import gradio as gr
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
import os

# Predefined dataset names (configs will be fetched dynamically)
PREDEFINED_DATASETS = [
    "abraranwar/agibotworld_alpha_rfm",
    "abraranwar/libero_rfm",
    "abraranwar/usc_koch_rewind_rfm",
    "aliangdw/metaworld",
    "anqil/rh20t_rfm",
    "anqil/rh20t_subset_rfm",
    "jesbu1/auto_eval_rfm",
    "jesbu1/egodex_rfm",
    "jesbu1/epic_rfm",
    "jesbu1/fino_net_rfm",
    "jesbu1/failsafe_rfm",
    "jesbu1/hand_paired_rfm",
    "jesbu1/galaxea_rfm",
    "jesbu1/h2r_rfm",
    "jesbu1/humanoid_everyday_rfm",
    "jesbu1/molmoact_rfm",
    "jesbu1/motif_rfm",
    "jesbu1/oxe_rfm",
    "jesbu1/oxe_rfm_eval",
    "jesbu1/ph2d_rfm",
    "jesbu1/racer_rfm",
    "jesbu1/roboarena_0825_rfm",
    "jesbu1/soar_rfm",
    "ykorkmaz/libero_failure_rfm",
    "aliangdw/usc_xarm_policy_ranking",
    "aliangdw/usc_franka_policy_ranking",
    "aliangdw/utd_so101_policy_ranking",
    "aliangdw/utd_so101_human",
]

# Dataset / configuration shown when the app first loads.
DEFAULT_DATASET = "jesbu1/oxe_rfm"
DEFAULT_CONFIG = "oxe_jaco_play"

# Repo used to build video URLs when no dataset name is supplied.
FALLBACK_VIDEO_DATASET = "aliangdw/rfm"

# Columns every RFM dataset must provide for the visualizer to work.
EXPECTED_FEATURES = ["task", "frames", "quality_label", "is_robot", "data_source"]


def load_rfm_dataset(dataset_name, config_name):
    """Load the ``train`` split of an RFM dataset from the HuggingFace Hub.

    Args:
        dataset_name: Hub repository id, e.g. ``"jesbu1/oxe_rfm"``.
        config_name: Name of the dataset configuration to load.

    Returns:
        A ``(dataset, status)`` tuple. ``dataset`` is ``None`` when the inputs
        are missing, loading fails, expected columns are absent, or the split
        is empty. ``status`` is a Markdown message describing the outcome.
    """
    # Validate inputs before touching the network.
    if not dataset_name or not config_name:
        return None, "❌ Please provide both dataset name and configuration"
    try:
        dataset = load_dataset_hf(dataset_name, name=config_name, split="train")

        # Check the dataset has the structure the visualizer reads.
        missing_features = [f for f in EXPECTED_FEATURES if f not in dataset.features]
        if missing_features:
            return None, f"⚠️ Dataset loaded but missing expected features: {missing_features}"

        if len(dataset) == 0:
            return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"

        return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
    except Exception as e:
        # Boundary handler: the UI always needs a status message, never a crash.
        error_msg = str(e)
        lowered = error_msg.lower()
        if "not found" in lowered:
            return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
        if "authentication" in lowered:
            return None, f"❌ Authentication required for {dataset_name}"
        return None, f"❌ Error loading dataset: {error_msg}"


def get_available_configs(dataset_name):
    """Return the configuration names of a Hub dataset, or ``[]`` on any error."""
    try:
        return get_dataset_config_names(dataset_name)
    except Exception as e:
        print(f"Error getting configs for {dataset_name}: {e}")
        return []


def update_config_choices_with_custom(dataset_name, preferred_config=None):
    """Build a dropdown update listing the configs of ``dataset_name``.

    Args:
        dataset_name: Hub repository id whose configs should be listed.
        preferred_config: Config to keep selected if it is among the fetched
            configs. Otherwise the first config is selected. When fetching
            yields nothing, it is kept as a custom value, because the dropdown
            allows custom input.

    Returns:
        A ``gr.update`` for the configuration dropdown.
    """
    if not dataset_name:
        return gr.update(choices=[], value="")

    # get_available_configs never raises; an empty list means fetching failed.
    configs = get_available_configs(dataset_name)
    if not configs:
        return gr.update(choices=[], value=preferred_config or "")

    value = preferred_config if preferred_config in configs else configs[0]
    return gr.update(choices=configs, value=value)


def _resolve_video_path(frames_data, dataset_name):
    """Turn a ``frames`` entry into a URL Gradio can play."""
    # Entries that are already absolute URLs are used unchanged.
    if frames_data.startswith(("http://", "https://")):
        return frames_data
    repo = dataset_name or FALLBACK_VIDEO_DATASET
    return f"https://huggingface.co/datasets/{repo}/resolve/main/{frames_data}"


