# =========================
# model.py
# =========================
import re
import torch
import numpy as np
import pandas as pd
from transformers import (
    AutoModelForPreTraining,
    AutoTokenizer,
    pipeline,
)
import streamlit as st  # added for resource caching

# =========================
# CONFIG (same as in your script)
# =========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
ELECTRA_MODEL = "Seznam/small-e-czech"
CLF_MODEL = "Stremie/xlm-roberta-base-clickbait"

RTD_CLICKBAIT_TH = 0.20
RTD_BORDERLINE_TH = 0.15
CLF_CLICK_TH = 0.65
CLF_NOT_TH   = 0.35
COMB_CLICK_TH = 0.60
COMB_NOT_TH   = 0.40
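# The *_TH constants above split each score into three bands used for labelling:
# high scores are labelled CLICKBAIT, low scores NOT, and the band in between BORDERLINE.
#   RTD_*  -> ELECTRA replaced-token-detection score (see rtd_label_from_score)
#   CLF_*  -> supervised classifier probability (see classify_supervised)
#   COMB_* -> weighted blend of both scores (see process_headlines)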

# =========================
# LOAD MODELS (with caching)
# =========================
# We use @st.cache_resource so the models are loaded only once per session
@st.cache_resource
def load_models():
    """Load and return both models and the tokenizer."""
    print("Loading models...")
    disc = AutoModelForPreTraining.from_pretrained(ELECTRA_MODEL).to(DEVICE).eval()
    tok  = AutoTokenizer.from_pretrained(ELECTRA_MODEL)

    clf = pipeline(
        "text-classification",
        model=CLF_MODEL,
        device=0 if DEVICE == "cuda" else -1
    )

    # ---- Robust label mapping for the classifier ----
    id2label = getattr(clf.model.config, "id2label", {}) or {}
    label_values_upper = {str(v).upper() for v in id2label.values()}
    if not ({"CLICKBAIT", "NOT"} <= label_values_upper):
        clf.model.config.id2label = {0: "NOT", 1: "CLICKBAIT"}
        clf.model.config.label2id = {"NOT": 0, "CLICKBAIT": 1}

    print("Models loaded.")
    return disc, tok, clf

# Copy all of your other functions (rtd_token_scores_batch, classify_supervised, etc.)
# here UNCHANGED.
# ... (the rest of the functions from your script follows below) ...
@torch.no_grad()
def rtd_token_scores_batch(texts, disc, tok, batch_size=32):
    """Return per-token replaced-token-detection probabilities for each text."""
    all_scores = []
    for i in range(0, len(texts), batch_size):
        enc = tok(texts[i:i+batch_size], return_tensors="pt", padding=True, truncation=True).to(DEVICE)
        out = disc(**enc)
        probs = torch.sigmoid(out.logits).detach().cpu().numpy()
        # Keep only real (non-padding) positions so the [CLS]/[SEP] slicing downstream
        # operates on the actual tokens of each headline, not on padding.
        mask = enc["attention_mask"].cpu().numpy().astype(bool)
        for row, m in zip(probs, mask):
            all_scores.append(row[m])
    return all_scores

def clickbait_score_rtd_from_probs(probs, k_top: int = 5) -> float:
    """Mean of the k highest per-token RTD probabilities (special tokens excluded), clipped to [0, 1]."""
    core = probs[1:-1] if len(probs) >= 2 else probs
    if core.size == 0: return 0.0
    k = min(k_top, core.size)
    topk = np.partition(core, -k)[-k:]
    score = float(np.mean(topk))
    return max(0.0, min(1.0, score))

def rtd_label_from_score(p: float) -> str:
    if p >= RTD_CLICKBAIT_TH: return "CLICKBAIT"
    if p >= RTD_BORDERLINE_TH: return "BORDERLINE"
    return "NOT"

def _normalize_label_to_index(lbl, LABEL2ID):
    if isinstance(lbl, int): return lbl
    s = str(lbl)
    if s in LABEL2ID: return LABEL2ID[s]
    m = re.search(r"(\d+)$", s)
    if m: return int(m.group(1))
    return None

def classify_supervised(texts, clf):
    ID2LABEL = clf.model.config.id2label
    LABEL2ID = clf.model.config.label2id
    sanitized = [str(t).strip() if pd.notna(t) else "" for t in texts]
    outs = clf(sanitized, top_k=None, truncation=True, max_length=256)
    results = []
    for scores in outs:
        prob_click, prob_not = 0.0, 0.0
        for s in scores:
            idx = _normalize_label_to_index(s["label"], LABEL2ID)
            if idx is None: continue
            name = ID2LABEL.get(idx, str(s["label"])).upper()
            if name == "CLICKBAIT": prob_click = float(s["score"])
            elif name == "NOT": prob_not = float(s["score"])

        binary_label = "CLICKBAIT" if prob_click >= prob_not else "NOT"
        if prob_click >= CLF_CLICK_TH: tri_label = "CLICKBAIT"
        elif prob_click <= CLF_NOT_TH: tri_label = "NOT"
        else: tri_label = "BORDERLINE"
        clf_margin = abs(prob_click - prob_not)
        results.append({
            "clf_prob_clickbait": prob_click, "clf_prob_not": prob_not,
            "clf_label": binary_label, "clf_label_3way": tri_label,
            "clf_margin": clf_margin,
        })
    return results


# =========================
# MAIN PROCESSING FUNCTION
# =========================
def process_headlines(headlines: list[str], k_top: int = 5) -> pd.DataFrame:
    """Process a list of headlines and return a DataFrame with the results."""
    if not headlines or all(not str(s).strip() for s in headlines):
        return pd.DataFrame()

    disc, tok, clf = load_models()
    df = pd.DataFrame({"Titulek": headlines})

    # RTD
    rtd_probs_all = rtd_token_scores_batch(headlines, disc, tok, batch_size=32)
    rtd_scores = [clickbait_score_rtd_from_probs(p, k_top=k_top) for p in rtd_probs_all]
    rtd_labels = [rtd_label_from_score(p) for p in rtd_scores]

    # Supervised
    sup_rows = classify_supervised(headlines, clf)
    df_sup = pd.DataFrame(sup_rows)

    # Assemble results
    df_out = df.copy()
    df_out["rtd_score"] = rtd_scores
    df_out["rtd_label"] = rtd_labels
    df_out = pd.concat([df_out, df_sup], axis=1)

    # Weighted blend: the supervised classifier dominates, RTD adds a weak auxiliary signal
    df_out["combined_score"] = (0.85 * df_out["clf_prob_clickbait"] + 0.15 * df_out["rtd_score"])

    final_labels = []
    for s in df_out["combined_score"]:
        if s >= COMB_CLICK_TH: final_labels.append("CLICKBAIT")
        elif s <= COMB_NOT_TH: final_labels.append("NOT")
        else: final_labels.append("BORDERLINE")
    df_out["final_label"] = final_labels

    # Select and rename columns for readability (the values are the Czech display
    # labels shown in the Streamlit table)
    final_cols = {
        "Titulek": "Titulek",                               # headline
        "final_label": "Výsledek",                          # verdict
        "combined_score": "Kombinované skóre",              # combined score
        "clf_prob_clickbait": "Pravděpodobnost clickbaitu", # clickbait probability
        "rtd_score": "RTD skóre",                           # RTD score
    }
    return df_out[list(final_cols)].rename(columns=final_cols)
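

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the Streamlit app).
# The sample headlines are made-up examples; the models are downloaded from
# the Hugging Face Hub on first use. Outside a Streamlit runtime the
# @st.cache_resource decorator typically just warns and the function still runs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_headlines = [
        "Neuvěříte, co se stalo, když otevřel dveře!",  # "You won't believe what happened when he opened the door!"
        "Vláda schválila rozpočet na příští rok.",      # "The government approved next year's budget."
    ]
    results = process_headlines(sample_headlines, k_top=5)
    print(results.to_string(index=False))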