PeterPinetree commited on
Commit
561f65d
·
verified ·
1 Parent(s): 49621ce

Update app.py

Browse files

Added semantic neighborhood and adjusted color scheme.

Files changed (1) hide show
  1. app.py +204 -51
app.py CHANGED
@@ -1,61 +1,214 @@
1
- #from fastapi import FastAPI
2
- import solara
3
  import random
 
 
 
 
4
  import torch
5
  import torch.nn.functional as F
6
- import pandas as pd
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
- #app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- #@app.get("/")
12
- #def greet_json():
13
- #return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-0.6B')
16
- model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen3-0.6B')
17
- text1 = solara.reactive("Never gonna give you up, never gonna let you")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @solara.component
19
  def Page():
20
- with solara.Column(margin=10):
21
- solara.Markdown("#Next token prediction visualization")
22
- solara.Markdown("I built this tool to help me understand autoregressive language models. For any given text, it gives the top 10 candidates to be the next token with their respective probabilities. The language model I'm using is the smallest version of Qwen: Qwen3-0.6B.")
23
- def on_action_cell(column, row_index):
24
- text1.value += tokenizer.decode(top_10.indices[0][row_index])
25
- cell_actions = [solara.CellAction(icon="mdi-thumb-up", name="Select", on_click=on_action_cell)]
26
- solara.InputText("Enter text:", value=text1, continuous_update=True)
27
- if text1.value != "":
28
- tokens = tokenizer.encode(text1.value, return_tensors="pt")
29
- spans1 = ""
30
- spans2 = ""
31
- for i, token in enumerate(tokens[0]):
32
- random.seed(i)
33
- random_color = ''.join([random.choice('0123456789ABCDEF') for k in range(6)])
34
- spans1 += " " + f"<span style='font-family: helvetica; color: #{random_color}'>{token}</span>"
35
- spans2 += " " + f"""<span style="
36
- padding: 6px;
37
- border-right: 3px solid white;
38
- line-height: 3em;
39
- font-family: courier;
40
- background-color: #{random_color};
41
- color: white;
42
- position: relative;
43
- "><span style="
44
- position: absolute;
45
- top: 5.5ch;
46
- line-height: 1em;
47
- left: -0.5px;
48
- font-size: 0.45em"> {token}</span>{tokenizer.decode([token])}</span>"""
49
- solara.Markdown(f'{spans2}')
50
- solara.Markdown(f'{spans1}')
51
- outputs = model.generate(tokens, max_new_tokens=1, output_scores=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id)
52
- scores = F.softmax(outputs.scores[0], dim=-1)
53
- top_10 = torch.topk(scores, 10)
54
- df = pd.DataFrame()
55
- df["probs"] = top_10.values[0]
56
- df["probs"] = [f"{value:.2%}" for value in df["probs"].values]
57
- df["next token ID"] = [top_10.indices[0][i].numpy() for i in range(10)]
58
- df["predicted next token"] = [tokenizer.decode(top_10.indices[0][i]) for i in range(10)]
59
- solara.Markdown("###Prediction")
60
- solara.DataFrame(df, items_per_page=10, cell_actions=cell_actions)
61
  Page()
 
1
+ # app.py
2
+ import json
3
  import random
4
+ from pathlib import Path
5
+
6
+ import solara
7
+ import pandas as pd
8
  import torch
9
  import torch.nn.functional as F
 
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
 
