"""Configuration for Hugging Face Hub access and metric display.

Defines the Hub client and repository ids, the mapping from result-column
names to human-readable labels, and the column groupings/orderings used by
the compact and full table views.
"""
import os

from huggingface_hub import HfApi

PROBLEM_TYPES = ["unconditional", "conditional"]

# Hugging Face Hub access. TOKEN is None when HF_TOKEN is not set;
# CACHE_PATH falls back to the current directory when HF_HOME is not set.
TOKEN = os.environ.get("HF_TOKEN")
CACHE_PATH = os.environ.get("HF_HOME", ".")
API = HfApi(token=TOKEN)

organization = "LeMaterial"
submissions_repo = f'{organization}/lemat-gen-bench-submissions'
results_repo = f'{organization}/lemat-genbench-results'

# Column display names mapping (result column name -> label shown in the UI).
COLUMN_DISPLAY_NAMES = {
    'model_name': 'Model',
    'n_structures': 'Total Structures',
    # Validity metrics
    'overall_valid_count': 'Valid',
    'charge_neutral_count': 'Charge Neutral',
    'distance_valid_count': 'Distance Valid',
    'plausibility_valid_count': 'Plausibility Valid',
    # Uniqueness and Novelty
    'unique_count': 'Unique',
    'novel_count': 'Novel',
    # Energy-based metrics
    'mean_formation_energy': 'Formation Energy (eV)',
    'formation_energy_std': 'Formation Energy Std',
    'stability_mean_above_hull': 'E Above Hull (eV)',
    'stability_std_e_above_hull': 'E Above Hull Std',
    'mean_relaxation_RMSD': 'Relaxation RMSD (Å)',
    # NOTE(review): key says "RMSE" while its mean counterpart says "RMSD".
    # Presumably the key must match the results column name, so it is left
    # unchanged here — confirm against the results dataset.
    'relaxation_RMSE_std': 'Relaxation RMSD Std',
    # Stability metrics
    'stable_count': 'Stable',
    'unique_in_stable_count': 'Unique in Stable',
    'sun_count': 'SUN',
    # Metastability metrics
    'metastable_count': 'Metastable',
    'unique_in_metastable_count': 'Unique in Metastable',
    'msun_count': 'MSUN',
    # Distribution metrics
    'JSDistance': 'JS Distance',
    'MMD': 'MMD',
    'FrechetDistance': 'FID',
    # Diversity metrics
    'element_diversity': 'Element Diversity',
    'space_group_diversity': 'Space Group Diversity',
    'site_diversity': 'Atomic Site Diversity',
    'physical_size_diversity': 'Crystal Size Diversity',
    # HHI metrics
    'hhi_production_mean': 'HHI Production',
    'hhi_reserve_mean': 'HHI Reserve',
    'hhi_combined_mean': 'HHI Combined',
}

# Metrics that can be shown as percentages (count-based metrics)
COUNT_BASED_METRICS = [
    'overall_valid_count',
    'charge_neutral_count',
    'distance_valid_count',
    'plausibility_valid_count',
    'unique_count',
    'novel_count',
    'stable_count',
    'unique_in_stable_count',
    'sun_count',
    'metastable_count',
    'unique_in_metastable_count',
    'msun_count',
]

# Metric groups for organized display. The ↑/↓ suffix in each group name
# appears to mark whether higher or lower values are better — confirm.
# Dict order determines the column order of FULL_VIEW_COLUMNS below.
METRIC_GROUPS = {
    'Validity ↑': [
        'overall_valid_count',
        'charge_neutral_count',
        'distance_valid_count',
        'plausibility_valid_count',
    ],
    'Uniqueness & Novelty ↑': [
        'unique_count',
        'novel_count',
    ],
    'Energy Metrics ↓': [
        'stability_mean_above_hull',
        'stability_std_e_above_hull',
        'mean_formation_energy',
        'formation_energy_std',
        'mean_relaxation_RMSD',
        'relaxation_RMSE_std',
    ],
    'Stability ↑': [
        'stable_count',
        'unique_in_stable_count',
        'sun_count',
    ],
    'Metastability ↑': [
        'metastable_count',
        'unique_in_metastable_count',
        'msun_count',
    ],
    'Distribution ↓': [
        'JSDistance',
        'MMD',
        'FrechetDistance',
    ],
    'Diversity ↑': [
        'element_diversity',
        'space_group_diversity',
        'site_diversity',
        'physical_size_diversity',
    ],
    # NOTE(review): 'hhi_combined_mean' has a display name above but is not
    # listed in any group, so it is excluded from FULL_VIEW_COLUMNS and
    # COLUMN_TO_GROUP. Confirm whether that omission is intentional.
    'HHI ↓': [
        'hhi_production_mean',
        'hhi_reserve_mean',
    ],
}

# Color coding for metric families (background colors with transparency).
# Keys must match METRIC_GROUPS keys exactly.
METRIC_GROUP_COLORS = {
    'Validity ↑': 'rgba(33, 150, 243, 0.15)',  # Light blue with transparency
    'Uniqueness & Novelty ↑': 'rgba(156, 39, 176, 0.15)',  # Light purple with transparency
    'Energy Metrics ↓': 'rgba(255, 152, 0, 0.15)',  # Light orange with transparency
    'Stability ↑': 'rgba(76, 175, 80, 0.15)',  # Light green with transparency
    'Metastability ↑': 'rgba(139, 195, 74, 0.15)',  # Light lime with transparency
    'Distribution ↓': 'rgba(233, 30, 99, 0.15)',  # Light pink with transparency
    'Diversity ↑': 'rgba(0, 188, 212, 0.15)',  # Light cyan with transparency
    'HHI ↓': 'rgba(255, 193, 7, 0.15)',  # Light amber with transparency
}


def get_column_to_group_mapping():
    """Return a dict mapping each grouped column name to its METRIC_GROUPS key.

    Groups are visited in METRIC_GROUPS order. If a column were listed in more
    than one group, the last group would win (plain dict-assignment semantics).
    Columns that appear in no group (e.g. 'model_name') are absent.
    """
    return {
        col: group_name
        for group_name, cols in METRIC_GROUPS.items()
        for col in cols
    }


# Map each column to its group for styling
COLUMN_TO_GROUP = get_column_to_group_mapping()

# Compact view columns (most important metrics visible without scrolling)
COMPACT_VIEW_COLUMNS = [
    'model_name',
    'overall_valid_count',
    'unique_count',
    'novel_count',
    'stable_count',
    'metastable_count',
    'sun_count',
    'msun_count',
    'stability_mean_above_hull',
    'mean_formation_energy',
    'mean_relaxation_RMSD',
]

# Full view columns: identifier columns followed by every grouped metric, in
# METRIC_GROUPS order. Built with a comprehension so no loop variables leak
# into the module namespace.
FULL_VIEW_COLUMNS = ['model_name', 'n_structures'] + [
    col for cols in METRIC_GROUPS.values() for col in cols
]

# Default columns for backward compatibility. This is an alias of the same
# list object as COMPACT_VIEW_COLUMNS, not a copy.
DEFAULT_COLUMNS = COMPACT_VIEW_COLUMNS