def _clamp_index(dataset, index):
    """Coerce ``index`` to an int within ``[0, len(dataset) - 1]``.

    Returns 0 when there is no (non-empty) dataset or ``index`` is not a number
    (e.g. ``None``). Slider values may arrive as floats, hence the ``int()``.
    """
    if dataset is None or len(dataset) == 0 or not isinstance(index, (int, float)):
        return 0
    return max(0, min(int(index), len(dataset) - 1))


def visualize_trajectory(dataset, index, dataset_name=None):
    """Retrieve one trajectory and format its metadata for display.

    Args:
        dataset: Loaded dataset (or ``None``).
        index: Row index to show.
        dataset_name: Repo id used to build the video URL. When falsy,
            ``FALLBACK_VIDEO_DATASET`` is used.

    Returns:
        ``(video_path, metadata_markdown, title, image)``. ``image`` is always
        ``None``. On failure, ``video_path`` is ``None`` and the error text is
        placed in the metadata and title slots.
    """
    if dataset is None:
        return None, "Error: Could not load dataset", "Error: Could not load dataset", None

    try:
        item = dataset[int(index)]

        task = item["task"]
        quality_label = item["quality_label"]
        is_robot = item["is_robot"]
        data_source = item["data_source"]

        # "frames" holds a video file path relative to the dataset repo.
        frames_data = item["frames"]
        if not isinstance(frames_data, str):
            return None, "Error: Invalid video path", "Error: Invalid video path", None
        video_path = _resolve_video_path(frames_data, dataset_name)

        metadata = "\n\n".join([
            "## Trajectory Information",
            f"**Video path:** {video_path}",
            f"**Language Task:** {task}",
            f"**Quality Label:** {quality_label}",
            f"**Data Type:** {'Robot' if is_robot else 'Human'}",
            f"**Source:** {data_source}",
            f"**Trajectory ID:** {item.get('id', 'N/A')}",
        ])

        # Gradio's Video component accepts the URL directly.
        return video_path, metadata, f"Trajectory {index}", None
    except Exception as e:
        return None, f"Error: {str(e)}", f"Error: {str(e)}", None