12
+ # ---- Model (same as original Space) -----------------------------------------
13
+ MODEL_ID = "Qwen/Qwen3-0.6B"
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
16
+
17
+ # ---- App state ---------------------------------------------------------------
18
+ text_rx = solara.reactive("twinkle, twinkle, little ")
19
+ top10_rx = solara.reactive(pd.DataFrame(columns=["probs", "next token ID", "predicted next token"]))
20
+ selected_token_id_rx = solara.reactive(None) # for neighborhood focus
21
+ notice_rx = solara.reactive("Enter text to see predictions.")
22
+ theme_css = solara.reactive("""
23
+ <style>
24
+ :root {
25
+ --primary: #38bdf8; /* light blue */
26
+ --bg: #ffffff; /* white */
27
+ --text: #000000; /* black */
28
+ --muted: #6b7280; /* gray-500 */
29
+ --border: #e5e7eb; /* gray-200 */
30
+ }
31
+ body { background: var(--bg); color: var(--text); }
32
+ h1, h2, h3 { color: var(--text); }
33
+ table td, table th { border-color: var(--border) !important; }
34
+ .solara-dataframe .MuiTableCell-root { font-size: 14px; }
35
+ .btn-primary { background: var(--primary); color: #000; border: 1px solid var(--primary); padding: 6px 10px; border-radius: 8px; }
36
+ .badge { display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; color:var(--text); }
37
+ </style>
38
+ """)
39
+
40
+ # ---- Load embedding assets (your files) --------------------------------------
41
+ ASSETS = Path("assets/embeddings")
42
+ COORDS_PATH = ASSETS / "pca_top5k_coords.json"
43
+ NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"
44
+
45
+ coords = {}
46
+ neighbors = {}
47
+ ids_set = set()
48
+
49
+ if COORDS_PATH.exists() and NEIGH_PATH.exists():
50
+ with COORDS_PATH.open("r", encoding="utf-8") as f:
51
+ coords = json.load(f) # {token_id: [x, y], ...}
52
+ with NEIGH_PATH.open("r", encoding="utf-8") as f:
53
+ neighbors = json.load(f) # {"neighbors": {token_id: [[nid, sim], ...]}}
54
+ ids_set = set(map(int, coords.keys()))
55
+ else:
56
+ notice_rx.set("Embedding files not found. Add assets/embeddings/*.json to enable the map.")
57
+
58
+ # ---- Helpers -----------------------------------------------------------------
59
+ def predict_top10(prompt: str) -> pd.DataFrame:
60
+ if not prompt:
61
+ return pd.DataFrame(columns=["probs", "next token ID", "predicted next token"])
62
+
63
+ tokens = tokenizer.encode(prompt, return_tensors="pt")
64
+ out = model.generate(
65
+ tokens,
66
+ max_new_tokens=1,
67
+ output_scores=True,
68
+ return_dict_in_generate=True,
69
+ pad_token_id=tokenizer.eos_token_id,
70
+ do_sample=False, # greedy (deterministic)
71
+ temperature=0.0,
72
+ top_k=1,
73
+ top_p=1.0,
74
+ )
75
+ scores = F.softmax(out.scores[0], dim=-1) # [1, vocab]
76
+ top_10 = torch.topk(scores, 10)
77
+
78
+ df = pd.DataFrame()
79
+ df["probs"] = top_10.values[0].detach().cpu().numpy()
80
+ df["probs"] = [f"{p:.2%}" for p in df["probs"]]
81
+ ids = [int(top_10.indices[0][i].detach().cpu().item()) for i in range(10)]
82
+ df["next token ID"] = ids
83
+ df["predicted next token"] = [tokenizer.decode([i]) for i in ids]
84
+ return df
85
+
86
+ def get_neighbor_list(token_id: int, k: int = 18):
87
+ if not ids_set or token_id not in ids_set:
88
+ return []
89
+ raw = neighbors.get("neighbors", {}).get(str(token_id), [])
90
+ # raw item is [nid, sim]; keep top k
91
+ return raw[:k]
92
+
93
+ # ---- Plot (Plotly scatter) ---------------------------------------------------
94
+ # We generate a static "all points" scatter once, then reuse it with highlights.
95
+ import plotly.graph_objects as go
96
 
97
+ def base_scatter():
98
+ if not coords:
99
+ return go.Figure().update_layout(
100
+ height=440, margin=dict(l=10, r=10, t=10, b=10),
101
+ paper_bgcolor="white", plot_bgcolor="white",
102
+ )
103
+ # unpack coordinates
104
+ xs, ys, tids = [], [], []
105
+ for tid_str, pt in coords.items():
106
+ xs.append(pt[0]); ys.append(pt[1]); tids.append(int(tid_str))
107
+ fig = go.Figure()
108
+ fig.add_trace(go.Scattergl(
109
+ x=xs, y=ys, mode="markers",
110
+ marker=dict(size=3, opacity=0.85),
111
+ text=[f"id {t}" for t in tids],
112
+ hoverinfo="skip", # keep hover minimal; we’ll show neighbors explicitly
113
+ ))
114
+ fig.update_layout(
115
+ height=440, margin=dict(l=10, r=10, t=10, b=10),
116
+ paper_bgcolor="white", plot_bgcolor="white",
117
+ xaxis=dict(visible=False), yaxis=dict(visible=False),
118
+ )
119
+ return fig
120
 
