LeMat-GenBench / about.py
cgeorgiaw's picture
cgeorgiaw HF Staff
added marked for relaxed
99351a1
import os
from huggingface_hub import HfApi
# Benchmark problem types handled by the leaderboard.
PROBLEM_TYPES = ["unconditional", "conditional"]

# Environment-driven settings. TOKEN is None when HF_TOKEN is unset;
# CACHE_PATH falls back to the current directory when HF_HOME is unset.
TOKEN = os.environ.get("HF_TOKEN")
CACHE_PATH = os.environ.get("HF_HOME", ".")

# Hugging Face Hub client created with the token above.
API = HfApi(token=TOKEN)

# Hub repositories (under the same organization) for submissions and results.
organization = "LeMaterial"
submissions_repo = f"{organization}/lemat-gen-bench-submissions"
results_repo = f"{organization}/lemat-genbench-results"
# Column display names mapping: results column key -> header text shown in
# the leaderboard table. Keys are presumably the column names used in the
# results data (TODO confirm), so only the display strings should be edited.
COLUMN_DISPLAY_NAMES = {
    'model_name': 'Model',
    'n_structures': 'Total Structures',
    # Validity metrics
    'overall_valid_count': 'Valid',
    'charge_neutral_count': 'Charge Neutral',
    'distance_valid_count': 'Distance Valid',
    'plausibility_valid_count': 'Plausibility Valid',
    # Uniqueness and Novelty
    'unique_count': 'Unique',
    'novel_count': 'Novel',
    # Energy-based metrics
    'mean_formation_energy': 'Formation Energy (eV)',
    'formation_energy_std': 'Formation Energy Std',
    'stability_mean_above_hull': 'E Above Hull (eV)',
    'stability_std_e_above_hull': 'E Above Hull Std',
    # \u00c5 is the angstrom symbol (Å). It is written as an escape so the
    # label cannot be mangled by a mis-decoded copy of this file again.
    'mean_relaxation_RMSD': 'Relaxation RMSD (\u00c5)',
    # NOTE(review): this key spells "RMSE" while its mean counterpart spells
    # "RMSD". It presumably mirrors the results schema, so the key is kept
    # as-is and only the label says RMSD. Confirm against the data.
    'relaxation_RMSE_std': 'Relaxation RMSD Std',
    # Stability metrics
    'stable_count': 'Stable',
    'unique_in_stable_count': 'Unique in Stable',
    'sun_count': 'SUN',
    # Metastability metrics
    'metastable_count': 'Metastable',
    'unique_in_metastable_count': 'Unique in Metastable',
    'msun_count': 'MSUN',
    # Distribution metrics
    'JSDistance': 'JS Distance',
    'MMD': 'MMD',
    'FrechetDistance': 'FID',
    # Diversity metrics
    'element_diversity': 'Element Diversity',
    'space_group_diversity': 'Space Group Diversity',
    'site_diversity': 'Atomic Site Diversity',
    'physical_size_diversity': 'Crystal Size Diversity',
    # HHI metrics
    'hhi_production_mean': 'HHI Production',
    'hhi_reserve_mean': 'HHI Reserve',
    'hhi_combined_mean': 'HHI Combined',
}
# Count-valued metrics (numbers of structures). These are the columns that may
# also be rendered as percentages, presumably of 'n_structures' (TODO confirm).
COUNT_BASED_METRICS = [
    # validity
    'overall_valid_count', 'charge_neutral_count',
    'distance_valid_count', 'plausibility_valid_count',
    # uniqueness & novelty
    'unique_count', 'novel_count',
    # stability
    'stable_count', 'unique_in_stable_count', 'sun_count',
    # metastability
    'metastable_count', 'unique_in_metastable_count', 'msun_count',
]
# Metric families for the grouped table view: family label -> ordered list of
# result columns. The trailing arrow presumably tells the reader whether
# higher (↑) or lower (↓) values are better.
METRIC_GROUPS = {
    'Validity ↑': [
        'overall_valid_count', 'charge_neutral_count',
        'distance_valid_count', 'plausibility_valid_count',
    ],
    'Uniqueness & Novelty ↑': ['unique_count', 'novel_count'],
    'Energy Metrics ↓': [
        'stability_mean_above_hull', 'stability_std_e_above_hull',
        'mean_formation_energy', 'formation_energy_std',
        'mean_relaxation_RMSD', 'relaxation_RMSE_std',
    ],
    'Stability ↑': ['stable_count', 'unique_in_stable_count', 'sun_count'],
    'Metastability ↑': [
        'metastable_count', 'unique_in_metastable_count', 'msun_count',
    ],
    'Distribution ↓': ['JSDistance', 'MMD', 'FrechetDistance'],
    'Diversity ↑': [
        'element_diversity', 'space_group_diversity',
        'site_diversity', 'physical_size_diversity',
    ],
    # NOTE(review): 'hhi_combined_mean' has a display name but is not listed
    # in any family, so it is not part of the grouped columns. Confirm that
    # this is intentional.
    'HHI ↓': ['hhi_production_mean', 'hhi_reserve_mean'],
}
# Background tint per metric family, as CSS rgba() strings. Every family uses
# the same 0.15 alpha; only the RGB base color differs.
METRIC_GROUP_COLORS = {
    family: f'rgba({red}, {green}, {blue}, 0.15)'
    for family, (red, green, blue) in (
        ('Validity ↑', (33, 150, 243)),              # light blue
        ('Uniqueness & Novelty ↑', (156, 39, 176)),  # light purple
        ('Energy Metrics ↓', (255, 152, 0)),         # light orange
        ('Stability ↑', (76, 175, 80)),              # light green
        ('Metastability ↑', (139, 195, 74)),         # light lime
        ('Distribution ↓', (233, 30, 99)),           # light pink
        ('Diversity ↑', (0, 188, 212)),              # light cyan
        ('HHI ↓', (255, 193, 7)),                    # light amber
    )
}
# Map each column to its group for styling
def get_column_to_group_mapping(groups=None):
    """Return a dict mapping each column name to its metric-group label.

    Args:
        groups: Mapping of group label -> iterable of column names. When
            ``None`` (the default), the module-level ``METRIC_GROUPS`` is used.

    Returns:
        dict: column name -> group label, in the order the columns are
        encountered. If a column is listed in more than one group, the group
        that comes last in iteration order wins.
    """
    # Compare with None explicitly so an empty mapping yields an empty result
    # instead of silently falling back to METRIC_GROUPS.
    if groups is None:
        groups = METRIC_GROUPS
    return {
        col: group_name
        for group_name, cols in groups.items()
        for col in cols
    }
