|
|
import marimo

# Version string recorded by marimo in the notebook file.
__generated_with = "0.8.22"
# Notebook application object; every cell below registers itself via @app.cell.
app = marimo.App(width="medium")
|
|
|
|
|
|
|
|
@app.cell
def __():
    # Expose marimo under the `mo` alias for the UI cells further down.
    import marimo as mo
    return (mo,)
|
|
|
|
|
|
|
|
@app.cell
def __():
    import pandas as pd

    # The path is relative, so it is resolved against the directory the
    # notebook is launched from.
    df = pd.read_csv("our_visualization/datasets/test_set.csv")
    # Cell output: preview of the first rows.
    df.head()
    return df, pd
|
|
|
|
|
|
|
|
@app.cell
def __():
    import pickle
    from utils import ChessBoard
    import onnxruntime as ort
    from leela_board import _idx_to_move_bn, _idx_to_move_wn
    import numpy as np
    from onnx2torch import convert
    import onnx
    import torch
    import os

    def get_models(root="/Users/sereda/Documents/chessXAI/our_visualization/models"):
        """Return the paths of the ``.onnx`` files directly inside *root*.

        Only names that end in ``.onnx`` are kept, and the result is sorted so
        that positional access such as ``get_models()[-1]`` does not depend on
        the arbitrary order in which ``os.listdir`` reports entries.

        Args:
            root: Directory to scan. The default is a machine-specific
                absolute path; pass another directory to use a different
                model folder.

        Returns:
            A sorted list of ``os.path.join(root, name)`` strings.

        Raises:
            OSError: If *root* cannot be listed.
        """
        return sorted(
            os.path.join(root, entry)
            for entry in os.listdir(root)
            if entry.endswith(".onnx")
        )

    def get_activations_from_model(model_path, pattern, fen):
        """Run one forward pass and capture the outputs of matching submodules.

        The ONNX file at *model_path* is loaded and converted with
        ``onnx2torch.convert``. A forward hook is attached to every submodule
        whose qualified name contains *pattern*, the model is called once on
        ``ChessBoard(fen).t`` with a leading dimension added, and the hooks
        are removed again.

        Args:
            model_path: Path of the ``.onnx`` file to load.
            pattern: Substring matched against ``model.named_modules()`` names.
            fen: Position string handed to ``ChessBoard``.

        Returns:
            A dict mapping each matching module name to that module's output
            as a NumPy array.
        """

        def register_hooks_for_capture(model, pattern):
            """Attach capture hooks; return ``(activations, handles)``."""
            activations = {}

            def get_activation(name):
                # One closure per module so each hook stores under its own name.
                def hook(module, inputs, output):
                    activations[name] = output.detach().cpu().numpy()

                return hook

            handles = [
                module.register_forward_hook(get_activation(name))
                for name, module in model.named_modules()
                if pattern in name
            ]
            return activations, handles

        model = convert(onnx.load(model_path))
        act, handles = register_hooks_for_capture(model, pattern)

        try:
            board = ChessBoard(fen)
            # `board.t` is assumed to be the tensor encoding of the position
            # (it must support `.unsqueeze`) - TODO confirm against `utils`.
            inputs = board.t
            # Only the hooked activations are needed, so the model outputs are
            # discarded and gradients are not tracked.
            with torch.no_grad():
                model(inputs.unsqueeze(dim=0))
        finally:
            # Always detach the hooks, even if the forward pass fails.
            for handle in handles:
                handle.remove()
        return act

    return (
        ChessBoard,
        convert,
        get_activations_from_model,
        get_models,
        np,
        onnx,
        ort,
        os,
        pickle,
        torch,
    )
|
|
|
|
|
|
|
|
@app.cell
def __(df, mo):
    # Round the lowest and highest rating down to a multiple of 100; the cast
    # to int keeps `range` working when the column is stored as floats.
    min_elo = int(df["Rating"].min()) // 100 * 100
    max_elo = int(df["Rating"].max()) // 100 * 100
    elo_list = [str(elo) for elo in range(min_elo, max_elo + 100, 100)]
    # Prefer "1000" as the initial choice, but fall back to the first bucket
    # when the data set does not cover it (the dropdown rejects a value that
    # is not among its options).
    _default_elo = "1000" if "1000" in elo_list else elo_list[0]
    dropdown_elo = mo.ui.dropdown(
        value=_default_elo,
        options=elo_list,
        label=f"Select rating in range of {min_elo} - {max_elo}",
    )
    # Cell output: the rating selector.
    dropdown_elo
    return dropdown_elo, elo_list, max_elo, min_elo
