# visualizer / app.py
# Source: commit b62dad4 by jesbu1 — "utd new datasets" (13.2 kB)
import gradio as gr
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
import os
# Hub dataset ids offered in the "Dataset Name" dropdown, in display order.
# Config names are deliberately not listed: they are looked up on the Hub
# once a dataset is chosen.
PREDEFINED_DATASETS = [
    "abraranwar/agibotworld_alpha_rfm",
    "abraranwar/libero_rfm",
    "abraranwar/usc_koch_rewind_rfm",
    "aliangdw/metaworld",
    "anqil/rh20t_rfm",
    "anqil/rh20t_subset_rfm",
    "jesbu1/auto_eval_rfm",
    "jesbu1/egodex_rfm",
    "jesbu1/epic_rfm",
    "jesbu1/fino_net_rfm",
    "jesbu1/failsafe_rfm",
    "jesbu1/galaxea_rfm",
    "jesbu1/h2r_rfm",
    "jesbu1/humanoid_everyday_rfm",
    "jesbu1/molmoact_rfm",
    "jesbu1/motif_rfm",
    "jesbu1/oxe_rfm",
    "jesbu1/oxe_rfm_eval",
    "jesbu1/ph2d_rfm",
    "jesbu1/racer_rfm",
    "jesbu1/roboarena_0825_rfm",
    "jesbu1/soar_rfm",
    "ykorkmaz/libero_failure_rfm",
    "aliangdw/usc_xarm_policy_ranking",
    "aliangdw/usc_franka_policy_ranking",
    "aliangdw/utd_so101_policy_ranking",
    "aliangdw/utd_so101_human",
]
def load_rfm_dataset(dataset_name, config_name):
    """Fetch the ``train`` split of one RFM dataset config from the Hugging Face Hub.

    Args:
        dataset_name: Hub dataset id, e.g. ``"jesbu1/oxe_rfm"``.
        config_name: Name of the dataset configuration to load.

    Returns:
        A ``(dataset, status_message)`` pair. ``dataset`` is ``None`` when an
        input is missing, loading raises, a required column is absent, or the
        split has no rows. ``status_message`` is a user-facing Markdown string.
    """
    required_columns = ("task", "frames", "quality_label", "is_robot", "data_source")
    try:
        if not (dataset_name and config_name):
            return None, "❌ Please provide both dataset name and configuration"

        loaded = load_dataset_hf(dataset_name, name=config_name, split="train")

        # Reject datasets that do not carry every column the viewer reads.
        absent = [col for col in required_columns if col not in loaded.features]
        if absent:
            return None, f"⚠️ Dataset loaded but missing expected features: {absent}"

        row_count = len(loaded)
        if row_count == 0:
            return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"
        return loaded, f"✅ Loaded {row_count} trajectories from {dataset_name}/{config_name}"
    except Exception as exc:
        # Map common failure messages onto friendlier statuses; anything else
        # is reported verbatim.
        lowered = str(exc).lower()
        if "not found" in lowered:
            return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
        if "authentication" in lowered:
            return None, f"❌ Authentication required for {dataset_name}"
        return None, f"❌ Error loading dataset: {exc}"
def get_available_configs(dataset_name):
    """Return the configuration names published for ``dataset_name`` on the Hub.

    Every failure (unknown dataset, network trouble, missing auth, ...) is
    printed to stdout and turned into an empty list, so this never raises
    ``Exception`` subclasses to its caller.
    """
    try:
        return get_dataset_config_names(dataset_name)
    except Exception as err:
        print(f"Error getting configs for {dataset_name}: {err}")
        return []
def update_config_choices_with_custom(dataset_name, current_value=None):
    """Build a Gradio update for the configuration dropdown of ``dataset_name``.

    Args:
        dataset_name: Hub dataset id. Empty or ``None`` clears the dropdown.
        current_value: Optional config that is currently selected. If it is one
            of the fetched configs it stays selected; otherwise the first
            fetched config is selected. The default ``None`` always selects the
            first config, matching the original one-argument behavior.

    Returns:
        ``gr.update(choices=..., value=...)``. When no configs can be fetched,
        choices are empty and the value is ``""``, leaving the user to type a
        config name by hand.
    """
    if not dataset_name:
        return gr.update(choices=[], value="")

    # get_available_configs already catches every fetch error and returns [],
    # so the "could not fetch" case is just an empty list here.
    configs = get_available_configs(dataset_name)
    if not configs:
        return gr.update(choices=[], value="")

    selected = current_value if current_value in configs else configs[0]
    return gr.update(choices=configs, value=selected)
