Spaces:
Running
Running
| import gradio as gr | |
| from datasets import load_dataset as load_dataset_hf, get_dataset_config_names | |
| import os | |
# Hub repository ids offered by default in the "Dataset Name" dropdown.
# Only the repositories are listed here; the configurations of each one are
# looked up at runtime.
PREDEFINED_DATASETS = [
    "abraranwar/agibotworld_alpha_rfm",
    "abraranwar/libero_rfm",
    "abraranwar/usc_koch_rewind_rfm",
    "aliangdw/metaworld",
    "anqil/rh20t_rfm",
    "anqil/rh20t_subset_rfm",
    "jesbu1/auto_eval_rfm",
    "jesbu1/egodex_rfm",
    "jesbu1/epic_rfm",
    "jesbu1/fino_net_rfm",
    "jesbu1/failsafe_rfm",
    "jesbu1/galaxea_rfm",
    "jesbu1/h2r_rfm",
    "jesbu1/humanoid_everyday_rfm",
    "jesbu1/molmoact_rfm",
    "jesbu1/motif_rfm",
    "jesbu1/oxe_rfm",
    "jesbu1/oxe_rfm_eval",
    "jesbu1/ph2d_rfm",
    "jesbu1/racer_rfm",
    "jesbu1/roboarena_0825_rfm",
    "jesbu1/soar_rfm",
    "ykorkmaz/libero_failure_rfm",
    "aliangdw/usc_xarm_policy_ranking",
    "aliangdw/usc_franka_policy_ranking",
    "aliangdw/utd_so101_policy_ranking",
    "aliangdw/utd_so101_human",
]
def load_rfm_dataset(dataset_name, config_name):
    """Fetch the ``train`` split of an RFM dataset from the HuggingFace Hub.

    Returns a ``(dataset, status_message)`` pair. ``dataset`` is ``None``
    when an input is missing, when loading raises, when one of the expected
    features is absent, or when the split has no rows; the message says
    which of these happened.
    """
    try:
        # Both identifiers are mandatory.
        if not (dataset_name and config_name):
            return None, "❌ Please provide both dataset name and configuration"

        loaded = load_dataset_hf(dataset_name, name=config_name, split="train")

        # Every one of these columns has to be present in the loaded split.
        required = ("task", "frames", "quality_label", "is_robot", "data_source")
        absent = [column for column in required if column not in loaded.features]
        if absent:
            return None, f"⚠️ Dataset loaded but missing expected features: {absent}"

        # A split without rows is reported instead of being returned.
        total = len(loaded)
        if total == 0:
            return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"

        return loaded, f"✅ Loaded {total} trajectories from {dataset_name}/{config_name}"
    except Exception as exc:
        # Turn the failure into a status line, keyed on the exception text.
        reason = str(exc)
        lowered = reason.lower()
        if "not found" in lowered:
            return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
        if "authentication" in lowered:
            return None, f"❌ Authentication required for {dataset_name}"
        return None, f"❌ Error loading dataset: {reason}"
def get_available_configs(dataset_name):
    """Return the configuration names of ``dataset_name``, or ``[]`` on failure."""
    try:
        # Delegate to the dedicated helper imported from ``datasets``.
        return get_dataset_config_names(dataset_name)
    except Exception as err:
        # Best effort: report the problem on stdout and fall through.
        print(f"Error getting configs for {dataset_name}: {err}")
    return []
def update_config_choices_with_custom(dataset_name):
    """Build the update for the configuration dropdown of ``dataset_name``.

    When configurations are found they become the dropdown choices and the
    first one is selected. With no dataset name, no configurations, or an
    error, the choices are emptied and the value is reset to ``""``.
    """
    if not dataset_name:
        return gr.update(choices=[], value="")

    try:
        # Always ask the dataset itself for its configurations.
        available = get_available_configs(dataset_name)
        if not available:
            return gr.update(choices=[], value="")
        return gr.update(choices=available, value=available[0])
    except Exception as err:
        # If fetching fails, leave the dropdown empty so a custom value can be typed.
        print(f"Warning: Could not fetch configs for {dataset_name}: {err}")
        return gr.update(choices=[], value="")
