# chessXAI / visualization_demo.py
# Uploaded by seredapj ("Upload 19 files", commit 0c536cc, verified)
import marimo
# marimo version recorded when this notebook file was generated.
__generated_with = "0.8.22"
# The notebook application; every function decorated with @app.cell below is one cell.
app = marimo.App(width="medium")
@app.cell
def __():
    # Expose the marimo module; other cells receive it through their `mo` parameter.
    import marimo as mo
    return (mo,)
@app.cell
def __():
    import pandas as pd

    # Puzzle table; the path is resolved relative to the current working directory.
    # Downstream cells read its "Rating", "Themes", "FEN" and "Moves" columns.
    _csv_path = "our_visualization/datasets/test_set.csv"
    df = pd.read_csv(_csv_path)
    # Show a preview of the first rows as this cell's output.
    df.head()
    return df, pd
@app.cell
def __():
    import pickle
    from utils import ChessBoard
    import onnxruntime as ort
    from leela_board import _idx_to_move_bn, _idx_to_move_wn
    import numpy as np
    from onnx2torch import convert
    import onnx
    import torch
    import os

    def get_models(root="/Users/sereda/Documents/chessXAI/our_visualization/models"):
        """Return the paths of the ``.onnx`` files directly inside *root*.

        The list is sorted so that positional picks such as ``get_models()[-1]``
        are reproducible; ``os.listdir`` returns entries in arbitrary order.

        Args:
            root: Directory to scan (not recursively). The default is an
                absolute path on the original author's machine.

        Returns:
            Sorted list of ``os.path.join(root, name)`` for every entry whose
            name ends with ``.onnx``.
        """
        # endswith rather than a substring test, so that names which merely
        # contain ".onnx" (e.g. "model.onnx.data") are not treated as models.
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.endswith(".onnx")
        )

    def get_activations_from_model(model_path, pattern, fen):
        """Run one position through an ONNX model and capture matching outputs.

        The ONNX file is converted to a torch module with ``onnx2torch.convert``.
        A forward hook is attached to every submodule whose qualified name
        contains *pattern*, and each hook stores that module's output.

        Args:
            model_path: Path to an ``.onnx`` file.
            pattern: Substring matched against ``model.named_modules()`` names.
            fen: Position to evaluate; passed to ``ChessBoard``.

        Returns:
            Dict mapping each matching module name to its output as a NumPy array.
        """
        model = convert(onnx.load(model_path))
        activations = {}

        def make_hook(name):
            def hook(module, inputs, output):
                activations[name] = output.detach().numpy()
            return hook

        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if pattern in name
        ]
        try:
            # Add a leading batch axis of size 1 to the board tensor.
            inputs = ChessBoard(fen).t.unsqueeze(dim=0)
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                # The model is expected to return three outputs; only the hook
                # captures are used.
                _, _, _ = model(inputs)
        finally:
            # Detach the hooks even if the forward pass raised.
            for handle in handles:
                handle.remove()
        return activations

    return (
        ChessBoard,
        convert,
        get_activations_from_model,
        get_models,
        np,
        onnx,
        ort,
        os,
        pickle,
        torch,
    )
@app.cell
def __(df, mo):
    # Ratings are grouped into 100-point buckets; round the extremes down to a bucket start.
    min_elo = df["Rating"].min() // 100 * 100
    max_elo = df["Rating"].max() // 100 * 100
    elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
    # Prefer 1000 as the initial bucket, but fall back to the lowest bucket when
    # the dataset has no 1000 bucket: the dropdown's value must be one of its options.
    _default_elo = "1000" if "1000" in elo_list else elo_list[0]
    dropdown_elo = mo.ui.dropdown(
        value=_default_elo,
        options=elo_list,
        label=f"Select rating in range of {min_elo} - {max_elo}",
    )
    dropdown_elo
    return dropdown_elo, elo_list, max_elo, min_elo
