Spaces:
Running
on
A100
Running
on
A100
refact ui
Browse files- README.md +1 -1
- acestep/gradio_ui/events/__init__.py +192 -14
- acestep/gradio_ui/events/generation_handlers.py +99 -28
- acestep/gradio_ui/events/results_handlers.py +30 -89
- acestep/gradio_ui/i18n/en.json +2 -0
- acestep/gradio_ui/i18n/ja.json +2 -0
- acestep/gradio_ui/i18n/zh.json +2 -0
- acestep/gradio_ui/interfaces/__init__.py +10 -2
- acestep/gradio_ui/interfaces/generation.py +375 -448
- acestep/gradio_ui/interfaces/result.py +150 -105
- acestep/handler.py +67 -35
- app.py +121 -41
README.md
CHANGED
|
@@ -16,7 +16,7 @@ short_description: Music Generation Foundation Model v1.5
|
|
| 16 |
<a href="https://ace-step-v1.5.github.io">Project</a> |
|
| 17 |
<a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
|
| 18 |
<a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5">ModelScope</a> |
|
| 19 |
-
<a href="https://huggingface.co/spaces/ACE-Step/
|
| 20 |
<a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
|
| 21 |
<a href="https://arxiv.org/abs/2506.00045">Technical Report</a>
|
| 22 |
</p>
|
|
|
|
| 16 |
<a href="https://ace-step-v1.5.github.io">Project</a> |
|
| 17 |
<a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
|
| 18 |
<a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5">ModelScope</a> |
|
| 19 |
+
<a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5">Space Demo</a> |
|
| 20 |
<a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
|
| 21 |
<a href="https://arxiv.org/abs/2506.00045">Technical Report</a>
|
| 22 |
</p>
|
acestep/gradio_ui/events/__init__.py
CHANGED
|
@@ -12,8 +12,20 @@ from . import training_handlers as train_h
|
|
| 12 |
from acestep.gradio_ui.i18n import t
|
| 13 |
|
| 14 |
|
| 15 |
-
def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
|
| 16 |
-
"""Setup event handlers connecting UI components and business logic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# ========== Dataset Handlers ==========
|
| 19 |
dataset_section["import_dataset_btn"].click(
|
|
@@ -260,17 +272,42 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 260 |
]
|
| 261 |
)
|
| 262 |
|
| 263 |
-
# ==========
|
| 264 |
generation_section["generation_mode"].change(
|
| 265 |
fn=gen_h.handle_generation_mode_change,
|
| 266 |
inputs=[generation_section["generation_mode"]],
|
| 267 |
outputs=[
|
| 268 |
generation_section["simple_mode_group"],
|
| 269 |
-
generation_section["
|
| 270 |
-
generation_section["
|
|
|
|
|
|
|
| 271 |
generation_section["generate_btn"],
|
| 272 |
generation_section["simple_sample_created"],
|
| 273 |
-
generation_section["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
]
|
| 275 |
)
|
| 276 |
|
|
@@ -451,10 +488,28 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 451 |
],
|
| 452 |
js=download_existing_js # Run the above JS
|
| 453 |
)
|
| 454 |
-
# ========== Send to
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
for btn_idx in range(1, 9):
|
| 456 |
-
results_section[f"
|
| 457 |
-
fn=
|
| 458 |
inputs=[
|
| 459 |
results_section[f"generated_audio_{btn_idx}"],
|
| 460 |
results_section["lm_metadata_state"]
|
|
@@ -468,7 +523,50 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 468 |
generation_section["key_scale"],
|
| 469 |
generation_section["vocal_language"],
|
| 470 |
generation_section["time_signature"],
|
| 471 |
-
results_section["is_format_caption_state"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
]
|
| 473 |
)
|
| 474 |
|
|
@@ -519,12 +617,84 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 519 |
]
|
| 520 |
)
|
| 521 |
|
| 522 |
-
def generation_wrapper(*args):
|
| 523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
# ========== Generation Handler ==========
|
| 525 |
generation_section["generate_btn"].click(
|
| 526 |
fn=generation_wrapper,
|
| 527 |
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
generation_section["captions"],
|
| 529 |
generation_section["lyrics"],
|
| 530 |
generation_section["bpm"],
|
|
@@ -634,8 +804,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 634 |
results_section["restore_params_btn"],
|
| 635 |
]
|
| 636 |
).then(
|
| 637 |
-
fn=lambda *args: res_h.generate_next_batch_background(
|
|
|
|
|
|
|
|
|
|
| 638 |
inputs=[
|
|
|
|
| 639 |
generation_section["autogen_checkbox"],
|
| 640 |
results_section["generation_params_state"],
|
| 641 |
results_section["current_batch_index"],
|
|
@@ -819,8 +993,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 819 |
results_section["restore_params_btn"],
|
| 820 |
]
|
| 821 |
).then(
|
| 822 |
-
fn=lambda *args: res_h.generate_next_batch_background(
|
|
|
|
|
|
|
|
|
|
| 823 |
inputs=[
|
|
|
|
| 824 |
generation_section["autogen_checkbox"],
|
| 825 |
results_section["generation_params_state"],
|
| 826 |
results_section["current_batch_index"],
|
|
|
|
| 12 |
from acestep.gradio_ui.i18n import t
|
| 13 |
|
| 14 |
|
| 15 |
+
def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
|
| 16 |
+
"""Setup event handlers connecting UI components and business logic
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
init_params: Dictionary containing initialization parameters including:
|
| 20 |
+
- dit_handler_2: Optional second DiT handler for multi-model setup
|
| 21 |
+
- available_dit_models: List of available DiT model names
|
| 22 |
+
- config_path: Primary model config path
|
| 23 |
+
- config_path_2: Secondary model config path (if available)
|
| 24 |
+
"""
|
| 25 |
+
# Get secondary DiT handler from init_params (for multi-model support)
|
| 26 |
+
dit_handler_2 = init_params.get('dit_handler_2') if init_params else None
|
| 27 |
+
config_path_1 = init_params.get('config_path', '') if init_params else ''
|
| 28 |
+
config_path_2 = init_params.get('config_path_2', '') if init_params else ''
|
| 29 |
|
| 30 |
# ========== Dataset Handlers ==========
|
| 31 |
dataset_section["import_dataset_btn"].click(
|
|
|
|
| 272 |
]
|
| 273 |
)
|
| 274 |
|
| 275 |
+
# ========== Generation Mode Toggle (Simple/Custom/Cover/Repaint) ==========
|
| 276 |
generation_section["generation_mode"].change(
|
| 277 |
fn=gen_h.handle_generation_mode_change,
|
| 278 |
inputs=[generation_section["generation_mode"]],
|
| 279 |
outputs=[
|
| 280 |
generation_section["simple_mode_group"],
|
| 281 |
+
generation_section["custom_mode_content"],
|
| 282 |
+
generation_section["cover_mode_group"],
|
| 283 |
+
generation_section["repainting_group"],
|
| 284 |
+
generation_section["task_type"],
|
| 285 |
generation_section["generate_btn"],
|
| 286 |
generation_section["simple_sample_created"],
|
| 287 |
+
generation_section["src_audio_group"],
|
| 288 |
+
generation_section["audio_cover_strength"],
|
| 289 |
+
]
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# ========== Process Source Audio Button ==========
|
| 293 |
+
# Combines Convert to Codes + Transcribe in one step
|
| 294 |
+
generation_section["process_src_btn"].click(
|
| 295 |
+
fn=lambda src, debug: gen_h.process_source_audio(dit_handler, llm_handler, src, debug),
|
| 296 |
+
inputs=[
|
| 297 |
+
generation_section["src_audio"],
|
| 298 |
+
generation_section["constrained_decoding_debug"]
|
| 299 |
+
],
|
| 300 |
+
outputs=[
|
| 301 |
+
generation_section["text2music_audio_code_string"],
|
| 302 |
+
results_section["status_output"],
|
| 303 |
+
generation_section["captions"],
|
| 304 |
+
generation_section["lyrics"],
|
| 305 |
+
generation_section["bpm"],
|
| 306 |
+
generation_section["audio_duration"],
|
| 307 |
+
generation_section["key_scale"],
|
| 308 |
+
generation_section["vocal_language"],
|
| 309 |
+
generation_section["time_signature"],
|
| 310 |
+
results_section["is_format_caption_state"],
|
| 311 |
]
|
| 312 |
)
|
| 313 |
|
|
|
|
| 488 |
],
|
| 489 |
js=download_existing_js # Run the above JS
|
| 490 |
)
|
| 491 |
+
# ========== Send to Cover Handlers ==========
|
| 492 |
+
def send_to_cover_handler(audio_file, lm_metadata):
    """Load a generated clip as source audio and switch the UI to Cover mode.

    Args:
        audio_file: Value of the result audio component whose button was
            clicked; ``None`` when that slot holds no audio.
        lm_metadata: Value of the LM metadata state. It is supplied by the
            click event but is not read by this handler.

    Returns:
        An 11-tuple with one entry per wired output, in this order:
        src_audio, bpm, captions, lyrics, audio_duration, key_scale,
        vocal_language, time_signature, is_format_caption_state,
        generation_mode, task_type. With no audio every entry is
        ``gr.skip()``. Otherwise only src_audio, generation_mode and
        task_type are written; the eight fields in between are skipped.
    """
    if audio_file is None:
        # Nothing to send: leave every output exactly as it is.
        return (gr.skip(),) * 11

    # bpm .. is_format_caption_state keep their current values.
    unchanged_fields = [gr.skip() for _ in range(8)]
    return (
        audio_file,         # src_audio
        *unchanged_fields,
        "cover",            # generation_mode - switch to cover
        "cover",            # task_type - set to cover
    )
|
| 509 |
+
|
| 510 |
for btn_idx in range(1, 9):
|
| 511 |
+
results_section[f"send_to_cover_btn_{btn_idx}"].click(
|
| 512 |
+
fn=send_to_cover_handler,
|
| 513 |
inputs=[
|
| 514 |
results_section[f"generated_audio_{btn_idx}"],
|
| 515 |
results_section["lm_metadata_state"]
|
|
|
|
| 523 |
generation_section["key_scale"],
|
| 524 |
generation_section["vocal_language"],
|
| 525 |
generation_section["time_signature"],
|
| 526 |
+
results_section["is_format_caption_state"],
|
| 527 |
+
generation_section["generation_mode"],
|
| 528 |
+
generation_section["task_type"],
|
| 529 |
+
]
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
# ========== Send to Repaint Handlers ==========
|
| 533 |
+
def send_to_repaint_handler(audio_file, lm_metadata):
    """Load a generated clip as source audio and switch the UI to Repaint mode.

    Args:
        audio_file: Value of the result audio component whose button was
            clicked; ``None`` when that slot holds no audio.
        lm_metadata: Value of the LM metadata state. It is supplied by the
            click event but is not read by this handler.

    Returns:
        An 11-tuple matching the wired outputs, in this order: src_audio,
        bpm, captions, lyrics, audio_duration, key_scale, vocal_language,
        time_signature, is_format_caption_state, generation_mode, task_type.
        With no audio every entry is ``gr.skip()``. Otherwise src_audio
        receives the clip, generation_mode and task_type are both set to
        ``"repaint"``, and the remaining eight outputs are skipped.
    """
    if audio_file is None:
        # No clip in this slot: touch none of the outputs.
        return (gr.skip(),) * 11

    updates = [audio_file]                          # src_audio
    # bpm, captions, lyrics, audio_duration, key_scale, vocal_language,
    # time_signature and is_format_caption_state are left untouched.
    updates.extend(gr.skip() for _ in range(8))
    updates.append("repaint")                       # generation_mode - switch to repaint
    updates.append("repaint")                       # task_type - set to repaint
    return tuple(updates)
|
| 550 |
+
|
| 551 |
+
for btn_idx in range(1, 9):
|
| 552 |
+
results_section[f"send_to_repaint_btn_{btn_idx}"].click(
|
| 553 |
+
fn=send_to_repaint_handler,
|
| 554 |
+
inputs=[
|
| 555 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 556 |
+
results_section["lm_metadata_state"]
|
| 557 |
+
],
|
| 558 |
+
outputs=[
|
| 559 |
+
generation_section["src_audio"],
|
| 560 |
+
generation_section["bpm"],
|
| 561 |
+
generation_section["captions"],
|
| 562 |
+
generation_section["lyrics"],
|
| 563 |
+
generation_section["audio_duration"],
|
| 564 |
+
generation_section["key_scale"],
|
| 565 |
+
generation_section["vocal_language"],
|
| 566 |
+
generation_section["time_signature"],
|
| 567 |
+
results_section["is_format_caption_state"],
|
| 568 |
+
generation_section["generation_mode"],
|
| 569 |
+
generation_section["task_type"],
|
| 570 |
]
|
| 571 |
)
|
| 572 |
|
|
|
|
| 617 |
]
|
| 618 |
)
|
| 619 |
|
| 620 |
+
def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
    """Validate inputs, optionally create a Simple-mode sample, then generate.

    This is a generator: it yields whatever
    ``res_h.generate_with_batch_management`` yields.

    Args:
        selected_model: Value of the DiT model selector; compared against
            ``config_path_2`` to decide which DiT handler runs.
        generation_mode: Current generation mode; ``"simple"`` triggers the
            create-sample step before generation.
        simple_query_input: Query text passed to ``create_sample``.
        simple_vocal_language: Vocal language passed to ``create_sample``.
        *args: The remaining generation inputs, forwarded positionally.
            Positions read or written here: captions (0), lyrics (1),
            bpm (2), key_scale (3), time_signature (4), vocal_language (5),
            audio_duration (11), src_audio (13), task_type (19),
            lm_temperature (27), think_checkbox (28), indices 30 and 31
            (used as LM top-k / top-p), index 38 (used as the
            constrained-decoding debug flag), and the instrumental checkbox
            at position -6.

    Raises:
        gr.Error: If the task type is cover/repaint and no source audio is
            set, or if ``create_sample`` reports failure.
    """
    forwarded = list(args)  # mutable copy so Simple mode can overwrite entries

    def arg_or(position, fallback):
        # Same guard the positional look-ups need: only index when long enough.
        return forwarded[position] if len(forwarded) > position else fallback

    src_audio = arg_or(13, None)
    task_type = arg_or(19, "text2music")

    # Validate: Cover and Repaint modes require source audio
    if src_audio is None and task_type in ["cover", "repaint"]:
        raise gr.Error(f"Source Audio is required for {task_type.capitalize()} mode. Please upload an audio file.")

    # Handle Simple mode: first create sample, then generate
    if generation_mode == "simple":
        # The instrumental checkbox is passed after all the regular
        # generation params, so it is addressed from the end of the list.
        instrumental = forwarded[-6] if len(forwarded) > 6 else False
        lm_temperature = arg_or(27, 0.85)
        lm_top_k = arg_or(30, 0)
        lm_top_p = arg_or(31, 0.9)
        constrained_decoding_debug = arg_or(38, False)

        # Call create_sample to generate caption/lyrics/metadata
        from acestep.inference import create_sample

        # A falsy top-k (None / 0) and a falsy or >= 1.0 top-p are sent as None.
        top_k_value = int(lm_top_k) if lm_top_k else None
        if not lm_top_p or lm_top_p >= 1.0:
            top_p_value = None
        else:
            top_p_value = lm_top_p

        result = create_sample(
            llm_handler=llm_handler,
            query=simple_query_input,
            instrumental=instrumental,
            vocal_language=simple_vocal_language,
            temperature=lm_temperature,
            top_k=top_k_value,
            top_p=top_p_value,
            use_constrained_decoding=True,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if not result.success:
            raise gr.Error(f"Failed to create sample: {result.status_message}")

        # Overwrite the leading text/metadata slots with the generated sample.
        generated_fields = (
            result.caption,        # captions
            result.lyrics,         # lyrics
            result.bpm,            # bpm
            result.keyscale,       # key_scale
            result.timesignature,  # time_signature
            result.language,       # vocal_language
        )
        for position, value in enumerate(generated_fields):
            forwarded[position] = value
        if result.duration and result.duration > 0:
            forwarded[11] = result.duration  # audio_duration
        # Enable thinking for Simple mode
        forwarded[28] = True  # think_checkbox

    # Use the secondary handler only when it exists and its config is selected.
    use_secondary = dit_handler_2 is not None and selected_model == config_path_2
    active_handler = dit_handler_2 if use_secondary else dit_handler
    yield from res_h.generate_with_batch_management(active_handler, llm_handler, *forwarded)
|
| 689 |
+
|
| 690 |
# ========== Generation Handler ==========
|
| 691 |
generation_section["generate_btn"].click(
|
| 692 |
fn=generation_wrapper,
|
| 693 |
inputs=[
|
| 694 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 695 |
+
generation_section["generation_mode"], # For Simple mode detection
|
| 696 |
+
generation_section["simple_query_input"], # Simple mode query
|
| 697 |
+
generation_section["simple_vocal_language"], # Simple mode vocal language
|
| 698 |
generation_section["captions"],
|
| 699 |
generation_section["lyrics"],
|
| 700 |
generation_section["bpm"],
|
|
|
|
| 804 |
results_section["restore_params_btn"],
|
| 805 |
]
|
| 806 |
).then(
|
| 807 |
+
fn=lambda selected_model, *args: res_h.generate_next_batch_background(
|
| 808 |
+
dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
|
| 809 |
+
llm_handler, *args
|
| 810 |
+
),
|
| 811 |
inputs=[
|
| 812 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 813 |
generation_section["autogen_checkbox"],
|
| 814 |
results_section["generation_params_state"],
|
| 815 |
results_section["current_batch_index"],
|
|
|
|
| 993 |
results_section["restore_params_btn"],
|
| 994 |
]
|
| 995 |
).then(
|
| 996 |
+
fn=lambda selected_model, *args: res_h.generate_next_batch_background(
|
| 997 |
+
dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
|
| 998 |
+
llm_handler, *args
|
| 999 |
+
),
|
| 1000 |
inputs=[
|
| 1001 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 1002 |
generation_section["autogen_checkbox"],
|
| 1003 |
results_section["generation_params_state"],
|
| 1004 |
results_section["current_batch_index"],
|
acestep/gradio_ui/events/generation_handlers.py
CHANGED
|
@@ -480,10 +480,14 @@ def update_negative_prompt_visibility(init_llm_checked):
|
|
| 480 |
|
| 481 |
def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
|
| 482 |
"""Update audio_cover_strength visibility and label"""
|
| 483 |
-
# Show if task is cover OR if LM is initialized
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
# Change label based on context
|
| 486 |
-
if init_llm_checked and
|
| 487 |
label = "LM codes strength"
|
| 488 |
info = "Control how many denoising steps use LM-generated codes"
|
| 489 |
else:
|
|
@@ -518,10 +522,12 @@ def update_instruction_ui(
|
|
| 518 |
track_name_visible = task_type_value in ["lego", "extract"]
|
| 519 |
# Show complete_track_classes for complete
|
| 520 |
complete_visible = task_type_value == "complete"
|
| 521 |
-
# Show audio_cover_strength for cover OR when LM is initialized
|
| 522 |
-
|
|
|
|
|
|
|
| 523 |
# Determine label and info based on context
|
| 524 |
-
if init_llm_checked and
|
| 525 |
audio_cover_strength_label = "LM codes strength"
|
| 526 |
audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
|
| 527 |
else:
|
|
@@ -605,9 +611,9 @@ def reset_format_caption_flag():
|
|
| 605 |
|
| 606 |
|
| 607 |
def update_audio_uploads_accordion(reference_audio, src_audio):
|
| 608 |
-
"""Update Audio Uploads
|
| 609 |
has_audio = (reference_audio is not None) or (src_audio is not None)
|
| 610 |
-
return gr.
|
| 611 |
|
| 612 |
|
| 613 |
def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
|
|
@@ -682,41 +688,106 @@ def update_audio_components_visibility(batch_size):
|
|
| 682 |
|
| 683 |
def handle_generation_mode_change(mode: str):
|
| 684 |
"""
|
| 685 |
-
Handle generation mode change between Simple and
|
| 686 |
-
|
| 687 |
-
In Simple mode:
|
| 688 |
-
- Show simple mode group (query input, instrumental checkbox, create button)
|
| 689 |
-
- Collapse caption and lyrics accordions
|
| 690 |
-
- Hide optional parameters accordion
|
| 691 |
-
- Disable generate button until sample is created
|
| 692 |
|
| 693 |
-
|
| 694 |
-
-
|
| 695 |
-
-
|
| 696 |
-
- Show
|
| 697 |
-
-
|
| 698 |
|
| 699 |
Args:
|
| 700 |
-
mode: "simple" or "
|
| 701 |
|
| 702 |
Returns:
|
| 703 |
Tuple of updates for:
|
| 704 |
- simple_mode_group (visibility)
|
| 705 |
-
-
|
| 706 |
-
-
|
|
|
|
|
|
|
| 707 |
- generate_btn (interactive state)
|
| 708 |
- simple_sample_created (reset state)
|
| 709 |
-
-
|
|
|
|
| 710 |
"""
|
| 711 |
is_simple = mode == "simple"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
|
| 713 |
return (
|
| 714 |
gr.update(visible=is_simple), # simple_mode_group
|
| 715 |
-
gr.
|
| 716 |
-
gr.
|
| 717 |
-
gr.update(
|
|
|
|
|
|
|
| 718 |
False, # simple_sample_created - reset to False on mode change
|
| 719 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
)
|
| 721 |
|
| 722 |
|
|
|
|
| 480 |
|
| 481 |
def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
|
| 482 |
"""Update audio_cover_strength visibility and label"""
|
| 483 |
+
# Show if task is cover OR if LM is initialized (but NOT for repaint mode)
|
| 484 |
+
# Repaint mode never shows this control
|
| 485 |
+
is_repaint = task_type_value == "repaint"
|
| 486 |
+
is_cover = task_type_value == "cover"
|
| 487 |
+
is_visible = is_cover or (init_llm_checked and not is_repaint)
|
| 488 |
+
|
| 489 |
# Change label based on context
|
| 490 |
+
if init_llm_checked and not is_cover:
|
| 491 |
label = "LM codes strength"
|
| 492 |
info = "Control how many denoising steps use LM-generated codes"
|
| 493 |
else:
|
|
|
|
| 522 |
track_name_visible = task_type_value in ["lego", "extract"]
|
| 523 |
# Show complete_track_classes for complete
|
| 524 |
complete_visible = task_type_value == "complete"
|
| 525 |
+
# Show audio_cover_strength for cover OR when LM is initialized (but NOT for repaint)
|
| 526 |
+
is_repaint = task_type_value == "repaint"
|
| 527 |
+
is_cover = task_type_value == "cover"
|
| 528 |
+
audio_cover_strength_visible = is_cover or (init_llm_checked and not is_repaint)
|
| 529 |
# Determine label and info based on context
|
| 530 |
+
if init_llm_checked and not is_cover:
|
| 531 |
audio_cover_strength_label = "LM codes strength"
|
| 532 |
audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
|
| 533 |
else:
|
|
|
|
| 611 |
|
| 612 |
|
| 613 |
def update_audio_uploads_accordion(reference_audio, src_audio):
    """Show the Audio Uploads component only while some audio input is set.

    Args:
        reference_audio: Current reference audio value, or ``None``.
        src_audio: Current source audio value, or ``None``.

    Returns:
        A ``gr.update`` whose ``visible`` flag is True when at least one of
        the two inputs is not ``None``.
    """
    nothing_uploaded = reference_audio is None and src_audio is None
    return gr.update(visible=not nothing_uploaded)
|
| 617 |
|
| 618 |
|
| 619 |
def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
|
|
|
|
| 688 |
|
| 689 |
def handle_generation_mode_change(mode: str):
    """Compute the UI updates for a switch between generation modes.

    Modes:
    - Simple: show the simple mode group and hide everything else.
    - Custom: show the custom content (prompt) only.
    - Cover: show the source audio group, the custom content and the
      audio_cover_strength control.
    - Repaint: show the source audio group, the custom content and the
      repaint time-range controls; audio_cover_strength stays hidden.

    Any unrecognised mode is treated like Custom: the custom content is
    shown and the task type falls back to "text2music".

    Args:
        mode: "simple", "custom", "cover", or "repaint".

    Returns:
        A 9-tuple of updates, in the order the outputs are wired:
        - simple_mode_group (visibility)
        - custom_mode_content (visibility)
        - cover_mode_group (visibility) - legacy, always hidden
        - repainting_group (visibility)
        - task_type (value)
        - generate_btn (interactive state)
        - simple_sample_created (reset state)
        - src_audio_group (visibility) - shown for cover and repaint
        - audio_cover_strength (visibility) - shown only for cover
    """
    is_simple = mode == "simple"
    is_cover = mode == "cover"
    is_repaint = mode == "repaint"

    # Map mode to task_type; simple and custom both run plain text2music.
    task_type_map = {
        "simple": "text2music",
        "custom": "text2music",
        "cover": "cover",
        "repaint": "repaint",
    }
    task_type_value = task_type_map.get(mode, "text2music")

    return (
        gr.update(visible=is_simple),  # simple_mode_group
        gr.update(visible=not is_simple),  # custom_mode_content - visible for custom/cover/repaint
        gr.update(visible=False),  # cover_mode_group - legacy, always hidden
        gr.update(visible=is_repaint),  # repainting_group - time range controls
        gr.update(value=task_type_value),  # task_type
        gr.update(interactive=True),  # generate_btn - always enabled (Simple mode does create+generate in one step)
        False,  # simple_sample_created - reset to False on mode change
        gr.update(visible=is_cover or is_repaint),  # src_audio_group - shown for cover and repaint
        gr.update(visible=is_cover),  # audio_cover_strength - only shown for cover mode
    )
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
    """Convert source audio to codes, then transcribe those codes.

    Runs ``dit_handler.convert_src_audio_to_codes`` followed by
    ``understand_music`` so the UI can do both steps from one button.

    Args:
        dit_handler: DiT handler instance used for audio code conversion.
        llm_handler: LLM handler instance used for transcription.
        src_audio: Source audio value (the original docstring describes it
            as a file path); ``None`` when nothing is uploaded.
        constrained_decoding_debug: Forwarded to ``understand_music`` as its
            debug flag.

    Returns:
        Tuple of (audio_codes, status_message, caption, lyrics, bpm,
        duration, keyscale, language, timesignature, is_format_caption).
        On any failure the text fields are empty strings, bpm and duration
        are ``None`` and is_format_caption is ``False``; audio_codes is still
        populated when only the transcription step failed.
    """

    def failure(codes, message):
        # Shared shape for every unsuccessful outcome.
        return (codes, message, "", "", None, None, "", "", "", False)

    if src_audio is None:
        return failure("", "No audio file provided")

    # Step 1: Convert audio to codes
    try:
        codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
        if not codes_string:
            return failure("", "Failed to convert audio to codes")
    except Exception as e:
        return failure("", f"Error converting audio: {e!s}")

    # Step 2: Transcribe the codes
    result = understand_music(
        llm_handler=llm_handler,
        audio_codes=codes_string,
        use_constrained_decoding=True,
        constrained_decoding_debug=constrained_decoding_debug,
    )

    if not result.success:
        # An uninitialised LLM gets the localized message; anything else
        # reports the status message carried by the result.
        if result.error == "LLM not initialized":
            return failure(codes_string, t("messages.lm_not_initialized"))
        return failure(codes_string, result.status_message)

    return (
        codes_string,
        result.status_message,
        result.caption,
        result.lyrics,
        result.bpm,
        result.duration,
        result.keyscale,
        result.language,
        result.timesignature,
        True,  # Set is_format_caption to True
    )
|
| 792 |
|
| 793 |
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -265,106 +265,45 @@ def _build_generation_info(
|
|
| 265 |
Formatted generation info string
|
| 266 |
"""
|
| 267 |
info_parts = []
|
|
|
|
| 268 |
|
| 269 |
-
# Part 1:
|
| 270 |
-
# Only count model time (LM + DiT), not post-processing like audio conversion
|
| 271 |
-
if time_costs and num_audios > 0:
|
| 272 |
-
lm_total = time_costs.get('lm_total_time', 0.0)
|
| 273 |
-
dit_total = time_costs.get('dit_total_time_cost', 0.0)
|
| 274 |
-
model_total = lm_total + dit_total
|
| 275 |
-
if model_total > 0:
|
| 276 |
-
avg_time_per_track = model_total / num_audios
|
| 277 |
-
avg_section = f"**🎯 Average Time per Track: {avg_time_per_track:.2f}s** ({num_audios} track(s))"
|
| 278 |
-
info_parts.append(avg_section)
|
| 279 |
-
|
| 280 |
-
# Part 2: LM-generated metadata (if available)
|
| 281 |
-
if lm_metadata:
|
| 282 |
-
metadata_lines = []
|
| 283 |
-
if lm_metadata.get('bpm'):
|
| 284 |
-
metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
|
| 285 |
-
if lm_metadata.get('caption'):
|
| 286 |
-
metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
|
| 287 |
-
if lm_metadata.get('lyrics'):
|
| 288 |
-
metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
|
| 289 |
-
if lm_metadata.get('duration'):
|
| 290 |
-
metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
|
| 291 |
-
if lm_metadata.get('keyscale'):
|
| 292 |
-
metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
|
| 293 |
-
if lm_metadata.get('language'):
|
| 294 |
-
metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
|
| 295 |
-
if lm_metadata.get('timesignature'):
|
| 296 |
-
metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
|
| 297 |
-
|
| 298 |
-
if metadata_lines:
|
| 299 |
-
metadata_section = "**🤖 LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
|
| 300 |
-
info_parts.append(metadata_section)
|
| 301 |
-
|
| 302 |
-
# Part 3: Time costs breakdown (formatted and beautified)
|
| 303 |
if time_costs:
|
| 304 |
-
time_lines = []
|
| 305 |
-
|
| 306 |
-
# LM time costs
|
| 307 |
-
lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
|
| 308 |
-
lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
|
| 309 |
lm_total = time_costs.get('lm_total_time', 0.0)
|
| 310 |
-
|
| 311 |
-
if lm_total > 0:
|
| 312 |
-
time_lines.append("**🧠 LM Time:**")
|
| 313 |
-
if lm_phase1 > 0:
|
| 314 |
-
time_lines.append(f" - Phase 1 (CoT): {lm_phase1:.2f}s")
|
| 315 |
-
if lm_phase2 > 0:
|
| 316 |
-
time_lines.append(f" - Phase 2 (Codes): {lm_phase2:.2f}s")
|
| 317 |
-
time_lines.append(f" - Total: {lm_total:.2f}s")
|
| 318 |
-
|
| 319 |
-
# DiT time costs
|
| 320 |
-
dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
|
| 321 |
-
dit_model = time_costs.get('dit_model_time_cost', 0.0)
|
| 322 |
-
dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
|
| 323 |
-
dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
|
| 324 |
dit_total = time_costs.get('dit_total_time_cost', 0.0)
|
| 325 |
-
|
| 326 |
-
time_lines.append("\n**🎵 DiT Time:**")
|
| 327 |
-
if dit_encoder > 0:
|
| 328 |
-
time_lines.append(f" - Encoder: {dit_encoder:.2f}s")
|
| 329 |
-
if dit_model > 0:
|
| 330 |
-
time_lines.append(f" - Model: {dit_model:.2f}s")
|
| 331 |
-
if dit_vae_decode > 0:
|
| 332 |
-
time_lines.append(f" - VAE Decode: {dit_vae_decode:.2f}s")
|
| 333 |
-
if dit_offload > 0:
|
| 334 |
-
time_lines.append(f" - Offload: {dit_offload:.2f}s")
|
| 335 |
-
time_lines.append(f" - Total: {dit_total:.2f}s")
|
| 336 |
|
| 337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
|
| 339 |
auto_score_time = time_costs.get('auto_score_time', 0.0)
|
| 340 |
auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
|
|
|
|
| 341 |
|
| 342 |
-
if
|
| 343 |
-
|
|
|
|
|
|
|
| 344 |
if audio_conversion_time > 0:
|
| 345 |
-
|
|
|
|
| 346 |
if auto_score_time > 0:
|
| 347 |
-
|
| 348 |
if auto_lrc_time > 0:
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
if time_lines:
|
| 352 |
-
time_section = "\n".join(time_lines)
|
| 353 |
-
info_parts.append(time_section)
|
| 354 |
-
|
| 355 |
-
# Part 4: Generation summary
|
| 356 |
-
summary_lines = [
|
| 357 |
-
"**🎵 Generation Complete**",
|
| 358 |
-
f" - **Seeds:** {seed_value}",
|
| 359 |
-
f" - **Steps:** {inference_steps}",
|
| 360 |
-
f" - **Audio Count:** {num_audios} audio(s)",
|
| 361 |
-
]
|
| 362 |
-
info_parts.append("\n".join(summary_lines))
|
| 363 |
-
|
| 364 |
-
# Part 5: Pipeline total time (at the end)
|
| 365 |
-
pipeline_total = time_costs.get('pipeline_total_time', 0.0) if time_costs else 0.0
|
| 366 |
-
if pipeline_total > 0:
|
| 367 |
-
info_parts.append(f"**⏱️ Total Time: {pipeline_total:.2f}s**")
|
| 368 |
|
| 369 |
# Combine all parts
|
| 370 |
return "\n\n".join(info_parts)
|
|
@@ -775,7 +714,9 @@ def generate_with_progress(
|
|
| 775 |
codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
|
| 776 |
|
| 777 |
details_accordion_updates = [gr.skip() for _ in range(8)]
|
| 778 |
-
#
|
|
|
|
|
|
|
| 779 |
|
| 780 |
# Clear LRC first (this triggers .change() to clear subtitles)
|
| 781 |
# Keep visible=True to ensure .change() event is properly triggered
|
|
|
|
| 265 |
Formatted generation info string
|
| 266 |
"""
|
| 267 |
info_parts = []
|
| 268 |
+
songs_label = f"({num_audios} songs)"
|
| 269 |
|
| 270 |
+
# Part 1: Total generation time (LM + DiT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
if time_costs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
lm_total = time_costs.get('lm_total_time', 0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
dit_total = time_costs.get('dit_total_time_cost', 0.0)
|
| 274 |
+
generation_total = lm_total + dit_total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
+
if generation_total > 0:
|
| 277 |
+
avg_per_song = generation_total / num_audios if num_audios > 0 else 0
|
| 278 |
+
gen_lines = [
|
| 279 |
+
f"**🎵 Total generation time {songs_label}: {generation_total:.2f}s**",
|
| 280 |
+
f"**{avg_per_song:.2f}s per song**",
|
| 281 |
+
]
|
| 282 |
+
if lm_total > 0:
|
| 283 |
+
gen_lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
|
| 284 |
+
if dit_total > 0:
|
| 285 |
+
gen_lines.append(f"- DiT phase {songs_label}: {dit_total:.2f}s")
|
| 286 |
+
info_parts.append("\n".join(gen_lines))
|
| 287 |
+
|
| 288 |
+
# Part 2: Total processing time (post-processing)
|
| 289 |
+
if time_costs:
|
| 290 |
audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
|
| 291 |
auto_score_time = time_costs.get('auto_score_time', 0.0)
|
| 292 |
auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
|
| 293 |
+
processing_total = audio_conversion_time + auto_score_time + auto_lrc_time
|
| 294 |
|
| 295 |
+
if processing_total > 0:
|
| 296 |
+
proc_lines = [
|
| 297 |
+
f"**🔧 Total processing time {songs_label}: {processing_total:.2f}s**",
|
| 298 |
+
]
|
| 299 |
if audio_conversion_time > 0:
|
| 300 |
+
info_format = time_costs.get('audio_format', 'mp3')
|
| 301 |
+
proc_lines.append(f"- to {info_format} {songs_label}: {audio_conversion_time:.2f}s")
|
| 302 |
if auto_score_time > 0:
|
| 303 |
+
proc_lines.append(f"- scoring {songs_label}: {auto_score_time:.2f}s")
|
| 304 |
if auto_lrc_time > 0:
|
| 305 |
+
proc_lines.append(f"- LRC detection {songs_label}: {auto_lrc_time:.2f}s")
|
| 306 |
+
info_parts.append("\n".join(proc_lines))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
|
| 308 |
# Combine all parts
|
| 309 |
return "\n\n".join(info_parts)
|
|
|
|
| 714 |
codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
|
| 715 |
|
| 716 |
details_accordion_updates = [gr.skip() for _ in range(8)]
|
| 717 |
+
# Auto-expand accordion if auto_score or auto_lrc is enabled
|
| 718 |
+
if auto_score or auto_lrc:
|
| 719 |
+
details_accordion_updates[i] = gr.Accordion(open=True)
|
| 720 |
|
| 721 |
# Clear LRC first (this triggers .change() to clear subtitles)
|
| 722 |
# Keep visible=True to ensure .change() event is properly triggered
|
acestep/gradio_ui/i18n/en.json
CHANGED
|
@@ -174,6 +174,8 @@
|
|
| 174 |
"title": "🎵 Results",
|
| 175 |
"generated_music": "🎵 Generated Music (Sample {n})",
|
| 176 |
"send_to_src_btn": "🔗 Send To Src Audio",
|
|
|
|
|
|
|
| 177 |
"save_btn": "💾 Save",
|
| 178 |
"score_btn": "📊 Score",
|
| 179 |
"lrc_btn": "🎵 LRC",
|
|
|
|
| 174 |
"title": "🎵 Results",
|
| 175 |
"generated_music": "🎵 Generated Music (Sample {n})",
|
| 176 |
"send_to_src_btn": "🔗 Send To Src Audio",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
"save_btn": "💾 Save",
|
| 180 |
"score_btn": "📊 Score",
|
| 181 |
"lrc_btn": "🎵 LRC",
|
acestep/gradio_ui/i18n/ja.json
CHANGED
|
@@ -174,6 +174,8 @@
|
|
| 174 |
"title": "🎵 結果",
|
| 175 |
"generated_music": "🎵 生成された音楽(サンプル {n})",
|
| 176 |
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
|
|
|
|
|
|
| 177 |
"save_btn": "💾 保存",
|
| 178 |
"score_btn": "📊 スコア",
|
| 179 |
"lrc_btn": "🎵 LRC",
|
|
|
|
| 174 |
"title": "🎵 結果",
|
| 175 |
"generated_music": "🎵 生成された音楽(サンプル {n})",
|
| 176 |
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
"save_btn": "💾 保存",
|
| 180 |
"score_btn": "📊 スコア",
|
| 181 |
"lrc_btn": "🎵 LRC",
|
acestep/gradio_ui/i18n/zh.json
CHANGED
|
@@ -174,6 +174,8 @@
|
|
| 174 |
"title": "🎵 结果",
|
| 175 |
"generated_music": "🎵 生成的音乐(样本 {n})",
|
| 176 |
"send_to_src_btn": "🔗 发送到源音频",
|
|
|
|
|
|
|
| 177 |
"save_btn": "💾 保存",
|
| 178 |
"score_btn": "📊 评分",
|
| 179 |
"lrc_btn": "🎵 LRC",
|
|
|
|
| 174 |
"title": "🎵 结果",
|
| 175 |
"generated_music": "🎵 生成的音乐(样本 {n})",
|
| 176 |
"send_to_src_btn": "🔗 发送到源音频",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
"save_btn": "💾 保存",
|
| 180 |
"score_btn": "📊 评分",
|
| 181 |
"lrc_btn": "🎵 LRC",
|
acestep/gradio_ui/interfaces/__init__.py
CHANGED
|
@@ -65,6 +65,14 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
|
|
| 65 |
<div class="main-header">
|
| 66 |
<h1>{t("app.title")}</h1>
|
| 67 |
<p>{t("app.subtitle")}</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
</div>
|
| 69 |
""")
|
| 70 |
|
|
@@ -81,8 +89,8 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
|
|
| 81 |
# Pass init_params to support hiding in service mode
|
| 82 |
training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
|
| 83 |
|
| 84 |
-
# Connect event handlers
|
| 85 |
-
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
|
| 86 |
|
| 87 |
# Connect training event handlers
|
| 88 |
setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
|
|
|
|
| 65 |
<div class="main-header">
|
| 66 |
<h1>{t("app.title")}</h1>
|
| 67 |
<p>{t("app.subtitle")}</p>
|
| 68 |
+
<p style="margin-top: 0.5rem;">
|
| 69 |
+
<a href="https://ace-step-v1.5.github.io" target="_blank">Project</a> |
|
| 70 |
+
<a href="https://huggingface.co/collections/ACE-Step/ace-step-15" target="_blank">Hugging Face</a> |
|
| 71 |
+
<a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5" target="_blank">ModelScope</a> |
|
| 72 |
+
<a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5" target="_blank">Space Demo</a> |
|
| 73 |
+
<a href="https://discord.gg/PeWDxrkdj7" target="_blank">Discord</a> |
|
| 74 |
+
<a href="https://arxiv.org/abs/2506.00045" target="_blank">Technical Report</a>
|
| 75 |
+
</p>
|
| 76 |
</div>
|
| 77 |
""")
|
| 78 |
|
|
|
|
| 89 |
# Pass init_params to support hiding in service mode
|
| 90 |
training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
|
| 91 |
|
| 92 |
+
# Connect event handlers (pass init_params for multi-model support)
|
| 93 |
+
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=init_params)
|
| 94 |
|
| 95 |
# Connect training event handlers
|
| 96 |
setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
|
acestep/gradio_ui/interfaces/generation.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
Gradio UI Generation Section Module
|
| 3 |
-
Contains generation section component definitions
|
| 4 |
"""
|
| 5 |
import gradio as gr
|
| 6 |
from acestep.constants import (
|
|
@@ -14,7 +14,7 @@ from acestep.gradio_ui.i18n import t
|
|
| 14 |
|
| 15 |
|
| 16 |
def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
|
| 17 |
-
"""Create generation section
|
| 18 |
|
| 19 |
Args:
|
| 20 |
dit_handler: DiT handler instance
|
|
@@ -32,10 +32,15 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 32 |
# Get current language from init_params if available
|
| 33 |
current_language = init_params.get('language', language) if init_params else language
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
with gr.Group():
|
| 36 |
-
# Service Configuration
|
| 37 |
accordion_open = not service_pre_initialized
|
| 38 |
-
accordion_visible = not service_pre_initialized
|
| 39 |
with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
|
| 40 |
# Language selector at the top
|
| 41 |
with gr.Row():
|
|
@@ -51,10 +56,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 51 |
scale=1,
|
| 52 |
)
|
| 53 |
|
| 54 |
-
# Dropdown options section - all dropdowns grouped together
|
| 55 |
with gr.Row(equal_height=True):
|
| 56 |
with gr.Column(scale=4):
|
| 57 |
-
# Set checkpoint value from init_params if pre-initialized
|
| 58 |
checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
|
| 59 |
checkpoint_dropdown = gr.Dropdown(
|
| 60 |
label=t("service.checkpoint_label"),
|
|
@@ -66,11 +69,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 66 |
refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
|
| 67 |
|
| 68 |
with gr.Row():
|
| 69 |
-
# Get available acestep-v15- model list
|
| 70 |
available_models = dit_handler.get_available_acestep_v15_models()
|
| 71 |
default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
|
| 72 |
-
|
| 73 |
-
# Set config_path value from init_params if pre-initialized
|
| 74 |
config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
|
| 75 |
config_path = gr.Dropdown(
|
| 76 |
label=t("service.model_path_label"),
|
|
@@ -78,7 +78,6 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 78 |
value=config_path_value,
|
| 79 |
info=t("service.model_path_info")
|
| 80 |
)
|
| 81 |
-
# Set device value from init_params if pre-initialized
|
| 82 |
device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
|
| 83 |
device = gr.Dropdown(
|
| 84 |
choices=["auto", "cuda", "cpu"],
|
|
@@ -88,11 +87,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 88 |
)
|
| 89 |
|
| 90 |
with gr.Row():
|
| 91 |
-
# Get available 5Hz LM model list
|
| 92 |
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
| 93 |
default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
|
| 94 |
-
|
| 95 |
-
# Set lm_model_path value from init_params if pre-initialized
|
| 96 |
lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
|
| 97 |
lm_model_path = gr.Dropdown(
|
| 98 |
label=t("service.lm_model_path_label"),
|
|
@@ -100,7 +96,6 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 100 |
value=lm_model_path_value,
|
| 101 |
info=t("service.lm_model_path_info")
|
| 102 |
)
|
| 103 |
-
# Set backend value from init_params if pre-initialized
|
| 104 |
backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
|
| 105 |
backend_dropdown = gr.Dropdown(
|
| 106 |
choices=["vllm", "pt"],
|
|
@@ -109,18 +104,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 109 |
info=t("service.backend_info")
|
| 110 |
)
|
| 111 |
|
| 112 |
-
# Checkbox options section - all checkboxes grouped together
|
| 113 |
with gr.Row():
|
| 114 |
-
# Set init_llm value from init_params if pre-initialized
|
| 115 |
init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
|
| 116 |
init_llm_checkbox = gr.Checkbox(
|
| 117 |
label=t("service.init_llm_label"),
|
| 118 |
value=init_llm_value,
|
| 119 |
info=t("service.init_llm_info"),
|
| 120 |
)
|
| 121 |
-
# Auto-detect flash attention availability
|
| 122 |
flash_attn_available = dit_handler.is_flash_attention_available()
|
| 123 |
-
# Set use_flash_attention value from init_params if pre-initialized
|
| 124 |
use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
|
| 125 |
use_flash_attention_checkbox = gr.Checkbox(
|
| 126 |
label=t("service.flash_attention_label"),
|
|
@@ -128,14 +119,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 128 |
interactive=flash_attn_available,
|
| 129 |
info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
|
| 130 |
)
|
| 131 |
-
# Set offload_to_cpu value from init_params if pre-initialized
|
| 132 |
offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
|
| 133 |
offload_to_cpu_checkbox = gr.Checkbox(
|
| 134 |
label=t("service.offload_cpu_label"),
|
| 135 |
value=offload_to_cpu_value,
|
| 136 |
info=t("service.offload_cpu_info")
|
| 137 |
)
|
| 138 |
-
# Set offload_dit_to_cpu value from init_params if pre-initialized
|
| 139 |
offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
|
| 140 |
offload_dit_to_cpu_checkbox = gr.Checkbox(
|
| 141 |
label=t("service.offload_dit_cpu_label"),
|
|
@@ -144,7 +133,6 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 144 |
)
|
| 145 |
|
| 146 |
init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
|
| 147 |
-
# Set init_status value from init_params if pre-initialized
|
| 148 |
init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
|
| 149 |
init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
|
| 150 |
|
|
@@ -173,505 +161,436 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 173 |
scale=2,
|
| 174 |
)
|
| 175 |
|
| 176 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
with gr.Row():
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
track_name = gr.Dropdown(
|
| 216 |
-
choices=TRACK_NAMES,
|
| 217 |
-
value=None,
|
| 218 |
-
label=t("generation.track_name_label"),
|
| 219 |
-
info=t("generation.track_name_info"),
|
| 220 |
-
visible=False
|
| 221 |
)
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
)
|
| 229 |
-
|
| 230 |
-
# Audio uploads
|
| 231 |
-
audio_uploads_accordion = gr.Accordion(t("generation.audio_uploads"), open=False)
|
| 232 |
-
with audio_uploads_accordion:
|
| 233 |
-
with gr.Row(equal_height=True):
|
| 234 |
-
with gr.Column(scale=2):
|
| 235 |
-
reference_audio = gr.Audio(
|
| 236 |
-
label=t("generation.reference_audio"),
|
| 237 |
-
type="filepath",
|
| 238 |
-
)
|
| 239 |
-
with gr.Column(scale=7):
|
| 240 |
-
src_audio = gr.Audio(
|
| 241 |
-
label=t("generation.source_audio"),
|
| 242 |
-
type="filepath",
|
| 243 |
-
)
|
| 244 |
-
with gr.Column(scale=1, min_width=80):
|
| 245 |
-
convert_src_to_codes_btn = gr.Button(
|
| 246 |
-
t("generation.convert_codes_btn"),
|
| 247 |
-
variant="secondary",
|
| 248 |
-
size="sm"
|
| 249 |
-
)
|
| 250 |
-
|
| 251 |
-
# Audio Codes for text2music - single input for transcription or cover task
|
| 252 |
-
with gr.Accordion(t("generation.lm_codes_hints"), open=False, visible=True) as text2music_audio_codes_group:
|
| 253 |
-
with gr.Row(equal_height=True):
|
| 254 |
-
text2music_audio_code_string = gr.Textbox(
|
| 255 |
-
label=t("generation.lm_codes_label"),
|
| 256 |
-
placeholder=t("generation.lm_codes_placeholder"),
|
| 257 |
-
lines=6,
|
| 258 |
-
info=t("generation.lm_codes_info"),
|
| 259 |
-
scale=9,
|
| 260 |
-
)
|
| 261 |
-
transcribe_btn = gr.Button(
|
| 262 |
-
t("generation.transcribe_btn"),
|
| 263 |
-
variant="secondary",
|
| 264 |
-
size="sm",
|
| 265 |
-
scale=1,
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
# Repainting controls
|
| 269 |
-
with gr.Group(visible=False) as repainting_group:
|
| 270 |
-
gr.HTML(f"<h5>{t('generation.repainting_controls')}</h5>")
|
| 271 |
-
with gr.Row():
|
| 272 |
-
repainting_start = gr.Number(
|
| 273 |
-
label=t("generation.repainting_start"),
|
| 274 |
-
value=0.0,
|
| 275 |
-
step=0.1,
|
| 276 |
-
)
|
| 277 |
-
repainting_end = gr.Number(
|
| 278 |
-
label=t("generation.repainting_end"),
|
| 279 |
-
value=-1,
|
| 280 |
-
minimum=-1,
|
| 281 |
-
step=0.1,
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
# Simple/Custom Mode Toggle
|
| 285 |
-
# In service mode: only Custom mode, hide the toggle
|
| 286 |
-
with gr.Row(visible=not service_mode):
|
| 287 |
-
generation_mode = gr.Radio(
|
| 288 |
-
choices=[
|
| 289 |
-
(t("generation.mode_simple"), "simple"),
|
| 290 |
-
(t("generation.mode_custom"), "custom"),
|
| 291 |
-
],
|
| 292 |
-
value="custom" if service_mode else "simple",
|
| 293 |
-
label=t("generation.mode_label"),
|
| 294 |
-
info=t("generation.mode_info"),
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
# Simple Mode Components - hidden in service mode
|
| 298 |
-
with gr.Group(visible=not service_mode) as simple_mode_group:
|
| 299 |
-
with gr.Row(equal_height=True):
|
| 300 |
-
simple_query_input = gr.Textbox(
|
| 301 |
-
label=t("generation.simple_query_label"),
|
| 302 |
-
placeholder=t("generation.simple_query_placeholder"),
|
| 303 |
-
lines=2,
|
| 304 |
-
info=t("generation.simple_query_info"),
|
| 305 |
-
scale=12,
|
| 306 |
-
)
|
| 307 |
-
|
| 308 |
-
with gr.Column(scale=1, min_width=100):
|
| 309 |
-
random_desc_btn = gr.Button(
|
| 310 |
-
"🎲",
|
| 311 |
-
variant="secondary",
|
| 312 |
-
size="sm",
|
| 313 |
-
scale=2
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
with gr.Row(equal_height=True):
|
| 317 |
-
with gr.Column(scale=1, variant="compact"):
|
| 318 |
-
simple_instrumental_checkbox = gr.Checkbox(
|
| 319 |
-
label=t("generation.instrumental_label"),
|
| 320 |
-
value=False,
|
| 321 |
-
)
|
| 322 |
-
with gr.Column(scale=18):
|
| 323 |
-
create_sample_btn = gr.Button(
|
| 324 |
-
t("generation.create_sample_btn"),
|
| 325 |
-
variant="primary",
|
| 326 |
-
size="lg",
|
| 327 |
-
)
|
| 328 |
-
with gr.Column(scale=1, variant="compact"):
|
| 329 |
-
simple_vocal_language = gr.Dropdown(
|
| 330 |
-
choices=VALID_LANGUAGES,
|
| 331 |
-
value="unknown",
|
| 332 |
-
allow_custom_value=True,
|
| 333 |
-
label=t("generation.simple_vocal_language_label"),
|
| 334 |
-
interactive=True,
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
# State to track if sample has been created in Simple mode
|
| 338 |
-
simple_sample_created = gr.State(value=False)
|
| 339 |
|
| 340 |
-
#
|
| 341 |
-
|
| 342 |
-
|
| 343 |
with gr.Row(equal_height=True):
|
| 344 |
captions = gr.Textbox(
|
| 345 |
-
label=
|
| 346 |
-
placeholder=
|
| 347 |
-
lines=
|
| 348 |
-
|
| 349 |
-
scale=12,
|
| 350 |
-
)
|
| 351 |
-
with gr.Column(scale=1, min_width=100):
|
| 352 |
-
sample_btn = gr.Button(
|
| 353 |
-
"🎲",
|
| 354 |
-
variant="secondary",
|
| 355 |
-
size="sm",
|
| 356 |
-
scale=2,
|
| 357 |
-
)
|
| 358 |
-
# Lyrics - wrapped in accordion that can be collapsed in Simple mode
|
| 359 |
-
# In service mode: auto-expand
|
| 360 |
-
with gr.Accordion(t("generation.lyrics_title"), open=service_mode) as lyrics_accordion:
|
| 361 |
-
lyrics = gr.Textbox(
|
| 362 |
-
label=t("generation.lyrics_label"),
|
| 363 |
-
placeholder=t("generation.lyrics_placeholder"),
|
| 364 |
-
lines=8,
|
| 365 |
-
info=t("generation.lyrics_info")
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
with gr.Row(variant="compact", equal_height=True):
|
| 369 |
-
instrumental_checkbox = gr.Checkbox(
|
| 370 |
-
label=t("generation.instrumental_label"),
|
| 371 |
-
value=False,
|
| 372 |
scale=1,
|
| 373 |
-
min_width=120,
|
| 374 |
-
container=True,
|
| 375 |
)
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
value="unknown",
|
| 382 |
-
label=t("generation.vocal_language_label"),
|
| 383 |
-
show_label=False,
|
| 384 |
-
container=True,
|
| 385 |
-
allow_custom_value=True,
|
| 386 |
-
scale=3,
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
# 右侧:格式化按钮 (Button)
|
| 390 |
-
# 放在同一行最右侧,操作更顺手
|
| 391 |
-
format_btn = gr.Button(
|
| 392 |
-
t("generation.format_btn"),
|
| 393 |
-
variant="secondary",
|
| 394 |
scale=1,
|
| 395 |
-
min_width=80,
|
| 396 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
-
#
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
key_scale = gr.Textbox(
|
| 409 |
-
label=t("generation.keyscale_label"),
|
| 410 |
-
placeholder=t("generation.keyscale_placeholder"),
|
| 411 |
-
value="",
|
| 412 |
-
info=t("generation.keyscale_info")
|
| 413 |
-
)
|
| 414 |
-
time_signature = gr.Dropdown(
|
| 415 |
-
choices=["2", "3", "4", "N/A", ""],
|
| 416 |
-
value="",
|
| 417 |
-
label=t("generation.timesig_label"),
|
| 418 |
-
allow_custom_value=True,
|
| 419 |
-
info=t("generation.timesig_info")
|
| 420 |
-
)
|
| 421 |
-
audio_duration = gr.Number(
|
| 422 |
-
label=t("generation.duration_label"),
|
| 423 |
-
value=-1,
|
| 424 |
-
minimum=-1,
|
| 425 |
-
maximum=600.0,
|
| 426 |
-
step=0.1,
|
| 427 |
-
info=t("generation.duration_info")
|
| 428 |
-
)
|
| 429 |
-
batch_size_input = gr.Number(
|
| 430 |
-
label=t("generation.batch_size_label"),
|
| 431 |
-
value=2,
|
| 432 |
-
minimum=1,
|
| 433 |
-
maximum=8,
|
| 434 |
-
step=1,
|
| 435 |
-
info=t("generation.batch_size_info"),
|
| 436 |
-
interactive=not service_mode # Fixed in service mode
|
| 437 |
-
)
|
| 438 |
|
| 439 |
-
#
|
| 440 |
-
|
| 441 |
-
#
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
with gr.Row():
|
| 444 |
inference_steps = gr.Slider(
|
| 445 |
minimum=1,
|
| 446 |
maximum=20,
|
| 447 |
value=8,
|
| 448 |
step=1,
|
| 449 |
-
label=
|
| 450 |
-
info=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
)
|
| 452 |
-
guidance_scale = gr.Slider(
|
| 453 |
-
minimum=1.0,
|
| 454 |
-
maximum=15.0,
|
| 455 |
-
value=7.0,
|
| 456 |
-
step=0.1,
|
| 457 |
-
label=t("generation.guidance_scale_label"),
|
| 458 |
-
info=t("generation.guidance_scale_info"),
|
| 459 |
-
visible=False
|
| 460 |
-
)
|
| 461 |
-
with gr.Column():
|
| 462 |
-
seed = gr.Textbox(
|
| 463 |
-
label=t("generation.seed_label"),
|
| 464 |
-
value="-1",
|
| 465 |
-
info=t("generation.seed_info")
|
| 466 |
-
)
|
| 467 |
-
random_seed_checkbox = gr.Checkbox(
|
| 468 |
-
label=t("generation.random_seed_label"),
|
| 469 |
-
value=True,
|
| 470 |
-
info=t("generation.random_seed_info")
|
| 471 |
-
)
|
| 472 |
audio_format = gr.Dropdown(
|
| 473 |
choices=["mp3", "flac"],
|
| 474 |
value="mp3",
|
| 475 |
-
label=
|
| 476 |
-
info=
|
| 477 |
-
interactive=not service_mode # Fixed in service mode
|
| 478 |
)
|
| 479 |
|
|
|
|
| 480 |
with gr.Row():
|
| 481 |
-
use_adg = gr.Checkbox(
|
| 482 |
-
label=t("generation.use_adg_label"),
|
| 483 |
-
value=False,
|
| 484 |
-
info=t("generation.use_adg_info"),
|
| 485 |
-
visible=False
|
| 486 |
-
)
|
| 487 |
shift = gr.Slider(
|
| 488 |
minimum=1.0,
|
| 489 |
maximum=5.0,
|
| 490 |
value=3.0,
|
| 491 |
step=0.1,
|
| 492 |
-
label=
|
| 493 |
-
info=
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
)
|
| 496 |
infer_method = gr.Dropdown(
|
| 497 |
choices=["ode", "sde"],
|
| 498 |
value="ode",
|
| 499 |
-
label=
|
| 500 |
-
info=
|
| 501 |
)
|
| 502 |
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
value=0.0,
|
| 516 |
-
step=0.01,
|
| 517 |
-
label=t("generation.cfg_interval_start"),
|
| 518 |
-
visible=False
|
| 519 |
-
)
|
| 520 |
-
cfg_interval_end = gr.Slider(
|
| 521 |
-
minimum=0.0,
|
| 522 |
-
maximum=1.0,
|
| 523 |
-
value=1.0,
|
| 524 |
-
step=0.01,
|
| 525 |
-
label=t("generation.cfg_interval_end"),
|
| 526 |
-
visible=False
|
| 527 |
-
)
|
| 528 |
-
|
| 529 |
-
# LM (Language Model) Parameters
|
| 530 |
-
gr.HTML(f"<h4>{t('generation.lm_params_title')}</h4>")
|
| 531 |
with gr.Row():
|
| 532 |
lm_temperature = gr.Slider(
|
| 533 |
-
label=t("generation.lm_temperature_label"),
|
| 534 |
minimum=0.0,
|
| 535 |
maximum=2.0,
|
| 536 |
value=0.85,
|
| 537 |
-
step=0.
|
| 538 |
-
|
| 539 |
-
info=
|
| 540 |
)
|
| 541 |
lm_cfg_scale = gr.Slider(
|
| 542 |
-
label=t("generation.lm_cfg_scale_label"),
|
| 543 |
minimum=1.0,
|
| 544 |
maximum=3.0,
|
| 545 |
value=2.0,
|
| 546 |
step=0.1,
|
| 547 |
-
|
| 548 |
-
info=
|
| 549 |
)
|
| 550 |
lm_top_k = gr.Slider(
|
| 551 |
-
label=t("generation.lm_top_k_label"),
|
| 552 |
minimum=0,
|
| 553 |
maximum=100,
|
| 554 |
value=0,
|
| 555 |
step=1,
|
| 556 |
-
|
| 557 |
-
info=
|
| 558 |
)
|
| 559 |
lm_top_p = gr.Slider(
|
| 560 |
-
label=t("generation.lm_top_p_label"),
|
| 561 |
minimum=0.0,
|
| 562 |
maximum=1.0,
|
| 563 |
value=0.9,
|
| 564 |
step=0.01,
|
| 565 |
-
|
| 566 |
-
info=
|
| 567 |
-
)
|
| 568 |
-
|
| 569 |
-
with gr.Row():
|
| 570 |
-
lm_negative_prompt = gr.Textbox(
|
| 571 |
-
label=t("generation.lm_negative_prompt_label"),
|
| 572 |
-
value="NO USER INPUT",
|
| 573 |
-
placeholder=t("generation.lm_negative_prompt_placeholder"),
|
| 574 |
-
info=t("generation.lm_negative_prompt_info"),
|
| 575 |
-
lines=2,
|
| 576 |
-
scale=2,
|
| 577 |
)
|
| 578 |
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
)
|
| 586 |
-
|
| 587 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
value=True,
|
| 589 |
-
info=t("generation.cot_language_info"),
|
| 590 |
-
scale=1,
|
| 591 |
)
|
| 592 |
-
|
| 593 |
-
label=
|
| 594 |
value=False,
|
| 595 |
-
info=t("generation.constrained_debug_info"),
|
| 596 |
-
scale=1,
|
| 597 |
-
interactive=not service_mode # Fixed in service mode
|
| 598 |
)
|
| 599 |
|
| 600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
auto_score = gr.Checkbox(
|
| 602 |
-
label=
|
| 603 |
value=False,
|
| 604 |
-
info=t("generation.auto_score_info"),
|
| 605 |
-
scale=1,
|
| 606 |
-
interactive=not service_mode # Fixed in service mode
|
| 607 |
)
|
| 608 |
auto_lrc = gr.Checkbox(
|
| 609 |
-
label=
|
| 610 |
value=False,
|
| 611 |
-
info=t("generation.auto_lrc_info"),
|
| 612 |
-
scale=1,
|
| 613 |
-
interactive=not service_mode # Fixed in service mode
|
| 614 |
-
)
|
| 615 |
-
lm_batch_chunk_size = gr.Number(
|
| 616 |
-
label=t("generation.lm_batch_chunk_label"),
|
| 617 |
-
value=8,
|
| 618 |
-
minimum=1,
|
| 619 |
-
maximum=32,
|
| 620 |
-
step=1,
|
| 621 |
-
info=t("generation.lm_batch_chunk_info"),
|
| 622 |
-
scale=1,
|
| 623 |
-
interactive=not service_mode # Fixed in service mode
|
| 624 |
-
)
|
| 625 |
-
|
| 626 |
-
with gr.Row():
|
| 627 |
-
audio_cover_strength = gr.Slider(
|
| 628 |
-
minimum=0.0,
|
| 629 |
-
maximum=1.0,
|
| 630 |
-
value=1.0,
|
| 631 |
-
step=0.01,
|
| 632 |
-
label=t("generation.codes_strength_label"),
|
| 633 |
-
info=t("generation.codes_strength_info"),
|
| 634 |
-
scale=1,
|
| 635 |
-
)
|
| 636 |
-
score_scale = gr.Slider(
|
| 637 |
-
minimum=0.01,
|
| 638 |
-
maximum=1.0,
|
| 639 |
-
value=0.5,
|
| 640 |
-
step=0.01,
|
| 641 |
-
label=t("generation.score_sensitivity_label"),
|
| 642 |
-
info=t("generation.score_sensitivity_info"),
|
| 643 |
-
scale=1,
|
| 644 |
-
visible=not service_mode # Hidden in service mode
|
| 645 |
)
|
| 646 |
|
| 647 |
-
#
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
|
| 676 |
return {
|
| 677 |
"service_config_accordion": service_config_accordion,
|
|
@@ -694,6 +613,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 694 |
"unload_lora_btn": unload_lora_btn,
|
| 695 |
"use_lora_checkbox": use_lora_checkbox,
|
| 696 |
"lora_status": lora_status,
|
|
|
|
|
|
|
| 697 |
"task_type": task_type,
|
| 698 |
"instruction_display_gen": instruction_display_gen,
|
| 699 |
"track_name": track_name,
|
|
@@ -717,7 +638,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 717 |
"repainting_start": repainting_start,
|
| 718 |
"repainting_end": repainting_end,
|
| 719 |
"audio_cover_strength": audio_cover_strength,
|
| 720 |
-
#
|
| 721 |
"generation_mode": generation_mode,
|
| 722 |
"simple_mode_group": simple_mode_group,
|
| 723 |
"simple_query_input": simple_query_input,
|
|
@@ -729,6 +650,13 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 729 |
"caption_accordion": caption_accordion,
|
| 730 |
"lyrics_accordion": lyrics_accordion,
|
| 731 |
"optional_params_accordion": optional_params_accordion,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
# Existing components
|
| 733 |
"captions": captions,
|
| 734 |
"sample_btn": sample_btn,
|
|
@@ -763,4 +691,3 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 763 |
"auto_lrc": auto_lrc,
|
| 764 |
"lm_batch_chunk_size": lm_batch_chunk_size,
|
| 765 |
}
|
| 766 |
-
|
|
|
|
| 1 |
"""
|
| 2 |
Gradio UI Generation Section Module
|
| 3 |
+
Contains generation section component definitions - Simplified UI
|
| 4 |
"""
|
| 5 |
import gradio as gr
|
| 6 |
from acestep.constants import (
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
|
| 17 |
+
"""Create generation section with simplified UI
|
| 18 |
|
| 19 |
Args:
|
| 20 |
dit_handler: DiT handler instance
|
|
|
|
| 32 |
# Get current language from init_params if available
|
| 33 |
current_language = init_params.get('language', language) if init_params else language
|
| 34 |
|
| 35 |
+
# Get available models
|
| 36 |
+
available_dit_models = init_params.get('available_dit_models', []) if init_params else []
|
| 37 |
+
current_model_value = init_params.get('config_path', '') if init_params else ''
|
| 38 |
+
show_model_selector = len(available_dit_models) > 1
|
| 39 |
+
|
| 40 |
with gr.Group():
|
| 41 |
+
# ==================== Service Configuration (Hidden in service mode) ====================
|
| 42 |
accordion_open = not service_pre_initialized
|
| 43 |
+
accordion_visible = not service_pre_initialized
|
| 44 |
with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
|
| 45 |
# Language selector at the top
|
| 46 |
with gr.Row():
|
|
|
|
| 56 |
scale=1,
|
| 57 |
)
|
| 58 |
|
|
|
|
| 59 |
with gr.Row(equal_height=True):
|
| 60 |
with gr.Column(scale=4):
|
|
|
|
| 61 |
checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
|
| 62 |
checkpoint_dropdown = gr.Dropdown(
|
| 63 |
label=t("service.checkpoint_label"),
|
|
|
|
| 69 |
refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
|
| 70 |
|
| 71 |
with gr.Row():
|
|
|
|
| 72 |
available_models = dit_handler.get_available_acestep_v15_models()
|
| 73 |
default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
|
|
|
|
|
|
|
| 74 |
config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
|
| 75 |
config_path = gr.Dropdown(
|
| 76 |
label=t("service.model_path_label"),
|
|
|
|
| 78 |
value=config_path_value,
|
| 79 |
info=t("service.model_path_info")
|
| 80 |
)
|
|
|
|
| 81 |
device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
|
| 82 |
device = gr.Dropdown(
|
| 83 |
choices=["auto", "cuda", "cpu"],
|
|
|
|
| 87 |
)
|
| 88 |
|
| 89 |
with gr.Row():
|
|
|
|
| 90 |
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
| 91 |
default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
|
|
|
|
|
|
|
| 92 |
lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
|
| 93 |
lm_model_path = gr.Dropdown(
|
| 94 |
label=t("service.lm_model_path_label"),
|
|
|
|
| 96 |
value=lm_model_path_value,
|
| 97 |
info=t("service.lm_model_path_info")
|
| 98 |
)
|
|
|
|
| 99 |
backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
|
| 100 |
backend_dropdown = gr.Dropdown(
|
| 101 |
choices=["vllm", "pt"],
|
|
|
|
| 104 |
info=t("service.backend_info")
|
| 105 |
)
|
| 106 |
|
|
|
|
| 107 |
with gr.Row():
|
|
|
|
| 108 |
init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
|
| 109 |
init_llm_checkbox = gr.Checkbox(
|
| 110 |
label=t("service.init_llm_label"),
|
| 111 |
value=init_llm_value,
|
| 112 |
info=t("service.init_llm_info"),
|
| 113 |
)
|
|
|
|
| 114 |
flash_attn_available = dit_handler.is_flash_attention_available()
|
|
|
|
| 115 |
use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
|
| 116 |
use_flash_attention_checkbox = gr.Checkbox(
|
| 117 |
label=t("service.flash_attention_label"),
|
|
|
|
| 119 |
interactive=flash_attn_available,
|
| 120 |
info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
|
| 121 |
)
|
|
|
|
| 122 |
offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
|
| 123 |
offload_to_cpu_checkbox = gr.Checkbox(
|
| 124 |
label=t("service.offload_cpu_label"),
|
| 125 |
value=offload_to_cpu_value,
|
| 126 |
info=t("service.offload_cpu_info")
|
| 127 |
)
|
|
|
|
| 128 |
offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
|
| 129 |
offload_dit_to_cpu_checkbox = gr.Checkbox(
|
| 130 |
label=t("service.offload_dit_cpu_label"),
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
|
|
|
|
| 136 |
init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
|
| 137 |
init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
|
| 138 |
|
|
|
|
| 161 |
scale=2,
|
| 162 |
)
|
| 163 |
|
| 164 |
+
# ==================== Model Selector (Top, only when multiple models) ====================
|
| 165 |
+
with gr.Row(visible=show_model_selector):
|
| 166 |
+
dit_model_selector = gr.Dropdown(
|
| 167 |
+
choices=available_dit_models,
|
| 168 |
+
value=current_model_value,
|
| 169 |
+
label="models",
|
| 170 |
+
scale=1,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Hidden dropdown when only one model (for event handler compatibility)
|
| 174 |
+
if not show_model_selector:
|
| 175 |
+
dit_model_selector = gr.Dropdown(
|
| 176 |
+
choices=available_dit_models if available_dit_models else [current_model_value],
|
| 177 |
+
value=current_model_value,
|
| 178 |
+
visible=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# ==================== Generation Mode (4 modes) ====================
|
| 182 |
+
gr.HTML("<div style='background: #4a5568; color: white; padding: 8px 16px; border-radius: 4px; font-weight: bold;'>Generation Mode</div>")
|
| 183 |
with gr.Row():
|
| 184 |
+
generation_mode = gr.Radio(
|
| 185 |
+
choices=[
|
| 186 |
+
("Simple", "simple"),
|
| 187 |
+
("Custom", "custom"),
|
| 188 |
+
("Cover", "cover"),
|
| 189 |
+
("Repaint", "repaint"),
|
| 190 |
+
],
|
| 191 |
+
value="custom",
|
| 192 |
+
label="",
|
| 193 |
+
show_label=False,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# ==================== Simple Mode Group ====================
|
| 197 |
+
with gr.Column(visible=False) as simple_mode_group:
|
| 198 |
+
# Row: Song Description + Vocal Language + Random button
|
| 199 |
+
with gr.Row(equal_height=True):
|
| 200 |
+
simple_query_input = gr.Textbox(
|
| 201 |
+
label=t("generation.simple_query_label"),
|
| 202 |
+
placeholder=t("generation.simple_query_placeholder"),
|
| 203 |
+
lines=2,
|
| 204 |
+
info=t("generation.simple_query_info"),
|
| 205 |
+
scale=10,
|
| 206 |
+
)
|
| 207 |
+
simple_vocal_language = gr.Dropdown(
|
| 208 |
+
choices=VALID_LANGUAGES,
|
| 209 |
+
value="unknown",
|
| 210 |
+
allow_custom_value=True,
|
| 211 |
+
label=t("generation.simple_vocal_language_label"),
|
| 212 |
+
interactive=True,
|
| 213 |
+
info="use unknown for instrumental",
|
| 214 |
+
scale=2,
|
| 215 |
+
)
|
| 216 |
+
with gr.Column(scale=1, min_width=60):
|
| 217 |
+
random_desc_btn = gr.Button(
|
| 218 |
+
"🎲",
|
| 219 |
+
variant="secondary",
|
| 220 |
+
size="lg",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
)
|
| 222 |
+
|
| 223 |
+
# Hidden components (kept for compatibility but not shown)
|
| 224 |
+
simple_instrumental_checkbox = gr.Checkbox(
|
| 225 |
+
label=t("generation.instrumental_label"),
|
| 226 |
+
value=False,
|
| 227 |
+
visible=False,
|
| 228 |
+
)
|
| 229 |
+
create_sample_btn = gr.Button(
|
| 230 |
+
t("generation.create_sample_btn"),
|
| 231 |
+
variant="primary",
|
| 232 |
+
size="lg",
|
| 233 |
+
visible=False,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# State to track if sample has been created in Simple mode
|
| 237 |
+
simple_sample_created = gr.State(value=False)
|
| 238 |
+
|
| 239 |
+
# ==================== Source Audio (for Cover/Repaint) ====================
|
| 240 |
+
# This is shown above the main content for Cover and Repaint modes
|
| 241 |
+
with gr.Column(visible=False) as src_audio_group:
|
| 242 |
+
with gr.Row(equal_height=True):
|
| 243 |
+
# Source Audio - scale=10 to match (refer_audio=2 + prompt/lyrics=8)
|
| 244 |
+
src_audio = gr.Audio(
|
| 245 |
+
label="Source Audio",
|
| 246 |
+
type="filepath",
|
| 247 |
+
scale=10,
|
| 248 |
+
)
|
| 249 |
+
# Process button - scale=1 to align with random button
|
| 250 |
+
with gr.Column(scale=1, min_width=80):
|
| 251 |
+
process_src_btn = gr.Button(
|
| 252 |
+
"Analyze",
|
| 253 |
+
variant="secondary",
|
| 254 |
+
size="lg",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Hidden Audio Codes storage (needed internally but not displayed)
|
| 258 |
+
text2music_audio_code_string = gr.Textbox(
|
| 259 |
+
label="Audio Codes",
|
| 260 |
+
visible=False,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# ==================== Custom/Cover/Repaint Mode Content ====================
|
| 264 |
+
with gr.Column() as custom_mode_content:
|
| 265 |
+
with gr.Row(equal_height=True):
|
| 266 |
+
# Left: Reference Audio
|
| 267 |
+
with gr.Column(scale=2, min_width=200):
|
| 268 |
+
reference_audio = gr.Audio(
|
| 269 |
+
label="Reference Audio (optional)",
|
| 270 |
+
type="filepath",
|
| 271 |
+
show_label=True,
|
| 272 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
+
# Middle: Prompt + Lyrics + Format button
|
| 275 |
+
with gr.Column(scale=8):
|
| 276 |
+
# Row 1: Prompt and Lyrics
|
| 277 |
with gr.Row(equal_height=True):
|
| 278 |
captions = gr.Textbox(
|
| 279 |
+
label="Prompt",
|
| 280 |
+
placeholder="Describe the music style, mood, instruments...",
|
| 281 |
+
lines=12,
|
| 282 |
+
max_lines=12,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
scale=1,
|
|
|
|
|
|
|
| 284 |
)
|
| 285 |
+
lyrics = gr.Textbox(
|
| 286 |
+
label="Lyrics",
|
| 287 |
+
placeholder="Enter lyrics here... Use [Verse], [Chorus] etc. for structure",
|
| 288 |
+
lines=12,
|
| 289 |
+
max_lines=12,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
scale=1,
|
|
|
|
| 291 |
)
|
| 292 |
+
|
| 293 |
+
# Row 2: Format button (only below Prompt and Lyrics)
|
| 294 |
+
format_btn = gr.Button(
|
| 295 |
+
"Format",
|
| 296 |
+
variant="secondary",
|
| 297 |
+
)
|
| 298 |
|
| 299 |
+
# Right: Random button
|
| 300 |
+
with gr.Column(scale=1, min_width=60):
|
| 301 |
+
sample_btn = gr.Button(
|
| 302 |
+
"🎲",
|
| 303 |
+
variant="secondary",
|
| 304 |
+
size="lg",
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Placeholder for removed audio_uploads_accordion (for compatibility)
|
| 308 |
+
audio_uploads_accordion = gr.Column(visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
+
# Legacy cover_mode_group (hidden, for backward compatibility)
|
| 311 |
+
cover_mode_group = gr.Column(visible=False)
|
| 312 |
+
# Legacy convert button (hidden, for backward compatibility)
|
| 313 |
+
convert_src_to_codes_btn = gr.Button("Convert to Codes", visible=False)
|
| 314 |
+
|
| 315 |
+
# ==================== Repaint Mode: Source + Time Range ====================
|
| 316 |
+
with gr.Column(visible=False) as repainting_group:
|
| 317 |
+
with gr.Row():
|
| 318 |
+
repainting_start = gr.Number(
|
| 319 |
+
label="Start (seconds)",
|
| 320 |
+
value=0.0,
|
| 321 |
+
step=0.1,
|
| 322 |
+
scale=1,
|
| 323 |
+
)
|
| 324 |
+
repainting_end = gr.Number(
|
| 325 |
+
label="End (seconds, -1 for end)",
|
| 326 |
+
value=-1,
|
| 327 |
+
minimum=-1,
|
| 328 |
+
step=0.1,
|
| 329 |
+
scale=1,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# ==================== Optional Parameters ====================
|
| 333 |
+
with gr.Accordion("⚙️ Optional Parameters", open=False, visible=False) as optional_params_accordion:
|
| 334 |
+
pass
|
| 335 |
+
|
| 336 |
+
# ==================== Advanced Settings ====================
|
| 337 |
+
with gr.Accordion("🔧 Advanced Settings", open=False) as advanced_options_accordion:
|
| 338 |
+
with gr.Row():
|
| 339 |
+
bpm = gr.Number(
|
| 340 |
+
label="BPM (optional)",
|
| 341 |
+
value=0,
|
| 342 |
+
step=1,
|
| 343 |
+
info="leave empty for N/A",
|
| 344 |
+
scale=1,
|
| 345 |
+
)
|
| 346 |
+
key_scale = gr.Textbox(
|
| 347 |
+
label="Key Signature (optional)",
|
| 348 |
+
placeholder="Leave empty for N/A",
|
| 349 |
+
value="",
|
| 350 |
+
info="A-G, #/♭, major/minor",
|
| 351 |
+
scale=1,
|
| 352 |
+
)
|
| 353 |
+
time_signature = gr.Dropdown(
|
| 354 |
+
choices=["", "2", "3", "4"],
|
| 355 |
+
value="",
|
| 356 |
+
label="Time Signature (optional)",
|
| 357 |
+
allow_custom_value=True,
|
| 358 |
+
info="2/4, 3/4, 4/4...",
|
| 359 |
+
scale=1,
|
| 360 |
+
)
|
| 361 |
+
audio_duration = gr.Number(
|
| 362 |
+
label="Audio Duration (seconds)",
|
| 363 |
+
value=-1,
|
| 364 |
+
minimum=-1,
|
| 365 |
+
maximum=600.0,
|
| 366 |
+
step=1,
|
| 367 |
+
info="Use -1 for random",
|
| 368 |
+
scale=1,
|
| 369 |
+
)
|
| 370 |
+
vocal_language = gr.Dropdown(
|
| 371 |
+
choices=VALID_LANGUAGES,
|
| 372 |
+
value="unknown",
|
| 373 |
+
label="Vocal Language",
|
| 374 |
+
allow_custom_value=True,
|
| 375 |
+
info="use `unknown` for instrumental",
|
| 376 |
+
scale=1,
|
| 377 |
+
)
|
| 378 |
+
batch_size_input = gr.Number(
|
| 379 |
+
label="batch size",
|
| 380 |
+
info="max 8",
|
| 381 |
+
value=2,
|
| 382 |
+
minimum=1,
|
| 383 |
+
maximum=8,
|
| 384 |
+
step=1,
|
| 385 |
+
scale=1,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# Row 1: DiT Inference Steps, Seed, Audio Format
|
| 389 |
with gr.Row():
|
| 390 |
inference_steps = gr.Slider(
|
| 391 |
minimum=1,
|
| 392 |
maximum=20,
|
| 393 |
value=8,
|
| 394 |
step=1,
|
| 395 |
+
label="DiT Inference Steps",
|
| 396 |
+
info="Turbo: max 8, Base: max 200",
|
| 397 |
+
)
|
| 398 |
+
seed = gr.Textbox(
|
| 399 |
+
label="Seed",
|
| 400 |
+
value="-1",
|
| 401 |
+
info="Use comma-separated values for batches",
|
| 402 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
audio_format = gr.Dropdown(
|
| 404 |
choices=["mp3", "flac"],
|
| 405 |
value="mp3",
|
| 406 |
+
label="Audio Format",
|
| 407 |
+
info="Audio format for saved files",
|
|
|
|
| 408 |
)
|
| 409 |
|
| 410 |
+
# Row 2: Shift, Random Seed, Inference Method
|
| 411 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
shift = gr.Slider(
|
| 413 |
minimum=1.0,
|
| 414 |
maximum=5.0,
|
| 415 |
value=3.0,
|
| 416 |
step=0.1,
|
| 417 |
+
label="Shift",
|
| 418 |
+
info="Timestep shift factor for base models (range 1.0-5.0, default 3.0). Not effective for turbo models.",
|
| 419 |
+
)
|
| 420 |
+
random_seed_checkbox = gr.Checkbox(
|
| 421 |
+
label="Random Seed",
|
| 422 |
+
value=True,
|
| 423 |
+
info="Enable to auto-generate seeds",
|
| 424 |
)
|
| 425 |
infer_method = gr.Dropdown(
|
| 426 |
choices=["ode", "sde"],
|
| 427 |
value="ode",
|
| 428 |
+
label="Inference Method",
|
| 429 |
+
info="Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 430 |
)
|
| 431 |
|
| 432 |
+
# Row 3: Custom Timesteps (full width)
|
| 433 |
+
custom_timesteps = gr.Textbox(
|
| 434 |
+
label="Custom Timesteps",
|
| 435 |
+
placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
|
| 436 |
+
value="",
|
| 437 |
+
info="Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
|
| 438 |
+
)
|
| 439 |
|
| 440 |
+
# Section: LM Generation Parameters
|
| 441 |
+
gr.HTML("<h4>🎵 LM Generation Parameters</h4>")
|
| 442 |
+
|
| 443 |
+
# Row 4: LM Temperature, LM CFG Scale, LM Top-K, LM Top-P
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
with gr.Row():
|
| 445 |
lm_temperature = gr.Slider(
|
|
|
|
| 446 |
minimum=0.0,
|
| 447 |
maximum=2.0,
|
| 448 |
value=0.85,
|
| 449 |
+
step=0.05,
|
| 450 |
+
label="LM Temperature",
|
| 451 |
+
info="5Hz LM temperature (higher = more random)",
|
| 452 |
)
|
| 453 |
lm_cfg_scale = gr.Slider(
|
|
|
|
| 454 |
minimum=1.0,
|
| 455 |
maximum=3.0,
|
| 456 |
value=2.0,
|
| 457 |
step=0.1,
|
| 458 |
+
label="LM CFG Scale",
|
| 459 |
+
info="5Hz LM CFG (1.0 = no CFG)",
|
| 460 |
)
|
| 461 |
lm_top_k = gr.Slider(
|
|
|
|
| 462 |
minimum=0,
|
| 463 |
maximum=100,
|
| 464 |
value=0,
|
| 465 |
step=1,
|
| 466 |
+
label="LM Top-K",
|
| 467 |
+
info="Top-k (0 = disabled)",
|
| 468 |
)
|
| 469 |
lm_top_p = gr.Slider(
|
|
|
|
| 470 |
minimum=0.0,
|
| 471 |
maximum=1.0,
|
| 472 |
value=0.9,
|
| 473 |
step=0.01,
|
| 474 |
+
label="LM Top-P",
|
| 475 |
+
info="Top-p (1.0 = disabled)",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
)
|
| 477 |
|
| 478 |
+
# Row 5: LM Negative Prompt (full width)
|
| 479 |
+
lm_negative_prompt = gr.Textbox(
|
| 480 |
+
label="LM Negative Prompt",
|
| 481 |
+
value="NO USER INPUT",
|
| 482 |
+
placeholder="Things to avoid in generation...",
|
| 483 |
+
lines=2,
|
| 484 |
+
info="Negative prompt (use when LM CFG Scale > 1.0)",
|
| 485 |
+
)
|
| 486 |
+
# audio_cover_strength remains hidden for now
|
| 487 |
+
audio_cover_strength = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, visible=False)
|
| 488 |
+
|
| 489 |
+
# Note: audio_duration, bpm, key_scale, time_signature are now visible in Advanced Settings
|
| 490 |
+
# ==================== Generate Button Row ====================
|
| 491 |
+
generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
|
| 492 |
+
with gr.Row(equal_height=True):
|
| 493 |
+
# Left: Thinking and Instrumental checkboxes
|
| 494 |
+
with gr.Column(scale=1, min_width=120):
|
| 495 |
+
think_checkbox = gr.Checkbox(
|
| 496 |
+
label="Thinking",
|
| 497 |
value=True,
|
|
|
|
|
|
|
| 498 |
)
|
| 499 |
+
instrumental_checkbox = gr.Checkbox(
|
| 500 |
+
label="Instrumental",
|
| 501 |
value=False,
|
|
|
|
|
|
|
|
|
|
| 502 |
)
|
| 503 |
|
| 504 |
+
# Center: Generate button
|
| 505 |
+
with gr.Column(scale=4):
|
| 506 |
+
generate_btn = gr.Button(
|
| 507 |
+
"🎵 Generate Music",
|
| 508 |
+
variant="primary",
|
| 509 |
+
size="lg",
|
| 510 |
+
interactive=generate_btn_interactive,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Right: auto_score, auto_lrc
|
| 514 |
+
with gr.Column(scale=1, min_width=120):
|
| 515 |
auto_score = gr.Checkbox(
|
| 516 |
+
label="Get Scores",
|
| 517 |
value=False,
|
|
|
|
|
|
|
|
|
|
| 518 |
)
|
| 519 |
auto_lrc = gr.Checkbox(
|
| 520 |
+
label="Get LRC",
|
| 521 |
value=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
)
|
| 523 |
|
| 524 |
+
# ==================== Hidden Components (for internal use) ====================
|
| 525 |
+
# These are needed for event handlers but not shown in UI
|
| 526 |
+
|
| 527 |
+
# Task type (set automatically based on generation_mode)
|
| 528 |
+
actual_model = init_params.get('config_path', 'acestep-v15-turbo') if service_pre_initialized else 'acestep-v15-turbo'
|
| 529 |
+
actual_model_lower = (actual_model or "").lower()
|
| 530 |
+
if "turbo" in actual_model_lower:
|
| 531 |
+
initial_task_choices = TASK_TYPES_TURBO
|
| 532 |
+
else:
|
| 533 |
+
initial_task_choices = TASK_TYPES_BASE
|
| 534 |
+
|
| 535 |
+
task_type = gr.Dropdown(
|
| 536 |
+
choices=initial_task_choices,
|
| 537 |
+
value="text2music",
|
| 538 |
+
visible=False,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
instruction_display_gen = gr.Textbox(
|
| 542 |
+
value=DEFAULT_DIT_INSTRUCTION,
|
| 543 |
+
visible=False,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
track_name = gr.Dropdown(
|
| 547 |
+
choices=TRACK_NAMES,
|
| 548 |
+
value=None,
|
| 549 |
+
visible=False,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
complete_track_classes = gr.CheckboxGroup(
|
| 553 |
+
choices=TRACK_NAMES,
|
| 554 |
+
visible=False,
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# Note: lyrics, vocal_language, instrumental_checkbox, format_btn are now visible in custom_mode_content
|
| 558 |
+
|
| 559 |
+
# Hidden advanced settings (keep defaults)
|
| 560 |
+
# Note: Most parameters are now visible in Advanced Settings section above
|
| 561 |
+
guidance_scale = gr.Slider(value=7.0, visible=False)
|
| 562 |
+
use_adg = gr.Checkbox(value=False, visible=False)
|
| 563 |
+
cfg_interval_start = gr.Slider(value=0.0, visible=False)
|
| 564 |
+
cfg_interval_end = gr.Slider(value=1.0, visible=False)
|
| 565 |
+
|
| 566 |
+
# LM parameters (remaining hidden ones)
|
| 567 |
+
use_cot_metas = gr.Checkbox(value=True, visible=False)
|
| 568 |
+
use_cot_caption = gr.Checkbox(value=True, visible=False)
|
| 569 |
+
use_cot_language = gr.Checkbox(value=True, visible=False)
|
| 570 |
+
constrained_decoding_debug = gr.Checkbox(value=False, visible=False)
|
| 571 |
+
allow_lm_batch = gr.Checkbox(value=True, visible=False)
|
| 572 |
+
lm_batch_chunk_size = gr.Number(value=8, visible=False)
|
| 573 |
+
score_scale = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, visible=False)
|
| 574 |
+
autogen_checkbox = gr.Checkbox(value=False, visible=False)
|
| 575 |
+
|
| 576 |
+
# Transcribe button (hidden)
|
| 577 |
+
transcribe_btn = gr.Button(value="Transcribe", visible=False)
|
| 578 |
+
text2music_audio_codes_group = gr.Group(visible=False)
|
| 579 |
+
|
| 580 |
+
# Note: format_btn is now visible in custom_mode_content
|
| 581 |
+
|
| 582 |
+
# Load file button (hidden for now)
|
| 583 |
+
load_file = gr.UploadButton(
|
| 584 |
+
label="Load",
|
| 585 |
+
file_types=[".json"],
|
| 586 |
+
file_count="single",
|
| 587 |
+
visible=False,
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
# Caption/Lyrics accordions (not used in new UI but needed for compatibility)
|
| 591 |
+
caption_accordion = gr.Accordion("Caption", visible=False)
|
| 592 |
+
lyrics_accordion = gr.Accordion("Lyrics", visible=False)
|
| 593 |
+
# Note: optional_params_accordion is now defined above (currently hidden, kept for compatibility)
|
| 594 |
|
| 595 |
return {
|
| 596 |
"service_config_accordion": service_config_accordion,
|
|
|
|
| 613 |
"unload_lora_btn": unload_lora_btn,
|
| 614 |
"use_lora_checkbox": use_lora_checkbox,
|
| 615 |
"lora_status": lora_status,
|
| 616 |
+
# DiT model selector
|
| 617 |
+
"dit_model_selector": dit_model_selector,
|
| 618 |
"task_type": task_type,
|
| 619 |
"instruction_display_gen": instruction_display_gen,
|
| 620 |
"track_name": track_name,
|
|
|
|
| 638 |
"repainting_start": repainting_start,
|
| 639 |
"repainting_end": repainting_end,
|
| 640 |
"audio_cover_strength": audio_cover_strength,
|
| 641 |
+
# Generation mode components
|
| 642 |
"generation_mode": generation_mode,
|
| 643 |
"simple_mode_group": simple_mode_group,
|
| 644 |
"simple_query_input": simple_query_input,
|
|
|
|
| 650 |
"caption_accordion": caption_accordion,
|
| 651 |
"lyrics_accordion": lyrics_accordion,
|
| 652 |
"optional_params_accordion": optional_params_accordion,
|
| 653 |
+
# Custom mode components
|
| 654 |
+
"custom_mode_content": custom_mode_content,
|
| 655 |
+
"cover_mode_group": cover_mode_group,
|
| 656 |
+
# Source audio group for Cover/Repaint
|
| 657 |
+
"src_audio_group": src_audio_group,
|
| 658 |
+
"process_src_btn": process_src_btn,
|
| 659 |
+
"advanced_options_accordion": advanced_options_accordion,
|
| 660 |
# Existing components
|
| 661 |
"captions": captions,
|
| 662 |
"sample_btn": sample_btn,
|
|
|
|
| 691 |
"auto_lrc": auto_lrc,
|
| 692 |
"lm_batch_chunk_size": lm_batch_chunk_size,
|
| 693 |
}
|
|
|
acestep/gradio_ui/interfaces/result.py
CHANGED
|
@@ -32,8 +32,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 32 |
buttons=[]
|
| 33 |
)
|
| 34 |
with gr.Row(equal_height=True):
|
| 35 |
-
|
| 36 |
-
t("results.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
variant="secondary",
|
| 38 |
size="sm",
|
| 39 |
scale=1
|
|
@@ -48,23 +54,17 @@ def create_results_section(dit_handler) -> dict:
|
|
| 48 |
t("results.score_btn"),
|
| 49 |
variant="secondary",
|
| 50 |
size="sm",
|
| 51 |
-
scale=1
|
|
|
|
| 52 |
)
|
| 53 |
lrc_btn_1 = gr.Button(
|
| 54 |
t("results.lrc_btn"),
|
| 55 |
variant="secondary",
|
| 56 |
size="sm",
|
| 57 |
-
scale=1
|
|
|
|
| 58 |
)
|
| 59 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
|
| 60 |
-
codes_display_1 = gr.Textbox(
|
| 61 |
-
label=t("results.codes_label", n=1),
|
| 62 |
-
interactive=False,
|
| 63 |
-
buttons=["copy"],
|
| 64 |
-
lines=4,
|
| 65 |
-
max_lines=4,
|
| 66 |
-
visible=True
|
| 67 |
-
)
|
| 68 |
score_display_1 = gr.Textbox(
|
| 69 |
label=t("results.quality_score_label", n=1),
|
| 70 |
interactive=False,
|
|
@@ -81,6 +81,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 81 |
max_lines=8,
|
| 82 |
visible=True
|
| 83 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
with gr.Column(visible=True) as audio_col_2:
|
| 85 |
generated_audio_2 = gr.Audio(
|
| 86 |
label=t("results.generated_music", n=2),
|
|
@@ -89,8 +97,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 89 |
buttons=[]
|
| 90 |
)
|
| 91 |
with gr.Row(equal_height=True):
|
| 92 |
-
|
| 93 |
-
t("results.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
variant="secondary",
|
| 95 |
size="sm",
|
| 96 |
scale=1
|
|
@@ -105,23 +119,17 @@ def create_results_section(dit_handler) -> dict:
|
|
| 105 |
t("results.score_btn"),
|
| 106 |
variant="secondary",
|
| 107 |
size="sm",
|
| 108 |
-
scale=1
|
|
|
|
| 109 |
)
|
| 110 |
lrc_btn_2 = gr.Button(
|
| 111 |
t("results.lrc_btn"),
|
| 112 |
variant="secondary",
|
| 113 |
size="sm",
|
| 114 |
-
scale=1
|
|
|
|
| 115 |
)
|
| 116 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
|
| 117 |
-
codes_display_2 = gr.Textbox(
|
| 118 |
-
label=t("results.codes_label", n=2),
|
| 119 |
-
interactive=False,
|
| 120 |
-
buttons=["copy"],
|
| 121 |
-
lines=4,
|
| 122 |
-
max_lines=4,
|
| 123 |
-
visible=True
|
| 124 |
-
)
|
| 125 |
score_display_2 = gr.Textbox(
|
| 126 |
label=t("results.quality_score_label", n=2),
|
| 127 |
interactive=False,
|
|
@@ -138,6 +146,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 138 |
max_lines=8,
|
| 139 |
visible=True
|
| 140 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
with gr.Column(visible=False) as audio_col_3:
|
| 142 |
generated_audio_3 = gr.Audio(
|
| 143 |
label=t("results.generated_music", n=3),
|
|
@@ -146,8 +162,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 146 |
buttons=[]
|
| 147 |
)
|
| 148 |
with gr.Row(equal_height=True):
|
| 149 |
-
|
| 150 |
-
t("results.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
variant="secondary",
|
| 152 |
size="sm",
|
| 153 |
scale=1
|
|
@@ -162,23 +184,17 @@ def create_results_section(dit_handler) -> dict:
|
|
| 162 |
t("results.score_btn"),
|
| 163 |
variant="secondary",
|
| 164 |
size="sm",
|
| 165 |
-
scale=1
|
|
|
|
| 166 |
)
|
| 167 |
lrc_btn_3 = gr.Button(
|
| 168 |
t("results.lrc_btn"),
|
| 169 |
variant="secondary",
|
| 170 |
size="sm",
|
| 171 |
-
scale=1
|
|
|
|
| 172 |
)
|
| 173 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
|
| 174 |
-
codes_display_3 = gr.Textbox(
|
| 175 |
-
label=t("results.codes_label", n=3),
|
| 176 |
-
interactive=False,
|
| 177 |
-
buttons=["copy"],
|
| 178 |
-
lines=4,
|
| 179 |
-
max_lines=4,
|
| 180 |
-
visible=True
|
| 181 |
-
)
|
| 182 |
score_display_3 = gr.Textbox(
|
| 183 |
label=t("results.quality_score_label", n=3),
|
| 184 |
interactive=False,
|
|
@@ -195,6 +211,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 195 |
max_lines=8,
|
| 196 |
visible=True
|
| 197 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
with gr.Column(visible=False) as audio_col_4:
|
| 199 |
generated_audio_4 = gr.Audio(
|
| 200 |
label=t("results.generated_music", n=4),
|
|
@@ -203,8 +227,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 203 |
buttons=[]
|
| 204 |
)
|
| 205 |
with gr.Row(equal_height=True):
|
| 206 |
-
|
| 207 |
-
t("results.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
variant="secondary",
|
| 209 |
size="sm",
|
| 210 |
scale=1
|
|
@@ -219,23 +249,17 @@ def create_results_section(dit_handler) -> dict:
|
|
| 219 |
t("results.score_btn"),
|
| 220 |
variant="secondary",
|
| 221 |
size="sm",
|
| 222 |
-
scale=1
|
|
|
|
| 223 |
)
|
| 224 |
lrc_btn_4 = gr.Button(
|
| 225 |
t("results.lrc_btn"),
|
| 226 |
variant="secondary",
|
| 227 |
size="sm",
|
| 228 |
-
scale=1
|
|
|
|
| 229 |
)
|
| 230 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
|
| 231 |
-
codes_display_4 = gr.Textbox(
|
| 232 |
-
label=t("results.codes_label", n=4),
|
| 233 |
-
interactive=False,
|
| 234 |
-
buttons=["copy"],
|
| 235 |
-
lines=4,
|
| 236 |
-
max_lines=4,
|
| 237 |
-
visible=True
|
| 238 |
-
)
|
| 239 |
score_display_4 = gr.Textbox(
|
| 240 |
label=t("results.quality_score_label", n=4),
|
| 241 |
interactive=False,
|
|
@@ -252,6 +276,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 252 |
max_lines=8,
|
| 253 |
visible=True
|
| 254 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
# Second row for batch size 5-8 (initially hidden)
|
| 257 |
with gr.Row(visible=False) as audio_row_5_8:
|
|
@@ -263,19 +295,12 @@ def create_results_section(dit_handler) -> dict:
|
|
| 263 |
buttons=[]
|
| 264 |
)
|
| 265 |
with gr.Row(equal_height=True):
|
| 266 |
-
|
|
|
|
| 267 |
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 268 |
-
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 269 |
-
lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 270 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
|
| 271 |
-
codes_display_5 = gr.Textbox(
|
| 272 |
-
label=t("results.codes_label", n=5),
|
| 273 |
-
interactive=False,
|
| 274 |
-
buttons=["copy"],
|
| 275 |
-
lines=4,
|
| 276 |
-
max_lines=4,
|
| 277 |
-
visible=True
|
| 278 |
-
)
|
| 279 |
score_display_5 = gr.Textbox(
|
| 280 |
label=t("results.quality_score_label", n=5),
|
| 281 |
interactive=False,
|
|
@@ -292,6 +317,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 292 |
max_lines=8,
|
| 293 |
visible=True
|
| 294 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
with gr.Column() as audio_col_6:
|
| 296 |
generated_audio_6 = gr.Audio(
|
| 297 |
label=t("results.generated_music", n=6),
|
|
@@ -300,19 +333,12 @@ def create_results_section(dit_handler) -> dict:
|
|
| 300 |
buttons=[]
|
| 301 |
)
|
| 302 |
with gr.Row(equal_height=True):
|
| 303 |
-
|
|
|
|
| 304 |
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 305 |
-
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 306 |
-
lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 307 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
|
| 308 |
-
codes_display_6 = gr.Textbox(
|
| 309 |
-
label=t("results.codes_label", n=6),
|
| 310 |
-
interactive=False,
|
| 311 |
-
buttons=["copy"],
|
| 312 |
-
lines=4,
|
| 313 |
-
max_lines=4,
|
| 314 |
-
visible=True
|
| 315 |
-
)
|
| 316 |
score_display_6 = gr.Textbox(
|
| 317 |
label=t("results.quality_score_label", n=6),
|
| 318 |
interactive=False,
|
|
@@ -329,6 +355,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 329 |
max_lines=8,
|
| 330 |
visible=True
|
| 331 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
with gr.Column() as audio_col_7:
|
| 333 |
generated_audio_7 = gr.Audio(
|
| 334 |
label=t("results.generated_music", n=7),
|
|
@@ -337,19 +371,12 @@ def create_results_section(dit_handler) -> dict:
|
|
| 337 |
buttons=[]
|
| 338 |
)
|
| 339 |
with gr.Row(equal_height=True):
|
| 340 |
-
|
|
|
|
| 341 |
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 342 |
-
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 343 |
-
lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 344 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
|
| 345 |
-
codes_display_7 = gr.Textbox(
|
| 346 |
-
label=t("results.codes_label", n=7),
|
| 347 |
-
interactive=False,
|
| 348 |
-
buttons=["copy"],
|
| 349 |
-
lines=4,
|
| 350 |
-
max_lines=4,
|
| 351 |
-
visible=True
|
| 352 |
-
)
|
| 353 |
score_display_7 = gr.Textbox(
|
| 354 |
label=t("results.quality_score_label", n=7),
|
| 355 |
interactive=False,
|
|
@@ -366,6 +393,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 366 |
max_lines=8,
|
| 367 |
visible=True
|
| 368 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
with gr.Column() as audio_col_8:
|
| 370 |
generated_audio_8 = gr.Audio(
|
| 371 |
label=t("results.generated_music", n=8),
|
|
@@ -374,19 +409,12 @@ def create_results_section(dit_handler) -> dict:
|
|
| 374 |
buttons=[]
|
| 375 |
)
|
| 376 |
with gr.Row(equal_height=True):
|
| 377 |
-
|
|
|
|
| 378 |
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 379 |
-
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 380 |
-
lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 381 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
|
| 382 |
-
codes_display_8 = gr.Textbox(
|
| 383 |
-
label=t("results.codes_label", n=8),
|
| 384 |
-
interactive=False,
|
| 385 |
-
buttons=["copy"],
|
| 386 |
-
lines=4,
|
| 387 |
-
max_lines=4,
|
| 388 |
-
visible=True
|
| 389 |
-
)
|
| 390 |
score_display_8 = gr.Textbox(
|
| 391 |
label=t("results.quality_score_label", n=8),
|
| 392 |
interactive=False,
|
|
@@ -403,11 +431,19 @@ def create_results_section(dit_handler) -> dict:
|
|
| 403 |
max_lines=8,
|
| 404 |
visible=True
|
| 405 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
|
| 408 |
|
| 409 |
-
# Batch navigation controls
|
| 410 |
-
with gr.Row(equal_height=True):
|
| 411 |
prev_batch_btn = gr.Button(
|
| 412 |
t("results.prev_btn"),
|
| 413 |
variant="secondary",
|
|
@@ -435,12 +471,13 @@ def create_results_section(dit_handler) -> dict:
|
|
| 435 |
size="sm"
|
| 436 |
)
|
| 437 |
|
| 438 |
-
# One-click restore parameters button
|
| 439 |
restore_params_btn = gr.Button(
|
| 440 |
t("results.restore_params_btn"),
|
| 441 |
variant="secondary",
|
| 442 |
-
interactive=False,
|
| 443 |
-
size="sm"
|
|
|
|
| 444 |
)
|
| 445 |
|
| 446 |
with gr.Accordion(t("results.batch_results_title"), open=False):
|
|
@@ -482,14 +519,22 @@ def create_results_section(dit_handler) -> dict:
|
|
| 482 |
"audio_col_6": audio_col_6,
|
| 483 |
"audio_col_7": audio_col_7,
|
| 484 |
"audio_col_8": audio_col_8,
|
| 485 |
-
"
|
| 486 |
-
"
|
| 487 |
-
"
|
| 488 |
-
"
|
| 489 |
-
"
|
| 490 |
-
"
|
| 491 |
-
"
|
| 492 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
"save_btn_1": save_btn_1,
|
| 494 |
"save_btn_2": save_btn_2,
|
| 495 |
"save_btn_3": save_btn_3,
|
|
|
|
| 32 |
buttons=[]
|
| 33 |
)
|
| 34 |
with gr.Row(equal_height=True):
|
| 35 |
+
send_to_cover_btn_1 = gr.Button(
|
| 36 |
+
t("results.send_to_cover_btn"),
|
| 37 |
+
variant="secondary",
|
| 38 |
+
size="sm",
|
| 39 |
+
scale=1
|
| 40 |
+
)
|
| 41 |
+
send_to_repaint_btn_1 = gr.Button(
|
| 42 |
+
t("results.send_to_repaint_btn"),
|
| 43 |
variant="secondary",
|
| 44 |
size="sm",
|
| 45 |
scale=1
|
|
|
|
| 54 |
t("results.score_btn"),
|
| 55 |
variant="secondary",
|
| 56 |
size="sm",
|
| 57 |
+
scale=1,
|
| 58 |
+
visible=False
|
| 59 |
)
|
| 60 |
lrc_btn_1 = gr.Button(
|
| 61 |
t("results.lrc_btn"),
|
| 62 |
variant="secondary",
|
| 63 |
size="sm",
|
| 64 |
+
scale=1,
|
| 65 |
+
visible=False
|
| 66 |
)
|
| 67 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
score_display_1 = gr.Textbox(
|
| 69 |
label=t("results.quality_score_label", n=1),
|
| 70 |
interactive=False,
|
|
|
|
| 81 |
max_lines=8,
|
| 82 |
visible=True
|
| 83 |
)
|
| 84 |
+
codes_display_1 = gr.Textbox(
|
| 85 |
+
label=t("results.codes_label", n=1),
|
| 86 |
+
interactive=False,
|
| 87 |
+
buttons=["copy"],
|
| 88 |
+
lines=4,
|
| 89 |
+
max_lines=4,
|
| 90 |
+
visible=True
|
| 91 |
+
)
|
| 92 |
with gr.Column(visible=True) as audio_col_2:
|
| 93 |
generated_audio_2 = gr.Audio(
|
| 94 |
label=t("results.generated_music", n=2),
|
|
|
|
| 97 |
buttons=[]
|
| 98 |
)
|
| 99 |
with gr.Row(equal_height=True):
|
| 100 |
+
send_to_cover_btn_2 = gr.Button(
|
| 101 |
+
t("results.send_to_cover_btn"),
|
| 102 |
+
variant="secondary",
|
| 103 |
+
size="sm",
|
| 104 |
+
scale=1
|
| 105 |
+
)
|
| 106 |
+
send_to_repaint_btn_2 = gr.Button(
|
| 107 |
+
t("results.send_to_repaint_btn"),
|
| 108 |
variant="secondary",
|
| 109 |
size="sm",
|
| 110 |
scale=1
|
|
|
|
| 119 |
t("results.score_btn"),
|
| 120 |
variant="secondary",
|
| 121 |
size="sm",
|
| 122 |
+
scale=1,
|
| 123 |
+
visible=False
|
| 124 |
)
|
| 125 |
lrc_btn_2 = gr.Button(
|
| 126 |
t("results.lrc_btn"),
|
| 127 |
variant="secondary",
|
| 128 |
size="sm",
|
| 129 |
+
scale=1,
|
| 130 |
+
visible=False
|
| 131 |
)
|
| 132 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
score_display_2 = gr.Textbox(
|
| 134 |
label=t("results.quality_score_label", n=2),
|
| 135 |
interactive=False,
|
|
|
|
| 146 |
max_lines=8,
|
| 147 |
visible=True
|
| 148 |
)
|
| 149 |
+
codes_display_2 = gr.Textbox(
|
| 150 |
+
label=t("results.codes_label", n=2),
|
| 151 |
+
interactive=False,
|
| 152 |
+
buttons=["copy"],
|
| 153 |
+
lines=4,
|
| 154 |
+
max_lines=4,
|
| 155 |
+
visible=True
|
| 156 |
+
)
|
| 157 |
with gr.Column(visible=False) as audio_col_3:
|
| 158 |
generated_audio_3 = gr.Audio(
|
| 159 |
label=t("results.generated_music", n=3),
|
|
|
|
| 162 |
buttons=[]
|
| 163 |
)
|
| 164 |
with gr.Row(equal_height=True):
|
| 165 |
+
send_to_cover_btn_3 = gr.Button(
|
| 166 |
+
t("results.send_to_cover_btn"),
|
| 167 |
+
variant="secondary",
|
| 168 |
+
size="sm",
|
| 169 |
+
scale=1
|
| 170 |
+
)
|
| 171 |
+
send_to_repaint_btn_3 = gr.Button(
|
| 172 |
+
t("results.send_to_repaint_btn"),
|
| 173 |
variant="secondary",
|
| 174 |
size="sm",
|
| 175 |
scale=1
|
|
|
|
| 184 |
t("results.score_btn"),
|
| 185 |
variant="secondary",
|
| 186 |
size="sm",
|
| 187 |
+
scale=1,
|
| 188 |
+
visible=False
|
| 189 |
)
|
| 190 |
lrc_btn_3 = gr.Button(
|
| 191 |
t("results.lrc_btn"),
|
| 192 |
variant="secondary",
|
| 193 |
size="sm",
|
| 194 |
+
scale=1,
|
| 195 |
+
visible=False
|
| 196 |
)
|
| 197 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
score_display_3 = gr.Textbox(
|
| 199 |
label=t("results.quality_score_label", n=3),
|
| 200 |
interactive=False,
|
|
|
|
| 211 |
max_lines=8,
|
| 212 |
visible=True
|
| 213 |
)
|
| 214 |
+
codes_display_3 = gr.Textbox(
|
| 215 |
+
label=t("results.codes_label", n=3),
|
| 216 |
+
interactive=False,
|
| 217 |
+
buttons=["copy"],
|
| 218 |
+
lines=4,
|
| 219 |
+
max_lines=4,
|
| 220 |
+
visible=True
|
| 221 |
+
)
|
| 222 |
with gr.Column(visible=False) as audio_col_4:
|
| 223 |
generated_audio_4 = gr.Audio(
|
| 224 |
label=t("results.generated_music", n=4),
|
|
|
|
| 227 |
buttons=[]
|
| 228 |
)
|
| 229 |
with gr.Row(equal_height=True):
|
| 230 |
+
send_to_cover_btn_4 = gr.Button(
|
| 231 |
+
t("results.send_to_cover_btn"),
|
| 232 |
+
variant="secondary",
|
| 233 |
+
size="sm",
|
| 234 |
+
scale=1
|
| 235 |
+
)
|
| 236 |
+
send_to_repaint_btn_4 = gr.Button(
|
| 237 |
+
t("results.send_to_repaint_btn"),
|
| 238 |
variant="secondary",
|
| 239 |
size="sm",
|
| 240 |
scale=1
|
|
|
|
| 249 |
t("results.score_btn"),
|
| 250 |
variant="secondary",
|
| 251 |
size="sm",
|
| 252 |
+
scale=1,
|
| 253 |
+
visible=False
|
| 254 |
)
|
| 255 |
lrc_btn_4 = gr.Button(
|
| 256 |
t("results.lrc_btn"),
|
| 257 |
variant="secondary",
|
| 258 |
size="sm",
|
| 259 |
+
scale=1,
|
| 260 |
+
visible=False
|
| 261 |
)
|
| 262 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
score_display_4 = gr.Textbox(
|
| 264 |
label=t("results.quality_score_label", n=4),
|
| 265 |
interactive=False,
|
|
|
|
| 276 |
max_lines=8,
|
| 277 |
visible=True
|
| 278 |
)
|
| 279 |
+
codes_display_4 = gr.Textbox(
|
| 280 |
+
label=t("results.codes_label", n=4),
|
| 281 |
+
interactive=False,
|
| 282 |
+
buttons=["copy"],
|
| 283 |
+
lines=4,
|
| 284 |
+
max_lines=4,
|
| 285 |
+
visible=True
|
| 286 |
+
)
|
| 287 |
|
| 288 |
# Second row for batch size 5-8 (initially hidden)
|
| 289 |
with gr.Row(visible=False) as audio_row_5_8:
|
|
|
|
| 295 |
buttons=[]
|
| 296 |
)
|
| 297 |
with gr.Row(equal_height=True):
|
| 298 |
+
send_to_cover_btn_5 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 299 |
+
send_to_repaint_btn_5 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 300 |
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 301 |
+
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 302 |
+
lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 303 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
score_display_5 = gr.Textbox(
|
| 305 |
label=t("results.quality_score_label", n=5),
|
| 306 |
interactive=False,
|
|
|
|
| 317 |
max_lines=8,
|
| 318 |
visible=True
|
| 319 |
)
|
| 320 |
+
codes_display_5 = gr.Textbox(
|
| 321 |
+
label=t("results.codes_label", n=5),
|
| 322 |
+
interactive=False,
|
| 323 |
+
buttons=["copy"],
|
| 324 |
+
lines=4,
|
| 325 |
+
max_lines=4,
|
| 326 |
+
visible=True
|
| 327 |
+
)
|
| 328 |
with gr.Column() as audio_col_6:
|
| 329 |
generated_audio_6 = gr.Audio(
|
| 330 |
label=t("results.generated_music", n=6),
|
|
|
|
| 333 |
buttons=[]
|
| 334 |
)
|
| 335 |
with gr.Row(equal_height=True):
|
| 336 |
+
send_to_cover_btn_6 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 337 |
+
send_to_repaint_btn_6 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 338 |
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 339 |
+
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 340 |
+
lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 341 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
score_display_6 = gr.Textbox(
|
| 343 |
label=t("results.quality_score_label", n=6),
|
| 344 |
interactive=False,
|
|
|
|
| 355 |
max_lines=8,
|
| 356 |
visible=True
|
| 357 |
)
|
| 358 |
+
codes_display_6 = gr.Textbox(
|
| 359 |
+
label=t("results.codes_label", n=6),
|
| 360 |
+
interactive=False,
|
| 361 |
+
buttons=["copy"],
|
| 362 |
+
lines=4,
|
| 363 |
+
max_lines=4,
|
| 364 |
+
visible=True
|
| 365 |
+
)
|
| 366 |
with gr.Column() as audio_col_7:
|
| 367 |
generated_audio_7 = gr.Audio(
|
| 368 |
label=t("results.generated_music", n=7),
|
|
|
|
| 371 |
buttons=[]
|
| 372 |
)
|
| 373 |
with gr.Row(equal_height=True):
|
| 374 |
+
send_to_cover_btn_7 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 375 |
+
send_to_repaint_btn_7 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 376 |
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 377 |
+
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 378 |
+
lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 379 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
score_display_7 = gr.Textbox(
|
| 381 |
label=t("results.quality_score_label", n=7),
|
| 382 |
interactive=False,
|
|
|
|
| 393 |
max_lines=8,
|
| 394 |
visible=True
|
| 395 |
)
|
| 396 |
+
codes_display_7 = gr.Textbox(
|
| 397 |
+
label=t("results.codes_label", n=7),
|
| 398 |
+
interactive=False,
|
| 399 |
+
buttons=["copy"],
|
| 400 |
+
lines=4,
|
| 401 |
+
max_lines=4,
|
| 402 |
+
visible=True
|
| 403 |
+
)
|
| 404 |
with gr.Column() as audio_col_8:
|
| 405 |
generated_audio_8 = gr.Audio(
|
| 406 |
label=t("results.generated_music", n=8),
|
|
|
|
| 409 |
buttons=[]
|
| 410 |
)
|
| 411 |
with gr.Row(equal_height=True):
|
| 412 |
+
send_to_cover_btn_8 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 413 |
+
send_to_repaint_btn_8 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 414 |
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 415 |
+
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 416 |
+
lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 417 |
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
score_display_8 = gr.Textbox(
|
| 419 |
label=t("results.quality_score_label", n=8),
|
| 420 |
interactive=False,
|
|
|
|
| 431 |
max_lines=8,
|
| 432 |
visible=True
|
| 433 |
)
|
| 434 |
+
codes_display_8 = gr.Textbox(
|
| 435 |
+
label=t("results.codes_label", n=8),
|
| 436 |
+
interactive=False,
|
| 437 |
+
buttons=["copy"],
|
| 438 |
+
lines=4,
|
| 439 |
+
max_lines=4,
|
| 440 |
+
visible=True
|
| 441 |
+
)
|
| 442 |
|
| 443 |
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
|
| 444 |
|
| 445 |
+
# Batch navigation controls (hidden for simplified UI)
|
| 446 |
+
with gr.Row(equal_height=True, visible=False):
|
| 447 |
prev_batch_btn = gr.Button(
|
| 448 |
t("results.prev_btn"),
|
| 449 |
variant="secondary",
|
|
|
|
| 471 |
size="sm"
|
| 472 |
)
|
| 473 |
|
| 474 |
+
# One-click restore parameters button (hidden for simplified UI)
|
| 475 |
restore_params_btn = gr.Button(
|
| 476 |
t("results.restore_params_btn"),
|
| 477 |
variant="secondary",
|
| 478 |
+
interactive=False,
|
| 479 |
+
size="sm",
|
| 480 |
+
visible=False
|
| 481 |
)
|
| 482 |
|
| 483 |
with gr.Accordion(t("results.batch_results_title"), open=False):
|
|
|
|
| 519 |
"audio_col_6": audio_col_6,
|
| 520 |
"audio_col_7": audio_col_7,
|
| 521 |
"audio_col_8": audio_col_8,
|
| 522 |
+
"send_to_cover_btn_1": send_to_cover_btn_1,
|
| 523 |
+
"send_to_cover_btn_2": send_to_cover_btn_2,
|
| 524 |
+
"send_to_cover_btn_3": send_to_cover_btn_3,
|
| 525 |
+
"send_to_cover_btn_4": send_to_cover_btn_4,
|
| 526 |
+
"send_to_cover_btn_5": send_to_cover_btn_5,
|
| 527 |
+
"send_to_cover_btn_6": send_to_cover_btn_6,
|
| 528 |
+
"send_to_cover_btn_7": send_to_cover_btn_7,
|
| 529 |
+
"send_to_cover_btn_8": send_to_cover_btn_8,
|
| 530 |
+
"send_to_repaint_btn_1": send_to_repaint_btn_1,
|
| 531 |
+
"send_to_repaint_btn_2": send_to_repaint_btn_2,
|
| 532 |
+
"send_to_repaint_btn_3": send_to_repaint_btn_3,
|
| 533 |
+
"send_to_repaint_btn_4": send_to_repaint_btn_4,
|
| 534 |
+
"send_to_repaint_btn_5": send_to_repaint_btn_5,
|
| 535 |
+
"send_to_repaint_btn_6": send_to_repaint_btn_6,
|
| 536 |
+
"send_to_repaint_btn_7": send_to_repaint_btn_7,
|
| 537 |
+
"send_to_repaint_btn_8": send_to_repaint_btn_8,
|
| 538 |
"save_btn_1": save_btn_1,
|
| 539 |
"save_btn_2": save_btn_2,
|
| 540 |
"save_btn_3": save_btn_3,
|
acestep/handler.py
CHANGED
|
@@ -315,6 +315,11 @@ class AceStepHandler:
|
|
| 315 |
offload_to_cpu: bool = False,
|
| 316 |
offload_dit_to_cpu: bool = False,
|
| 317 |
quantization: Optional[str] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
) -> Tuple[str, bool]:
|
| 319 |
"""
|
| 320 |
Initialize DiT model service
|
|
@@ -327,6 +332,10 @@ class AceStepHandler:
|
|
| 327 |
compile_model: Whether to use torch.compile to optimize the model
|
| 328 |
offload_to_cpu: Whether to offload models to CPU when not in use
|
| 329 |
offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
Returns:
|
| 332 |
(status_message, enable_generate_button)
|
|
@@ -440,54 +449,77 @@ class AceStepHandler:
|
|
| 440 |
logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
|
| 441 |
|
| 442 |
|
| 443 |
-
|
| 444 |
-
if
|
| 445 |
-
self.silence_latent =
|
| 446 |
-
|
| 447 |
-
# and is small enough that it won't significantly impact VRAM
|
| 448 |
-
self.silence_latent = self.silence_latent.to(device).to(self.dtype)
|
| 449 |
else:
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
else:
|
| 452 |
raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")
|
| 453 |
|
| 454 |
-
# 2. Load VAE
|
| 455 |
-
vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
|
| 456 |
-
if
|
| 457 |
-
self.vae =
|
| 458 |
-
|
| 459 |
-
vae_dtype = self._get_vae_dtype(device)
|
| 460 |
-
if not self.offload_to_cpu:
|
| 461 |
-
self.vae = self.vae.to(device).to(vae_dtype)
|
| 462 |
-
else:
|
| 463 |
-
self.vae = self.vae.to("cpu").to(vae_dtype)
|
| 464 |
-
self.vae.eval()
|
| 465 |
else:
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
|
| 476 |
-
if not self.offload_to_cpu:
|
| 477 |
-
self.text_encoder = self.text_encoder.to(device).to(self.dtype)
|
| 478 |
else:
|
| 479 |
-
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
else:
|
| 482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
|
| 484 |
# Determine actual attention implementation used
|
| 485 |
actual_attn = getattr(self.config, "_attn_implementation", "eager")
|
| 486 |
|
|
|
|
|
|
|
|
|
|
| 487 |
status_msg = f"✅ Model initialized successfully on {device}\n"
|
| 488 |
status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
|
| 489 |
-
|
| 490 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
status_msg += f"Dtype: {self.dtype}\n"
|
| 492 |
status_msg += f"Attention: {actual_attn}\n"
|
| 493 |
status_msg += f"Compiled: {compile_model}\n"
|
|
|
|
| 315 |
offload_to_cpu: bool = False,
|
| 316 |
offload_dit_to_cpu: bool = False,
|
| 317 |
quantization: Optional[str] = None,
|
| 318 |
+
# Shared components (for multi-model setup to save memory)
|
| 319 |
+
shared_vae = None,
|
| 320 |
+
shared_text_encoder = None,
|
| 321 |
+
shared_text_tokenizer = None,
|
| 322 |
+
shared_silence_latent = None,
|
| 323 |
) -> Tuple[str, bool]:
|
| 324 |
"""
|
| 325 |
Initialize DiT model service
|
|
|
|
| 332 |
compile_model: Whether to use torch.compile to optimize the model
|
| 333 |
offload_to_cpu: Whether to offload models to CPU when not in use
|
| 334 |
offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
|
| 335 |
+
shared_vae: Optional shared VAE instance (for multi-model setup)
|
| 336 |
+
shared_text_encoder: Optional shared text encoder instance (for multi-model setup)
|
| 337 |
+
shared_text_tokenizer: Optional shared text tokenizer instance (for multi-model setup)
|
| 338 |
+
shared_silence_latent: Optional shared silence latent tensor (for multi-model setup)
|
| 339 |
|
| 340 |
Returns:
|
| 341 |
(status_message, enable_generate_button)
|
|
|
|
| 449 |
logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
|
| 450 |
|
| 451 |
|
| 452 |
+
# Load or use shared silence_latent
|
| 453 |
+
if shared_silence_latent is not None:
|
| 454 |
+
self.silence_latent = shared_silence_latent
|
| 455 |
+
logger.info("[initialize_service] Using shared silence_latent")
|
|
|
|
|
|
|
| 456 |
else:
|
| 457 |
+
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
| 458 |
+
if os.path.exists(silence_latent_path):
|
| 459 |
+
self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
|
| 460 |
+
# Always keep silence_latent on GPU - it's used in many places outside model context
|
| 461 |
+
# and is small enough that it won't significantly impact VRAM
|
| 462 |
+
self.silence_latent = self.silence_latent.to(device).to(self.dtype)
|
| 463 |
+
else:
|
| 464 |
+
raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
|
| 465 |
else:
|
| 466 |
raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")
|
| 467 |
|
| 468 |
+
# 2. Load or use shared VAE
|
| 469 |
+
vae_checkpoint_path = os.path.join(checkpoint_dir, "vae") # Define for status message
|
| 470 |
+
if shared_vae is not None:
|
| 471 |
+
self.vae = shared_vae
|
| 472 |
+
logger.info("[initialize_service] Using shared VAE")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
else:
|
| 474 |
+
if os.path.exists(vae_checkpoint_path):
|
| 475 |
+
self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
|
| 476 |
+
# Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
|
| 477 |
+
vae_dtype = self._get_vae_dtype(device)
|
| 478 |
+
if not self.offload_to_cpu:
|
| 479 |
+
self.vae = self.vae.to(device).to(vae_dtype)
|
| 480 |
+
else:
|
| 481 |
+
self.vae = self.vae.to("cpu").to(vae_dtype)
|
| 482 |
+
self.vae.eval()
|
|
|
|
|
|
|
|
|
|
| 483 |
else:
|
| 484 |
+
raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
|
| 485 |
+
|
| 486 |
+
if compile_model:
|
| 487 |
+
self.vae = torch.compile(self.vae)
|
| 488 |
+
|
| 489 |
+
# 3. Load or use shared text encoder and tokenizer
|
| 490 |
+
text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B") # Define for status message
|
| 491 |
+
if shared_text_encoder is not None and shared_text_tokenizer is not None:
|
| 492 |
+
self.text_encoder = shared_text_encoder
|
| 493 |
+
self.text_tokenizer = shared_text_tokenizer
|
| 494 |
+
logger.info("[initialize_service] Using shared text encoder and tokenizer")
|
| 495 |
else:
|
| 496 |
+
if os.path.exists(text_encoder_path):
|
| 497 |
+
self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
|
| 498 |
+
self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
|
| 499 |
+
if not self.offload_to_cpu:
|
| 500 |
+
self.text_encoder = self.text_encoder.to(device).to(self.dtype)
|
| 501 |
+
else:
|
| 502 |
+
self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
|
| 503 |
+
self.text_encoder.eval()
|
| 504 |
+
else:
|
| 505 |
+
raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
|
| 506 |
|
| 507 |
# Determine actual attention implementation used
|
| 508 |
actual_attn = getattr(self.config, "_attn_implementation", "eager")
|
| 509 |
|
| 510 |
+
# Determine if using shared components
|
| 511 |
+
using_shared = shared_vae is not None or shared_text_encoder is not None
|
| 512 |
+
|
| 513 |
status_msg = f"✅ Model initialized successfully on {device}\n"
|
| 514 |
status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
|
| 515 |
+
if shared_vae is None:
|
| 516 |
+
status_msg += f"VAE: {vae_checkpoint_path}\n"
|
| 517 |
+
else:
|
| 518 |
+
status_msg += f"VAE: shared\n"
|
| 519 |
+
if shared_text_encoder is None:
|
| 520 |
+
status_msg += f"Text encoder: {text_encoder_path}\n"
|
| 521 |
+
else:
|
| 522 |
+
status_msg += f"Text encoder: shared\n"
|
| 523 |
status_msg += f"Dtype: {self.dtype}\n"
|
| 524 |
status_msg += f"Attention: {actual_attn}\n"
|
| 525 |
status_msg += f"Compiled: {compile_model}\n"
|
app.py
CHANGED
|
@@ -53,7 +53,24 @@ def get_persistent_storage_path():
|
|
| 53 |
1. Must be enabled in Space settings
|
| 54 |
2. Path is typically /data for Docker SDK
|
| 55 |
3. Falls back to app directory if /data is not writable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# Try HuggingFace Space persistent storage first
|
| 58 |
hf_data_path = "/data"
|
| 59 |
|
|
@@ -80,6 +97,14 @@ def get_persistent_storage_path():
|
|
| 80 |
def main():
|
| 81 |
"""Main entry point for HuggingFace Space"""
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# Get persistent storage path (auto-detect)
|
| 84 |
persistent_storage_path = get_persistent_storage_path()
|
| 85 |
|
|
@@ -87,14 +112,15 @@ def main():
|
|
| 87 |
gpu_memory_gb = get_gpu_memory_gb()
|
| 88 |
auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
|
| 89 |
|
| 90 |
-
if
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
| 98 |
|
| 99 |
# Create handler instances
|
| 100 |
print("Creating handlers...")
|
|
@@ -107,6 +133,9 @@ def main():
|
|
| 107 |
"SERVICE_MODE_DIT_MODEL",
|
| 108 |
"acestep-v15-turbo"
|
| 109 |
)
|
|
|
|
|
|
|
|
|
|
| 110 |
lm_model_path = os.environ.get(
|
| 111 |
"SERVICE_MODE_LM_MODEL",
|
| 112 |
"acestep-5Hz-lm-1.7B"
|
|
@@ -115,50 +144,97 @@ def main():
|
|
| 115 |
device = "auto"
|
| 116 |
|
| 117 |
print(f"Service mode configuration:")
|
| 118 |
-
print(f" DiT model: {config_path}")
|
|
|
|
|
|
|
| 119 |
print(f" LM model: {lm_model_path}")
|
| 120 |
print(f" Backend: {backend}")
|
| 121 |
print(f" Offload to CPU: {auto_offload}")
|
|
|
|
| 122 |
|
| 123 |
# Determine flash attention availability
|
| 124 |
use_flash_attention = dit_handler.is_flash_attention_available()
|
| 125 |
print(f" Flash Attention: {use_flash_attention}")
|
| 126 |
|
| 127 |
-
# Initialize
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
config_path=config_path,
|
| 132 |
-
device=device,
|
| 133 |
-
use_flash_attention=use_flash_attention,
|
| 134 |
-
compile_model=False,
|
| 135 |
-
offload_to_cpu=auto_offload,
|
| 136 |
-
offload_dit_to_cpu=False
|
| 137 |
-
)
|
| 138 |
|
| 139 |
-
if
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
| 141 |
else:
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
|
| 161 |
-
init_status += f"\n{lm_status}"
|
| 162 |
|
| 163 |
# Prepare initialization parameters for UI
|
| 164 |
init_params = {
|
|
@@ -166,6 +242,7 @@ def main():
|
|
| 166 |
'service_mode': True,
|
| 167 |
'checkpoint': None,
|
| 168 |
'config_path': config_path,
|
|
|
|
| 169 |
'device': device,
|
| 170 |
'init_llm': True,
|
| 171 |
'lm_model_path': lm_model_path,
|
|
@@ -176,9 +253,12 @@ def main():
|
|
| 176 |
'init_status': init_status,
|
| 177 |
'enable_generate': enable_generate,
|
| 178 |
'dit_handler': dit_handler,
|
|
|
|
|
|
|
| 179 |
'llm_handler': llm_handler,
|
| 180 |
'language': 'en',
|
| 181 |
'persistent_storage_path': persistent_storage_path,
|
|
|
|
| 182 |
}
|
| 183 |
|
| 184 |
print("Service initialization completed!")
|
|
|
|
| 53 |
1. Must be enabled in Space settings
|
| 54 |
2. Path is typically /data for Docker SDK
|
| 55 |
3. Falls back to app directory if /data is not writable
|
| 56 |
+
|
| 57 |
+
Local development:
|
| 58 |
+
- Set CHECKPOINT_DIR environment variable to use local checkpoints
|
| 59 |
+
Example: CHECKPOINT_DIR=/path/to/checkpoints python app.py
|
| 60 |
+
The path should be the parent directory of 'checkpoints' folder
|
| 61 |
"""
|
| 62 |
+
# Check for local checkpoint directory override (for development)
|
| 63 |
+
checkpoint_dir_override = os.environ.get("CHECKPOINT_DIR")
|
| 64 |
+
if checkpoint_dir_override:
|
| 65 |
+
# If user specifies the checkpoints folder directly, use its parent
|
| 66 |
+
if checkpoint_dir_override.endswith("/checkpoints") or checkpoint_dir_override.endswith("\\checkpoints"):
|
| 67 |
+
checkpoint_dir_override = os.path.dirname(checkpoint_dir_override)
|
| 68 |
+
if os.path.exists(checkpoint_dir_override):
|
| 69 |
+
print(f"Using local checkpoint directory (CHECKPOINT_DIR): {checkpoint_dir_override}")
|
| 70 |
+
return checkpoint_dir_override
|
| 71 |
+
else:
|
| 72 |
+
print(f"Warning: CHECKPOINT_DIR path does not exist: {checkpoint_dir_override}")
|
| 73 |
+
|
| 74 |
# Try HuggingFace Space persistent storage first
|
| 75 |
hf_data_path = "/data"
|
| 76 |
|
|
|
|
| 97 |
def main():
|
| 98 |
"""Main entry point for HuggingFace Space"""
|
| 99 |
|
| 100 |
+
# Check for DEBUG_UI mode (skip model initialization for UI development)
|
| 101 |
+
debug_ui = os.environ.get("DEBUG_UI", "").lower() in ("1", "true", "yes")
|
| 102 |
+
if debug_ui:
|
| 103 |
+
print("=" * 60)
|
| 104 |
+
print("DEBUG_UI mode enabled - skipping model initialization")
|
| 105 |
+
print("UI will be fully functional but generation is disabled")
|
| 106 |
+
print("=" * 60)
|
| 107 |
+
|
| 108 |
# Get persistent storage path (auto-detect)
|
| 109 |
persistent_storage_path = get_persistent_storage_path()
|
| 110 |
|
|
|
|
| 112 |
gpu_memory_gb = get_gpu_memory_gb()
|
| 113 |
auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
|
| 114 |
|
| 115 |
+
if not debug_ui:
|
| 116 |
+
if auto_offload:
|
| 117 |
+
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
|
| 118 |
+
print("Auto-enabling CPU offload to reduce GPU memory usage")
|
| 119 |
+
elif gpu_memory_gb > 0:
|
| 120 |
+
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
|
| 121 |
+
print("CPU offload disabled by default")
|
| 122 |
+
else:
|
| 123 |
+
print("No GPU detected, running on CPU")
|
| 124 |
|
| 125 |
# Create handler instances
|
| 126 |
print("Creating handlers...")
|
|
|
|
| 133 |
"SERVICE_MODE_DIT_MODEL",
|
| 134 |
"acestep-v15-turbo"
|
| 135 |
)
|
| 136 |
+
# Second DiT model - default to turbo-shift3 for two-model setup
|
| 137 |
+
config_path_2 = os.environ.get("SERVICE_MODE_DIT_MODEL_2", "acestep-v15-turbo-shift3").strip()
|
| 138 |
+
|
| 139 |
lm_model_path = os.environ.get(
|
| 140 |
"SERVICE_MODE_LM_MODEL",
|
| 141 |
"acestep-5Hz-lm-1.7B"
|
|
|
|
| 144 |
device = "auto"
|
| 145 |
|
| 146 |
print(f"Service mode configuration:")
|
| 147 |
+
print(f" DiT model 1: {config_path}")
|
| 148 |
+
if config_path_2:
|
| 149 |
+
print(f" DiT model 2: {config_path_2}")
|
| 150 |
print(f" LM model: {lm_model_path}")
|
| 151 |
print(f" Backend: {backend}")
|
| 152 |
print(f" Offload to CPU: {auto_offload}")
|
| 153 |
+
print(f" DEBUG_UI: {debug_ui}")
|
| 154 |
|
| 155 |
# Determine flash attention availability
|
| 156 |
use_flash_attention = dit_handler.is_flash_attention_available()
|
| 157 |
print(f" Flash Attention: {use_flash_attention}")
|
| 158 |
|
| 159 |
+
# Initialize models (skip in DEBUG_UI mode)
|
| 160 |
+
init_status = ""
|
| 161 |
+
enable_generate = False
|
| 162 |
+
dit_handler_2 = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
if debug_ui:
|
| 165 |
+
# In DEBUG_UI mode, skip all model initialization
|
| 166 |
+
init_status = "⚠️ DEBUG_UI mode - models not loaded\nUI is functional but generation is disabled"
|
| 167 |
+
enable_generate = False
|
| 168 |
+
print("Skipping model initialization (DEBUG_UI mode)")
|
| 169 |
else:
|
| 170 |
+
# Initialize primary DiT model
|
| 171 |
+
print(f"Initializing DiT model 1: {config_path}...")
|
| 172 |
+
init_status, enable_generate = dit_handler.initialize_service(
|
| 173 |
+
project_root=current_dir,
|
| 174 |
+
config_path=config_path,
|
| 175 |
+
device=device,
|
| 176 |
+
use_flash_attention=use_flash_attention,
|
| 177 |
+
compile_model=False,
|
| 178 |
+
offload_to_cpu=auto_offload,
|
| 179 |
+
offload_dit_to_cpu=False
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
if not enable_generate:
|
| 183 |
+
print(f"Warning: DiT model 1 initialization issue: {init_status}", file=sys.stderr)
|
| 184 |
+
else:
|
| 185 |
+
print("DiT model 1 initialized successfully")
|
| 186 |
+
|
| 187 |
+
# Initialize second DiT model if configured
|
| 188 |
+
if config_path_2:
|
| 189 |
+
print(f"Initializing DiT model 2: {config_path_2}...")
|
| 190 |
+
dit_handler_2 = AceStepHandler(persistent_storage_path=persistent_storage_path)
|
| 191 |
+
|
| 192 |
+
# Share VAE, text encoder, text tokenizer, and silence latent from the first handler to save memory
|
| 193 |
+
init_status_2, enable_generate_2 = dit_handler_2.initialize_service(
|
| 194 |
+
project_root=current_dir,
|
| 195 |
+
config_path=config_path_2,
|
| 196 |
+
device=device,
|
| 197 |
+
use_flash_attention=use_flash_attention,
|
| 198 |
+
compile_model=False,
|
| 199 |
+
offload_to_cpu=auto_offload,
|
| 200 |
+
offload_dit_to_cpu=False,
|
| 201 |
+
# Share components from first handler
|
| 202 |
+
shared_vae=dit_handler.vae,
|
| 203 |
+
shared_text_encoder=dit_handler.text_encoder,
|
| 204 |
+
shared_text_tokenizer=dit_handler.text_tokenizer,
|
| 205 |
+
shared_silence_latent=dit_handler.silence_latent,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
if not enable_generate_2:
|
| 209 |
+
print(f"Warning: DiT model 2 initialization issue: {init_status_2}", file=sys.stderr)
|
| 210 |
+
init_status += f"\n⚠️ DiT model 2 failed: {init_status_2}"
|
| 211 |
+
else:
|
| 212 |
+
print("DiT model 2 initialized successfully")
|
| 213 |
+
init_status += f"\n✅ DiT model 2: {config_path_2}"
|
| 214 |
+
|
| 215 |
+
# Initialize LM model
|
| 216 |
+
checkpoint_dir = dit_handler._get_checkpoint_dir()
|
| 217 |
+
print(f"Initializing 5Hz LM: {lm_model_path}...")
|
| 218 |
+
lm_status, lm_success = llm_handler.initialize(
|
| 219 |
+
checkpoint_dir=checkpoint_dir,
|
| 220 |
+
lm_model_path=lm_model_path,
|
| 221 |
+
backend=backend,
|
| 222 |
+
device=device,
|
| 223 |
+
offload_to_cpu=auto_offload,
|
| 224 |
+
dtype=dit_handler.dtype
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if lm_success:
|
| 228 |
+
print("5Hz LM initialized successfully")
|
| 229 |
+
init_status += f"\n{lm_status}"
|
| 230 |
+
else:
|
| 231 |
+
print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
|
| 232 |
+
init_status += f"\n{lm_status}"
|
| 233 |
|
| 234 |
+
# Build available models list for UI
|
| 235 |
+
available_dit_models = [config_path]
|
| 236 |
+
if config_path_2 and dit_handler_2 is not None:
|
| 237 |
+
available_dit_models.append(config_path_2)
|
|
|
|
|
|
|
| 238 |
|
| 239 |
# Prepare initialization parameters for UI
|
| 240 |
init_params = {
|
|
|
|
| 242 |
'service_mode': True,
|
| 243 |
'checkpoint': None,
|
| 244 |
'config_path': config_path,
|
| 245 |
+
'config_path_2': config_path_2 if config_path_2 else None,
|
| 246 |
'device': device,
|
| 247 |
'init_llm': True,
|
| 248 |
'lm_model_path': lm_model_path,
|
|
|
|
| 253 |
'init_status': init_status,
|
| 254 |
'enable_generate': enable_generate,
|
| 255 |
'dit_handler': dit_handler,
|
| 256 |
+
'dit_handler_2': dit_handler_2,
|
| 257 |
+
'available_dit_models': available_dit_models,
|
| 258 |
'llm_handler': llm_handler,
|
| 259 |
'language': 'en',
|
| 260 |
'persistent_storage_path': persistent_storage_path,
|
| 261 |
+
'debug_ui': debug_ui,
|
| 262 |
}
|
| 263 |
|
| 264 |
print("Service initialization completed!")
|