# Create the Gradio interface
with gr.Blocks(title="RFM Dataset Visualizer") as demo:
    gr.Markdown("# RFM Dataset Visualizer")
    gr.Markdown("Browse through trajectory videos and their metadata from the Reward Foundation Model dataset.")

    # Dataset selection
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name_input = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value=DEFAULT_DATASET,
                label="Dataset Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=2):
            config_name_input = gr.Dropdown(
                choices=[],
                value="",
                label="Configuration Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
        with gr.Column(scale=1):
            load_btn = gr.Button("Load Dataset", variant="primary")

    # Status message
    status_output = gr.Markdown("Ready to load dataset...")

    # Dataset info
    dataset_info = gr.Markdown("")

    # Visualization section
    with gr.Row():
        with gr.Column(scale=2):
            # Video/Image display
            video_output = gr.Video(label="Trajectory Video", height=400, autoplay=True)
            image_output = gr.Image(label="Frame Preview", height=400, visible=False)
        with gr.Column(scale=1):
            # Metadata display
            metadata_output = gr.Markdown(label="Metadata")

    # Navigation controls
    with gr.Row():
        with gr.Column(scale=1):
            prev_btn = gr.Button("⬅️ Previous", variant="secondary")
        with gr.Column(scale=2):
            # Slider for navigation; its range is set once a dataset loads.
            slider = gr.Slider(
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        with gr.Column(scale=1):
            next_btn = gr.Button("Next ➡️", variant="secondary")

    # Current trajectory title
    title_output = gr.Textbox(label="Current Trajectory", interactive=False)

    # State variables
    current_dataset = gr.State(None)
    current_index = gr.State(0)
    # Repo id of the dataset actually loaded. Video URLs must use this rather
    # than the dropdown, which the user may change without re-loading.
    loaded_dataset_name = gr.State(None)

    def load_dataset(dataset_name, config_name):
        """Load a dataset and compute the values shown in the interface.

        Returns:
            ``(dataset, status, info_markdown, start_index, max_index)``.
            On failure, ``dataset`` is ``None``, ``info`` is ``""`` and both
            indices are 0.
        """
        dataset, status = load_rfm_dataset(dataset_name, config_name)
        if dataset is None:
            return None, status, "", 0, 0
        max_index = len(dataset) - 1
        info = f"**Dataset Info:**\n- **Total Trajectories:** {len(dataset)}\n- **Features:** {list(dataset.features.keys())}"
        return dataset, status, info, 0, max_index

    def _load_for_ui(dataset_name, config_name):
        """Adapt ``load_dataset`` to the load-event outputs.

        Drops ``max_index`` (the slider range is set by ``update_slider_range``)
        and records which repo the loaded rows came from.
        """
        dataset, status, info, index, _max_index = load_dataset(dataset_name, config_name)
        return dataset, status, info, index, (dataset_name if dataset is not None else None)

    def update_trajectory(dataset, index, dataset_name=None):
        """Render the trajectory at ``index``, clamped into the dataset's range."""
        if dataset is None:
            return None, "No dataset loaded", "No dataset loaded", None
        return visualize_trajectory(dataset, _clamp_index(dataset, index), dataset_name)

    def next_trajectory(dataset, current_idx, dataset_name=None):
        """Go to the next trajectory; returns ``(index, video, metadata, title, image)``."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        next_idx = _clamp_index(dataset, _clamp_index(dataset, current_idx) + 1)
        video, metadata, title, image = visualize_trajectory(dataset, next_idx, dataset_name)
        return next_idx, video, metadata, title, image

    def prev_trajectory(dataset, current_idx, dataset_name=None):
        """Go to the previous trajectory; returns ``(index, video, metadata, title, image)``."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        prev_idx = _clamp_index(dataset, _clamp_index(dataset, current_idx) - 1)
        video, metadata, title, image = visualize_trajectory(dataset, prev_idx, dataset_name)
        return prev_idx, video, metadata, title, image

    def _select_trajectory(dataset, index, dataset_name=None):
        """Render the slider-selected trajectory and return its clamped index first."""
        return (_clamp_index(dataset, index), *update_trajectory(dataset, index, dataset_name))

    def _show_first(dataset, dataset_name):
        """Render index 0 of a freshly loaded dataset (or the 'no dataset' view)."""
        return update_trajectory(dataset, 0, dataset_name)

    def update_slider_range(dataset):
        """Update the slider with new maximum value based on dataset length."""
        if dataset is None:
            return gr.update(
                maximum=0,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        max_value = len(dataset) - 1
        return gr.update(
            maximum=max_value,
            value=0,  # Reset to beginning
            label=f"Trajectory Index (0 to {max_value})",
            interactive=True,
        )

    def update_config_choices(dataset_name):
        """Update the config dropdown choices based on selected dataset."""
        return update_config_choices_with_custom(dataset_name)

    display_outputs = [video_output, metadata_output, title_output, image_output]
    load_outputs = [current_dataset, status_output, dataset_info, current_index, loaded_dataset_name]

    # Connect the components
    # Update config choices when dataset changes (selects the first config).
    dataset_name_input.change(
        fn=update_config_choices,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    # Refresh configs, keeping the current selection when it is still available.
    refresh_btn.click(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input, config_name_input],
        outputs=[config_name_input],
    )

    # Load -> reset slider -> render index 0 explicitly. Setting the slider to
    # the value it already has does not fire an event, so relying on the slider
    # would leave the previous dataset's video on screen.
    load_btn.click(
        fn=_load_for_ui,
        inputs=[dataset_name_input, config_name_input],
        outputs=load_outputs,
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        fn=_show_first,
        inputs=[current_dataset, loaded_dataset_name],
        outputs=display_outputs,
    )

    # `.input` fires only on user interaction. Programmatic slider updates
    # (load, next/prev) already render, so they must not re-render here.
    slider.input(
        fn=_select_trajectory,
        inputs=[current_dataset, slider, loaded_dataset_name],
        outputs=[current_index, *display_outputs],
    )

    next_btn.click(
        fn=next_trajectory,
        inputs=[current_dataset, current_index, loaded_dataset_name],
        outputs=[current_index, *display_outputs],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    prev_btn.click(
        fn=prev_trajectory,
        inputs=[current_dataset, current_index, loaded_dataset_name],
        outputs=[current_index, *display_outputs],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    # Load initial dataset and configs
    demo.load(
        fn=lambda: (DEFAULT_DATASET, DEFAULT_CONFIG),  # Set initial values
        outputs=[dataset_name_input, config_name_input],
    ).then(
        # Pass the current config as preferred so DEFAULT_CONFIG is not
        # overwritten by the first fetched config.
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input, config_name_input],
        outputs=[config_name_input],
    ).then(
        fn=_load_for_ui,
        inputs=[dataset_name_input, config_name_input],
        outputs=load_outputs,
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        fn=_show_first,
        inputs=[current_dataset, loaded_dataset_name],
        outputs=display_outputs,
    )


def main():
    """Main function to launch the RFM visualizer."""
    demo.launch()


# Launch the app
if __name__ == "__main__":
    main()