# app.py
import json
from pathlib import Path

import pandas as pd
import solara
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---- Model (same as original Space) -----------------------------------------
MODEL_ID = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# ---- App state ---------------------------------------------------------------
text_rx = solara.reactive("twinkle, twinkle, little ")
top10_rx = solara.reactive(pd.DataFrame(columns=["probs", "next token ID", "predicted next token"]))
selected_token_id_rx = solara.reactive(None)  # for neighborhood focus
notice_rx = solara.reactive("Enter text to see predictions.")

theme_css = solara.reactive("""
<style>
  :root {
    --primary: #38bdf8;  /* light blue */
    --bg: #ffffff;       /* white */
    --text: #000000;     /* black */
    --muted: #6b7280;    /* gray-500 */
    --border: #e5e7eb;   /* gray-200 */
  }
  body { background: var(--bg); color: var(--text); }
  h1, h2, h3 { color: var(--text); }
  table td, table th { border-color: var(--border) !important; }
  .solara-dataframe .MuiTableCell-root { font-size: 14px; }
  .btn-primary { background: var(--primary); color: #000; border: 1px solid var(--primary); padding: 6px 10px; border-radius: 8px; }
  .badge { display: inline-block; padding: 2px 8px; border: 1px solid var(--border); border-radius: 999px; color: var(--text); }
</style>
""")
# ---- Load embedding assets (your files) --------------------------------------
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"

coords = {}
neighbors = {}
ids_set = set()
if COORDS_PATH.exists() and NEIGH_PATH.exists():
    with COORDS_PATH.open("r", encoding="utf-8") as f:
        coords = json.load(f)  # {token_id: [x, y], ...}
    with NEIGH_PATH.open("r", encoding="utf-8") as f:
        neighbors = json.load(f)  # {"neighbors": {token_id: [[nid, sim], ...]}}
    ids_set = set(map(int, coords.keys()))
else:
    notice_rx.set("Embedding files not found. Add assets/embeddings/*.json to enable the map.")
# ---- Helpers ------------------------------------------------------------------
def predict_top10(prompt: str) -> pd.DataFrame:
    if not prompt:
        return pd.DataFrame(columns=["probs", "next token ID", "predicted next token"])
    tokens = tokenizer.encode(prompt, return_tensors="pt")
    out = model.generate(
        tokens,
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,  # greedy (deterministic); temperature/top_k/top_p are omitted
                          # because transformers ignores them when sampling is off
    )
    scores = F.softmax(out.scores[0], dim=-1)  # [1, vocab_size]
    top_10 = torch.topk(scores, 10)
    df = pd.DataFrame()
    probs = top_10.values[0].detach().cpu().numpy()
    df["probs"] = [f"{p:.2%}" for p in probs]
    ids = [int(i) for i in top_10.indices[0].detach().cpu().tolist()]
    df["next token ID"] = ids
    df["predicted next token"] = [tokenizer.decode([i]) for i in ids]
    return df
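# Note: for a single prediction step, a plain forward pass is essentially equivalent to
# generate() and skips its extra machinery; sketched as a comment so nothing runs on import:
#   with torch.no_grad():
#       logits = model(tokens).logits[0, -1]   # logits for the next token
#       scores = F.softmax(logits, dim=-1)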
def get_neighbor_list(token_id: int, k: int = 18):
    if not ids_set or token_id not in ids_set:
        return []
    raw = neighbors.get("neighbors", {}).get(str(token_id), [])
    # each raw item is [nid, sim]; keep the top k
    return raw[:k]
# ---- Plot (Plotly scatter) ---------------------------------------------------
# We generate a static "all points" scatter once, then reuse it with highlights.
import plotly.graph_objects as go


def base_scatter():
    if not coords:
        return go.Figure().update_layout(
            height=440, margin=dict(l=10, r=10, t=10, b=10),
            paper_bgcolor="white", plot_bgcolor="white",
        )
    # unpack coordinates
    xs, ys, tids = [], [], []
    for tid_str, pt in coords.items():
        xs.append(pt[0])
        ys.append(pt[1])
        tids.append(int(tid_str))
    fig = go.Figure()
    fig.add_trace(go.Scattergl(
        x=xs, y=ys, mode="markers",
        marker=dict(size=3, opacity=0.85),
        text=[f"id {t}" for t in tids],
        hoverinfo="skip",  # keep hover minimal; we show neighbors explicitly
    ))
    fig.update_layout(
        height=440, margin=dict(l=10, r=10, t=10, b=10),
        paper_bgcolor="white", plot_bgcolor="white",
        xaxis=dict(visible=False), yaxis=dict(visible=False),
    )
    return fig


base_fig = base_scatter()
fig_rx = solara.reactive(base_fig)
def highlight(token_id: int):
    """Return a new figure with the target token and its neighbors highlighted."""
    fig = go.Figure(base_fig.to_dict())  # detached copy of the base scatter
    if not coords or token_id not in ids_set:
        return fig
    # Target
    tx, ty = coords[str(token_id)]
    fig.add_trace(go.Scattergl(
        x=[tx], y=[ty], mode="markers",
        marker=dict(size=8, line=dict(width=1), symbol="circle"),
        name="target",
    ))
    # Neighbors (skip any id missing from the coordinate file)
    nbrs = [(nid, sim) for nid, sim in get_neighbor_list(token_id) if str(nid) in coords]
    if nbrs:
        nx = [coords[str(nid)][0] for nid, _ in nbrs]
        ny = [coords[str(nid)][1] for nid, _ in nbrs]
        fig.add_trace(go.Scattergl(
            x=nx, y=ny, mode="markers",
            marker=dict(size=6, symbol="circle-open"),
            name="neighbors",
        ))
    fig.update_layout(showlegend=False)
    return fig
# ---- UI actions ---------------------------------------------------------------
def on_append_cell(column, row_index):
    # append the chosen next token to the text input
    df = top10_rx.value
    if row_index < len(df):
        token_id = int(df.iloc[row_index]["next token ID"])
        decoded = tokenizer.decode([token_id])
        text_rx.set(text_rx.value + decoded)
        selected_token_id_rx.set(token_id)
        # update the plot
        fig_rx.set(highlight(token_id))


cell_actions = [solara.CellAction(icon="mdi-plus", name="Append & highlight", on_click=on_append_cell)]


def on_predict():
    df = predict_top10(text_rx.value)
    top10_rx.set(df)
    notice_rx.set("Click a candidate to append it and highlight its neighborhood.")
    # also select the top-1 candidate for convenience
    if len(df) > 0:
        tid = int(df.iloc[0]["next token ID"])
        selected_token_id_rx.set(tid)
        fig_rx.set(highlight(tid))


def on_show_neighborhood():
    # take the last token of the prompt (if any), otherwise do nothing
    ids = tokenizer.encode(text_rx.value)
    if ids:
        token_id = int(ids[-1])
        selected_token_id_rx.set(token_id)
        fig_rx.set(highlight(token_id))
# ---- Page --------------------------------------------------------------------
@solara.component
def Page():
    solara.HTML(tag="div", unsafe_innerHTML=theme_css.value)  # inject CSS theme
    with solara.Column(margin=12, gap="16px"):
        solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
        solara.Markdown(
            "Type text, then **Predict** to see the next-token distribution. "
            "Click a candidate to append it and highlight its **semantic neighborhood**."
        )
        with solara.Row(gap="8px"):
            solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"min-width": "520px"})
            solara.Button("Predict", on_click=on_predict, classes=["btn-primary"])
            solara.Button("Show neighborhood of last token", on_click=on_show_neighborhood)
        solara.Markdown(f"*{notice_rx.value}*")

        # Top-10 table
        solara.Markdown("### Prediction")
        solara.DataFrame(top10_rx.value, items_per_page=10, cell_actions=cell_actions)

        # Neighborhood panel
        solara.Markdown("### Semantic Neighborhood")
        if not coords:
            solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
        else:
            solara.FigurePlotly(fig_rx.value)
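# Run locally with the standard Solara CLI, e.g. `solara run app.py`; Solara picks up the
# module-level `Page` component itself, so no explicit Page() call is needed.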