121
+ base_fig = base_scatter()
122
+ fig_rx = solara.reactive(base_fig)
123
+
124
+ def highlight(token_id: int):
125
+ """Return a new figure with neighbors + target highlighted."""
126
+ fig = base_fig.to_dict() # detach copy
127
+ fig = go.Figure(fig)
128
+
129
+ if not coords or token_id not in ids_set:
130
+ return fig
131
+
132
+ # Target
133
+ tx, ty = coords[str(token_id)]
134
+ fig.add_trace(go.Scattergl(
135
+ x=[tx], y=[ty], mode="markers",
136
+ marker=dict(size=8, line=dict(width=1), symbol="circle"),
137
+ name="target",
138
+ ))
139
+
140
+ # Neighbors
141
+ nbrs = get_neighbor_list(token_id)
142
+ if nbrs:
143
+ nx = [coords[str(nid)][0] for nid, _ in nbrs]
144
+ ny = [coords[str(nid)][1] for nid, _ in nbrs]
145
+ fig.add_trace(go.Scattergl(
146
+ x=nx, y=ny, mode="markers",
147
+ marker=dict(size=6, symbol="circle-open"),
148
+ name="neighbors",
149
+ ))
150
+ fig.update_layout(showlegend=False)
151
+ return fig
152
+
153
+ # ---- UI actions --------------------------------------------------------------
154
+ def on_append_cell(column, row_index):
155
+ # append chosen next token to the text input
156
+ df = top10_rx.value
157
+ if row_index < len(df):
158
+ token_id = int(df.iloc[row_index]["next token ID"])
159
+ decoded = tokenizer.decode([token_id])
160
+ text_rx.set(text_rx.value + decoded)
161
+ selected_token_id_rx.set(token_id)
162
+ # Update plot
163
+ fig_rx.set(highlight(token_id))
164
+
165
+ cell_actions = [solara.CellAction(icon="mdi-plus", name="Append & highlight", on_click=on_append_cell)]
166
+
167
+ def on_predict():
168
+ df = predict_top10(text_rx.value)
169
+ top10_rx.set(df)
170
+ notice_rx.set("Click a candidate to append it and highlight its neighborhood.")
171
+ # also set selected to the top-1 for convenience
172
+ if len(df) > 0:
173
+ tid = int(df.iloc[0]["next token ID"])
174
+ selected_token_id_rx.set(tid)
175
+ fig_rx.set(highlight(tid))
176
+
177
+ def on_show_neighborhood():
178
+ # take last token in the prompt (if any), otherwise do nothing
179
+ ids = tokenizer.encode(text_rx.value)
180
+ if ids:
181
+ token_id = int(ids[-1])
182
+ selected_token_id_rx.set(token_id)
183
+ fig_rx.set(highlight(token_id))
184
+
185
+ # ---- Page --------------------------------------------------------------------
186
  @solara.component
187
  def Page():
188
+ solara.HTML(tag="div", unsafe_inner_html=theme_css.value) # inject CSS theme
189
+
190
+ with solara.Column(margin=12, gap="16px"):
191
+ solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
192
+ solara.Markdown(
193
+ "Type text, then **Predict** to see the next-token distribution. "
194
+ "Click a candidate to append it and highlight its **semantic neighborhood**."
195
+ )
196
+ with solara.Row(gap="8px"):
197
+ solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
198
+ solara.Button("Predict", on_click=on_predict, classes=["btn-primary"])
199
+ solara.Button("Show neighborhood of last token", on_click=on_show_neighborhood)
200
+
201
+ solara.Markdown(f"*{notice_rx.value}*")
202
+
203
+ # Top-10 table
204
+ solara.Markdown("### Prediction")
205
+ solara.DataFrame(top10_rx.value, items_per_page=10, cell_actions=cell_actions)
206
+
207
+ # Neighborhood panel
208
+ solara.Markdown("### Semantic Neighborhood")
209
+ if not coords:
210
+ solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
211
+ else:
212
+ solara.FigurePlotly(fig_rx.value)
213
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  Page()