def visualize_trajectory(dataset, index, dataset_name=None):
    """Look up one trajectory and return what the UI needs to display it.

    Args:
        dataset: Indexable collection of trajectory records, or ``None``.
            Each record must provide ``task``, ``quality_label``,
            ``is_robot``, ``data_source`` and ``frames``; ``id`` is optional.
        index: Position of the trajectory; converted with ``int()``.
        dataset_name: Hub repository id used to build the video URL. When
            empty, the ``aliangdw/rfm`` repository is used instead.

    Returns:
        A 4-tuple ``(video_url, metadata_markdown, title, image)``. ``image``
        is always ``None``. On failure ``video_url`` is ``None`` and both
        text slots carry the same error message.
    """
    if dataset is None:
        return None, "Error: Could not load dataset", "Error: Could not load dataset", None

    try:
        # Normalise once so the lookup and the title agree; a float such as
        # 3.0 would otherwise be shown as "Trajectory 3.0".
        position = int(index)
        item = dataset[position]

        # Get metadata
        task = item["task"]
        quality_label = item["quality_label"]
        is_robot = item["is_robot"]
        data_source = item["data_source"]

        # The frames entry is expected to be a path relative to the repo root.
        frames_data = item["frames"]
        if not isinstance(frames_data, str):
            return None, "Error: Invalid video path", "Error: Invalid video path", None

        # Use the selected repository if provided, otherwise the default one.
        repo_id = dataset_name if dataset_name else "aliangdw/rfm"
        video_path = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{frames_data}"

        # Blank lines between the fields keep each one in its own Markdown
        # paragraph; single newlines are not rendered as line breaks.
        metadata = "\n\n".join(
            [
                "## Trajectory Information",
                f"**Video path:** {video_path}",
                f"**Language Task:** {task}",
                f"**Quality Label:** {quality_label}",
                f"**Data Type:** {'Robot' if is_robot else 'Human'}",
                f"**Source:** {data_source}",
                f"**Trajectory ID:** {item.get('id', 'N/A')}",
            ]
        )

        # Return video path for Gradio to display
        return video_path, metadata, f"Trajectory {position}", None
    except Exception as e:
        # UI boundary: report any lookup/format failure instead of raising.
        return None, f"Error: {str(e)}", f"Error: {str(e)}", None