def visualize_trajectory(dataset, index, dataset_name=None):
    """Fetch one trajectory and build the values the UI displays for it.

    Args:
        dataset: Indexable collection of trajectory records (the loaded Hugging
            Face dataset in this app). Each item must support ``item[key]`` and
            ``item.get(key, default)`` and carry ``task``, ``quality_label``,
            ``is_robot``, ``data_source`` and ``frames``; ``id`` is optional.
        index: Position of the trajectory; converted with ``int()``.
        dataset_name: Hub dataset id used to resolve a relative ``frames`` path.
            When falsy, paths resolve against ``aliangdw/rfm``.

    Returns:
        ``(video_url, metadata_markdown, title, image)``, where ``image`` is
        always ``None``. On any failure ``video_url`` is ``None`` and both text
        fields carry an error message instead of raising.
    """
    if dataset is None:
        return None, "Error: Could not load dataset", "Error: Could not load dataset", None

    try:
        idx = int(index)
        item = dataset[idx]

        task = item["task"]
        quality_label = item["quality_label"]
        is_robot = item["is_robot"]
        data_source = item["data_source"]

        frames_data = item["frames"]
        if not isinstance(frames_data, str):
            return None, "Error: Invalid video path", "Error: Invalid video path", None

        if frames_data.startswith(("http://", "https://")):
            # Already an absolute URL: use it directly instead of nesting it
            # inside a Hub "resolve" URL.
            video_path = frames_data
        else:
            repo = dataset_name or "aliangdw/rfm"
            video_path = f"https://huggingface.co/datasets/{repo}/resolve/main/{frames_data}"

        # One list item per field. In Markdown, consecutive plain lines form a
        # single paragraph (soft line breaks), which ran all fields together.
        metadata = "\n".join(
            [
                "## Trajectory Information",
                f"- **Video path:** {video_path}",
                f"- **Language Task:** {task}",
                f"- **Quality Label:** {quality_label}",
                f"- **Data Type:** {'Robot' if is_robot else 'Human'}",
                f"- **Source:** {data_source}",
                f"- **Trajectory ID:** {item.get('id', 'N/A')}",
            ]
        )
        return video_path, metadata, f"Trajectory {idx}", None
    except Exception as e:
        return None, f"Error: {str(e)}", f"Error: {str(e)}", None
