import chess.engine
from constants import ROOT_DIR, CONTENT_HEIGHT, LEFT_PANE_WIDTH, EXPORT_FORMAT, EXPORT_SCALE
from time import sleep
from copy import deepcopy
from board2planes import board2planes
import yaml
import os
from os.path import isdir, join
import sys
import numpy as np

SIMULATE_TF = False  # TODO: remove this option, deprecated

# Turn off TensorFlow importing and generate random data to speed up development.
DEV_MODE = False
SIMULATED_LAYERS = 6
SIMULATED_HEADS = 64
FIXED_ROW = None  # e.g. 1; None to disable
FIXED_COL = None  # e.g. 5; None to disable

if DEV_MODE:
    class DummyModel:
        """Stand-in for the real network: returns random attention maps."""
        def __init__(self, layers, heads):
            self.layers = layers
            self.heads = heads

        def __call__(self, *args, **kwargs):
            data = [np.random.rand(1, self.heads, 64, 64) for _ in range(self.layers)]
            return [None, None, None, data]
else:
    import tensorflow as tf
    from tensorflow.compat.v1 import ConfigProto
    from tensorflow.compat.v1 import InteractiveSession
    import tfprocess  # lc0 training module; provides TFProcess used in load_model()


# Class to hold data, state and configuration.
# Dash is stateless and in general it is a very bad idea to store data in global variables
# on the server side. However, this application is meant to be run by a single user on a
# local machine, so it is safe to keep data and state on a global object.
class GlobalData:
    def __init__(self):
        if not DEV_MODE:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
            # config = ConfigProto()
            # config.gpu_options.allow_growth = True
            # session = InteractiveSession(config=config)
            # tf.keras.backend.clear_session()
        self.tmp = 0
        self.export_format = EXPORT_FORMAT
        self.export_scale = EXPORT_SCALE
        self.fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'
        # Alternative test positions:
        # '2kr3r/ppp2b2/2n4p/4p3/Q2Pq1pP/2P1N3/PP3PP1/R1B1KB1R w KQ - 3 18'
        # '6n1/1p1k4/3p4/pNp5/P1P4p/7P/1P4KP/r7 w - - 2 121'
        self.board = chess.Board(fen=self.fen)
        self.focused_square_ind = 0
        self.active_move_table_cell = None  # tuple (row_ind, col_id), e.g. (12, 'White')
        self.activations = None
        self.visualization_mode = 'ROW'
        self.visualization_mode_is_64x64 = False
        self.subplot_mode = 'big'  # 'fit' or 'big'
        self.subplot_cols = 0
        self.subplot_rows = 0
        self.number_of_heads = 0
        self.selected_head = None
        self.show_all_heads = True
        self.show_colorscale = False
        self.colorscale_mode = 'mode1'
        self.figure_container_height = '100%'
        # Used to pass new values to hidden indicator elements which trigger follow-up callbacks.
        self.running_counter = 0
        self.grid_has_changed = False
        self.screen_w = 0
        self.screen_h = 0
        self.figure_w = 0
        self.figure_h = 0
        self.heatmap_w = 0
        self.heatmap_h = 0
        self.heatmap_fig_w = 0
        self.heatmap_fig_h = 0
        self.heatmap_gap = 0
        self.colorscale_x_offset = 0
        self.heatmap_horizontal_gap = 0.275
        self.figure_cache = {}
        self.update_grid_shape()
        self.pgn_data = []  # list of boards in the PGN
        self.move_table_boards = {}  # boards in the PGN, keyed by (move_table.row_ind, move_table.column_id)
        if not SIMULATE_TF:
            self.selected_layer = None
        else:
            self.selected_layer = 0
        self.nr_of_layers_in_body = -1
        self.has_attention_policy = False
        self.model_paths = []
        self.model_names = []
        self.model_yamls = {}  # key = model path, value = yaml of that model
        self.model_cache = {}
        self.find_models2()
        self.model_path = None  # e.g. self.model_paths[0]
        self.model = None
        self.tfp = None  # TFProcess instance
        if not SIMULATE_TF:
            self.load_model()
        self.activations_data = None
        if self.model is not None or SIMULATE_TF:
            self.update_activations_data()
        if self.selected_layer is not None:
            self.set_layer(self.selected_layer)
        self.move_table_active_cell = None
        self.force_update_graph = False

    def set_subplot_mode(self, fit_to_page):
        if fit_to_page == [True]:
            self.subplot_mode = 'fit'
        else:
            self.subplot_mode = 'big'
        self.update_grid_shape()

    def set_screen_size(self, w, h):
        self.screen_w = w
        self.screen_h = h
        self.figure_w = w * LEFT_PANE_WIDTH / 100
        self.figure_h = h * CONTENT_HEIGHT / 100
        print('GRAPH AREA', self.figure_w, self.figure_h)

    def set_heatmap_size(self, size):
        if size != '1':
            self.heatmap_w = float(size[0])
            self.heatmap_h = float(size[1])
            self.heatmap_fig_w = float(size[2])
            self.heatmap_fig_h = float(size[3])
            self.heatmap_gap = round(float(size[4]), 2)
            self.colorscale_x_offset = float(size[5]) / self.heatmap_fig_w
            self.force_update_graph = size[6] == 1

    def set_colorscale_mode(self, mode, colorscale_mode, colorscale_mode_64x64, show):
        if mode == '64x64':
            self.colorscale_mode = colorscale_mode_64x64
        else:
            self.colorscale_mode = colorscale_mode
        self.show_colorscale = show == [True]
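    # Illustrative sketch (comments only, not part of the app logic; the exact key
    # contents below are an assumption, see get_figure_cache_key()): the figure cache
    # is a plain dict keyed by a tuple of display state, so any state change that
    # alters that tuple is a cache miss.
    #
    #   key = global_data.get_figure_cache_key()   # (rows, cols, is_64x64, head, ...)
    #   if key in global_data.figure_cache:        # what check_if_figure_is_cached() does
    #       fig = deepcopy(global_data.figure_cache[key])  # deepcopy so callers may mutate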
    def cache_figure(self, fig):
        if not self.check_if_figure_is_cached() and fig != {}:
            key = self.get_figure_cache_key()
            cached_fig = deepcopy(fig)
            cached_fig.update_layout({'coloraxis1': None}, overwrite=True)
            self.figure_cache[key] = cached_fig

    def get_cached_figure(self):
        if self.check_if_figure_is_cached():
            key = self.get_figure_cache_key()
            fig = deepcopy(self.figure_cache[key])
        else:
            fig = None
        return fig

    def check_if_figure_is_cached(self):
        key = self.get_figure_cache_key()
        return key in self.figure_cache

    def get_figure_cache_key(self):
        return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64,
                self.selected_head if not self.show_all_heads else -1,
                self.show_colorscale, self.colorscale_mode, self.board.turn)

    def get_side_to_move(self):
        return ['Black', 'White'][self.board.turn]

    def load_model(self):
        if self.model_path in self.model_cache:
            self.model, self.tfp = self.model_cache[self.model_path]
        elif self.model_path is not None:
            if not DEV_MODE:
                net = self.model_path
                yaml_path = self.model_yamls[self.model_path]
                with open(yaml_path) as f:
                    cfg = f.read()
                cfg = yaml.safe_load(cfg)
                if 'dropout_rate' in cfg['model']:
                    print('Setting dropout_rate to 0.0')
                    cfg['model']['dropout_rate'] = 0.0
                tfp = tfprocess.TFProcess(cfg)
                tfp.init_net()
                tfp.replace_weights(net, ignore_errors=True)
                self.model = tfp.model
                self.tfp = tfp
            else:
                self.model = DummyModel(SIMULATED_LAYERS, SIMULATED_HEADS)
                self.tfp = None
        else:
            self.model = None
            self.tfp = None

    def find_models(self):
        # Superseded by find_models2(); kept for reference.
        root = ROOT_DIR
        models_root_folder = os.path.join(root, 'models')
        model_folders = [f for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
        model_paths = [os.path.relpath(join(models_root_folder, f))
                       for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
        self.model_names = model_folders
        self.model_paths = model_paths

    def find_models2(self):
        root = ROOT_DIR
        models_root_folder = os.path.join(root, 'models')
        model_paths = [os.path.relpath(join(models_root_folder, f))
                       for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
        models = []
        paths = []
        yamls = []
        for path in model_paths:
            # Only accept model folders that contain exactly one yaml and at least one net.
            yaml_files = [file for file in os.listdir(path) if file.endswith(".yaml")]
            if len(yaml_files) != 1:
                continue
            model_files = [file for file in os.listdir(path) if file.endswith(".pb.gz")]
            if len(model_files) == 0:
                continue
            models += model_files
            paths += [os.path.relpath(join(path, f)) for f in model_files]
            yaml_file = os.path.relpath(join(path, yaml_files[0]))
            yamls += [yaml_file] * len(model_files)
        self.model_yamls = {path: yaml_file for path, yaml_file in zip(paths, yamls)}
        self.model_names = models
        self.model_paths = paths

    def update_activations_data(self):
        if self.model is not None and self.selected_layer is None:
            self.selected_layer = 0
        if not SIMULATE_TF:
            if (self.selected_layer is not None and self.model is not None
                    and self.selected_layer != 'Smolgen'):
                if not DEV_MODE:
                    inputs = board2planes(self.board)
                    inputs = tf.reshape(tf.convert_to_tensor(inputs, dtype=tf.float32), [-1, 112, 8, 8])
                else:
                    inputs = None
                outputs = self.model(inputs)
                self.activations_data = outputs[-1]
                for i, x in enumerate(self.activations_data):
                    print('LAYERS', i, x.shape)
            elif self.selected_layer == 'Smolgen' and self.tfp is not None and self.tfp.use_smolgen:
                weights = self.tfp.smol_weight_gen_dense.get_weights()[0]
                self.activations_data = weights.reshape((weights.shape[0], 64, 64))
        else:
            layers = SIMULATED_LAYERS
            heads = SIMULATED_HEADS
            self.activations_data = [np.random.rand(1, heads, 64, 64) for _ in range(layers)]
        if self.model is not None:
            if self.model_path not in self.model_cache:
                self.model_cache[self.model_path] = [self.model, self.tfp]
            self.update_layers_in_body_count()
        # TODO: figure out a better way to determine if we have policy attention weights.
        # TODO: what happens if policy vis is selected and the user switches to a model
        #       without a policy layer? Take care of this case.
        if self.activations_data is not None and self.activations_data[-2].shape == (1, 8, 24):
            self.has_attention_policy = True
        else:
            self.has_attention_policy = False

    def update_grid_shape(self):
        # TODO: add a client-side callback triggered by an Interval component to save
        #       window or precise container dimensions to a Div.
        # TODO: trigger the server-side figure update callback when dimensions are
        #       recorded, and store them in global_data.
        # TODO: if needed, recalculate subplot rows/cols and the container scaler based
        #       on the changed dimensions.
        def calc_cols(heads, rows):
            # Smallest column count that fits all heads into the given number of rows.
            if heads % rows == 0:
                cols = int(heads / rows)
            else:
                cols = int(1 + heads / rows)
            return cols

        if FIXED_ROW and FIXED_COL:
            self.subplot_cols = FIXED_COL
            self.subplot_rows = FIXED_ROW
            return None
        heads = self.number_of_heads
        if self.subplot_mode == 'fit':
            max_rows_in_screen = 4
            if heads <= 4:
                rows = 1
            elif heads <= 8:
                rows = 2
            else:
                rows = heads // 8 + int(heads % 8 != 0)
        elif self.subplot_mode == 'big':
            max_rows_in_screen = 2
            rows = heads // 4 + int(heads % 4 != 0)
        if rows > max_rows_in_screen:
            container_height = f'{int((rows / max_rows_in_screen) * 100)}%'
        else:
            container_height = '100%'
        if rows != 0:
            cols = calc_cols(heads, rows)
        else:
            cols = 0
        if self.subplot_rows != rows or self.subplot_cols != cols:
            self.grid_has_changed = True
        self.subplot_cols = cols
        self.subplot_rows = rows
        if self.show_all_heads:
            self.figure_container_height = container_height
        else:
            self.figure_container_height = '100%'

    def update_selected_activation_data(self):
        if self.activations_data is not None:
            if self.selected_layer not in ('Policy', 'Smolgen'):
                if not DEV_MODE:
                    activations = tf.squeeze(self.activations_data[self.selected_layer], axis=0).numpy()
                else:
                    activations = np.squeeze(self.activations_data[self.selected_layer], axis=0)
            elif self.selected_layer == 'Policy':
                print('RAW POLICY SHAPE', self.activations_data[-1].shape)
                activations = self.activations_data[-1].numpy()
                # Promotion weights (shape 8,24) could be padded and concatenated onto the
                # 64x64 policy matrix; left disabled for now:
                # promo = np.squeeze(self.activations_data[-2].numpy(), axis=0)  # shape 8,24
                # pad_shape = (48, 8) if self.board.turn else (8, 48)
                # promo_padded = np.pad(promo, (pad_shape, (0, 0)), mode='constant', constant_values=None)
                # self.activations = np.expand_dims(np.concatenate((activations, promo_padded), axis=1), axis=0)  # shape 1,64,88
            elif self.selected_layer == 'Smolgen':
                activations = self.tfp.smol_weight_gen_dense.get_weights()[0].reshape((256, 64, 64))
            self.activations = activations[:, ::-1, :]  # flip along the y-axis

    def set_visualization_mode(self, mode):
        self.visualization_mode = mode
        self.visualization_mode_is_64x64 = mode == '64x64'

    def set_layer(self, layer):
        self.selected_layer = layer
        self.update_selected_activation_data()
        if layer not in ('Policy', 'Smolgen'):
            self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
        elif layer == 'Policy':
            self.number_of_heads = 1
        elif layer == 'Smolgen':
            self.number_of_heads = self.activations.shape[0]
        self.set_head(0)
        self.update_grid_shape()

    def set_head(self, head):
        self.selected_head = head

    def set_model(self, model):
        if model != self.model_path:
            self.model_path = model
            self.load_model()
            self.update_activations_data()
            self.update_selected_activation_data()
            self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
            if self.selected_head is None:
                self.selected_head = 0
            else:
                self.selected_head = min(self.selected_head, self.number_of_heads - 1)
            self.update_grid_shape()
            if SIMULATE_TF:
                sleep(2)

    def update_layers_in_body_count(self):
        # TODO: figure out a robust way to separate attention layers in the body from
        #       the rest. UPDATE: use the yaml.
        heads = self.activations_data[0].shape[1]
        for ind, layer in enumerate(self.activations_data):
            if layer.shape[1] != heads or len(layer.shape) != 4:
                ind = ind - 1
                break
        self.nr_of_layers_in_body = ind + 1
        if self.selected_layer not in ('Policy', 'Smolgen'):
            self.selected_layer = min(self.selected_layer, self.nr_of_layers_in_body - 1)

    def get_head_data(self, head):
        if self.activations.shape[0] <= head:
            return None
        if self.visualization_mode == '64x64':
            data = self.activations[head, :, :]
        elif self.visualization_mode == 'ROW':
            if self.board.turn or self.selected_layer == 'Smolgen':  # White to move
                row = 63 - self.focused_square_ind
                data = self.activations[head, row, :].reshape((8, 8))
            else:
                # Black to move: mirror the focused square index within its rank before
                # selecting the row, then flip the resulting 8x8 map vertically.
                multiples = self.focused_square_ind // 8
                remainder = self.focused_square_ind % 8
                a = 7 - remainder
                b = multiples * 8
                row = a + b
                data = self.activations[head, row, :].reshape((8, 8))[::-1, :]
        else:  # 'COL' selection
            if self.board.turn or self.selected_layer == 'Smolgen':  # White to move
                col = self.focused_square_ind
                data = self.activations[head, :, col].reshape((8, 8))[::-1, ::-1]
            else:
                focused = 63 - self.focused_square_ind
                multiples = focused // 8
                remainder = focused % 8
                a = 7 - remainder
                b = multiples * 8
                col = a + b
                data = self.activations[head, :, col].reshape((8, 8))[:, ::-1]
        return data

    def set_fen(self, fen):
        self.board.set_fen(fen)
        self.fen = fen
        self.update_activations_data()
        self.update_selected_activation_data()

    def set_board(self, board):
        self.board = deepcopy(board)
        self.update_activations_data()
        self.update_selected_activation_data()


global_data = GlobalData()
print('global data created')
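
# Minimal usage sketch for manual testing (assumptions: a model folder exists under
# ROOT_DIR/models so load_model() succeeds, or DEV_MODE is enabled; the FEN below is
# just an arbitrary legal position, not one used by the app).
if __name__ == '__main__':
    if global_data.model is not None or SIMULATE_TF:
        global_data.set_fen('rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1')
        global_data.set_layer(0)  # first attention layer in the body
        global_data.set_head(0)
        data = global_data.get_head_data(0)  # an 8x8 slice in the default 'ROW' mode
        print('head 0 data shape:', None if data is None else data.shape)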