import sys
from pathlib import Path

# Make the current working directory importable so the local `model` module
# (expected to live alongside this script) can be found when the script is
# run from the project root.
sys.path.append(str(Path.cwd()))

# Project-local model loader; `tree_flatten` flattens a nested parameter
# tree into a flat list of (name, array) pairs.
from model import load_model
from mlx.utils import tree_flatten
| |
|
| |
|
def run_diagnostic_checks():
    """
    Performs the verification checks outlined in the review.

    Loads the model definition from the current directory, then reports:
      * total parameter count,
      * MLP projection weight shapes (dual-branch or SwiGLU, per config),
      * token-embedding weight shape,
      * presence of the expected MLP weight tensors.

    All failures are reported to stdout; nothing is raised to the caller.
    """
    print("--- Running Diagnostic Checks ---")

    try:
        model = load_model(".")
        print("Successfully loaded model definition.")
    except Exception as e:
        # Without a model none of the remaining checks can run.
        print(f"Error loading model: {e}")
        return

    _report_param_count(model)
    _verify_mlp_shapes(model)
    _verify_embedding_shape(model)
    _sanity_check_weights(model)

    print("--- Diagnostic Checks Complete ---")


def _report_param_count(model):
    """Print the model's total parameter count in millions."""
    try:
        params = model.parameters()
        num_params = sum(p.size for _, p in tree_flatten(params))
        print(f"Total number of parameters: {num_params / 1e6:.2f}M")
    except Exception as e:
        print(f"Error calculating parameters: {e}")


def _verify_mlp_shapes(model):
    """Check the first transformer block's MLP weight shapes against the config."""
    print("--- Verifying MLP Weight Shapes ---")
    try:
        first_block = model.layers[0]
        args = model.args
        print(f"use_dual_mlp detected: {args.use_dual_mlp}")

        if args.use_dual_mlp:
            # Dual-branch MLP: gated branch uses intermediate_size, plain
            # branch uses intermediate_size_mlp.
            g_up_shape = first_block.feed_forward.g_up.weight.shape
            p_up_shape = first_block.feed_forward.p_up.weight.shape
            print(f"Gated MLP branch (g_up) weight shape: {g_up_shape}")
            print(f"Plain MLP branch (p_up) weight shape: {p_up_shape}")
            assert g_up_shape == (args.intermediate_size, args.hidden_size)
            assert p_up_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("DualMLP weight shapes are correct.")
        else:
            # Standard SwiGLU: both projections use intermediate_size_mlp.
            gate_proj_shape = first_block.feed_forward.gate_proj.weight.shape
            up_proj_shape = first_block.feed_forward.up_proj.weight.shape
            print(f"SwiGLUMLP gate_proj weight shape: {gate_proj_shape}")
            print(f"SwiGLUMLP up_proj weight shape: {up_proj_shape}")
            assert gate_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            assert up_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("SwiGLUMLP weight shapes are correct.")
    except AttributeError as e:
        print(
            f"Error accessing MLP weights. It seems the structure is not as expected: {e}"
        )
    except AssertionError:
        print("Error: MLP weight shapes do not match the configuration.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying shapes: {e}")


def _verify_embedding_shape(model):
    """Check the token-embedding weight shape against the config."""
    print("--- Verifying Embedding Shape ---")
    try:
        embedding_shape = model.tok_embeddings.weight.shape
        print(f"Embedding weight shape: {embedding_shape}")

        args = model.args
        print(f"Expected embedding shape: ({args.vocab_size}, {args.hidden_size})")

        assert embedding_shape == (args.vocab_size, args.hidden_size)
        print("Embedding shape is correct.")
    except AssertionError:
        # A bare assert carries no message; report the mismatch explicitly
        # rather than printing an empty "unexpected error".
        print("Error: embedding shape does not match the configuration.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying embedding shape: {e}")


def _sanity_check_weights(model):
    """Verify the expected MLP weight tensors are present on the loaded model."""
    print("--- Sanity Checking Loaded Weights ---")
    try:
        # Attribute access alone is the check: a missing weight raises.
        if model.args.use_dual_mlp:
            _ = model.layers[0].feed_forward.g_gate.weight
            _ = model.layers[0].feed_forward.g_up.weight
            _ = model.layers[0].feed_forward.g_down.weight
            _ = model.layers[0].feed_forward.p_up.weight
            _ = model.layers[0].feed_forward.p_down.weight
            print("Found dual-branch MLP weights in the model.")
        else:
            _ = model.layers[0].feed_forward.gate_proj.weight
            _ = model.layers[0].feed_forward.up_proj.weight
            _ = model.layers[0].feed_forward.down_proj.weight
            print("Found SwiGLU MLP weights in the model.")
        print("Weight presence sanity check passed.")
    except Exception as e:
        print(f"An error occurred during sanity check: {e}")
| |
|
| |
|
# Script entry point: run all diagnostics when executed directly.
if __name__ == "__main__":
    run_diagnostic_checks()
| |
|