# Create the Gradio interface
with gr.Blocks(title="RFM Dataset Visualizer") as demo:
    gr.Markdown("# RFM Dataset Visualizer")
    gr.Markdown("Browse through trajectory videos and their metadata from the Reward Foundation Model dataset.")

    # Dataset selection
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name_input = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value="jesbu1/oxe_rfm",
                label="Dataset Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=2):
            config_name_input = gr.Dropdown(
                choices=[],
                value="",
                label="Configuration Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
        with gr.Column(scale=1):
            load_btn = gr.Button("Load Dataset", variant="primary")

    # Status message
    status_output = gr.Markdown("Ready to load dataset...")
    # Dataset info
    dataset_info = gr.Markdown("")

    # Visualization section
    with gr.Row():
        with gr.Column(scale=2):
            # Video/Image display
            video_output = gr.Video(label="Trajectory Video", height=400, autoplay=True)
            image_output = gr.Image(label="Frame Preview", height=400, visible=False)
        with gr.Column(scale=1):
            # Metadata display
            metadata_output = gr.Markdown(label="Metadata")

    # Navigation controls
    with gr.Row():
        with gr.Column(scale=1):
            prev_btn = gr.Button("⬅️ Previous", variant="secondary")
        with gr.Column(scale=2):
            # Navigation slider; its range is set once a dataset is loaded.
            slider = gr.Slider(
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        with gr.Column(scale=1):
            next_btn = gr.Button("Next ➡️", variant="secondary")

    # Current trajectory title
    title_output = gr.Textbox(label="Current Trajectory", interactive=False)

    # State variables
    current_dataset = gr.State(None)
    current_index = gr.State(0)
    # Hub id of the dataset actually held in `current_dataset`. Video URLs are
    # built from this rather than from the dropdown, which the user may have
    # edited since the last load. Using the dropdown produced URLs into the
    # wrong repository.
    current_dataset_name = gr.State(None)

    # Component groups shared by several event chains.
    load_outputs = [current_dataset, status_output, dataset_info, current_index, current_dataset_name]
    trajectory_outputs = [video_output, metadata_output, title_output, image_output]

    def load_dataset(dataset_name, config_name):
        """Load a dataset config and return one value per ``load_outputs`` entry.

        Returns:
            ``(dataset or None, status, info_markdown, 0, loaded_name or None)``.
        """
        dataset, status = load_rfm_dataset(dataset_name, config_name)
        if dataset is None:
            return None, status, "", 0, None
        info = (
            f"**Dataset Info:**\n- **Total Trajectories:** {len(dataset)}\n"
            f"- **Features:** {list(dataset.features.keys())}"
        )
        return dataset, status, info, 0, dataset_name

    def clamp_index(dataset, index):
        """Coerce a slider/State value into a valid row index for ``dataset``.

        A missing dataset, ``None`` or a non-numeric value maps to 0; numbers
        are clamped into ``[0, len(dataset) - 1]`` and truncated to ``int``.
        """
        if dataset is None or index is None or not isinstance(index, (int, float)):
            return 0
        return int(max(0, min(index, len(dataset) - 1)))

    def update_trajectory(dataset, index, dataset_name=None):
        """Render the trajectory at ``index`` (clamped into range)."""
        if dataset is None:
            return None, "No dataset loaded", "No dataset loaded", None
        return visualize_trajectory(dataset, clamp_index(dataset, index), dataset_name)

    def show_first_trajectory(dataset, dataset_name):
        """Render trajectory 0 of a freshly loaded (or failed) dataset."""
        return update_trajectory(dataset, 0, dataset_name)

    def next_trajectory(dataset, current_idx, dataset_name=None):
        """Go to the next trajectory, stopping at the last one."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        next_idx = min(current_idx + 1, len(dataset) - 1)
        video, metadata, title, image = visualize_trajectory(dataset, next_idx, dataset_name)
        return next_idx, video, metadata, title, image

    def prev_trajectory(dataset, current_idx, dataset_name=None):
        """Go to the previous trajectory, stopping at index 0."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        prev_idx = max(current_idx - 1, 0)
        video, metadata, title, image = visualize_trajectory(dataset, prev_idx, dataset_name)
        return prev_idx, video, metadata, title, image

    def update_slider_range(dataset):
        """Resize the slider to the dataset length (or disable it without one)."""
        if dataset is None:
            return gr.update(
                maximum=0,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        max_value = len(dataset) - 1
        return gr.update(
            maximum=max_value,
            value=0,  # Reset to beginning
            label=f"Trajectory Index (0 to {max_value})",
            interactive=True,
        )

    def refresh_configs_keep_selection(dataset_name, config_name):
        """Re-fetch configs for ``dataset_name``, keeping ``config_name`` if still offered.

        Falls back to the first fetched config, or clears the dropdown when none
        can be fetched.
        """
        configs = get_available_configs(dataset_name) if dataset_name else []
        if not configs:
            return gr.update(choices=[], value="")
        selected = config_name if config_name in configs else configs[0]
        return gr.update(choices=configs, value=selected)

    # Update config choices when the dataset changes (selects the first config).
    dataset_name_input.change(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    # Refresh configs button (useful for custom dataset names); keeps the
    # user's current config selected when the dataset still provides it.
    refresh_btn.click(
        fn=refresh_configs_keep_selection,
        inputs=[dataset_name_input, config_name_input],
        outputs=[config_name_input],
    )

    load_btn.click(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=load_outputs,
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        # Render index 0 explicitly. If the slider was already at 0, its value
        # does not change, so slider.change may not fire and the previous
        # video would remain on screen.
        fn=show_first_trajectory,
        inputs=[current_dataset, current_dataset_name],
        outputs=trajectory_outputs,
    )

    slider.change(
        fn=update_trajectory,
        inputs=[current_dataset, slider, current_dataset_name],
        outputs=trajectory_outputs,
    ).then(
        fn=clamp_index,
        inputs=[current_dataset, slider],
        outputs=[current_index],
    )

    # Next/Previous render directly, then move the slider to match.
    next_btn.click(
        fn=next_trajectory,
        inputs=[current_dataset, current_index, current_dataset_name],
        outputs=[current_index, *trajectory_outputs],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    prev_btn.click(
        fn=prev_trajectory,
        inputs=[current_dataset, current_index, current_dataset_name],
        outputs=[current_index, *trajectory_outputs],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    # Load initial dataset and configs
    demo.load(
        fn=lambda: ("jesbu1/oxe_rfm", "oxe_jaco_play"),  # Set initial values
        outputs=[dataset_name_input, config_name_input],
    ).then(
        # Keep the requested initial config when the dataset provides it. The
        # previous refresh always selected configs[0], discarding it.
        fn=refresh_configs_keep_selection,
        inputs=[dataset_name_input, config_name_input],
        outputs=[config_name_input],
    ).then(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=load_outputs,
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        fn=show_first_trajectory,
        inputs=[current_dataset, current_dataset_name],
        outputs=trajectory_outputs,
    )
def main(**launch_kwargs):
    """Launch the RFM visualizer.

    Any keyword arguments are forwarded unchanged to ``demo.launch`` (for
    example ``server_name``, ``server_port`` or ``share``). Calling ``main()``
    with no arguments is equivalent to a plain ``demo.launch()``.
    """
    demo.launch(**launch_kwargs)
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()