@app.cell
def __(df, dropdown_elo, mo):
    unique_themes = set()
    # Half-open bucket [lo, lo + 100): a rating exactly on a boundary belongs to
    # one bucket only. The top bucket, max_elo, still contains the maximum rating.
    _lo = int(dropdown_elo.value)
    df_rated = df[(df["Rating"] >= _lo) & (df["Rating"] < _lo + 100)]
    # A bucket inside [min, max] can still be empty; halt this cell (and its
    # dependents) with a message instead of failing on an empty theme list below.
    mo.stop(df_rated.empty, mo.md(f"No puzzles with rating in [{_lo}, {_lo + 100})."))
    # Collect every space-separated theme tag that occurs in the bucket.
    for i in range(len(df_rated)):
        themes = df_rated.iloc[i]["Themes"].split(" ")
        for theme in themes:
            unique_themes.add(theme)
    unique_themes_list = sorted(unique_themes)
    dropdown_themes = mo.ui.dropdown(
        value=unique_themes_list[0],
        options=unique_themes_list,
        label="Select puzzle theme",
    )
    dropdown_themes
    return (
        df_rated,
        dropdown_themes,
        i,
        theme,
        themes,
        unique_themes,
        unique_themes_list,
    )
@app.cell
def __(df_rated, dropdown_themes):
    # Positional (iloc) indices of the rated puzzles tagged with the chosen theme.
    themes_mask = [
        _pos
        for _pos, _tags in enumerate(df_rated["Themes"])
        if dropdown_themes.value in _tags.split(" ")
    ]
    _selected = df_rated.iloc[themes_mask]
    fens = list(_selected["FEN"])
    # Display the matching puzzles as this cell's output.
    _selected[["FEN", "Moves", "Themes", "Rating"]]
    return fens, themes_mask
@app.cell
def __(fens, mo):
    # One option per matching puzzle position; the first one starts selected.
    _first_fen = fens[0]
    dropdown_fen = mo.ui.dropdown(
        value=_first_fen,
        options=fens,
        label="Select FEN",
    )
    dropdown_fen
    return (dropdown_fen,)
@app.cell
def __(df_rated, dropdown_fen, mo):
    # Solution line of the selected puzzle (first row with that FEN), split on spaces.
    _row = df_rated[df_rated["FEN"] == dropdown_fen.value]
    moves = _row["Moves"].iloc[0].split(" ")
    # The player's moves are at indices 1, 3, 5, ...
    player_moves = moves[1::2]
    # For the k-th player move, the prefix moves[0 .. 2k] that is played before it.
    board_moves = [moves[: 2 * _k + 1] for _k in range(len(player_moves))]
    # Dropdown label (player move) -> move prefix to replay.
    moves_dict = dict(zip(player_moves, board_moves))
    dropdown_moves = mo.ui.dropdown(
        options=moves_dict,
        value=player_moves[0],
        label="Select which player move to look at",
    )
    dropdown_moves
    return board_moves, dropdown_moves, moves, moves_dict, player_moves
@app.cell
def __(dropdown_moves, mo):
    # NOTE(review): 15 is hard-coded, presumably the model's layer count — confirm
    # against the loaded model.
    dropdown_layer = mo.ui.dropdown(
        value="0",
        options=[f"{i}" for i in range(15)],
        label="Select layer (smaller - closer to input)",
    )
    # Default focus: the first two characters of the selected player move
    # (its origin square, assuming moves are in UCI notation).
    focus_square = mo.ui.text_area(
        value=dropdown_moves.selected_key[:2],
        placeholder="Input square to look at (e.g. a1, b8, ...)",
    )
    mo.vstack([dropdown_layer, focus_square])
    return dropdown_layer, focus_square
