Spaces:
Running
Running
| from pathlib import Path | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from gradio_leaderboard import Leaderboard | |
| from datetime import datetime | |
| import os | |
| from about import ( | |
| PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo, | |
| COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS, | |
| METRIC_GROUP_COLORS, COLUMN_TO_GROUP | |
| ) | |
def get_leaderboard():
    """Fetch the results dataset from the Hub as a pandas DataFrame.

    Forces a re-download so the leaderboard always reflects the latest
    evaluation results. Returns an empty placeholder frame (with
    date/model/score/verified columns) when the results repo has no rows.
    """
    ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
    full_df = pd.DataFrame(ds)
    # NOTE: removed leftover debug `print(full_df.columns)` from the original.
    if len(full_df) == 0:
        return pd.DataFrame({'date': [], 'model': [], 'score': [], 'verified': []})
    return full_df
def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Prepare a leaderboard frame for display.

    Selects columns (a compact preset, or the chosen metric groups),
    optionally converts count-based metrics to percentage strings,
    rounds float columns, renames columns to their display names and
    applies per-group background-color styling.
    """
    if len(df) == 0:
        return df

    # Decide which raw columns to show.
    if compact_view:
        from about import COMPACT_VIEW_COLUMNS
        columns = [c for c in COMPACT_VIEW_COLUMNS if c in df.columns]
    else:
        columns = ['model_name']
        if 'n_structures' in df.columns:
            columns.append('n_structures')
        # No explicit selection means "show every group".
        groups = selected_groups or list(METRIC_GROUPS.keys())
        for grp in groups:
            for metric in METRIC_GROUPS.get(grp, []):
                if metric in df.columns and metric not in columns:
                    columns.append(metric)

    out = df[columns].copy()

    # Tag already-relaxed submissions with a lightning bolt.
    if 'relaxed' in df.columns and 'model_name' in out.columns:
        out['model_name'] = df.apply(
            lambda r: f"{r['model_name']} ⚡" if r.get('relaxed', False) else r['model_name'],
            axis=1,
        )

    # Count metrics -> "xx.x%" strings, relative to n_structures.
    if show_percentage and 'n_structures' in df.columns:
        totals = df['n_structures']
        for metric in COUNT_BASED_METRICS:
            if metric in out.columns:
                out[metric] = (df[metric] / totals * 100).round(1).astype(str) + '%'

    # Round remaining float columns for a cleaner table.
    for c in out.columns:
        if out[c].dtype in ['float64', 'float32']:
            out[c] = out[c].round(4)

    renamed = out.rename(columns=COLUMN_DISPLAY_NAMES)
    # Styling needs the pre-rename column names to look up metric groups.
    return apply_color_styling(renamed, columns)
def apply_color_styling(display_df, original_cols):
    """Color each column by its metric group using a pandas Styler.

    ``original_cols`` carries the pre-rename (raw) column names,
    positionally aligned with ``display_df``'s renamed columns, so the
    group color for each displayed column can be looked up.
    """
    def _colorize(frame):
        # Start with no styling anywhere, same shape as the data.
        css = pd.DataFrame('', index=frame.index, columns=frame.columns)
        for pos, shown in enumerate(frame.columns):
            if pos >= len(original_cols):
                continue
            group = COLUMN_TO_GROUP.get(original_cols[pos])
            if group is None:
                continue
            color = METRIC_GROUP_COLORS.get(group, '')
            if color:
                css[shown] = f'background-color: {color}'
        return css

    return display_df.style.apply(_colorize, axis=None)
def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction):
    """Re-render the leaderboard from the cached dataframe.

    Uses the cached frame to avoid re-downloading data on every UI
    change; optionally sorts by the selected column (the dropdown shows
    display names, which are mapped back to raw column names here).
    """
    working = cached_df.copy()

    if sort_by and sort_by != "None":
        # Invert the display-name mapping: display name -> raw column.
        raw_by_display = {display: raw for raw, display in COLUMN_DISPLAY_NAMES.items()}
        raw_col = raw_by_display.get(sort_by, sort_by)
        if raw_col in working.columns:
            working = working.sort_values(by=raw_col, ascending=sort_direction == "Ascending")

    return format_dataframe(working, show_percentage, selected_groups, compact_view)
def show_output_box(message):
    """Reveal the (initially hidden) status textbox with *message*."""
    return gr.update(visible=True, value=message)
def submit_cif_files(model_name, problem_type, cif_files, relaxed, profile: gr.OAuthProfile | None):
    """Submit structures to the leaderboard.

    Validates the form inputs, then uploads the structure file plus a
    ``metadata.json`` describing the submission to the submissions
    dataset repo under a unique submission ID.

    Args:
        model_name: Display name for the model/method.
        problem_type: One of the PROBLEM_TYPES choices.
        cif_files: Path to the uploaded file (CSV, pkl, or ZIP of CIFs).
        relaxed: Whether the submitted structures are already relaxed.
        profile: OAuth profile injected by Gradio's login; None if not logged in.

    Returns:
        (status_message, submission_id) — submission_id is None on failure.
    """
    import tempfile
    from huggingface_hub import upload_file

    # Validate inputs before touching the network.
    if not model_name or not model_name.strip():
        return "Error: Please provide a model name.", None
    if not problem_type:
        return "Error: Please select a problem type.", None
    if not cif_files:
        return "Error: Please upload a file.", None
    if not profile:
        return "Error: Please log in to submit.", None

    try:
        username = profile.username
        timestamp = datetime.now().isoformat()

        # Submission metadata stored alongside the uploaded file.
        submission_data = {
            "username": username,
            "model_name": model_name.strip(),
            "problem_type": problem_type,
            "relaxed": relaxed,
            "timestamp": timestamp,
            "file_name": Path(cif_files).name
        }

        # Unique submission ID (':' is not repo-path safe, spaces removed).
        submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}"

        # Upload the submission file itself.
        file_path = Path(cif_files)
        uploaded_file_path = f"submissions/{submission_id}/{file_path.name}"
        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=uploaded_file_path,
            repo_id=submissions_repo,
            token=TOKEN,
            repo_type="dataset"
        )

        # Upload metadata as JSON via a temp file. Cleanup is in a
        # finally block so the temp file no longer leaks if the upload
        # raises (the original only unlinked on success).
        metadata_path = f"submissions/{submission_id}/metadata.json"
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(submission_data, f, indent=2)
            temp_metadata_path = f.name
        try:
            upload_file(
                path_or_fileobj=temp_metadata_path,
                path_in_repo=metadata_path,
                repo_id=submissions_repo,
                token=TOKEN,
                repo_type="dataset"
            )
        finally:
            os.unlink(temp_metadata_path)

        return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
    except Exception as e:
        # UI boundary: surface the failure as a status message instead of crashing.
        return f"Error during submission: {str(e)}", None
def generate_metric_legend_html():
    """Build the color-coded metric-group legend as an HTML table string."""
    # Group label -> (comma-separated metric list, direction hint).
    metric_details = {
        'Validity ↑': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', '↑ Higher is better'),
        'Uniqueness & Novelty ↑': ('Unique, Novel', '↑ Higher is better'),
        'Energy Metrics ↓': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', '↓ Lower is better'),
        'Stability ↑': ('Stable, Unique in Stable, SUN', '↑ Higher is better'),
        'Metastability ↑': ('Metastable, Unique in Metastable, MSUN', '↑ Higher is better'),
        'Distribution ↓': ('JS Distance, MMD, FID', '↓ Lower is better'),
        'Diversity ↑': ('Element, Space Group, Atomic Site, Crystal Size', '↑ Higher is better'),
        'HHI ↓': ('HHI Production, HHI Reserve', '↓ Lower is better'),
    }

    header_style = 'border: 1px solid #ddd; padding: 8px; text-align: left;'
    cell_style = 'border: 1px solid #ddd; padding: 8px;'

    parts = ['<table style="width: 100%; border-collapse: collapse;">', '<thead><tr>']
    for heading in ('Color', 'Group', 'Metrics', 'Direction'):
        parts.append(f'<th style="{header_style}">{heading}</th>')
    parts.append('</tr></thead><tbody>')

    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ('', ''))
        # Strip the direction arrows from the group label column.
        label = group.replace('↑', '').replace('↓', '').strip()
        parts.append('<tr>')
        parts.append(f'<td style="{cell_style}"><div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div></td>')
        parts.append(f'<td style="{cell_style}"><strong>{label}</strong></td>')
        parts.append(f'<td style="{cell_style}">{metrics}</td>')
        parts.append(f'<td style="{cell_style}">{direction}</td>')
        parts.append('</tr>')

    parts.append('</tbody></table>')
    return ''.join(parts)
def gradio_interface() -> gr.Blocks:
    """Build the full Gradio app: Leaderboard, About, and Submit tabs."""
    with gr.Blocks() as demo:
        gr.Markdown("## Welcome to the LeMaterial Generative Benchmark Leaderboard!")
        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown("# LeMat-GenBench")
                # Display options row: view toggles, sorting, and group selection.
                with gr.Row():
                    with gr.Column(scale=1):
                        compact_view = gr.Checkbox(
                            value=True,
                            label="Compact View",
                            info="Show only key metrics"
                        )
                        show_percentage = gr.Checkbox(
                            value=True,
                            label="Show as Percentages",
                            info="Display count-based metrics as percentages of total structures"
                        )
                    with gr.Column(scale=1):
                        # Dropdown shows display names; update_leaderboard maps
                        # them back to raw column names before sorting.
                        sort_choices = ["None"] + [COLUMN_DISPLAY_NAMES.get(col, col) for col in COLUMN_DISPLAY_NAMES.keys()]
                        sort_by = gr.Dropdown(
                            choices=sort_choices,
                            value="None",
                            label="Sort By",
                            info="Select column to sort by"
                        )
                        sort_direction = gr.Radio(
                            choices=["Ascending", "Descending"],
                            value="Descending",
                            label="Sort Direction"
                        )
                    with gr.Column(scale=2):
                        selected_groups = gr.CheckboxGroup(
                            choices=list(METRIC_GROUPS.keys()),
                            value=list(METRIC_GROUPS.keys()),
                            label="Metric Families (only active when Compact View is off)",
                            info="Select which metric groups to display"
                        )
                # Metric legend with color coding (collapsed by default).
                with gr.Accordion("Metric Groups Legend", open=False):
                    gr.HTML(generate_metric_legend_html())
                try:
                    # Load the results once and cache them in gr.State so UI
                    # changes don't re-download the dataset.
                    initial_df = get_leaderboard()
                    cached_df_state = gr.State(initial_df)
                    formatted_df = format_dataframe(initial_df, show_percentage=True, selected_groups=list(METRIC_GROUPS.keys()), compact_view=True)
                    leaderboard_table = gr.Dataframe(
                        label="GenBench Leaderboard",
                        value=formatted_df,
                        interactive=False,
                        wrap=True,
                        # Wider first column for model names; guard against an
                        # empty frame (no columns to size).
                        column_widths=["180px"] + ["160px"] * (len(formatted_df.columns) - 1) if len(formatted_df.columns) > 0 else None,
                        show_fullscreen_button=True
                    )
                    # Every control re-renders the table from the cached frame.
                    show_percentage.change(
                        fn=update_leaderboard,
                        inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
                        outputs=leaderboard_table
                    )
                    selected_groups.change(
                        fn=update_leaderboard,
                        inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
                        outputs=leaderboard_table
                    )
                    compact_view.change(
                        fn=update_leaderboard,
                        inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
                        outputs=leaderboard_table
                    )
                    sort_by.change(
                        fn=update_leaderboard,
                        inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
                        outputs=leaderboard_table
                    )
                    sort_direction.change(
                        fn=update_leaderboard,
                        inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
                        outputs=leaderboard_table
                    )
                except Exception as e:
                    # Leaderboard failed to load — show the error inline instead
                    # of breaking the whole app.
                    gr.Markdown(f"Leaderboard is empty or error loading: {str(e)}")
                gr.Markdown("Verified submissions mean the results came from a model submission rather than a CIF submission. The lightening bolt (⚡) next to the model name indicates that the submitted structures were already relaxed.")
            with gr.TabItem("❔About", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown(
                    """
                    ## About LeMat-Gen-Bench
                    **Welcome to the LeMat-Bench Leaderboard!** This leaderboard showcases generative models for materials discovery evaluated on the LeMat-Bench benchmark. Read more in our pre-print.
                    """)
            with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown(
                    """
                    # Materials Submission
                    Upload a CSV, pkl, or a ZIP of CIFs with your structures.
                    """
                )
                # Holds the submission ID returned by submit_cif_files.
                filename = gr.State(value=None)
                gr.LoginButton()
                with gr.Row():
                    with gr.Column():
                        model_name_input = gr.Textbox(
                            label="Model Name",
                            placeholder="Enter your model name",
                            info="Provide a name for your model/method"
                        )
                        problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
                    with gr.Column():
                        cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
                        relaxed = gr.Checkbox(
                            value=False,
                            label="Structures are already relaxed",
                            info="Check this box if your submitted structures have already been relaxed"
                        )
                submit_btn = gr.Button("Submission")
                message = gr.Textbox(label="Status", lines=1, visible=False)
                # help message
                gr.Markdown("If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space.")
                # NOTE(review): submit_cif_files also takes a gr.OAuthProfile
                # argument; Gradio injects it automatically from the login
                # session, which is presumably why it is absent from `inputs`.
                submit_btn.click(
                    submit_cif_files,
                    inputs=[model_name_input, problem_type, cif_file, relaxed],
                    outputs=[message, filename],
                ).then(
                    fn=show_output_box,
                    inputs=[message],
                    outputs=[message],
                )
    return demo
if __name__ == "__main__":
    # Build and launch the app when run as a script.
    app = gradio_interface()
    app.launch()