from pathlib import Path
import json
import pandas as pd
import numpy as np
import gradio as gr
from datasets import load_dataset
from gradio_leaderboard import Leaderboard
from datetime import datetime
import os

from about import (
    PROBLEM_TYPES,
    TOKEN,
    CACHE_PATH,
    API,
    submissions_repo,
    results_repo,
    COLUMN_DISPLAY_NAMES,
    COUNT_BASED_METRICS,
    METRIC_GROUPS,
    METRIC_GROUP_COLORS,
    COLUMN_TO_GROUP,
)


def get_leaderboard():
    """Download the results dataset from the Hub and return it as a DataFrame.

    Forces a re-download so the leaderboard always reflects the latest results.
    Returns an empty placeholder frame when the dataset has no rows.
    """
    ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
    full_df = pd.DataFrame(ds)
    print(full_df.columns)
    if len(full_df) == 0:
        return pd.DataFrame({'date': [], 'model': [], 'score': [], 'verified': []})
    return full_df


def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Format the dataframe with proper column names and optional percentages.

    Args:
        df: Raw leaderboard DataFrame as returned by ``get_leaderboard``.
        show_percentage: Render count-based metrics as ``xx.x%`` strings
            relative to ``n_structures``.
        selected_groups: Metric-group names (keys of ``METRIC_GROUPS``) to show
            when ``compact_view`` is off; ``None``/empty means all groups.
        compact_view: When True, show only the predefined compact column set.

    Returns:
        A pandas Styler with per-group background colors applied.
    """
    if len(df) == 0:
        return df

    # Build column list based on view mode.
    selected_cols = ['model_name']
    if compact_view:
        # Use the predefined compact columns. Imported lazily to avoid a
        # module-level cycle with `about` at import time.
        from about import COMPACT_VIEW_COLUMNS
        selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
    else:
        # Build the column list from the selected metric groups.
        if 'n_structures' in df.columns:
            selected_cols.append('n_structures')
        # If no groups are selected, show all of them.
        if not selected_groups:
            selected_groups = list(METRIC_GROUPS.keys())
        for group in selected_groups:
            if group in METRIC_GROUPS:
                for col in METRIC_GROUPS[group]:
                    if col in df.columns and col not in selected_cols:
                        selected_cols.append(col)

    # Work on a copy restricted to the selected columns.
    display_df = df[selected_cols].copy()

    # Suffix the model name with a lightning bolt when the row is relaxed.
    if 'relaxed' in df.columns and 'model_name' in display_df.columns:
        display_df['model_name'] = df.apply(
            lambda row: f"{row['model_name']} ⚡" if row.get('relaxed', False) else row['model_name'],
            axis=1,
        )

    # Convert count-based metrics to percentage strings if requested.
    if show_percentage and 'n_structures' in df.columns:
        n_structures = df['n_structures']
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                # Percentage of total structures, one decimal, with a '%' sign.
                display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + '%'

    # Round remaining float columns for a cleaner display.
    for col in display_df.columns:
        if display_df[col].dtype in ['float64', 'float32']:
            display_df[col] = display_df[col].round(4)

    # Rename columns to their human-readable display names.
    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)

    # Apply color coding based on metric groups (needs the raw column names).
    styled_df = apply_color_styling(display_df, selected_cols)

    return styled_df


def apply_color_styling(display_df, original_cols):
    """Apply background colors to dataframe based on metric groups using pandas Styler.

    ``original_cols`` must be the raw column names in the same order as
    ``display_df``'s (renamed) columns, so the group lookup can be positional.
    """

    def style_by_group(x):
        # Start with a same-shaped frame of empty style strings.
        styles = pd.DataFrame('', index=x.index, columns=x.columns)
        # Map display column names back to original column names by position.
        for i, display_col in enumerate(x.columns):
            if i < len(original_cols):
                original_col = original_cols[i]
                # Color the whole column if it belongs to a metric group.
                if original_col in COLUMN_TO_GROUP:
                    group = COLUMN_TO_GROUP[original_col]
                    color = METRIC_GROUP_COLORS.get(group, '')
                    if color:
                        styles[display_col] = f'background-color: {color}'
        return styles

    # axis=None passes the whole DataFrame to the styling function at once.
    return display_df.style.apply(style_by_group, axis=None)


def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df,
                       sort_by, sort_direction):
    """Update the leaderboard based on user selections.

    Uses the cached dataframe (a ``gr.State``) to avoid re-downloading data on
    every UI change. Sorting accepts display names and maps them back to raw
    column names before sorting.
    """
    # Use the cached dataframe instead of re-downloading.
    df_to_format = cached_df.copy()

    # Convert the display name back to the raw column name for sorting.
    if sort_by and sort_by != "None":
        display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()}
        raw_column_name = display_to_raw.get(sort_by, sort_by)
        if raw_column_name in df_to_format.columns:
            ascending = (sort_direction == "Ascending")
            df_to_format = df_to_format.sort_values(by=raw_column_name, ascending=ascending)

    formatted_df = format_dataframe(df_to_format, show_percentage, selected_groups, compact_view)
    return formatted_df


def show_output_box(message):
    """Reveal the status textbox with the given message."""
    return gr.update(value=message, visible=True)


def submit_cif_files(problem_type, cif_files, relaxed, profile: gr.OAuthProfile | None):
    # TODO: Implement submission logic that includes the relaxed flag.
    # Return a (message, filename) pair so the click handler's two outputs
    # are satisfied instead of failing at runtime on a bare None.
    return "Submission is not implemented yet.", None


def generate_metric_legend_html():
    """Generate HTML table with color-coded metric group legend."""
    metric_details = {
        'Validity ↑': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', '↑ Higher is better'),
        'Uniqueness & Novelty ↑': ('Unique, Novel', '↑ Higher is better'),
        'Energy Metrics ↓': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', '↓ Lower is better'),
        'Stability ↑': ('Stable, Unique in Stable, SUN', '↑ Higher is better'),
        'Metastability ↑': ('Metastable, Unique in Metastable, MSUN', '↑ Higher is better'),
        'Distribution ↓': ('JS Distance, MMD, FID', '↓ Lower is better'),
        'Diversity ↑': ('Element, Space Group, Atomic Site, Crystal Size', '↑ Higher is better'),
        'HHI ↓': ('HHI Production, HHI Reserve', '↓ Lower is better'),
    }
    # NOTE(review): the original table markup was stripped by a formatting
    # mangle (only the cell text survived); reconstructed as a plain HTML
    # table — Color | Group | Metrics | Direction — confirm styling matches
    # the intended look.
    html = '<table style="width: 100%; border-collapse: collapse;">'
    html += '<thead>'
    html += '<tr>'
    html += '<th style="text-align: left; padding: 4px;">Color</th>'
    html += '<th style="text-align: left; padding: 4px;">Group</th>'
    html += '<th style="text-align: left; padding: 4px;">Metrics</th>'
    html += '<th style="text-align: left; padding: 4px;">Direction</th>'
    html += '</tr>'
    html += '</thead>'
    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ('', ''))
        group_name = group.replace('↑', '').replace('↓', '').strip()
        html += '<tr>'
        html += f'<td style="background-color: {color}; width: 40px; padding: 4px;"></td>'
        html += f'<td style="padding: 4px;">{group_name}</td>'
        html += f'<td style="padding: 4px;">{metrics}</td>'
        html += f'<td style="padding: 4px;">{direction}</td>'
        html += '</tr>'
    html += '</table>'
    return html


def gradio_interface() -> gr.Blocks:
    """Build the full Gradio Blocks app: leaderboard, about, and submit tabs."""
    with gr.Blocks() as demo:
        gr.Markdown("## Welcome to the LeMaterial Generative Benchmark Leaderboard!")
        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown("# LeMat-GenBench")

                # Display options.
                with gr.Row():
                    with gr.Column(scale=1):
                        compact_view = gr.Checkbox(
                            value=True,
                            label="Compact View",
                            info="Show only key metrics"
                        )
                        show_percentage = gr.Checkbox(
                            value=True,
                            label="Show as Percentages",
                            info="Display count-based metrics as percentages of total structures"
                        )
                    with gr.Column(scale=1):
                        # Choices are display names; update_leaderboard maps
                        # them back to raw column names before sorting.
                        sort_choices = ["None"] + [
                            COLUMN_DISPLAY_NAMES.get(col, col)
                            for col in COLUMN_DISPLAY_NAMES.keys()
                        ]
                        sort_by = gr.Dropdown(
                            choices=sort_choices,
                            value="None",
                            label="Sort By",
                            info="Select column to sort by"
                        )
                        sort_direction = gr.Radio(
                            choices=["Ascending", "Descending"],
                            value="Descending",
                            label="Sort Direction"
                        )
                    with gr.Column(scale=2):
                        selected_groups = gr.CheckboxGroup(
                            choices=list(METRIC_GROUPS.keys()),
                            value=list(METRIC_GROUPS.keys()),
                            label="Metric Families (only active when Compact View is off)",
                            info="Select which metric groups to display"
                        )

                # Metric legend with color coding.
                with gr.Accordion("Metric Groups Legend", open=False):
                    gr.HTML(generate_metric_legend_html())

                try:
                    # Initial dataframe - load once and cache in gr.State so
                    # later UI changes never re-download the dataset.
                    initial_df = get_leaderboard()
                    cached_df_state = gr.State(initial_df)
                    formatted_df = format_dataframe(
                        initial_df,
                        show_percentage=True,
                        selected_groups=list(METRIC_GROUPS.keys()),
                        compact_view=True,
                    )
                    leaderboard_table = gr.Dataframe(
                        label="GenBench Leaderboard",
                        value=formatted_df,
                        interactive=False,
                        wrap=True,
                        column_widths=(
                            ["180px"] + ["160px"] * (len(formatted_df.columns) - 1)
                            if len(formatted_df.columns) > 0 else None
                        ),
                        show_fullscreen_button=True,
                    )

                    # Re-render the table whenever any display option changes;
                    # all handlers share the same inputs/outputs wiring.
                    update_inputs = [show_percentage, selected_groups, compact_view,
                                     cached_df_state, sort_by, sort_direction]
                    show_percentage.change(
                        fn=update_leaderboard,
                        inputs=update_inputs,
                        outputs=leaderboard_table,
                    )
                    selected_groups.change(
                        fn=update_leaderboard,
                        inputs=update_inputs,
                        outputs=leaderboard_table,
                    )
                    compact_view.change(
                        fn=update_leaderboard,
                        inputs=update_inputs,
                        outputs=leaderboard_table,
                    )
                    sort_by.change(
                        fn=update_leaderboard,
                        inputs=update_inputs,
                        outputs=leaderboard_table,
                    )
                    sort_direction.change(
                        fn=update_leaderboard,
                        inputs=update_inputs,
                        outputs=leaderboard_table,
                    )
                except Exception as e:
                    gr.Markdown(f"Leaderboard is empty or error loading: {str(e)}")

                gr.Markdown(
                    "Verified submissions mean the results came from a model "
                    "submission rather than a CIF submission."
                )

            with gr.TabItem("❔About", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown(
                    """
                    ## About LeMat-Gen-Bench

                    **Welcome to the LeMat-Bench Leaderboard**,

                    There are unconditional generation and conditional generation components of this leaderboard.
                    """)

            with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown(
                    """
                    # Materials Submission

                    Upload a CSV, pkl, or a ZIP of CIFs with your structures.
                    """
                )
                filename = gr.State(value=None)
                gr.LoginButton()
                with gr.Row():
                    with gr.Column():
                        problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
                    with gr.Column():
                        cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
                        relaxed = gr.Checkbox(
                            value=False,
                            label="Structures are already relaxed",
                            info="Check this box if your submitted structures have already been relaxed"
                        )

                submit_btn = gr.Button("Submission")
                message = gr.Textbox(label="Status", lines=1, visible=False)

                # Help message.
                gr.Markdown(
                    "If you have issues with submission or using the leaderboard, "
                    "please start a discussion in the Community tab of this Space."
                )

                submit_btn.click(
                    submit_cif_files,
                    inputs=[problem_type, cif_file, relaxed],
                    outputs=[message, filename],
                ).then(
                    fn=show_output_box,
                    inputs=[message],
                    outputs=[message],
                )
    return demo


if __name__ == "__main__":
    gradio_interface().launch()