PeterPinetree's picture
Update app.py
561f65d verified
raw
history blame
7.97 kB
# app.py
import json
import random
from pathlib import Path
import solara
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---- Model (same as original Space) -----------------------------------------
MODEL_ID = "Qwen/Qwen3-0.6B"
# Loaded once at import time; every prediction callback below reuses these.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_ID),
    AutoModelForCausalLM.from_pretrained(MODEL_ID),
)
# ---- App state ---------------------------------------------------------------
_TOP10_COLUMNS = ["probs", "next token ID", "predicted next token"]
_THEME_CSS = """
<style>
:root {
--primary: #38bdf8; /* light blue */
--bg: #ffffff; /* white */
--text: #000000; /* black */
--muted: #6b7280; /* gray-500 */
--border: #e5e7eb; /* gray-200 */
}
body { background: var(--bg); color: var(--text); }
h1, h2, h3 { color: var(--text); }
table td, table th { border-color: var(--border) !important; }
.solara-dataframe .MuiTableCell-root { font-size: 14px; }
.btn-primary { background: var(--primary); color: #000; border: 1px solid var(--primary); padding: 6px 10px; border-radius: 8px; }
.badge { display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; color:var(--text); }
</style>
"""

# Prompt text bound to the input box.
text_rx = solara.reactive("twinkle, twinkle, little ")
# Prediction table shown in the UI; starts empty with the expected columns.
top10_rx = solara.reactive(pd.DataFrame(columns=_TOP10_COLUMNS))
# Token id whose neighborhood is currently focused (None until something is picked).
selected_token_id_rx = solara.reactive(None)
# Status/help line rendered under the controls.
notice_rx = solara.reactive("Enter text to see predictions.")
# Stylesheet injected into the page as raw HTML.
theme_css = solara.reactive(_THEME_CSS)
# ---- Load embedding assets (your files) --------------------------------------
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"

coords = {}      # {token_id (JSON string key): [x, y], ...}
neighbors = {}   # {"neighbors": {token_id (string key): [[nid, sim], ...]}}
ids_set = set()  # integer token ids that have coordinates on the map

if not (COORDS_PATH.exists() and NEIGH_PATH.exists()):
    notice_rx.set("Embedding files not found. Add assets/embeddings/*.json to enable the map.")
else:
    coords = json.loads(COORDS_PATH.read_text(encoding="utf-8"))
    neighbors = json.loads(NEIGH_PATH.read_text(encoding="utf-8"))
    ids_set = {int(key) for key in coords}
# ---- Helpers -----------------------------------------------------------------
def predict_top10(prompt: str) -> pd.DataFrame:
    """Return the 10 most likely next tokens for *prompt*.

    Probabilities are the softmax of the raw last-position logits from one
    forward pass. ``generate(..., output_scores=True)`` is not used because
    its scores come after the logits processors, and the distribution shown
    should be the model's own.

    Args:
        prompt: Text to condition on. Empty (or None) yields an empty frame.

    Returns:
        DataFrame with columns ``probs`` (percentage strings), ``next token ID``
        (int) and ``predicted next token`` (decoded text), most likely first.
    """
    columns = ["probs", "next token ID", "predicted next token"]
    if not prompt:
        return pd.DataFrame(columns=columns)

    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1]  # logits for the next position
    # Softmax in float32 so half-precision weights neither lose accuracy nor
    # break the conversion to Python floats below.
    probs = F.softmax(logits.float(), dim=-1)
    top = torch.topk(probs, min(10, probs.shape[-1]))

    ids = [int(i) for i in top.indices.tolist()]
    return pd.DataFrame({
        "probs": [f"{p:.2%}" for p in top.values.tolist()],
        "next token ID": ids,
        "predicted next token": [tokenizer.decode([i]) for i in ids],
    })
def get_neighbor_list(token_id: int, k: int = 18):
    """Return at most *k* ``[neighbor_id, similarity]`` pairs for *token_id*.

    Gives an empty list when no embedding map is loaded or the token has no
    coordinates on it.
    """
    on_map = bool(ids_set) and token_id in ids_set
    if not on_map:
        return []
    by_token = neighbors.get("neighbors", {})
    # JSON object keys are strings, hence the str() lookup.
    return by_token.get(str(token_id), [])[:k]
# ---- Plot (Plotly scatter) ---------------------------------------------------
# We generate a static "all points" scatter once, then reuse it with highlights.
import plotly.graph_objects as go
def base_scatter():
    """Build the background scatter with one point per token in ``coords``.

    When no coordinates are loaded, returns an empty white figure of the
    same height instead.
    """
    if not coords:
        empty = go.Figure()
        return empty.update_layout(
            height=440, margin=dict(l=10, r=10, t=10, b=10),
            paper_bgcolor="white", plot_bgcolor="white",
        )

    token_ids = [int(key) for key in coords]
    points = list(coords.values())

    fig = go.Figure()
    fig.add_trace(go.Scattergl(
        x=[p[0] for p in points],
        y=[p[1] for p in points],
        mode="markers",
        marker=dict(size=3, opacity=0.85),
        text=[f"id {tid}" for tid in token_ids],
        hoverinfo="skip",  # no hover; the neighborhood is drawn explicitly instead
    ))
    fig.update_layout(
        height=440, margin=dict(l=10, r=10, t=10, b=10),
        paper_bgcolor="white", plot_bgcolor="white",
        xaxis=dict(visible=False), yaxis=dict(visible=False),
    )
    return fig


