"""
Gradio UI Components Module
Contains all Gradio interface component definitions and layouts
"""
import gradio as gr
from typing import Callable, Optional
def create_gradio_interface(handler) -> gr.Blocks:
    """Create the complete Gradio interface.

    Lays out the dataset explorer, generation, and results sections inside a
    single ``gr.Blocks`` app, then wires their event handlers together.

    Args:
        handler: Business logic handler instance; passed to every section
            builder and to ``setup_event_handlers``.

    Returns:
        The constructed ``gr.Blocks`` instance. It is not launched here.
    """
    custom_css = """
    .main-header {
        text-align: center;
        margin-bottom: 2rem;
    }
    .section-header {
        background: linear-gradient(90deg, #4CAF50, #45a049);
        color: white;
        padding: 10px;
        border-radius: 5px;
        margin: 10px 0;
    }
    """

    with gr.Blocks(
        title="ACE-Step V1.5 Demo",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        # Use explicit markup so the `.main-header` style above is applied.
        # As bare text, the newline between title and subtitle collapses to a
        # space in HTML and both render as a single unstyled run-on line.
        gr.HTML(
            '<div class="main-header">'
            "<h1>♪ACE-Step V1.5 Demo</h1>"
            "<p>Generate music from text captions and lyrics using diffusion models</p>"
            "</div>"
        )

        # Sections are built in display order: Gradio appends each component
        # to the enclosing layout at the moment it is constructed.
        dataset_section = create_dataset_section(handler)
        generation_section = create_generation_section(handler)
        results_section = create_results_section(handler)

        # Called while the Blocks context is still open, because Gradio only
        # allows event listeners to be attached inside an active Blocks context.
        setup_event_handlers(demo, handler, dataset_section, generation_section, results_section)

    return demo
def create_dataset_section(handler) -> dict:
    """Create the dataset explorer section.

    Builds the controls used to import a dataset, look up an item in it,
    inspect the item's instruction/metadata/audio, and auto-fill the
    generation form from it. Components are appended to whatever Gradio
    layout context is open when this is called.

    Args:
        handler: Business logic handler instance. Not used while building this
            section; presumably kept for a uniform signature with the other
            section builders.

    Returns:
        Dict mapping component names (e.g. ``"dataset_type"``,
        ``"get_item_btn"``, ``"auto_fill_btn"``) to the created Gradio
        components, so event handlers can be wired up elsewhere.
    """
    with gr.Group():
        # Visible section title. The original empty gr.HTML('') rendered
        # nothing; this applies the `.section-header` CSS class that
        # create_gradio_interface passes to gr.Blocks.
        gr.HTML('<div class="section-header"><h3>📂 Dataset Explorer</h3></div>')

        # Dataset choice plus item-lookup controls (scales 2/1/1/2 share a row).
        with gr.Row(equal_height=True):
            dataset_type = gr.Dropdown(
                choices=["train", "test"],
                value="train",
                label="Dataset",
                info="Choose dataset to explore",
                scale=2,
            )
            import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
            search_type = gr.Dropdown(
                choices=["keys", "idx", "random"],
                value="random",
                label="Search Type",
                info="How to find items",
                scale=1,
            )
            search_value = gr.Textbox(
                label="Search Value",
                placeholder="Enter keys or index (leave empty for random)",
                info="Keys: exact match, Index: 0 to dataset size-1",
                scale=2,
            )

        # Read-only views of the currently selected item.
        instruction_display = gr.Textbox(
            label="📝 Instruction",
            interactive=False,
            placeholder="No instruction available",
            lines=1,
        )
        repaint_viz_plot = gr.Plot()

        with gr.Accordion("📋 Item Metadata (JSON)", open=False):
            item_info_json = gr.Code(
                label="Complete Item Information",
                language="json",
                interactive=False,
                lines=15,
            )

        with gr.Row(equal_height=True):
            item_src_audio = gr.Audio(
                label="Source Audio",
                type="filepath",
                interactive=False,
                scale=8,
            )
            # Starts non-interactive; presumably enabled by an event handler
            # once a dataset is imported (TODO confirm in setup_event_handlers).
            get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)

        with gr.Row(equal_height=True):
            item_target_audio = gr.Audio(
                label="Target Audio",
                type="filepath",
                interactive=False,
                scale=8,
            )
            item_refer_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",
                interactive=False,
                scale=2,
            )

        with gr.Row():
            use_src_checkbox = gr.Checkbox(
                label="Use Source Audio from Dataset",
                value=True,
                info="Check to use the source audio from dataset",
            )
            data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")

        auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")

    return {
        "dataset_type": dataset_type,
        "import_dataset_btn": import_dataset_btn,
        "search_type": search_type,
        "search_value": search_value,
        "instruction_display": instruction_display,
        "repaint_viz_plot": repaint_viz_plot,
        "item_info_json": item_info_json,
        "item_src_audio": item_src_audio,
        "get_item_btn": get_item_btn,
        "item_target_audio": item_target_audio,
        "item_refer_audio": item_refer_audio,
        "use_src_checkbox": use_src_checkbox,
        "data_status": data_status,
        "auto_fill_btn": auto_fill_btn,
    }