|
|
|
|
|
|
|
|
@app.cell
def __(df, dropdown_elo, mo):
    unique_themes = set()
    _low = int(dropdown_elo.value)
    # Half-open bucket [low, low + 100): a puzzle rated exactly on a boundary
    # belongs to one bucket only instead of showing up in two.
    df_rated = df[(df["Rating"] >= _low) & (df["Rating"] < _low + 100)]
    # Collect every whitespace-separated tag in the "Themes" column; rows
    # without a value are skipped.
    for i, _tags in enumerate(df_rated["Themes"].dropna()):
        themes = _tags.split()
        for theme in themes:
            unique_themes.add(theme)
    unique_themes_list = sorted(unique_themes)

    # Nothing to choose from: halt this cell (and the cells depending on it)
    # with a message instead of failing on `unique_themes_list[0]`.
    mo.stop(
        not unique_themes_list,
        mo.md("No puzzle themes found for the selected rating."),
    )

    dropdown_themes = mo.ui.dropdown(
        value=unique_themes_list[0],
        options=unique_themes_list,
        label=f"Select puzzle theme",
    )
    # Cell output: the theme selector.
    dropdown_themes
    return (
        df_rated,
        dropdown_themes,
        i,
        theme,
        themes,
        unique_themes,
        unique_themes_list,
    )
|
|
|
|
|
|
|
|
@app.cell
def __(df_rated, dropdown_themes):
    def _():
        # Positional indices of the rows whose "Themes" tags include the
        # selected theme. Wrapped in a function so the loop names stay local
        # to this cell.
        wanted = dropdown_themes.value
        return [
            pos
            for pos, tags in enumerate(df_rated["Themes"])
            if isinstance(tags, str) and wanted in tags.split()
        ]

    themes_mask = _()
    _matching = df_rated.iloc[themes_mask]
    fens = list(_matching["FEN"])
    # Cell output: table of the matching puzzles.
    _matching[["FEN", "Moves", "Themes", "Rating"]]
    return fens, themes_mask
|
|
|
|
|
|
|
|
@app.cell
def __(fens, mo):
    # Without any matching position there is no valid default; halt here
    # (and in dependent cells) rather than failing on `fens[0]`.
    mo.stop(not fens, mo.md("No puzzles match the selected rating and theme."))
    dropdown_fen = mo.ui.dropdown(value=fens[0], options=fens, label="Select FEN")
    # Cell output: the position selector.
    dropdown_fen
    return (dropdown_fen,)
|
|
|
|
|
|
|
|
@app.cell
def __(df_rated, dropdown_fen, mo):
    # Move list of the first row whose FEN equals the selected one.
    moves = df_rated[df_rated["FEN"] == dropdown_fen.value]["Moves"].iloc[0].split(" ")
    # Every second move, starting with the second, is offered for selection.
    player_moves = moves[1::2]
    mo.stop(
        not player_moves,
        mo.md("The selected puzzle has no player move to inspect."),
    )
    # For the k-th offered move: all moves that come before it in the list.
    board_moves = [moves[: 2 * _k + 1] for _k in range(len(player_moves))]
    # Dropdown key = offered move, dropdown value = the moves preceding it.
    moves_dict = dict(zip(player_moves, board_moves))
    dropdown_moves = mo.ui.dropdown(
        options=moves_dict,
        value=player_moves[0],
        label="Select which player move to look at",
    )

    # Cell output: the move selector.
    dropdown_moves
    return board_moves, dropdown_moves, moves, moves_dict, player_moves
|
|
|
|
|
|
|
|
@app.cell
def __(dropdown_moves, mo):
    # Layer choices "0" .. "14".
    dropdown_layer = mo.ui.dropdown(
        value="0",
        options=[str(_n) for _n in range(15)],
        label="Select layer (smaller - closer to input)",
    )
    # Pre-fill the square with the first two characters of the selected move.
    focus_square = mo.ui.text_area(
        value=dropdown_moves.selected_key[:2],
        placeholder="Input square to look at (e.g. a1, b8, ...",
    )
    # Cell output: both inputs stacked vertically.
    mo.vstack([dropdown_layer, focus_square])
    return dropdown_layer, focus_square
