| | import argparse |
| | import json |
| | from pathlib import Path |
| | from safetensors import safe_open |
| |
|
| |
|
def check_model_shape(model_path: str):
    """Inspect a model directory and report whether its MLP looks dual-branch.

    Reads ``config.json`` and ``model.safetensors`` from *model_path* and
    prints a short report to stdout. Two signals are combined:

    * the config holds a positive numeric ``intermediate_size_mlp`` value, and
    * at least one weight name contains ``"mlp"`` but none of ``gate_proj``,
      ``up_proj`` or ``down_proj``.

    A missing file, an unreadable or malformed config, and unreadable weights
    are each reported with a printed message followed by an early return
    rather than an exception.

    Args:
        model_path: Directory expected to contain ``config.json`` and
            ``model.safetensors``.

    Returns:
        None. All results are printed.
    """
    model_path = Path(model_path)
    config_path = model_path / "config.json"
    weights_path = model_path / "model.safetensors"

    if not config_path.exists():
        print(f"Error: config.json not found in {model_path}")
        return

    if not weights_path.exists():
        print(f"Error: model.safetensors not found in {model_path}")
        return

    print(f"--- Checking model shape in {model_path} ---")

    # --- Signal 1: the configuration ---
    # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses,
    # so a corrupt config is reported like the other failures in this function
    # instead of surfacing as a traceback.
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error: could not read config.json: {e}")
        return

    if not isinstance(config, dict):
        print("Error: config.json does not contain a JSON object")
        return

    # The value may be null or a string in hand-edited configs; comparing
    # those with 0 would raise TypeError, so only real numbers count.
    mlp_size = config.get("intermediate_size_mlp", 0)
    has_dual_mlp_config = isinstance(mlp_size, (int, float)) and mlp_size > 0
    print(f"Config has 'intermediate_size_mlp': {has_dual_mlp_config}")

    # --- Signal 2: the weight names ---
    # Any "mlp" tensor that is not one of the three standard projections is
    # treated as evidence of an extra branch; the first match is enough.
    standard_projections = ("gate_proj", "up_proj", "down_proj")
    try:
        with safe_open(weights_path, framework="mlx") as f:
            dual_key = next(
                (
                    key
                    for key in f.keys()
                    if "mlp" in key
                    and not any(name in key for name in standard_projections)
                ),
                None,
            )
    except Exception as e:
        # Deliberately broad: this is a reporting boundary, and the failure
        # may come from the file format or from the requested framework
        # backend being unavailable.
        print(f"Could not read weights from model.safetensors: {e}")
        return

    has_dual_mlp_weights = dual_key is not None
    if has_dual_mlp_weights:
        print(f"Found potential dual-branch weight: {dual_key}")

    print(f"Found potential dual-branch MLP weights: {has_dual_mlp_weights}")

    # --- Conclusion ---
    print("\n--- Conclusion ---")
    if has_dual_mlp_config and has_dual_mlp_weights:
        print("✅ The model appears to be a DUAL-BRANCH MLP variant.")
    elif has_dual_mlp_config and not has_dual_mlp_weights:
        print(
            "⚠️ The model configuration suggests a dual-branch MLP, but no corresponding weights were found."
        )
        print(" It will likely run as a SINGLE-BRANCH model.")
    else:
        # NOTE(review): this branch also covers "weights found but no config
        # key"; that case is reported as single-branch - confirm intended.
        print("✅ The model appears to be a SINGLE-BRANCH MLP variant.")
    print("--------------------\n")
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: takes one optional positional argument, the
    # model directory, which defaults to the current working directory.
    cli = argparse.ArgumentParser(description="Check the MLP shape of a model variant.")
    cli.add_argument(
        "model_path",
        nargs="?",
        default=".",
        type=str,
        help="Path to the model directory to check.",
    )
    check_model_shape(cli.parse_args().model_path)
| |
|