# Create the Gradio interface.
#
# Everything below lives inside one Blocks context: the layout, the callback
# functions, and the event wiring that connects them.
with gr.Blocks(title="RFM Dataset Visualizer") as demo:
    gr.Markdown("# RFM Dataset Visualizer")
    gr.Markdown("Browse through trajectory videos and their metadata from the Reward Foundation Model dataset.")

    # Dataset selection
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name_input = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value="jesbu1/oxe_rfm",
                label="Dataset Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=2):
            config_name_input = gr.Dropdown(
                choices=[],
                value="",
                label="Configuration Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
        with gr.Column(scale=1):
            load_btn = gr.Button("Load Dataset", variant="primary")

    # Status message
    status_output = gr.Markdown("Ready to load dataset...")

    # Dataset info
    dataset_info = gr.Markdown("")

    # Visualization section
    with gr.Row():
        with gr.Column(scale=2):
            # Video/Image display
            video_output = gr.Video(label="Trajectory Video", height=400, autoplay=True)
            image_output = gr.Image(label="Frame Preview", height=400, visible=False)
        with gr.Column(scale=1):
            # Metadata display
            metadata_output = gr.Markdown(label="Metadata")

    # Navigation controls
    with gr.Row():
        with gr.Column(scale=1):
            prev_btn = gr.Button("⬅️ Previous", variant="secondary")
        with gr.Column(scale=2):
            # Slider for navigation; its range is set once a dataset is loaded
            slider = gr.Slider(
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        with gr.Column(scale=1):
            next_btn = gr.Button("Next ➡️", variant="secondary")

    # Current trajectory title
    title_output = gr.Textbox(label="Current Trajectory", interactive=False)

    # State variables
    current_dataset = gr.State(None)
    current_index = gr.State(0)

    def load_dataset(dataset_name, config_name):
        """Load a dataset and produce the values for the load event outputs.

        Returns exactly four values, one per output component wired below:
        ``(current_dataset, status_output, dataset_info, current_index)``.
        The slider range is handled separately by ``update_slider_range``.
        """
        dataset, status = load_rfm_dataset(dataset_name, config_name)
        if dataset is None:
            return None, status, "", 0
        info = f"**Dataset Info:**\n- **Total Trajectories:** {len(dataset)}\n- **Features:** {list(dataset.features.keys())}"
        return dataset, status, info, 0

    def update_trajectory(dataset, index, dataset_name=None):
        """Show the trajectory at ``index``, clamping it into the valid range."""
        if dataset is None:
            return None, "No dataset loaded", "No dataset loaded", None
        # Ensure index is within bounds and is a valid number
        if index is None or not isinstance(index, (int, float)):
            index = 0
        elif index >= len(dataset):
            index = len(dataset) - 1
        elif index < 0:
            index = 0
        return visualize_trajectory(dataset, int(index), dataset_name)

    def next_trajectory(dataset, current_idx, dataset_name=None):
        """Advance by one trajectory, stopping at the last one."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        next_idx = min(current_idx + 1, len(dataset) - 1)
        video, metadata, title, image = visualize_trajectory(dataset, next_idx, dataset_name)
        return next_idx, video, metadata, title, image

    def prev_trajectory(dataset, current_idx, dataset_name=None):
        """Go back by one trajectory, stopping at index 0."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        prev_idx = max(current_idx - 1, 0)
        video, metadata, title, image = visualize_trajectory(dataset, prev_idx, dataset_name)
        return prev_idx, video, metadata, title, image

    def update_slider_range(dataset):
        """Resize the slider to the dataset length, or disable it when none is loaded."""
        if dataset is not None:
            max_value = len(dataset) - 1
            return gr.update(
                maximum=max_value,
                value=0,  # Reset to beginning
                label=f"Trajectory Index (0 to {max_value})",
                interactive=True,
            )
        return gr.update(
            maximum=0,
            value=0,
            label="Select a dataset first",
            interactive=False,
        )

    def update_config_choices(dataset_name):
        """Update the config dropdown choices based on the selected dataset."""
        return update_config_choices_with_custom(dataset_name)

    # Connect the components

    # Update config choices when dataset changes
    dataset_name_input.change(
        fn=update_config_choices,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    # Refresh configs button for custom datasets
    refresh_btn.click(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    # Load the dataset, then resize the slider to match it
    load_btn.click(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=[current_dataset, status_output, dataset_info, current_index],
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    )

    # Moving the slider shows that trajectory, then records the (clamped) index
    slider.change(
        fn=lambda dataset, idx, dataset_name: update_trajectory(dataset, idx, dataset_name),
        inputs=[current_dataset, slider, dataset_name_input],
        outputs=[video_output, metadata_output, title_output, image_output],
    ).then(
        fn=lambda dataset, idx: idx if dataset is None or idx < len(dataset) else len(dataset) - 1,
        inputs=[current_dataset, slider],
        outputs=[current_index],
    )

    # Next/Previous update the stored index first, then mirror it onto the slider
    next_btn.click(
        fn=lambda dataset, idx, dataset_name: next_trajectory(dataset, idx, dataset_name),
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index, video_output, metadata_output, title_output, image_output],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    prev_btn.click(
        fn=lambda dataset, idx, dataset_name: prev_trajectory(dataset, idx, dataset_name),
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index, video_output, metadata_output, title_output, image_output],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    # Load initial dataset and configs
    demo.load(
        fn=lambda: ("jesbu1/oxe_rfm", "oxe_jaco_play"),  # Set initial values
        outputs=[dataset_name_input, config_name_input],
    ).then(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    ).then(
        fn=lambda dataset_name, config_name: load_dataset(dataset_name, config_name),
        inputs=[dataset_name_input, config_name_input],
        outputs=[current_dataset, status_output, dataset_info, current_index],
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    )
def main():
    """Entry point: launch the ``demo`` Blocks app."""
    demo.launch()


# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    main()