import os
import math
import json
import warnings
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
from torch import nn
import networkx as nx
import streamlit as st
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import umap
from sklearn.neighbors import NearestNeighbors, KernelDensity
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import pairwise_distances
from scipy.spatial import procrustes
from scipy.linalg import orthogonal_procrustes
import plotly.graph_objects as go
# Optional libs (use if present)
try:
    import hdbscan  # Robust density-based clustering
    HAS_HDBSCAN = True
except Exception:
    HAS_HDBSCAN = False
try:
    import igraph as ig
    import leidenalg as la
    HAS_IGRAPH_LEIDEN = True
except Exception:
    HAS_IGRAPH_LEIDEN = False
try:
    import pyvista as pv  # Volume & isosurfaces (VTK)
    HAS_PYVISTA = True
except Exception:
    HAS_PYVISTA = False
try:
    import spacy  # POS tagging for token coloring; get_pos_tags falls back to "UNK" without it
    HAS_SPACY = True
except Exception:
    HAS_SPACY = False
# ====== Configuration =========================================================================
@dataclass
class Config:
    # Model
    model_name: str = "Qwen/Qwen1.5-1.8B"
    max_length: int = 64
    # Data
    corpus: Optional[List[str]] = None
    # Graph & Clustering
    graph_mode: str = "threshold"
    knn_k: int = 8
    sim_threshold: float = 0.05   # fraction of edges kept; 0.05 keeps the top 5% most similar pairs
    use_cosine: bool = True
    # Anchors / LoT-style features (global)
    anchor_k: int = 16            # number of global prototypes (KMeans on pooled states)
    anchor_temp: float = 0.7      # softmax temperature for converting distances to probabilities
    # Clustering per layer
    cluster_method: str = "auto"  # {"auto", "leiden", "hdbscan", "dbscan", "kmeans"}
    n_clusters_kmeans: int = 6    # fallback for kmeans
    hdbscan_min_cluster_size: int = 4
    # UMAP & alignment
    umap_n_neighbors: int = 30
    umap_min_dist: float = 0.05
    umap_metric: str = "cosine"
    fit_pool_per_layer: int = 512  # number of states sampled per layer to fit UMAP
    align_layers: bool = True      # Procrustes-align each layer's 2D embedding to the previous one
    # Visualization
    color_by: str = "pos"          # "cluster" or "pos" (part of speech)
    # Output
    out_dir: str = "qwen_mri3d_outputs"
    plotly_html: str = "qwen_layers_3d.html"
# Default corpus (small and diverse; adjust freely)
DEFAULT_CORPUS = [
    "Is a Universal Basic Income (UBI) a viable solution to poverty, or does it simply discourage people from working?",
    "Explain the arguments for and against the independence of Taiwan from the perspective of both the US and China.",
    "What are the ethical arguments surrounding the use of CRISPR technology to edit human embryos for non-medical enhancements?",
    "Analyze the effectiveness of strict lockdowns versus herd immunity strategies during the COVID-19 pandemic.",
    "Why is nuclear energy controversial despite being a low-carbon power source? Present both the safety concerns and the environmental benefits.",
    "Does the existence of evil in the world disprove the existence of a benevolent God? Summarize the philosophical debate.",
    "Summarize the main arguments used by gun rights advocates against stricter background checks in the United States.",
    "Should autonomous weapons systems (killer robots) be banned internationally, even if they could reduce soldier casualties?",
    "Was the dropping of the atomic bombs on Hiroshima and Nagasaki militarily necessary to end World War II?",
    "What are the competing arguments regarding transgender women participating in biological women's sports categories?"
]
# Four candidate models selectable in the sidebar
MODELS = ["Qwen/Qwen1.5-0.5B", "deepseek-ai/deepseek-coder-1.3b-instruct", "openai-community/gpt2", "prem-research/MiniGuard-v0.1"]
# ====== Utilities =============================================================================
def seed_everything(seed: int = 42):
    np.random.seed(seed)
    torch.manual_seed(seed)

def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
    Xn = X / norms
    return Xn @ Xn.T
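
# Hedged sanity check (illustrative helper added for clarity, not part of the pipeline):
# each row compared against itself should score ~1 (slightly below, due to the epsilon in the norm).
def _demo_cosine_similarity():
    X = np.random.default_rng(0).normal(size=(5, 8))
    S = cosine_similarity_matrix(X)
    assert S.shape == (5, 5)
    assert np.allclose(np.diag(S), 1.0, atol=1e-4)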
def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
    """
    Align B to A_ref using Procrustes analysis (rotation/reflection only).
    Preserves the local geometry of B while matching A_ref's global orientation.
    """
    # Center both point clouds
    mu_a = A_ref.mean(0)
    mu_b = B.mean(0)
    A0 = A_ref - mu_a
    B0 = B - mu_b
    # Solve for the orthogonal R that minimizes ||A0 - B0 @ R||_F:
    #   M = B0.T @ A0;  U, S, Vt = svd(M);  R = U @ Vt
    R, _ = orthogonal_procrustes(B0, A0)
    # Rotate B into A's orientation, then shift to A's center
    return B0 @ R + mu_a
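
# Hedged sanity check (illustrative helper, not called by the pipeline): a rotated copy of a
# point cloud should align back onto the reference up to floating-point error.
def _demo_orthogonal_align():
    rng = np.random.default_rng(0)
    A = rng.normal(size=(50, 2))
    theta = 0.7
    R_true = np.array([[np.cos(theta), -np.sin(theta)],
                       [np.sin(theta),  np.cos(theta)]])
    B = A @ R_true  # same cloud, rotated
    assert np.allclose(orthogonal_align(A, B), A, atol=1e-6)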
def get_pos_tags(text: str, tokenizer, tokens: List[str]) -> List[str]:
    """
    Map LLM tokens to spaCy POS tags.
    Heuristic: tag each (cleaned) token fragment on its own. Proper subword-to-word
    alignment would need tokenizer(return_offsets_mapping=True); this approximation
    is imperfect for fragments like "ing" but visually useful.
    """
    if not HAS_SPACY:
        return ["UNK"] * len(tokens)
    try:
        nlp = spacy.load("en_core_web_sm")
    except Exception:
        # Fallback if the spaCy model is not downloaded
        return ["UNK"] * len(tokens)
    pos_tags = []
    for t_str in tokens:
        # Strip subword space markers ("Ġ" for GPT-2-style BPE, "▁" for SentencePiece)
        clean_t = t_str.replace("Ġ", "").replace("▁", "").strip()
        if not clean_t:
            pos_tags.append("SYM")  # likely a special character
            continue
        # Tag the single token fragment
        sub_doc = nlp(clean_t)
        if len(sub_doc) > 0:
            pos_tags.append(sub_doc[0].pos_)
        else:
            pos_tags.append("UNK")
    return pos_tags
def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
    nbrs = NearestNeighbors(n_neighbors=min(k + 1, len(coords)), metric=metric)
    nbrs.fit(coords)
    distances, indices = nbrs.kneighbors(coords)
    G = nx.Graph()
    G.add_nodes_from(range(len(coords)))
    for i in range(len(coords)):
        for j in indices[i, 1:]:  # skip self (first neighbor)
            G.add_edge(int(i), int(j))
    return G
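
# Hedged sanity check (illustrative helper): every node links to its k nearest neighbors,
# so the undirected kNN graph has minimum degree >= k.
def _demo_knn_graph():
    coords = np.random.default_rng(1).normal(size=(20, 4))
    G = build_knn_graph(coords, k=3, metric="euclidean")
    assert G.number_of_nodes() == 20
    assert min(dict(G.degree()).values()) >= 3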
def build_threshold_graph(H: np.ndarray, top_pct: float = 0.05, use_cosine: bool = True,
                          include_ties: bool = True) -> nx.Graph:
    if use_cosine:
        S = cosine_similarity_matrix(H)
    else:
        S = H @ H.T
    N = S.shape[0]
    iu = np.triu_indices(N, k=1)
    vals = S[iu]
    # Threshold at the (1 - top_pct) quantile
    q = 1.0 - top_pct
    thr = float(np.quantile(vals, q))
    G = nx.Graph()
    G.add_nodes_from(range(N))
    if include_ties:
        mask = vals >= thr
    else:
        # Strictly greater than the threshold reduces tie inflation
        mask = vals > thr
    rows = iu[0][mask]
    cols = iu[1][mask]
    wts = vals[mask]
    for r, c, w in zip(rows, cols, wts):
        G.add_edge(int(r), int(c), weight=float(w))
    return G
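
# Hedged sanity check (illustrative helper): with top_pct=0.1 roughly the top 10% most
# similar token pairs become edges (ties can add a few more).
def _demo_threshold_graph():
    H = np.random.default_rng(3).normal(size=(30, 8))
    G = build_threshold_graph(H, top_pct=0.1, use_cosine=True)
    n_pairs = 30 * 29 // 2
    assert G.number_of_nodes() == 30
    assert 0 < G.number_of_edges() < n_pairs * 0.2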
def percolation_stats(G: nx.Graph) -> Dict[str, float]:
    """
    Compute percolation observables (φ, #clusters, χ):
    φ : fraction of nodes in the Giant Connected Component (GCC)
    χ : mean size of components excluding the GCC
    """
    n = G.number_of_nodes()
    if n == 0:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])
    comps = list(nx.connected_components(G))
    sizes = [len(c) for c in comps]
    if not sizes:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])
    largest = max(sizes)
    phi = largest / n
    non_gcc_sizes = [s for s in sizes if s != largest]
    chi = float(np.mean(non_gcc_sizes)) if non_gcc_sizes else 0.0
    return dict(phi=float(phi),
                num_clusters=len(comps),
                chi=float(chi),
                largest_component_size=largest,
                component_sizes=sorted(sizes, reverse=True))
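
# Hedged sanity check (illustrative helper): on a 10-node graph with a 5-node chain,
# one linked pair, and three isolated nodes, φ = 5/10 and χ = mean([2, 1, 1, 1]) = 1.25.
def _demo_percolation_stats():
    G = nx.Graph()
    G.add_nodes_from(range(10))
    G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])  # GCC of size 5
    G.add_edge(5, 6)                                    # nodes 7-9 stay isolated
    stats = percolation_stats(G)
    assert stats["phi"] == 0.5
    assert stats["num_clusters"] == 5
    assert abs(stats["chi"] - 1.25) < 1e-9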
def cluster_layer(features: np.ndarray, G: Optional[nx.Graph], method: str,
                  n_clusters_kmeans: int = 6, hdbscan_min_cluster_size: int = 4) -> np.ndarray:
    method = method.lower()
    N = len(features)
    if method == "auto":
        if HAS_IGRAPH_LEIDEN and G and G.number_of_edges() > 0:
            return leiden_communities(G)
        elif HAS_HDBSCAN:
            return hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size).fit_predict(features)
        else:
            return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)
    if method == "leiden" and HAS_IGRAPH_LEIDEN and G is not None and G.number_of_edges() > 0:
        return leiden_communities(G)
    if method == "hdbscan" and HAS_HDBSCAN:
        return hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size).fit_predict(features)
    if method == "dbscan":
        return DBSCAN().fit_predict(features)
    # "kmeans", and any requested backend that is unavailable, fall through to KMeans
    return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)
# Helper for Leiden community detection
def leiden_communities(G: nx.Graph) -> np.ndarray:
    if not HAS_IGRAPH_LEIDEN:
        raise RuntimeError("igraph/leidenalg not installed")
    mapping = {n: i for i, n in enumerate(G.nodes())}
    edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
    ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
    part = la.find_partition(ig_g, la.RBConfigurationVertexPartition)
    labels = np.zeros(len(mapping), dtype=int)
    for cid, comm in enumerate(part):
        for node in comm:
            labels[node] = cid
    return labels
# ====== Model I/O (hidden states) =============================================================
@dataclass
class HiddenStatesBundle:
    """
    Encapsulates a single input's hidden states and metadata.
    hidden_layers: list of np.ndarray of shape (T, D), length = num_layers + 1 (incl. embedding)
    tokens       : list of token strings of length T
    """
    hidden_layers: List[np.ndarray]
    tokens: List[str]
def load_qwen(model_name: str, device: str, dtype: torch.dtype):
    """
    Load Qwen with output_hidden_states=True. AutoTokenizer is used for broader compatibility.
    """
    print(f"[Load] {model_name} on {device} ({dtype})")
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
    model.eval().to(device)
    if device == "cuda" and dtype == torch.float16:
        model = model.half()
    return model, tok
def extract_hidden_states(model, tokenizer, text: str, max_length: int, device: str) -> HiddenStatesBundle:
    """
    Run a single forward pass to collect all hidden states (incl. the embedding layer).
    Returns CPU numpy arrays to keep GPU memory low.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    out = model(**inputs)
    # Tuple length = num_layers + 1 (embedding layer first)
    hs = [h[0].detach().float().cpu().numpy() for h in out.hidden_states]  # shapes: (T, D)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)
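
# Hedged usage sketch (illustrative helper; pass in any model/tokenizer pair, e.g. from
# get_model_and_tok): the hidden-state list has num_layers + 1 entries, embeddings first,
# and each entry has one row per token.
def _demo_extract_shapes(model, tokenizer, device: str = "cpu"):
    bundle = extract_hidden_states(model, tokenizer, "Hello world", max_length=16, device=device)
    n_layers = getattr(model.config, "num_hidden_layers", None)
    if n_layers is not None:
        assert len(bundle.hidden_layers) == n_layers + 1
    assert all(h.shape[0] == len(bundle.tokens) for h in bundle.hidden_layers)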
# ====== LoT-style anchors & features ==========================================================
def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
    """
    Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
    These centroids are "anchors" (LoT-like choices) used to build low-dim features:
        f(state) = [dist(state, anchor_j)]_{j=1..K}
    """
    print(f"[Anchors] Fitting {K} global centroids on {len(all_states_sampled)} states ...")
    kmeans = KMeans(n_clusters=K, n_init="auto", random_state=random_state)
    kmeans.fit(all_states_sampled)
    return kmeans.cluster_centers_  # (K, D)

def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    For states H (N, D) and anchors A (K, D):
      - Compute Euclidean distances to each anchor → Dists (N, K)
      - Convert to soft probabilities with exp(-Dist/T), normalized row-wise → P (N, K)
      - Uncertainty = entropy(P) (cf. LoT Eq. (6))
      - The argmin-distance anchor supports "consistency"-style comparisons (cf. Eq. (5))
    Returns (Dists, P, entropy).
    """
    # Distances (N, K)
    dists = pairwise_distances(H, anchors, metric="euclidean")
    # Soft assignments via a numerically stable softmax
    logits = -dists / max(temperature, 1e-6)
    logits = logits - logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True) + 1e-12
    # Uncertainty (entropy)
    H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
    return dists, P, H_unc
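
# Hedged sanity check (illustrative helper): a state sitting on an anchor gets a confident
# (low-entropy) soft assignment; a state equidistant from both anchors is maximally uncertain.
def _demo_anchor_features():
    anchors = np.array([[0.0, 0.0], [10.0, 10.0]])
    H = np.array([[0.1, 0.0],   # ~on anchor 0
                  [5.0, 5.0]])  # equidistant from both anchors
    dists, P, ent = anchor_features(H, anchors, temperature=0.7)
    assert P[0, 0] > 0.99                  # confident assignment
    assert abs(ent[1] - np.log(2)) < 1e-6  # max entropy for two anchors
    assert ent[1] > ent[0]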
# ====== Dimensionality reduction / embeddings =================================================
def fit_umap_2d(pool: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> umap.UMAP:
    """
    Fit UMAP once on a diverse pool across layers to fix the orientation.
    Each layer is then embedded via .transform() into the SAME 2D space → an "MRI stack".
    """
    reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    reducer.fit(pool)
    return reducer
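
# Hedged usage sketch (illustrative helper): fitting once and transforming per "layer" keeps
# every slice in one shared 2D coordinate frame.
def _demo_umap_shared_frame():
    rng = np.random.default_rng(2)
    pool = rng.normal(size=(200, 16))
    reducer = fit_umap_2d(pool, n_neighbors=10, min_dist=0.1, metric="euclidean")
    layer_a = reducer.transform(rng.normal(size=(30, 16)))
    layer_b = reducer.transform(rng.normal(size=(30, 16)))
    assert layer_a.shape == (30, 2) and layer_b.shape == (30, 2)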
def fit_umap_3d(all_states: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> np.ndarray:
    """
    Fit a global 3D UMAP embedding for all states at once (alternative to the slice stack).
    Returns coords_3d (N, 3) for the concatenated states passed in.
    """
    reducer = umap.UMAP(n_components=3, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    return reducer.fit_transform(all_states)
# ====== Visualization =========================================================================
def plotly_3d_layers(xy_layers: List[np.ndarray],
                     layer_tokens: List[List[str]],
                     layer_cluster_labels: List[np.ndarray],
                     layer_pos_tags: List[List[str]],
                     layer_uncertainty: List[np.ndarray],
                     layer_graphs: List[nx.Graph],
                     color_by: str = "cluster",
                     title: str = "3D Cluster Formation",
                     prompt: Optional[str] = None) -> go.Figure:
    fig_data = []
    # Categorical colormap for POS tags
    pos_map = {
        "NOUN": "#1f77b4", "VERB": "#d62728", "ADJ": "#2ca02c",
        "ADV": "#ff7f0e", "PRON": "#9467bd", "DET": "#8c564b",
        "ADP": "#e377c2", "NUM": "#7f7f7f", "PUNCT": "#bcbd22",
        "SYM": "#17becf", "UNK": "#bababa"
    }
    L = len(xy_layers)
    for l, (xy, tokens, labels, pos, unc, G) in enumerate(
            zip(xy_layers, layer_tokens, layer_cluster_labels, layer_pos_tags, layer_uncertainty, layer_graphs)):
        if len(xy) == 0:
            continue
        x, y = xy[:, 0], xy[:, 1]
        z = np.full_like(x, l, dtype=float)
        # Color logic
        if color_by == "pos":
            # Map POS strings to colors
            node_colors = [pos_map.get(p, "#333333") for p in pos]
            show_scale = False
            colorscale = None
        else:
            # Cluster ID
            node_colors = labels
            show_scale = (l == 0)
            colorscale = 'Viridis'
        # Hover text
        node_text = [
            f"L{l} | {tok}<br>POS: {p}<br>Cluster: {c}<br>Unc: {u:.2f}"
            for tok, p, c, u in zip(tokens, pos, labels, unc)
        ]
        node_trace = go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            name=f"Layer {l}",
            showlegend=False,
            marker=dict(
                size=3,
                opacity=1,
                color=node_colors,
                colorscale=colorscale,
                showscale=show_scale,
                colorbar=dict(title="Cluster ID") if show_scale else None
            ),
            text=node_text,
            hovertemplate="%{text}<extra></extra>"
        )
        fig_data.append(node_trace)
        # Edges (intra-layer similarity graph)
        if G is not None and G.number_of_edges() > 0:
            edge_x, edge_y, edge_z = [], [], []
            for u, v in G.edges():
                edge_x += [x[u], x[v], None]
                edge_y += [y[u], y[v], None]
                edge_z += [z[u], z[v], None]
            edge_trace = go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                mode='lines',
                line=dict(width=2, color='red'),
                opacity=0.6,
                hoverinfo='skip',
                showlegend=False
            )
            fig_data.append(edge_trace)
    # Trajectories (connect the same token across layers)
    if L > 1:
        T = len(xy_layers[0])
        # Sample trajectories to avoid lag when T is large
        step = max(1, T // 100)
        for i in range(0, T, step):
            xs = [xy_layers[l][i, 0] for l in range(L)]
            ys = [xy_layers[l][i, 1] for l in range(L)]
            zs = list(range(L))
            traj = go.Scatter3d(
                x=xs, y=ys, z=zs,
                mode='lines',
                line=dict(width=3, color='rgba(50,50,50,0.5)'),
                hoverinfo='skip',
                showlegend=False
            )
            fig_data.append(traj)
    if color_by == "pos":
        # Add legend-only traces for the POS categories actually present
        present_pos = sorted({p for layer in layer_pos_tags for p in layer})
        for p in present_pos:
            fig_data.append(
                go.Scatter3d(
                    x=[None], y=[None], z=[None],  # legend-only
                    mode="markers",
                    name=p,
                    marker=dict(size=8, color=pos_map.get(p, "#333333")),
                    showlegend=True,
                    hoverinfo="skip"
                )
            )
    fig = go.Figure(data=fig_data)
    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            xanchor="center",
        ),
        annotations=[
            dict(
                text=f"<b>Prompt:</b> {prompt}",
                x=0.5,
                y=1.02,
                xref="paper",
                yref="paper",
                showarrow=False,
                font=dict(size=13),
                align="center"
            )
        ] if prompt else [],
        scene=dict(
            xaxis_title="UMAP X",
            yaxis_title="UMAP Y",
            zaxis_title="Layer Depth",
            aspectratio=dict(x=1, y=1, z=1.5)
        ),
        height=900,
        margin=dict(l=0, r=0, b=0, t=40)
    )
    return fig
def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
    seed_everything(42)
    # 1. Extract hidden states
    from transformers import logging
    logging.set_verbosity_error()
    main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
    layers_np = main_bundle.hidden_layers
    tokens = main_bundle.tokens
    L_all = len(layers_np)
    # 2. POS tags
    pos_tags = get_pos_tags(main_text, tok, tokens)
    # 3. Pooling & anchors (LoT)
    # (Simplified for the demo: pool states from every other layer of the main text only)
    pool_states = np.vstack([layers_np[l] for l in range(0, L_all, 2)])
    idx = np.random.choice(len(pool_states), min(len(pool_states), 2000), replace=False)
    anchors = fit_global_anchors(pool_states[idx], cfg.anchor_k)
    # 4. Process layers
    layer_features = []
    layer_uncertainties = []
    layer_graphs = []
    layer_cluster_labels = []
    percolation = []
    for l in range(L_all):
        H = layers_np[l]
        # Features & uncertainty
        dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
        layer_features.append(dists)
        layer_uncertainties.append(H_unc)
        # Graphs
        if cfg.graph_mode == "knn":
            G = build_knn_graph(dists, cfg.knn_k, metric="euclidean")
        else:
            G = build_threshold_graph(H, cfg.sim_threshold, use_cosine=cfg.use_cosine)
        layer_graphs.append(G)
        # Clusters
        labels = cluster_layer(dists, G, cfg.cluster_method,
                               cfg.n_clusters_kmeans, cfg.hdbscan_min_cluster_size)
        layer_cluster_labels.append(labels)
        # Percolation
        percolation.append(percolation_stats(G))
    # 5. UMAP & alignment
    # Fit UMAP on the pooled states to establish a shared coordinate system
    reducer = umap.UMAP(n_components=2, n_neighbors=cfg.umap_n_neighbors,
                        min_dist=cfg.umap_min_dist, metric=cfg.umap_metric, random_state=42)
    reducer.fit(pool_states[idx])
    xy_by_layer = []
    for l in range(L_all):
        # Transform into 2D
        xy = reducer.transform(layers_np[l])
        # Procrustes alignment: align layer l to layer l-1
        if cfg.align_layers and l > 0:
            xy = orthogonal_align(xy_by_layer[l - 1], xy)
        xy_by_layer.append(xy)
    # 6. Plot
    model_short = cfg.model_name.rsplit('/', 1)[-1]
    fig = plotly_3d_layers(
        xy_layers=xy_by_layer,
        layer_tokens=[tokens] * L_all,
        layer_cluster_labels=layer_cluster_labels,
        layer_pos_tags=[pos_tags] * L_all,
        layer_uncertainty=layer_uncertainties,
        layer_graphs=layer_graphs,
        color_by=cfg.color_by,
        title=f"{model_short} 3D MRI | Color: {cfg.color_by.upper()} | Aligned: {cfg.align_layers}",
        prompt=main_text
    )
    # 7. Save artifacts
    if save_artifacts:
        # Create the output directory if needed, then write the interactive HTML
        os.makedirs(cfg.out_dir, exist_ok=True)
        out_path = os.path.join(cfg.out_dir, cfg.plotly_html)
        fig.write_html(out_path)
        print(f"Saved 3D plot to: {out_path}")
    return fig, {"percolation": percolation, "tokens": tokens}
@st.cache_resource  # cache the loaded model across Streamlit reruns ("cached after first run")
def get_model_and_tok(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        config=config,
        torch_dtype=dtype if device == "cuda" else None,
        device_map="auto" if device == "cuda" else None
    )
    model.eval()
    if device != "cuda":
        model = model.to(device)
    return model, tok, device, dtype
def main():
    st.set_page_config(page_title="LLM Hidden Layer Explorer", layout="wide")
    st.title("Token Embedding Explorer (Live Hidden States)")
    with st.sidebar:
        st.header("Model / Input")
        model_name = st.selectbox("Model", MODELS, index=1)
        max_length = st.slider("Max tokens", 16, 256, 64, step=16)
        st.header("Graph")
        graph_mode = st.selectbox("Graph mode", ["knn", "threshold"], index=0)
        knn_k = st.slider("k (kNN)", 2, 50, 8) if graph_mode == "knn" else 8
        # build_threshold_graph interprets this as the fraction of top-similarity pairs kept
        sim_threshold = st.slider("Top edge fraction", 0.01, 0.50, 0.05, step=0.01) if graph_mode == "threshold" else 0.05
        use_cosine = st.checkbox("Use cosine similarity", value=True)
        st.header("Anchors / LoT")
        anchor_k = st.slider("anchor_k", 4, 64, 16, step=1)
        anchor_temp = st.slider("anchor_temp", 0.05, 2.0, 0.7, step=0.05)
        st.header("UMAP")
        umap_n_neighbors = st.slider("n_neighbors", 5, 100, 30, step=1)
        umap_min_dist = st.slider("min_dist", 0.0, 0.99, 0.05, step=0.01)
        umap_metric = st.selectbox("metric", ["cosine", "euclidean"], index=0)
        st.header("Performance")
        fit_pool_per_layer = st.slider("fit_pool_per_layer", 64, 2048, 512, step=64)
        st.header("Outputs")
        save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)
    prompt_col, run_col = st.columns([4, 1])
    with prompt_col:
        main_text = st.selectbox(
            "Prompt to visualize (hidden states computed on this text)",
            options=DEFAULT_CORPUS,
            index=0,
            help="Select a predefined prompt for analysis"
        )
    with run_col:
        st.write("")
        st.write("")
        run_btn = st.button("Run", type="primary")
    cfg = Config(
        model_name=model_name,
        max_length=max_length,
        corpus=None,  # DEFAULT_CORPUS remains the pool source unless exposed in the UI
        graph_mode=graph_mode,
        knn_k=knn_k,
        sim_threshold=sim_threshold,
        use_cosine=use_cosine,
        anchor_k=anchor_k,
        anchor_temp=anchor_temp,
        umap_n_neighbors=umap_n_neighbors,
        umap_min_dist=umap_min_dist,
        umap_metric=umap_metric,
        fit_pool_per_layer=fit_pool_per_layer,
        # other fields keep their dataclass defaults
    )
    if run_btn:
        if not main_text.strip():
            st.error("Please enter some text.")
            return
        with st.spinner("Loading model (cached after first run)..."):
            model, tok, device, dtype = get_model_and_tok(cfg.model_name)
        with st.spinner("Running pipeline (hidden states → features → UMAP → Plotly)..."):
            fig, outputs = run_pipeline(
                cfg=cfg,
                model=model,
                tok=tok,
                device=device,
                main_text=main_text,
                save_artifacts=save_artifacts,
            )
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Loaded {cfg.model_name} on {device} ({dtype})")
        with st.expander("Percolation summary"):
            percolation = outputs.get("percolation", [])
            for l, stt in enumerate(percolation):
                st.write(f"L={l:02d} | φ={stt['phi']:.3f} | #C={stt['num_clusters']} | χ={stt['chi']:.2f}")
        with st.expander("Debug: config"):
            st.json(asdict(cfg))
# ====== Main ==================================================================================
if __name__ == "__main__":
    torch.set_grad_enabled(False)
    main()