def create_generation_section(handler) -> dict:
"""Create generation section"""
with gr.Group():
gr.HTML('')
# Service Configuration
with gr.Accordion("🔧 Service Configuration", open=True) as service_config_accordion:
with gr.Row():
with gr.Column(scale=2):
checkpoint_dropdown = gr.Dropdown(
label="Checkpoint File",
choices=handler.get_available_checkpoints(),
value=None,
info="Select a trained model checkpoint file (full path or filename)"
)
with gr.Column(scale=1):
refresh_btn = gr.Button("🔄 Refresh", size="sm")
with gr.Row():
# Get available acestep-v15- model list
available_models = handler.get_available_acestep_v15_models()
default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
config_path = gr.Dropdown(
label="Main Model Path",
choices=available_models,
value=default_model,
info="Select the model configuration directory (auto-scanned from checkpoints)"
)
device = gr.Dropdown(
choices=["auto", "cuda", "cpu"],
value="auto",
label="Device",
info="Processing device (auto-detect recommended)"
)
with gr.Row():
# Get available 5Hz LM model list
available_lm_models = handler.get_available_5hz_lm_models()
default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
lm_model_path = gr.Dropdown(
label="5Hz LM Model Path",
choices=available_lm_models,
value=default_lm_model,
info="Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)"
)
init_llm_checkbox = gr.Checkbox(
label="Initialize 5Hz LM",
value=False,
info="Check to initialize 5Hz LM during service initialization",
)
with gr.Row():
# Auto-detect flash attention availability
flash_attn_available = handler.is_flash_attention_available()
use_flash_attention_checkbox = gr.Checkbox(
label="Use Flash Attention",
value=flash_attn_available,
interactive=flash_attn_available,
info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
)
offload_to_cpu_checkbox = gr.Checkbox(
label="Offload to CPU",
value=False,
info="Offload models to CPU when not in use to save GPU memory"
)
offload_dit_to_cpu_checkbox = gr.Checkbox(
label="Offload DiT to CPU",
value=False,
info="Offload DiT model to CPU when not in use (only effective if Offload to CPU is checked)"
)
init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
init_status = gr.Textbox(label="Status", interactive=False, lines=3)
# Inputs
with gr.Row():
with gr.Column(scale=2):
with gr.Accordion("📝 Required Inputs", open=True):
# Task type
# Determine initial task_type choices based on default model
default_model_lower = (default_model or "").lower()
if "turbo" in default_model_lower:
initial_task_choices = ["text2music", "repaint", "cover"]
else:
initial_task_choices = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
with gr.Row():
with gr.Column(scale=2):
task_type = gr.Dropdown(
choices=initial_task_choices,
value="text2music",
label="Task Type",
info="Select the task type for generation",
)
with gr.Column(scale=8):
instruction_display_gen = gr.Textbox(
label="Instruction",
value="Fill the audio semantic mask based on the given conditions:",
interactive=False,
lines=1,
info="Instruction is automatically generated based on task type",
)
track_name = gr.Dropdown(
choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
"keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
value=None,
label="Track Name",
info="Select track name for lego/extract tasks",
visible=False
)
complete_track_classes = gr.CheckboxGroup(
choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
"keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
label="Track Names",
info="Select multiple track classes for complete task",
visible=False
)
# Audio uploads
with gr.Accordion("🎵 Audio Uploads", open=False):
with gr.Row():
with gr.Column(scale=2):
reference_audio = gr.Audio(
label="Reference Audio (optional)",
type="filepath",
)
with gr.Column(scale=8):
src_audio = gr.Audio(
label="Source Audio (optional)",
type="filepath",
)
audio_code_string = gr.Textbox(
label="Audio Codes (optional)",
placeholder="<|audio_code_10695|><|audio_code_54246|>...",
lines=4,
visible=False,
info="Paste precomputed audio code tokens"
)
# Audio Codes for text2music
with gr.Accordion("🎼 Audio Codes (for text2music)", open=True, visible=True) as text2music_audio_codes_group:
text2music_audio_code_string = gr.Textbox(
label="Audio Codes",
placeholder="<|audio_code_10695|><|audio_code_54246|>...",
lines=6,
info="Paste precomputed audio code tokens for text2music generation"
)
# 5Hz LM
with gr.Row(visible=True) as use_5hz_lm_row:
use_5hz_lm_btn = gr.Button(
"Generate LM Hints",
variant="secondary",
size="lg",
)
lm_temperature = gr.Slider(
label="Temperature",
minimum=0.0,
maximum=2.0,
value=0.7,
step=0.1,
scale=2,
info="Temperature for 5Hz LM sampling"
)
# Repainting controls
with gr.Group(visible=False) as repainting_group:
gr.HTML("