# Computed once at import time: column name -> metric-group label, derived from
# METRIC_GROUPS. Columns not listed in any group (e.g. 'model_name',
# 'n_structures') have no entry.
COLUMN_TO_GROUP = get_column_to_group_mapping()
# Compact view: the model name followed by the headline metrics, short enough
# to stay visible without scrolling.
COMPACT_VIEW_COLUMNS = [
    'model_name',
    # validity, uniqueness, novelty
    'overall_valid_count', 'unique_count', 'novel_count',
    # stability / metastability counts, then SUN / MSUN
    'stable_count', 'metastable_count', 'sun_count', 'msun_count',
    # energy means
    'stability_mean_above_hull', 'mean_formation_energy',
    'mean_relaxation_RMSD',
]
# Full view columns: the identity columns followed by every metric, family by
# family, in METRIC_GROUPS order. A comprehension is used so no loop variables
# are left behind as module-level names.
FULL_VIEW_COLUMNS = ['model_name', 'n_structures'] + [
    col for cols in METRIC_GROUPS.values() for col in cols
]
# Default columns for backward compatibility. NOTE: this binds the same list
# object as COMPACT_VIEW_COLUMNS (an alias, not a copy), so mutating either
# name mutates both.
DEFAULT_COLUMNS = COMPACT_VIEW_COLUMNS