# Built once at import time and used as the starting point for highlighted views.
base_fig = base_scatter()
fig_rx = solara.reactive(base_fig)
def highlight(token_id: int):
    """Return a copy of ``base_fig`` with *token_id* and its neighbors marked.

    The target is drawn as a filled circle and its neighbors (from
    ``get_neighbor_list``) as open circles. A neighbor id with no entry in
    ``coords`` is skipped rather than raising ``KeyError``. If no map is
    loaded or the token is not on it, the plain copy is returned.
    """
    fig = go.Figure(base_fig.to_dict())  # detached copy; base_fig is never mutated
    if not coords or token_id not in ids_set:
        return fig

    # Target
    target = coords[str(token_id)]
    fig.add_trace(go.Scattergl(
        x=[target[0]], y=[target[1]], mode="markers",
        marker=dict(size=8, line=dict(width=1), symbol="circle"),
        name="target",
    ))

    # Neighbors: keep only those that actually have map coordinates.
    points = [
        coords[str(nid)]
        for nid, _ in get_neighbor_list(token_id)
        if str(nid) in coords
    ]
    if points:
        fig.add_trace(go.Scattergl(
            x=[p[0] for p in points], y=[p[1] for p in points], mode="markers",
            marker=dict(size=6, symbol="circle-open"),
            name="neighbors",
        ))
    fig.update_layout(showlegend=False)
    return fig
# ---- UI actions --------------------------------------------------------------
def on_append_cell(column, row_index):
    """Cell-action handler: append the clicked candidate to the prompt.

    Appends the decoded token from row *row_index* of the prediction table,
    focuses the map on that token, and recomputes the predictions so the
    table matches the extended text. *column* comes from the DataFrame cell
    action and is not used.
    """
    df = top10_rx.value
    # Ignore clicks that do not map to a current row (negative or past the end).
    if not 0 <= row_index < len(df):
        return
    token_id = int(df.iloc[row_index]["next token ID"])
    new_text = text_rx.value + tokenizer.decode([token_id])
    text_rx.set(new_text)
    selected_token_id_rx.set(token_id)
    # Update plot
    fig_rx.set(highlight(token_id))
    # Without this the table keeps offering candidates for the previous text.
    top10_rx.set(predict_top10(new_text))


cell_actions = [solara.CellAction(icon="mdi-plus", name="Append & highlight", on_click=on_append_cell)]
def on_predict():
    """Compute next-token candidates for the current text and update the UI.

    Fills the prediction table. When there is at least one candidate, it also
    selects and highlights the most likely one. With empty input the table is
    cleared and the notice asks for text, instead of pointing at an empty table.
    """
    df = predict_top10(text_rx.value)
    top10_rx.set(df)
    if df.empty:
        notice_rx.set("Enter text to see predictions.")
        return
    notice_rx.set("Click a candidate to append it and highlight its neighborhood.")
    # Convenience: focus the map on the top-1 candidate.
    top_id = int(df.iloc[0]["next token ID"])
    selected_token_id_rx.set(top_id)
    fig_rx.set(highlight(top_id))
def on_show_neighborhood():
    """Highlight the embedding neighborhood of the last token of the prompt.

    Does nothing when the prompt encodes to no tokens. If the map is loaded
    but that token is not on it, the notice says so instead of silently
    showing the unhighlighted map.
    """
    # add_special_tokens=False: the last id must come from the user's text,
    # not from a special token the tokenizer might append.
    ids = tokenizer.encode(text_rx.value, add_special_tokens=False)
    if not ids:
        return
    token_id = int(ids[-1])
    selected_token_id_rx.set(token_id)
    fig_rx.set(highlight(token_id))
    if coords and token_id not in ids_set:
        notice_rx.set(f"Token {token_id} is not on the embedding map.")
# ---- Page --------------------------------------------------------------------
@solara.component
def Page():
    """Top-level page: prompt controls, prediction table and embedding map."""
    # The theme stylesheet goes in as raw HTML ahead of the main layout.
    solara.HTML(tag="div", unsafe_inner_html=theme_css.value)

    title = "# Next-Token Predictor + Semantic Neighborhood"
    intro = (
        "Type text, then **Predict** to see the next-token distribution. "
        "Click a candidate to append it and highlight its **semantic neighborhood**."
    )

    with solara.Column(margin=12, gap="16px"):
        solara.Markdown(title)
        solara.Markdown(intro)

        # Prompt input and action buttons
        with solara.Row(gap="8px"):
            solara.InputText(
                "Enter text",
                value=text_rx,
                continuous_update=True,
                style={"minWidth": "520px"},
            )
            solara.Button("Predict", on_click=on_predict, classes=["btn-primary"])
            solara.Button("Show neighborhood of last token", on_click=on_show_neighborhood)
        solara.Markdown(f"*{notice_rx.value}*")

        # Top-10 table
        solara.Markdown("### Prediction")
        solara.DataFrame(top10_rx.value, items_per_page=10, cell_actions=cell_actions)

        # Neighborhood panel
        solara.Markdown("### Semantic Neighborhood")
        if coords:
            solara.FigurePlotly(fig_rx.value)
        else:
            solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")


Page()