from pathlib import Path import json import pandas as pd import numpy as np import gradio as gr from datasets import load_dataset from gradio_leaderboard import Leaderboard from datetime import datetime import os from about import ( PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo, COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS, METRIC_GROUP_COLORS, COLUMN_TO_GROUP ) def get_leaderboard(): ds = load_dataset(results_repo, split='train', download_mode="force_redownload") full_df = pd.DataFrame(ds) print(full_df.columns) if len(full_df) == 0: return pd.DataFrame({'date':[], 'model':[], 'score':[], 'verified':[]}) return full_df def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True): """Format the dataframe with proper column names and optional percentages.""" if len(df) == 0: return df # Build column list based on view mode selected_cols = ['model_name'] if compact_view: # Use predefined compact columns from about import COMPACT_VIEW_COLUMNS selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns] else: # Build from selected groups if 'n_structures' in df.columns: selected_cols.append('n_structures') # If no groups selected, show all if not selected_groups: selected_groups = list(METRIC_GROUPS.keys()) # Add columns from selected groups for group in selected_groups: if group in METRIC_GROUPS: for col in METRIC_GROUPS[group]: if col in df.columns and col not in selected_cols: selected_cols.append(col) # Create a copy with selected columns display_df = df[selected_cols].copy() # Add relaxed symbol to model name if relaxed column is True if 'relaxed' in df.columns and 'model_name' in display_df.columns: display_df['model_name'] = df.apply( lambda row: f"{row['model_name']} ⚡" if row.get('relaxed', False) else row['model_name'], axis=1 ) # Convert count-based metrics to percentages if requested if show_percentage and 'n_structures' in df.columns: n_structures = df['n_structures'] for col in COUNT_BASED_METRICS: if col in display_df.columns: # Calculate percentage and format as string with % display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + '%' # Round numeric columns for cleaner display for col in display_df.columns: if display_df[col].dtype in ['float64', 'float32']: display_df[col] = display_df[col].round(4) # Rename columns for display display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES) # Apply color coding based on metric groups styled_df = apply_color_styling(display_df, selected_cols) return styled_df def apply_color_styling(display_df, original_cols): """Apply background colors to dataframe based on metric groups using pandas Styler.""" def style_by_group(x): # Create a DataFrame with the same shape filled with empty strings styles = pd.DataFrame('', index=x.index, columns=x.columns) # Map display column names back to original column names for i, display_col in enumerate(x.columns): if i < len(original_cols): original_col = original_cols[i] # Check if this column belongs to a metric group if original_col in COLUMN_TO_GROUP: group = COLUMN_TO_GROUP[original_col] color = METRIC_GROUP_COLORS.get(group, '') if color: styles[display_col] = f'background-color: {color}' return styles # Apply the styling function return display_df.style.apply(style_by_group, axis=None) def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction): """Update the leaderboard based on user selections. Uses cached dataframe to avoid re-downloading data on every change. """ # Use cached dataframe instead of re-downloading df_to_format = cached_df.copy() # Convert display name back to raw column name for sorting if sort_by and sort_by != "None": # Create reverse mapping from display names to raw column names display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()} raw_column_name = display_to_raw.get(sort_by, sort_by) if raw_column_name in df_to_format.columns: ascending = (sort_direction == "Ascending") df_to_format = df_to_format.sort_values(by=raw_column_name, ascending=ascending) formatted_df = format_dataframe(df_to_format, show_percentage, selected_groups, compact_view) return formatted_df def show_output_box(message): return gr.update(value=message, visible=True) def submit_cif_files(problem_type, cif_files, relaxed, profile: gr.OAuthProfile | None): # TODO: Implement submission logic that includes the relaxed flag return def generate_metric_legend_html(): """Generate HTML table with color-coded metric group legend.""" metric_details = { 'Validity ↑': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', '↑ Higher is better'), 'Uniqueness & Novelty ↑': ('Unique, Novel', '↑ Higher is better'), 'Energy Metrics ↓': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', '↓ Lower is better'), 'Stability ↑': ('Stable, Unique in Stable, SUN', '↑ Higher is better'), 'Metastability ↑': ('Metastable, Unique in Metastable, MSUN', '↑ Higher is better'), 'Distribution ↓': ('JS Distance, MMD, FID', '↓ Lower is better'), 'Diversity ↑': ('Element, Space Group, Atomic Site, Crystal Size', '↑ Higher is better'), 'HHI ↓': ('HHI Production, HHI Reserve', '↓ Lower is better'), } html = '
| Color | ' html += 'Group | ' html += 'Metrics | ' html += 'Direction | ' html += '
|---|---|---|---|
| ' html += f' | {group_name} | ' html += f'{metrics} | ' html += f'{direction} | ' html += '