@app.cell
def __(ChessBoard, dropdown_fen, dropdown_moves):
    def _():
        """Return the FEN reached by replaying the selected move prefix.

        Starts from the puzzle's FEN and plays every move in
        ``dropdown_moves.value`` (the moves preceding the chosen player move),
        so the analysed position follows the move dropdown.
        """
        board = ChessBoard(dropdown_fen.value)
        for move in dropdown_moves.value:
            # Previously commented out, which left the board at the puzzle's
            # starting position no matter which move was selected.
            board.move(move)
        return board.board.pc_board.fen()

    # Wrapped in the private `_` function so that `board` and `move` stay local
    # to this cell instead of becoming notebook globals.
    FEN = _()
    return (FEN,)
@app.cell
def __(focus_square):
    import chess
    from global_data import global_data

    # Index of the focus square: a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63
    # (8 * (rank - 1) + file). parse_square raises ValueError for anything that
    # is not a real square name, instead of silently yielding an out-of-range
    # index. Surrounding whitespace (e.g. a newline from the text area) and
    # upper case are tolerated.
    focus_square_ind = chess.parse_square(focus_square.value.strip().lower())

    def set_plotting_parameters(act, layer_number, fen):
        """Configure the shared ``global_data`` object for one heatmap plot.

        Args:
            act: Mapping of captured module name -> activation array, as
                returned by ``get_activations_from_model``.
            layer_number: Layer to plot. It replaces every "0" in the first
                captured name that contains "0".
            fen: Position drawn with the heatmap (parsed by ``chess.Board``).

        Raises:
            KeyError: no captured name contains "0", or the derived name for
                *layer_number* was not captured.
        """
        # Take the first captured name containing "0" (presumably the layer-0
        # module) and substitute the requested layer number.
        # NOTE(review): str.replace rewrites every "0" in the name — confirm the
        # module naming has exactly one.
        template_key = next((k for k in act if "0" in k), None)
        if template_key is None:
            raise KeyError(f"No captured activation name contains '0': {list(act)}")
        layer_key = template_key.replace("0", f"{layer_number}")
        global_data.model = 'test'
        # Take index 0 of the first axis (presumably the batch of one) and
        # reverse the third axis — TODO confirm the axis meanings.
        global_data.activations = act[layer_key][0, :, ::-1, :]
        # 8 x 4 subplot grid (32 panels).
        global_data.subplot_rows = 8
        global_data.subplot_cols = 4
        global_data.board = chess.Board(fen)
        global_data.show_all_heads = True
        global_data.visualization_mode = 'ROW'
        global_data.focused_square_ind = focus_square_ind
        global_data.visualization_mode_is_64x64 = False
        global_data.colorscale_mode = "mode1"
        global_data.show_colorscale = False

    return chess, focus_square_ind, global_data, set_plotting_parameters
@app.cell
def __(FEN, get_activations_from_model, get_models):
    # Capture every module whose name contains PATTERN in one forward pass.
    # Picking a layer then only reruns the plotting cell below; the ONNX model
    # is not converted and evaluated again. ("smolgen_weights" is another
    # pattern that has been tried here.)
    PATTERN = "mha/QK/softmax"
    # Sort so the chosen model does not depend on directory listing order.
    _models = sorted(get_models())
    if not _models:
        raise FileNotFoundError("get_models() found no .onnx model files")
    MODEL = _models[-1]
    ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
    return ACTIVATIONS, MODEL, PATTERN


@app.cell
def __(ACTIVATIONS, FEN, dropdown_layer, set_plotting_parameters):
    # Configure global_data for the selected layer, then draw the heatmaps.
    set_plotting_parameters(ACTIVATIONS, int(dropdown_layer.value), FEN)
    from activation_heatmap import heatmap_figure
    fig = heatmap_figure()
    fig.update_layout(height=1500, width=1200)
    fig
    return fig, heatmap_figure
@app.cell
def __():
    # Placeholder cell for planned work:
    # TODO: add FENs for the positions after the opponent's moves.
    # TODO: pick default squares of interest.
    return
if __name__ == "__main__":
    # Running the file directly as a script starts the marimo app.
    app.run()