|
|
|
|
|
|
|
|
@app.cell
def __(ChessBoard, dropdown_fen, dropdown_moves):
    def _():
        board = ChessBoard(dropdown_fen.value)
        # NOTE(review): the moves attached to the selected dropdown entry are
        # only printed; nothing visible here applies them to `board`, so the
        # FEN returned below comes from the board as constructed from the
        # selected position - confirm that this is intended.
        for move in dropdown_moves.value:
            print(move)

        return board.board.pc_board.fen()

    FEN = _()
    return (FEN,)
|
|
|
|
|
|
|
|
@app.cell
def __(focus_square):
    import chess
    from global_data import global_data

    # Validate the typed square before converting it; otherwise something
    # like "i9" or an empty field silently yields an index outside 0..63 or
    # fails with an unrelated error.
    _square = focus_square.value.strip().lower()
    if len(_square) < 2 or _square[0] not in "abcdefgh" or _square[1] not in "12345678":
        raise ValueError(
            f"invalid square {focus_square.value!r}; expected something like 'a1' or 'h8'"
        )
    # a1 -> 0, b1 -> 1, ..., h8 -> 63 (only the first two characters are used).
    focus_square_ind = 8 * (int(_square[1]) - 1) + ord(_square[0]) - ord("a")

    def set_plotting_parameters(act, layer_number, fen):
        """Store the activations of one layer and the plot settings in ``global_data``.

        Args:
            act: Mapping from captured module name to its activation array.
            layer_number: Layer whose activations should be plotted.
            fen: Position string used to build the ``chess.Board`` to show.

        Raises:
            KeyError: If no captured name contains "0", or if the name derived
                for *layer_number* is not present in *act*.
        """
        # The key for the requested layer is derived from the first captured
        # name containing "0" by replacing "0" with the layer number.
        # NOTE(review): `str.replace` rewrites every "0" in that name, so this
        # relies on the layer index being the only zero in it - confirm
        # against the real module names.
        zero_keys = [k for k in act.keys() if "0" in k]
        if not zero_keys:
            raise KeyError(f"no captured activation name contains '0': {list(act)}")
        layer_key = zero_keys[0].replace("0", f"{layer_number}")
        if layer_key not in act:
            raise KeyError(
                f"no activations captured for {layer_key!r}; available: {list(act)}"
            )
        print(act.keys())
        global_data.model = 'test'
        # First entry along axis 0, with axis 2 reversed.
        global_data.activations = act[layer_key][0, :, ::-1, :]
        print(global_data.activations.shape)
        global_data.subplot_rows = 8
        global_data.subplot_cols = 4
        global_data.board = chess.Board(fen)
        global_data.show_all_heads = True

        global_data.visualization_mode = 'ROW'
        # Square index computed once above, when this cell ran.
        global_data.focused_square_ind = focus_square_ind

        global_data.visualization_mode_is_64x64 = False
        global_data.colorscale_mode = "mode1"
        global_data.show_colorscale = False

    return chess, focus_square_ind, global_data, set_plotting_parameters
|
|
|
|
|
|
|
|
@app.cell
def __(
    FEN,
    dropdown_layer,
    get_activations_from_model,
    get_models,
    set_plotting_parameters,
):
    # Substring used to select the modules whose outputs are captured.
    PATTERN = "mha/QK/softmax"

    _available_models = get_models()
    if not _available_models:
        # Fail with an explicit message instead of an IndexError on `[-1]`.
        raise FileNotFoundError("get_models() did not find any .onnx model file")
    # The last model of the list is the one visualised.
    MODEL = _available_models[-1]
    ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
    set_plotting_parameters(ACTIVATIONS, int(dropdown_layer.value), FEN)
    # Imported only after the plotting parameters are set, as in the original
    # cell - presumably `activation_heatmap` reads `global_data`; keep this
    # order unless that is verified not to matter.
    from activation_heatmap import heatmap_figure
    fig = heatmap_figure()
    fig.update_layout(height=1500, width=1200)
    # Cell output: the heatmap figure.
    fig
    return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure
|
|
|
|
|
|
|
|
@app.cell
def __():
    # Empty cell: defines nothing and produces no output.
    return
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the notebook's cells when the file is executed as a script.
    app.run()
|
|
|