ChuxiJ commited on
Commit
0b990cd
·
1 Parent(s): fa05c34
README.md CHANGED
@@ -16,7 +16,7 @@ short_description: Music Generation Foundation Model v1.5
16
  <a href="https://ace-step-v1.5.github.io">Project</a> |
17
  <a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
18
  <a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5">ModelScope</a> |
19
- <a href="https://huggingface.co/spaces/ACE-Step/ACE-Step-1.5">Space Demo</a> |
20
  <a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
21
  <a href="https://arxiv.org/abs/2506.00045">Technical Report</a>
22
  </p>
 
16
  <a href="https://ace-step-v1.5.github.io">Project</a> |
17
  <a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
18
  <a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5">ModelScope</a> |
19
+ <a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5">Space Demo</a> |
20
  <a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
21
  <a href="https://arxiv.org/abs/2506.00045">Technical Report</a>
22
  </p>
acestep/gradio_ui/events/__init__.py CHANGED
@@ -12,8 +12,20 @@ from . import training_handlers as train_h
12
  from acestep.gradio_ui.i18n import t
13
 
14
 
15
- def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
16
- """Setup event handlers connecting UI components and business logic"""
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # ========== Dataset Handlers ==========
19
  dataset_section["import_dataset_btn"].click(
@@ -260,17 +272,42 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
260
  ]
261
  )
262
 
263
- # ========== Simple/Custom Mode Toggle ==========
264
  generation_section["generation_mode"].change(
265
  fn=gen_h.handle_generation_mode_change,
266
  inputs=[generation_section["generation_mode"]],
267
  outputs=[
268
  generation_section["simple_mode_group"],
269
- generation_section["caption_accordion"],
270
- generation_section["lyrics_accordion"],
 
 
271
  generation_section["generate_btn"],
272
  generation_section["simple_sample_created"],
273
- generation_section["optional_params_accordion"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  ]
275
  )
276
 
@@ -451,10 +488,28 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
451
  ],
452
  js=download_existing_js # Run the above JS
453
  )
454
- # ========== Send to SRC Handlers ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  for btn_idx in range(1, 9):
456
- results_section[f"send_to_src_btn_{btn_idx}"].click(
457
- fn=res_h.send_audio_to_src_with_metadata,
458
  inputs=[
459
  results_section[f"generated_audio_{btn_idx}"],
460
  results_section["lm_metadata_state"]
@@ -468,7 +523,50 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
468
  generation_section["key_scale"],
469
  generation_section["vocal_language"],
470
  generation_section["time_signature"],
471
- results_section["is_format_caption_state"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  ]
473
  )
474
 
@@ -519,12 +617,84 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
519
  ]
520
  )
521
 
522
- def generation_wrapper(*args):
523
- yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  # ========== Generation Handler ==========
525
  generation_section["generate_btn"].click(
526
  fn=generation_wrapper,
527
  inputs=[
 
 
 
 
528
  generation_section["captions"],
529
  generation_section["lyrics"],
530
  generation_section["bpm"],
@@ -634,8 +804,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
634
  results_section["restore_params_btn"],
635
  ]
636
  ).then(
637
- fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
 
 
 
638
  inputs=[
 
639
  generation_section["autogen_checkbox"],
640
  results_section["generation_params_state"],
641
  results_section["current_batch_index"],
@@ -819,8 +993,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
819
  results_section["restore_params_btn"],
820
  ]
821
  ).then(
822
- fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
 
 
 
823
  inputs=[
 
824
  generation_section["autogen_checkbox"],
825
  results_section["generation_params_state"],
826
  results_section["current_batch_index"],
 
12
  from acestep.gradio_ui.i18n import t
13
 
14
 
15
+ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
16
+ """Setup event handlers connecting UI components and business logic
17
+
18
+ Args:
19
+ init_params: Dictionary containing initialization parameters including:
20
+ - dit_handler_2: Optional second DiT handler for multi-model setup
21
+ - available_dit_models: List of available DiT model names
22
+ - config_path: Primary model config path
23
+ - config_path_2: Secondary model config path (if available)
24
+ """
25
+ # Get secondary DiT handler from init_params (for multi-model support)
26
+ dit_handler_2 = init_params.get('dit_handler_2') if init_params else None
27
+ config_path_1 = init_params.get('config_path', '') if init_params else ''
28
+ config_path_2 = init_params.get('config_path_2', '') if init_params else ''
29
 
30
  # ========== Dataset Handlers ==========
31
  dataset_section["import_dataset_btn"].click(
 
272
  ]
273
  )
274
 
275
+ # ========== Generation Mode Toggle (Simple/Custom/Cover/Repaint) ==========
276
  generation_section["generation_mode"].change(
277
  fn=gen_h.handle_generation_mode_change,
278
  inputs=[generation_section["generation_mode"]],
279
  outputs=[
280
  generation_section["simple_mode_group"],
281
+ generation_section["custom_mode_content"],
282
+ generation_section["cover_mode_group"],
283
+ generation_section["repainting_group"],
284
+ generation_section["task_type"],
285
  generation_section["generate_btn"],
286
  generation_section["simple_sample_created"],
287
+ generation_section["src_audio_group"],
288
+ generation_section["audio_cover_strength"],
289
+ ]
290
+ )
291
+
292
+ # ========== Process Source Audio Button ==========
293
+ # Combines Convert to Codes + Transcribe in one step
294
+ generation_section["process_src_btn"].click(
295
+ fn=lambda src, debug: gen_h.process_source_audio(dit_handler, llm_handler, src, debug),
296
+ inputs=[
297
+ generation_section["src_audio"],
298
+ generation_section["constrained_decoding_debug"]
299
+ ],
300
+ outputs=[
301
+ generation_section["text2music_audio_code_string"],
302
+ results_section["status_output"],
303
+ generation_section["captions"],
304
+ generation_section["lyrics"],
305
+ generation_section["bpm"],
306
+ generation_section["audio_duration"],
307
+ generation_section["key_scale"],
308
+ generation_section["vocal_language"],
309
+ generation_section["time_signature"],
310
+ results_section["is_format_caption_state"],
311
  ]
312
  )
313
 
 
488
  ],
489
  js=download_existing_js # Run the above JS
490
  )
491
+ # ========== Send to Cover Handlers ==========
492
+ def send_to_cover_handler(audio_file, lm_metadata):
493
+ """Send audio to cover mode and switch to cover"""
494
+ if audio_file is None:
495
+ return (gr.skip(),) * 11
496
+ return (
497
+ audio_file, # src_audio
498
+ gr.skip(), # bpm
499
+ gr.skip(), # captions
500
+ gr.skip(), # lyrics
501
+ gr.skip(), # audio_duration
502
+ gr.skip(), # key_scale
503
+ gr.skip(), # vocal_language
504
+ gr.skip(), # time_signature
505
+ gr.skip(), # is_format_caption_state
506
+ "cover", # generation_mode - switch to cover
507
+ "cover", # task_type - set to cover
508
+ )
509
+
510
  for btn_idx in range(1, 9):
511
+ results_section[f"send_to_cover_btn_{btn_idx}"].click(
512
+ fn=send_to_cover_handler,
513
  inputs=[
514
  results_section[f"generated_audio_{btn_idx}"],
515
  results_section["lm_metadata_state"]
 
523
  generation_section["key_scale"],
524
  generation_section["vocal_language"],
525
  generation_section["time_signature"],
526
+ results_section["is_format_caption_state"],
527
+ generation_section["generation_mode"],
528
+ generation_section["task_type"],
529
+ ]
530
+ )
531
+
532
+ # ========== Send to Repaint Handlers ==========
533
+ def send_to_repaint_handler(audio_file, lm_metadata):
534
+ """Send audio to repaint mode and switch to repaint"""
535
+ if audio_file is None:
536
+ return (gr.skip(),) * 11
537
+ return (
538
+ audio_file, # src_audio
539
+ gr.skip(), # bpm
540
+ gr.skip(), # captions
541
+ gr.skip(), # lyrics
542
+ gr.skip(), # audio_duration
543
+ gr.skip(), # key_scale
544
+ gr.skip(), # vocal_language
545
+ gr.skip(), # time_signature
546
+ gr.skip(), # is_format_caption_state
547
+ "repaint", # generation_mode - switch to repaint
548
+ "repaint", # task_type - set to repaint
549
+ )
550
+
551
+ for btn_idx in range(1, 9):
552
+ results_section[f"send_to_repaint_btn_{btn_idx}"].click(
553
+ fn=send_to_repaint_handler,
554
+ inputs=[
555
+ results_section[f"generated_audio_{btn_idx}"],
556
+ results_section["lm_metadata_state"]
557
+ ],
558
+ outputs=[
559
+ generation_section["src_audio"],
560
+ generation_section["bpm"],
561
+ generation_section["captions"],
562
+ generation_section["lyrics"],
563
+ generation_section["audio_duration"],
564
+ generation_section["key_scale"],
565
+ generation_section["vocal_language"],
566
+ generation_section["time_signature"],
567
+ results_section["is_format_caption_state"],
568
+ generation_section["generation_mode"],
569
+ generation_section["task_type"],
570
  ]
571
  )
572
 
 
617
  ]
618
  )
619
 
620
+ def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
621
+ """Wrapper that selects the appropriate DiT handler based on model selection"""
622
+ # Convert args to list for modification
623
+ args_list = list(args)
624
+
625
+ # args order (after simple mode params):
626
+ # captions (0), lyrics (1), bpm (2), key_scale (3), time_signature (4), vocal_language (5),
627
+ # inference_steps (6), guidance_scale (7), random_seed_checkbox (8), seed (9),
628
+ # reference_audio (10), audio_duration (11), batch_size_input (12), src_audio (13),
629
+ # text2music_audio_code_string (14), repainting_start (15), repainting_end (16),
630
+ # instruction_display_gen (17), audio_cover_strength (18), task_type (19), ...
631
+ # ... lm_temperature (27), think_checkbox (28), ...
632
+ # ... instrumental_checkbox (at position after all regular params)
633
+
634
+ src_audio = args_list[13] if len(args_list) > 13 else None
635
+ task_type = args_list[19] if len(args_list) > 19 else "text2music"
636
+
637
+ # Validate: Cover and Repaint modes require source audio
638
+ if task_type in ["cover", "repaint"] and src_audio is None:
639
+ raise gr.Error(f"Source Audio is required for {task_type.capitalize()} mode. Please upload an audio file.")
640
+
641
+ # Handle Simple mode: first create sample, then generate
642
+ if generation_mode == "simple":
643
+ # Get instrumental from the main checkbox (args[-6] based on input order)
644
+ # The instrumental_checkbox is passed after all the regular generation params
645
+ instrumental = args_list[-6] if len(args_list) > 6 else False # instrumental_checkbox position
646
+ lm_temperature = args_list[27] if len(args_list) > 27 else 0.85
647
+ lm_top_k = args_list[30] if len(args_list) > 30 else 0
648
+ lm_top_p = args_list[31] if len(args_list) > 31 else 0.9
649
+ constrained_decoding_debug = args_list[38] if len(args_list) > 38 else False
650
+
651
+ # Call create_sample to generate caption/lyrics/metadata
652
+ from acestep.inference import create_sample
653
+
654
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
655
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
656
+
657
+ result = create_sample(
658
+ llm_handler=llm_handler,
659
+ query=simple_query_input,
660
+ instrumental=instrumental,
661
+ vocal_language=simple_vocal_language,
662
+ temperature=lm_temperature,
663
+ top_k=top_k_value,
664
+ top_p=top_p_value,
665
+ use_constrained_decoding=True,
666
+ constrained_decoding_debug=constrained_decoding_debug,
667
+ )
668
+
669
+ if not result.success:
670
+ raise gr.Error(f"Failed to create sample: {result.status_message}")
671
+
672
+ # Update args with generated data
673
+ args_list[0] = result.caption # captions
674
+ args_list[1] = result.lyrics # lyrics
675
+ args_list[2] = result.bpm # bpm
676
+ args_list[3] = result.keyscale # key_scale
677
+ args_list[4] = result.timesignature # time_signature
678
+ args_list[5] = result.language # vocal_language
679
+ if result.duration and result.duration > 0:
680
+ args_list[11] = result.duration # audio_duration
681
+ # Enable thinking for Simple mode
682
+ args_list[28] = True # think_checkbox
683
+
684
+ # Determine which handler to use
685
+ active_handler = dit_handler # Default to primary handler
686
+ if dit_handler_2 is not None and selected_model == config_path_2:
687
+ active_handler = dit_handler_2
688
+ yield from res_h.generate_with_batch_management(active_handler, llm_handler, *args_list)
689
+
690
  # ========== Generation Handler ==========
691
  generation_section["generate_btn"].click(
692
  fn=generation_wrapper,
693
  inputs=[
694
+ generation_section["dit_model_selector"], # Model selection input
695
+ generation_section["generation_mode"], # For Simple mode detection
696
+ generation_section["simple_query_input"], # Simple mode query
697
+ generation_section["simple_vocal_language"], # Simple mode vocal language
698
  generation_section["captions"],
699
  generation_section["lyrics"],
700
  generation_section["bpm"],
 
804
  results_section["restore_params_btn"],
805
  ]
806
  ).then(
807
+ fn=lambda selected_model, *args: res_h.generate_next_batch_background(
808
+ dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
809
+ llm_handler, *args
810
+ ),
811
  inputs=[
812
+ generation_section["dit_model_selector"], # Model selection input
813
  generation_section["autogen_checkbox"],
814
  results_section["generation_params_state"],
815
  results_section["current_batch_index"],
 
993
  results_section["restore_params_btn"],
994
  ]
995
  ).then(
996
+ fn=lambda selected_model, *args: res_h.generate_next_batch_background(
997
+ dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
998
+ llm_handler, *args
999
+ ),
1000
  inputs=[
1001
+ generation_section["dit_model_selector"], # Model selection input
1002
  generation_section["autogen_checkbox"],
1003
  results_section["generation_params_state"],
1004
  results_section["current_batch_index"],
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -480,10 +480,14 @@ def update_negative_prompt_visibility(init_llm_checked):
480
 
481
  def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
482
  """Update audio_cover_strength visibility and label"""
483
- # Show if task is cover OR if LM is initialized
484
- is_visible = (task_type_value == "cover") or init_llm_checked
 
 
 
 
485
  # Change label based on context
486
- if init_llm_checked and task_type_value != "cover":
487
  label = "LM codes strength"
488
  info = "Control how many denoising steps use LM-generated codes"
489
  else:
@@ -518,10 +522,12 @@ def update_instruction_ui(
518
  track_name_visible = task_type_value in ["lego", "extract"]
519
  # Show complete_track_classes for complete
520
  complete_visible = task_type_value == "complete"
521
- # Show audio_cover_strength for cover OR when LM is initialized
522
- audio_cover_strength_visible = (task_type_value == "cover") or init_llm_checked
 
 
523
  # Determine label and info based on context
524
- if init_llm_checked and task_type_value != "cover":
525
  audio_cover_strength_label = "LM codes strength"
526
  audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
527
  else:
@@ -605,9 +611,9 @@ def reset_format_caption_flag():
605
 
606
 
607
  def update_audio_uploads_accordion(reference_audio, src_audio):
608
- """Update Audio Uploads accordion open state based on whether audio files are present"""
609
  has_audio = (reference_audio is not None) or (src_audio is not None)
610
- return gr.Accordion(open=has_audio)
611
 
612
 
613
  def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
@@ -682,41 +688,106 @@ def update_audio_components_visibility(batch_size):
682
 
683
  def handle_generation_mode_change(mode: str):
684
  """
685
- Handle generation mode change between Simple and Custom modes.
686
-
687
- In Simple mode:
688
- - Show simple mode group (query input, instrumental checkbox, create button)
689
- - Collapse caption and lyrics accordions
690
- - Hide optional parameters accordion
691
- - Disable generate button until sample is created
692
 
693
- In Custom mode:
694
- - Hide simple mode group
695
- - Expand caption and lyrics accordions
696
- - Show optional parameters accordion
697
- - Enable generate button
698
 
699
  Args:
700
- mode: "simple" or "custom"
701
 
702
  Returns:
703
  Tuple of updates for:
704
  - simple_mode_group (visibility)
705
- - caption_accordion (open state)
706
- - lyrics_accordion (open state)
 
 
707
  - generate_btn (interactive state)
708
  - simple_sample_created (reset state)
709
- - optional_params_accordion (visibility)
 
710
  """
711
  is_simple = mode == "simple"
 
 
 
 
 
 
 
 
 
 
 
 
712
 
713
  return (
714
  gr.update(visible=is_simple), # simple_mode_group
715
- gr.Accordion(open=not is_simple), # caption_accordion - collapsed in simple, open in custom
716
- gr.Accordion(open=not is_simple), # lyrics_accordion - collapsed in simple, open in custom
717
- gr.update(interactive=not is_simple), # generate_btn - disabled in simple until sample created
 
 
718
  False, # simple_sample_created - reset to False on mode change
719
- gr.Accordion(open=not is_simple), # optional_params_accordion - hidden in simple mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
  )
721
 
722
 
 
480
 
481
  def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
482
  """Update audio_cover_strength visibility and label"""
483
+ # Show if task is cover OR if LM is initialized (but NOT for repaint mode)
484
+ # Repaint mode never shows this control
485
+ is_repaint = task_type_value == "repaint"
486
+ is_cover = task_type_value == "cover"
487
+ is_visible = is_cover or (init_llm_checked and not is_repaint)
488
+
489
  # Change label based on context
490
+ if init_llm_checked and not is_cover:
491
  label = "LM codes strength"
492
  info = "Control how many denoising steps use LM-generated codes"
493
  else:
 
522
  track_name_visible = task_type_value in ["lego", "extract"]
523
  # Show complete_track_classes for complete
524
  complete_visible = task_type_value == "complete"
525
+ # Show audio_cover_strength for cover OR when LM is initialized (but NOT for repaint)
526
+ is_repaint = task_type_value == "repaint"
527
+ is_cover = task_type_value == "cover"
528
+ audio_cover_strength_visible = is_cover or (init_llm_checked and not is_repaint)
529
  # Determine label and info based on context
530
+ if init_llm_checked and not is_cover:
531
  audio_cover_strength_label = "LM codes strength"
532
  audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
533
  else:
 
611
 
612
 
613
  def update_audio_uploads_accordion(reference_audio, src_audio):
614
+ """Update Audio Uploads visibility based on whether audio files are present"""
615
  has_audio = (reference_audio is not None) or (src_audio is not None)
616
+ return gr.update(visible=has_audio)
617
 
618
 
619
  def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
 
688
 
689
  def handle_generation_mode_change(mode: str):
690
  """
691
+ Handle generation mode change between Simple, Custom, Cover, and Repaint modes.
 
 
 
 
 
 
692
 
693
+ Modes:
694
+ - Simple: Show simple mode group, hide others
695
+ - Custom: Show custom content (prompt), hide others
696
+ - Cover: Show src_audio_group + custom content + LM codes strength
697
+ - Repaint: Show src_audio_group + custom content + repaint time controls (hide LM codes strength)
698
 
699
  Args:
700
+ mode: "simple", "custom", "cover", or "repaint"
701
 
702
  Returns:
703
  Tuple of updates for:
704
  - simple_mode_group (visibility)
705
+ - custom_mode_content (visibility)
706
+ - cover_mode_group (visibility) - legacy, always hidden
707
+ - repainting_group (visibility)
708
+ - task_type (value)
709
  - generate_btn (interactive state)
710
  - simple_sample_created (reset state)
711
+ - src_audio_group (visibility) - shown for cover and repaint
712
+ - audio_cover_strength (visibility) - shown only for cover mode
713
  """
714
  is_simple = mode == "simple"
715
+ is_custom = mode == "custom"
716
+ is_cover = mode == "cover"
717
+ is_repaint = mode == "repaint"
718
+
719
+ # Map mode to task_type
720
+ task_type_map = {
721
+ "simple": "text2music",
722
+ "custom": "text2music",
723
+ "cover": "cover",
724
+ "repaint": "repaint",
725
+ }
726
+ task_type_value = task_type_map.get(mode, "text2music")
727
 
728
  return (
729
  gr.update(visible=is_simple), # simple_mode_group
730
+ gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
731
+ gr.update(visible=False), # cover_mode_group - legacy, always hidden
732
+ gr.update(visible=is_repaint), # repainting_group - time range controls
733
+ gr.update(value=task_type_value), # task_type
734
+ gr.update(interactive=True), # generate_btn - always enabled (Simple mode does create+generate in one step)
735
  False, # simple_sample_created - reset to False on mode change
736
+ gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
737
+ gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
738
+ )
739
+
740
+
741
+ def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
742
+ """
743
+ Process source audio: convert to codes and then transcribe.
744
+ This combines convert_src_audio_to_codes_wrapper + transcribe_audio_codes.
745
+
746
+ Args:
747
+ dit_handler: DiT handler instance for audio code conversion
748
+ llm_handler: LLM handler instance for transcription
749
+ src_audio: Path to source audio file
750
+ constrained_decoding_debug: Whether to enable debug logging
751
+
752
+ Returns:
753
+ Tuple of (audio_codes, status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
754
+ """
755
+ if src_audio is None:
756
+ return ("", "No audio file provided", "", "", None, None, "", "", "", False)
757
+
758
+ # Step 1: Convert audio to codes
759
+ try:
760
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
761
+ if not codes_string:
762
+ return ("", "Failed to convert audio to codes", "", "", None, None, "", "", "", False)
763
+ except Exception as e:
764
+ return ("", f"Error converting audio: {str(e)}", "", "", None, None, "", "", "", False)
765
+
766
+ # Step 2: Transcribe the codes
767
+ result = understand_music(
768
+ llm_handler=llm_handler,
769
+ audio_codes=codes_string,
770
+ use_constrained_decoding=True,
771
+ constrained_decoding_debug=constrained_decoding_debug,
772
+ )
773
+
774
+ # Handle error case
775
+ if not result.success:
776
+ if result.error == "LLM not initialized":
777
+ return (codes_string, t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False)
778
+ return (codes_string, result.status_message, "", "", None, None, "", "", "", False)
779
+
780
+ return (
781
+ codes_string,
782
+ result.status_message,
783
+ result.caption,
784
+ result.lyrics,
785
+ result.bpm,
786
+ result.duration,
787
+ result.keyscale,
788
+ result.language,
789
+ result.timesignature,
790
+ True # Set is_format_caption to True
791
  )
792
 
793
 
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -265,106 +265,45 @@ def _build_generation_info(
265
  Formatted generation info string
266
  """
267
  info_parts = []
 
268
 
269
- # Part 1: Per-track average time (prominently displayed at the top)
270
- # Only count model time (LM + DiT), not post-processing like audio conversion
271
- if time_costs and num_audios > 0:
272
- lm_total = time_costs.get('lm_total_time', 0.0)
273
- dit_total = time_costs.get('dit_total_time_cost', 0.0)
274
- model_total = lm_total + dit_total
275
- if model_total > 0:
276
- avg_time_per_track = model_total / num_audios
277
- avg_section = f"**🎯 Average Time per Track: {avg_time_per_track:.2f}s** ({num_audios} track(s))"
278
- info_parts.append(avg_section)
279
-
280
- # Part 2: LM-generated metadata (if available)
281
- if lm_metadata:
282
- metadata_lines = []
283
- if lm_metadata.get('bpm'):
284
- metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
285
- if lm_metadata.get('caption'):
286
- metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
287
- if lm_metadata.get('lyrics'):
288
- metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
289
- if lm_metadata.get('duration'):
290
- metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
291
- if lm_metadata.get('keyscale'):
292
- metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
293
- if lm_metadata.get('language'):
294
- metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
295
- if lm_metadata.get('timesignature'):
296
- metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
297
-
298
- if metadata_lines:
299
- metadata_section = "**🤖 LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
300
- info_parts.append(metadata_section)
301
-
302
- # Part 3: Time costs breakdown (formatted and beautified)
303
  if time_costs:
304
- time_lines = []
305
-
306
- # LM time costs
307
- lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
308
- lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
309
  lm_total = time_costs.get('lm_total_time', 0.0)
310
-
311
- if lm_total > 0:
312
- time_lines.append("**🧠 LM Time:**")
313
- if lm_phase1 > 0:
314
- time_lines.append(f" - Phase 1 (CoT): {lm_phase1:.2f}s")
315
- if lm_phase2 > 0:
316
- time_lines.append(f" - Phase 2 (Codes): {lm_phase2:.2f}s")
317
- time_lines.append(f" - Total: {lm_total:.2f}s")
318
-
319
- # DiT time costs
320
- dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
321
- dit_model = time_costs.get('dit_model_time_cost', 0.0)
322
- dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
323
- dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
324
  dit_total = time_costs.get('dit_total_time_cost', 0.0)
325
- if dit_total > 0:
326
- time_lines.append("\n**🎵 DiT Time:**")
327
- if dit_encoder > 0:
328
- time_lines.append(f" - Encoder: {dit_encoder:.2f}s")
329
- if dit_model > 0:
330
- time_lines.append(f" - Model: {dit_model:.2f}s")
331
- if dit_vae_decode > 0:
332
- time_lines.append(f" - VAE Decode: {dit_vae_decode:.2f}s")
333
- if dit_offload > 0:
334
- time_lines.append(f" - Offload: {dit_offload:.2f}s")
335
- time_lines.append(f" - Total: {dit_total:.2f}s")
336
 
337
- # Post-processing time costs
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
339
  auto_score_time = time_costs.get('auto_score_time', 0.0)
340
  auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
 
341
 
342
- if audio_conversion_time > 0 or auto_score_time > 0 or auto_lrc_time > 0:
343
- time_lines.append("\n**🔧 Post-processing Time:**")
 
 
344
  if audio_conversion_time > 0:
345
- time_lines.append(f" - Audio Conversion: {audio_conversion_time:.2f}s")
 
346
  if auto_score_time > 0:
347
- time_lines.append(f" - Auto Score: {auto_score_time:.2f}s")
348
  if auto_lrc_time > 0:
349
- time_lines.append(f" - Auto LRC: {auto_lrc_time:.2f}s")
350
-
351
- if time_lines:
352
- time_section = "\n".join(time_lines)
353
- info_parts.append(time_section)
354
-
355
- # Part 4: Generation summary
356
- summary_lines = [
357
- "**🎵 Generation Complete**",
358
- f" - **Seeds:** {seed_value}",
359
- f" - **Steps:** {inference_steps}",
360
- f" - **Audio Count:** {num_audios} audio(s)",
361
- ]
362
- info_parts.append("\n".join(summary_lines))
363
-
364
- # Part 5: Pipeline total time (at the end)
365
- pipeline_total = time_costs.get('pipeline_total_time', 0.0) if time_costs else 0.0
366
- if pipeline_total > 0:
367
- info_parts.append(f"**⏱️ Total Time: {pipeline_total:.2f}s**")
368
 
369
  # Combine all parts
370
  return "\n\n".join(info_parts)
@@ -775,7 +714,9 @@ def generate_with_progress(
775
  codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
776
 
777
  details_accordion_updates = [gr.skip() for _ in range(8)]
778
- # Don't change accordion visibility - keep it always expandable
 
 
779
 
780
  # Clear LRC first (this triggers .change() to clear subtitles)
781
  # Keep visible=True to ensure .change() event is properly triggered
 
265
  Formatted generation info string
266
  """
267
  info_parts = []
268
+ songs_label = f"({num_audios} songs)"
269
 
270
+ # Part 1: Total generation time (LM + DiT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  if time_costs:
 
 
 
 
 
272
  lm_total = time_costs.get('lm_total_time', 0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  dit_total = time_costs.get('dit_total_time_cost', 0.0)
274
+ generation_total = lm_total + dit_total
 
 
 
 
 
 
 
 
 
 
275
 
276
+ if generation_total > 0:
277
+ avg_per_song = generation_total / num_audios if num_audios > 0 else 0
278
+ gen_lines = [
279
+ f"**🎵 Total generation time {songs_label}: {generation_total:.2f}s**",
280
+ f"**{avg_per_song:.2f}s per song**",
281
+ ]
282
+ if lm_total > 0:
283
+ gen_lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
284
+ if dit_total > 0:
285
+ gen_lines.append(f"- DiT phase {songs_label}: {dit_total:.2f}s")
286
+ info_parts.append("\n".join(gen_lines))
287
+
288
+ # Part 2: Total processing time (post-processing)
289
+ if time_costs:
290
  audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
291
  auto_score_time = time_costs.get('auto_score_time', 0.0)
292
  auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
293
+ processing_total = audio_conversion_time + auto_score_time + auto_lrc_time
294
 
295
+ if processing_total > 0:
296
+ proc_lines = [
297
+ f"**🔧 Total processing time {songs_label}: {processing_total:.2f}s**",
298
+ ]
299
  if audio_conversion_time > 0:
300
+ info_format = time_costs.get('audio_format', 'mp3')
301
+ proc_lines.append(f"- to {info_format} {songs_label}: {audio_conversion_time:.2f}s")
302
  if auto_score_time > 0:
303
+ proc_lines.append(f"- scoring {songs_label}: {auto_score_time:.2f}s")
304
  if auto_lrc_time > 0:
305
+ proc_lines.append(f"- LRC detection {songs_label}: {auto_lrc_time:.2f}s")
306
+ info_parts.append("\n".join(proc_lines))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  # Combine all parts
309
  return "\n\n".join(info_parts)
 
714
  codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
715
 
716
  details_accordion_updates = [gr.skip() for _ in range(8)]
717
+ # Auto-expand accordion if auto_score or auto_lrc is enabled
718
+ if auto_score or auto_lrc:
719
+ details_accordion_updates[i] = gr.Accordion(open=True)
720
 
721
  # Clear LRC first (this triggers .change() to clear subtitles)
722
  # Keep visible=True to ensure .change() event is properly triggered
acestep/gradio_ui/i18n/en.json CHANGED
@@ -174,6 +174,8 @@
174
  "title": "🎵 Results",
175
  "generated_music": "🎵 Generated Music (Sample {n})",
176
  "send_to_src_btn": "🔗 Send To Src Audio",
 
 
177
  "save_btn": "💾 Save",
178
  "score_btn": "📊 Score",
179
  "lrc_btn": "🎵 LRC",
 
174
  "title": "🎵 Results",
175
  "generated_music": "🎵 Generated Music (Sample {n})",
176
  "send_to_src_btn": "🔗 Send To Src Audio",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
  "save_btn": "💾 Save",
180
  "score_btn": "📊 Score",
181
  "lrc_btn": "🎵 LRC",
acestep/gradio_ui/i18n/ja.json CHANGED
@@ -174,6 +174,8 @@
174
  "title": "🎵 結果",
175
  "generated_music": "🎵 生成された音楽(サンプル {n})",
176
  "send_to_src_btn": "🔗 ソースオーディオに送信",
 
 
177
  "save_btn": "💾 保存",
178
  "score_btn": "📊 スコア",
179
  "lrc_btn": "🎵 LRC",
 
174
  "title": "🎵 結果",
175
  "generated_music": "🎵 生成された音楽(サンプル {n})",
176
  "send_to_src_btn": "🔗 ソースオーディオに送信",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
  "save_btn": "💾 保存",
180
  "score_btn": "📊 スコア",
181
  "lrc_btn": "🎵 LRC",
acestep/gradio_ui/i18n/zh.json CHANGED
@@ -174,6 +174,8 @@
174
  "title": "🎵 结果",
175
  "generated_music": "🎵 生成的音乐(样本 {n})",
176
  "send_to_src_btn": "🔗 发送到源音频",
 
 
177
  "save_btn": "💾 保存",
178
  "score_btn": "📊 评分",
179
  "lrc_btn": "🎵 LRC",
 
174
  "title": "🎵 结果",
175
  "generated_music": "🎵 生成的音乐(样本 {n})",
176
  "send_to_src_btn": "🔗 发送到源音频",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
  "save_btn": "💾 保存",
180
  "score_btn": "📊 评分",
181
  "lrc_btn": "🎵 LRC",
acestep/gradio_ui/interfaces/__init__.py CHANGED
@@ -65,6 +65,14 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
65
  <div class="main-header">
66
  <h1>{t("app.title")}</h1>
67
  <p>{t("app.subtitle")}</p>
 
 
 
 
 
 
 
 
68
  </div>
69
  """)
70
 
@@ -81,8 +89,8 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
81
  # Pass init_params to support hiding in service mode
82
  training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
83
 
84
- # Connect event handlers
85
- setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
86
 
87
  # Connect training event handlers
88
  setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
 
65
  <div class="main-header">
66
  <h1>{t("app.title")}</h1>
67
  <p>{t("app.subtitle")}</p>
68
+ <p style="margin-top: 0.5rem;">
69
+ <a href="https://ace-step-v1.5.github.io" target="_blank">Project</a> |
70
+ <a href="https://huggingface.co/collections/ACE-Step/ace-step-15" target="_blank">Hugging Face</a> |
71
+ <a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5" target="_blank">ModelScope</a> |
72
+ <a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5" target="_blank">Space Demo</a> |
73
+ <a href="https://discord.gg/PeWDxrkdj7" target="_blank">Discord</a> |
74
+ <a href="https://arxiv.org/abs/2506.00045" target="_blank">Technical Report</a>
75
+ </p>
76
  </div>
77
  """)
78
 
 
89
  # Pass init_params to support hiding in service mode
90
  training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
91
 
92
+ # Connect event handlers (pass init_params for multi-model support)
93
+ setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=init_params)
94
 
95
  # Connect training event handlers
96
  setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
acestep/gradio_ui/interfaces/generation.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Gradio UI Generation Section Module
3
- Contains generation section component definitions
4
  """
5
  import gradio as gr
6
  from acestep.constants import (
@@ -14,7 +14,7 @@ from acestep.gradio_ui.i18n import t
14
 
15
 
16
  def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
17
- """Create generation section
18
 
19
  Args:
20
  dit_handler: DiT handler instance
@@ -32,10 +32,15 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
32
  # Get current language from init_params if available
33
  current_language = init_params.get('language', language) if init_params else language
34
 
 
 
 
 
 
35
  with gr.Group():
36
- # Service Configuration - collapse if pre-initialized, hide if in service mode
37
  accordion_open = not service_pre_initialized
38
- accordion_visible = not service_pre_initialized # Hide when running in service mode
39
  with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
40
  # Language selector at the top
41
  with gr.Row():
@@ -51,10 +56,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
51
  scale=1,
52
  )
53
 
54
- # Dropdown options section - all dropdowns grouped together
55
  with gr.Row(equal_height=True):
56
  with gr.Column(scale=4):
57
- # Set checkpoint value from init_params if pre-initialized
58
  checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
59
  checkpoint_dropdown = gr.Dropdown(
60
  label=t("service.checkpoint_label"),
@@ -66,11 +69,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
66
  refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
67
 
68
  with gr.Row():
69
- # Get available acestep-v15- model list
70
  available_models = dit_handler.get_available_acestep_v15_models()
71
  default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
72
-
73
- # Set config_path value from init_params if pre-initialized
74
  config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
75
  config_path = gr.Dropdown(
76
  label=t("service.model_path_label"),
@@ -78,7 +78,6 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
78
  value=config_path_value,
79
  info=t("service.model_path_info")
80
  )
81
- # Set device value from init_params if pre-initialized
82
  device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
83
  device = gr.Dropdown(
84
  choices=["auto", "cuda", "cpu"],
@@ -88,11 +87,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
88
  )
89
 
90
  with gr.Row():
91
- # Get available 5Hz LM model list
92
  available_lm_models = llm_handler.get_available_5hz_lm_models()
93
  default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
94
-
95
- # Set lm_model_path value from init_params if pre-initialized
96
  lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
97
  lm_model_path = gr.Dropdown(
98
  label=t("service.lm_model_path_label"),
@@ -100,7 +96,6 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
100
  value=lm_model_path_value,
101
  info=t("service.lm_model_path_info")
102
  )
103
- # Set backend value from init_params if pre-initialized
104
  backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
105
  backend_dropdown = gr.Dropdown(
106
  choices=["vllm", "pt"],
@@ -109,18 +104,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
109
  info=t("service.backend_info")
110
  )
111
 
112
- # Checkbox options section - all checkboxes grouped together
113
  with gr.Row():
114
- # Set init_llm value from init_params if pre-initialized
115
  init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
116
  init_llm_checkbox = gr.Checkbox(
117
  label=t("service.init_llm_label"),
118
  value=init_llm_value,
119
  info=t("service.init_llm_info"),
120
  )
121
- # Auto-detect flash attention availability
122
  flash_attn_available = dit_handler.is_flash_attention_available()
123
- # Set use_flash_attention value from init_params if pre-initialized
124
  use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
125
  use_flash_attention_checkbox = gr.Checkbox(
126
  label=t("service.flash_attention_label"),
@@ -128,14 +119,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
128
  interactive=flash_attn_available,
129
  info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
130
  )
131
- # Set offload_to_cpu value from init_params if pre-initialized
132
  offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
133
  offload_to_cpu_checkbox = gr.Checkbox(
134
  label=t("service.offload_cpu_label"),
135
  value=offload_to_cpu_value,
136
  info=t("service.offload_cpu_info")
137
  )
138
- # Set offload_dit_to_cpu value from init_params if pre-initialized
139
  offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
140
  offload_dit_to_cpu_checkbox = gr.Checkbox(
141
  label=t("service.offload_dit_cpu_label"),
@@ -144,7 +133,6 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
144
  )
145
 
146
  init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
147
- # Set init_status value from init_params if pre-initialized
148
  init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
149
  init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
150
 
@@ -173,505 +161,436 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
173
  scale=2,
174
  )
175
 
176
- # Inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  with gr.Row():
178
- with gr.Column(scale=2):
179
- with gr.Accordion(t("generation.required_inputs"), open=True):
180
- # Task type
181
- # Determine initial task_type choices based on actual model in use
182
- # When service is pre-initialized, use config_path from init_params
183
- actual_model = init_params.get('config_path', default_model) if service_pre_initialized else default_model
184
- actual_model_lower = (actual_model or "").lower()
185
- if "turbo" in actual_model_lower:
186
- initial_task_choices = TASK_TYPES_TURBO
187
- else:
188
- initial_task_choices = TASK_TYPES_BASE
189
-
190
- with gr.Row(equal_height=True):
191
- with gr.Column(scale=2):
192
- task_type = gr.Dropdown(
193
- choices=initial_task_choices,
194
- value="text2music",
195
- label=t("generation.task_type_label"),
196
- info=t("generation.task_type_info"),
197
- )
198
- with gr.Column(scale=7):
199
- instruction_display_gen = gr.Textbox(
200
- label=t("generation.instruction_label"),
201
- value=DEFAULT_DIT_INSTRUCTION,
202
- interactive=False,
203
- lines=1,
204
- info=t("generation.instruction_info"),
205
- )
206
- with gr.Column(scale=1, min_width=100):
207
- load_file = gr.UploadButton(
208
- t("generation.load_btn"),
209
- file_types=[".json"],
210
- file_count="single",
211
- variant="secondary",
212
- size="sm",
213
- )
214
-
215
- track_name = gr.Dropdown(
216
- choices=TRACK_NAMES,
217
- value=None,
218
- label=t("generation.track_name_label"),
219
- info=t("generation.track_name_info"),
220
- visible=False
221
  )
222
-
223
- complete_track_classes = gr.CheckboxGroup(
224
- choices=TRACK_NAMES,
225
- label=t("generation.track_classes_label"),
226
- info=t("generation.track_classes_info"),
227
- visible=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  )
229
-
230
- # Audio uploads
231
- audio_uploads_accordion = gr.Accordion(t("generation.audio_uploads"), open=False)
232
- with audio_uploads_accordion:
233
- with gr.Row(equal_height=True):
234
- with gr.Column(scale=2):
235
- reference_audio = gr.Audio(
236
- label=t("generation.reference_audio"),
237
- type="filepath",
238
- )
239
- with gr.Column(scale=7):
240
- src_audio = gr.Audio(
241
- label=t("generation.source_audio"),
242
- type="filepath",
243
- )
244
- with gr.Column(scale=1, min_width=80):
245
- convert_src_to_codes_btn = gr.Button(
246
- t("generation.convert_codes_btn"),
247
- variant="secondary",
248
- size="sm"
249
- )
250
-
251
- # Audio Codes for text2music - single input for transcription or cover task
252
- with gr.Accordion(t("generation.lm_codes_hints"), open=False, visible=True) as text2music_audio_codes_group:
253
- with gr.Row(equal_height=True):
254
- text2music_audio_code_string = gr.Textbox(
255
- label=t("generation.lm_codes_label"),
256
- placeholder=t("generation.lm_codes_placeholder"),
257
- lines=6,
258
- info=t("generation.lm_codes_info"),
259
- scale=9,
260
- )
261
- transcribe_btn = gr.Button(
262
- t("generation.transcribe_btn"),
263
- variant="secondary",
264
- size="sm",
265
- scale=1,
266
- )
267
-
268
- # Repainting controls
269
- with gr.Group(visible=False) as repainting_group:
270
- gr.HTML(f"<h5>{t('generation.repainting_controls')}</h5>")
271
- with gr.Row():
272
- repainting_start = gr.Number(
273
- label=t("generation.repainting_start"),
274
- value=0.0,
275
- step=0.1,
276
- )
277
- repainting_end = gr.Number(
278
- label=t("generation.repainting_end"),
279
- value=-1,
280
- minimum=-1,
281
- step=0.1,
282
- )
283
-
284
- # Simple/Custom Mode Toggle
285
- # In service mode: only Custom mode, hide the toggle
286
- with gr.Row(visible=not service_mode):
287
- generation_mode = gr.Radio(
288
- choices=[
289
- (t("generation.mode_simple"), "simple"),
290
- (t("generation.mode_custom"), "custom"),
291
- ],
292
- value="custom" if service_mode else "simple",
293
- label=t("generation.mode_label"),
294
- info=t("generation.mode_info"),
295
- )
296
-
297
- # Simple Mode Components - hidden in service mode
298
- with gr.Group(visible=not service_mode) as simple_mode_group:
299
- with gr.Row(equal_height=True):
300
- simple_query_input = gr.Textbox(
301
- label=t("generation.simple_query_label"),
302
- placeholder=t("generation.simple_query_placeholder"),
303
- lines=2,
304
- info=t("generation.simple_query_info"),
305
- scale=12,
306
- )
307
-
308
- with gr.Column(scale=1, min_width=100):
309
- random_desc_btn = gr.Button(
310
- "🎲",
311
- variant="secondary",
312
- size="sm",
313
- scale=2
314
- )
315
-
316
- with gr.Row(equal_height=True):
317
- with gr.Column(scale=1, variant="compact"):
318
- simple_instrumental_checkbox = gr.Checkbox(
319
- label=t("generation.instrumental_label"),
320
- value=False,
321
- )
322
- with gr.Column(scale=18):
323
- create_sample_btn = gr.Button(
324
- t("generation.create_sample_btn"),
325
- variant="primary",
326
- size="lg",
327
- )
328
- with gr.Column(scale=1, variant="compact"):
329
- simple_vocal_language = gr.Dropdown(
330
- choices=VALID_LANGUAGES,
331
- value="unknown",
332
- allow_custom_value=True,
333
- label=t("generation.simple_vocal_language_label"),
334
- interactive=True,
335
- )
336
-
337
- # State to track if sample has been created in Simple mode
338
- simple_sample_created = gr.State(value=False)
339
 
340
- # Music Caption - wrapped in accordion that can be collapsed in Simple mode
341
- # In service mode: auto-expand
342
- with gr.Accordion(t("generation.caption_title"), open=service_mode) as caption_accordion:
343
  with gr.Row(equal_height=True):
344
  captions = gr.Textbox(
345
- label=t("generation.caption_label"),
346
- placeholder=t("generation.caption_placeholder"),
347
- lines=3,
348
- info=t("generation.caption_info"),
349
- scale=12,
350
- )
351
- with gr.Column(scale=1, min_width=100):
352
- sample_btn = gr.Button(
353
- "🎲",
354
- variant="secondary",
355
- size="sm",
356
- scale=2,
357
- )
358
- # Lyrics - wrapped in accordion that can be collapsed in Simple mode
359
- # In service mode: auto-expand
360
- with gr.Accordion(t("generation.lyrics_title"), open=service_mode) as lyrics_accordion:
361
- lyrics = gr.Textbox(
362
- label=t("generation.lyrics_label"),
363
- placeholder=t("generation.lyrics_placeholder"),
364
- lines=8,
365
- info=t("generation.lyrics_info")
366
- )
367
-
368
- with gr.Row(variant="compact", equal_height=True):
369
- instrumental_checkbox = gr.Checkbox(
370
- label=t("generation.instrumental_label"),
371
- value=False,
372
  scale=1,
373
- min_width=120,
374
- container=True,
375
  )
376
-
377
- # 中间:语言选择 (Dropdown)
378
- # 移除 gr.HTML hack,直接使用 label 参数,Gradio 会自动处理对齐
379
- vocal_language = gr.Dropdown(
380
- choices=VALID_LANGUAGES,
381
- value="unknown",
382
- label=t("generation.vocal_language_label"),
383
- show_label=False,
384
- container=True,
385
- allow_custom_value=True,
386
- scale=3,
387
- )
388
-
389
- # 右侧:格式化按钮 (Button)
390
- # 放在同一行最右侧,操作更顺手
391
- format_btn = gr.Button(
392
- t("generation.format_btn"),
393
- variant="secondary",
394
  scale=1,
395
- min_width=80,
396
  )
 
 
 
 
 
 
397
 
398
- # Optional Parameters
399
- # In service mode: auto-expand
400
- with gr.Accordion(t("generation.optional_params"), open=service_mode) as optional_params_accordion:
401
- with gr.Row():
402
- bpm = gr.Number(
403
- label=t("generation.bpm_label"),
404
- value=None,
405
- step=1,
406
- info=t("generation.bpm_info")
407
- )
408
- key_scale = gr.Textbox(
409
- label=t("generation.keyscale_label"),
410
- placeholder=t("generation.keyscale_placeholder"),
411
- value="",
412
- info=t("generation.keyscale_info")
413
- )
414
- time_signature = gr.Dropdown(
415
- choices=["2", "3", "4", "N/A", ""],
416
- value="",
417
- label=t("generation.timesig_label"),
418
- allow_custom_value=True,
419
- info=t("generation.timesig_info")
420
- )
421
- audio_duration = gr.Number(
422
- label=t("generation.duration_label"),
423
- value=-1,
424
- minimum=-1,
425
- maximum=600.0,
426
- step=0.1,
427
- info=t("generation.duration_info")
428
- )
429
- batch_size_input = gr.Number(
430
- label=t("generation.batch_size_label"),
431
- value=2,
432
- minimum=1,
433
- maximum=8,
434
- step=1,
435
- info=t("generation.batch_size_info"),
436
- interactive=not service_mode # Fixed in service mode
437
- )
438
 
439
- # Advanced Settings
440
- # Default UI settings use turbo mode (max 20 steps, default 8, show shift with default 3)
441
- # These will be updated after model initialization based on handler.is_turbo_model()
442
- with gr.Accordion(t("generation.advanced_settings"), open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  with gr.Row():
444
  inference_steps = gr.Slider(
445
  minimum=1,
446
  maximum=20,
447
  value=8,
448
  step=1,
449
- label=t("generation.inference_steps_label"),
450
- info=t("generation.inference_steps_info")
 
 
 
 
 
451
  )
452
- guidance_scale = gr.Slider(
453
- minimum=1.0,
454
- maximum=15.0,
455
- value=7.0,
456
- step=0.1,
457
- label=t("generation.guidance_scale_label"),
458
- info=t("generation.guidance_scale_info"),
459
- visible=False
460
- )
461
- with gr.Column():
462
- seed = gr.Textbox(
463
- label=t("generation.seed_label"),
464
- value="-1",
465
- info=t("generation.seed_info")
466
- )
467
- random_seed_checkbox = gr.Checkbox(
468
- label=t("generation.random_seed_label"),
469
- value=True,
470
- info=t("generation.random_seed_info")
471
- )
472
  audio_format = gr.Dropdown(
473
  choices=["mp3", "flac"],
474
  value="mp3",
475
- label=t("generation.audio_format_label"),
476
- info=t("generation.audio_format_info"),
477
- interactive=not service_mode # Fixed in service mode
478
  )
479
 
 
480
  with gr.Row():
481
- use_adg = gr.Checkbox(
482
- label=t("generation.use_adg_label"),
483
- value=False,
484
- info=t("generation.use_adg_info"),
485
- visible=False
486
- )
487
  shift = gr.Slider(
488
  minimum=1.0,
489
  maximum=5.0,
490
  value=3.0,
491
  step=0.1,
492
- label=t("generation.shift_label"),
493
- info=t("generation.shift_info"),
494
- visible=True
 
 
 
 
495
  )
496
  infer_method = gr.Dropdown(
497
  choices=["ode", "sde"],
498
  value="ode",
499
- label=t("generation.infer_method_label"),
500
- info=t("generation.infer_method_info"),
501
  )
502
 
503
- with gr.Row():
504
- custom_timesteps = gr.Textbox(
505
- label=t("generation.custom_timesteps_label"),
506
- placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
507
- value="",
508
- info=t("generation.custom_timesteps_info"),
509
- )
510
 
511
- with gr.Row():
512
- cfg_interval_start = gr.Slider(
513
- minimum=0.0,
514
- maximum=1.0,
515
- value=0.0,
516
- step=0.01,
517
- label=t("generation.cfg_interval_start"),
518
- visible=False
519
- )
520
- cfg_interval_end = gr.Slider(
521
- minimum=0.0,
522
- maximum=1.0,
523
- value=1.0,
524
- step=0.01,
525
- label=t("generation.cfg_interval_end"),
526
- visible=False
527
- )
528
-
529
- # LM (Language Model) Parameters
530
- gr.HTML(f"<h4>{t('generation.lm_params_title')}</h4>")
531
  with gr.Row():
532
  lm_temperature = gr.Slider(
533
- label=t("generation.lm_temperature_label"),
534
  minimum=0.0,
535
  maximum=2.0,
536
  value=0.85,
537
- step=0.1,
538
- scale=1,
539
- info=t("generation.lm_temperature_info")
540
  )
541
  lm_cfg_scale = gr.Slider(
542
- label=t("generation.lm_cfg_scale_label"),
543
  minimum=1.0,
544
  maximum=3.0,
545
  value=2.0,
546
  step=0.1,
547
- scale=1,
548
- info=t("generation.lm_cfg_scale_info")
549
  )
550
  lm_top_k = gr.Slider(
551
- label=t("generation.lm_top_k_label"),
552
  minimum=0,
553
  maximum=100,
554
  value=0,
555
  step=1,
556
- scale=1,
557
- info=t("generation.lm_top_k_info")
558
  )
559
  lm_top_p = gr.Slider(
560
- label=t("generation.lm_top_p_label"),
561
  minimum=0.0,
562
  maximum=1.0,
563
  value=0.9,
564
  step=0.01,
565
- scale=1,
566
- info=t("generation.lm_top_p_info")
567
- )
568
-
569
- with gr.Row():
570
- lm_negative_prompt = gr.Textbox(
571
- label=t("generation.lm_negative_prompt_label"),
572
- value="NO USER INPUT",
573
- placeholder=t("generation.lm_negative_prompt_placeholder"),
574
- info=t("generation.lm_negative_prompt_info"),
575
- lines=2,
576
- scale=2,
577
  )
578
 
579
- with gr.Row():
580
- use_cot_metas = gr.Checkbox(
581
- label=t("generation.cot_metas_label"),
582
- value=True,
583
- info=t("generation.cot_metas_info"),
584
- scale=1,
585
- )
586
- use_cot_language = gr.Checkbox(
587
- label=t("generation.cot_language_label"),
 
 
 
 
 
 
 
 
 
 
588
  value=True,
589
- info=t("generation.cot_language_info"),
590
- scale=1,
591
  )
592
- constrained_decoding_debug = gr.Checkbox(
593
- label=t("generation.constrained_debug_label"),
594
  value=False,
595
- info=t("generation.constrained_debug_info"),
596
- scale=1,
597
- interactive=not service_mode # Fixed in service mode
598
  )
599
 
600
- with gr.Row():
 
 
 
 
 
 
 
 
 
 
601
  auto_score = gr.Checkbox(
602
- label=t("generation.auto_score_label"),
603
  value=False,
604
- info=t("generation.auto_score_info"),
605
- scale=1,
606
- interactive=not service_mode # Fixed in service mode
607
  )
608
  auto_lrc = gr.Checkbox(
609
- label=t("generation.auto_lrc_label"),
610
  value=False,
611
- info=t("generation.auto_lrc_info"),
612
- scale=1,
613
- interactive=not service_mode # Fixed in service mode
614
- )
615
- lm_batch_chunk_size = gr.Number(
616
- label=t("generation.lm_batch_chunk_label"),
617
- value=8,
618
- minimum=1,
619
- maximum=32,
620
- step=1,
621
- info=t("generation.lm_batch_chunk_info"),
622
- scale=1,
623
- interactive=not service_mode # Fixed in service mode
624
- )
625
-
626
- with gr.Row():
627
- audio_cover_strength = gr.Slider(
628
- minimum=0.0,
629
- maximum=1.0,
630
- value=1.0,
631
- step=0.01,
632
- label=t("generation.codes_strength_label"),
633
- info=t("generation.codes_strength_info"),
634
- scale=1,
635
- )
636
- score_scale = gr.Slider(
637
- minimum=0.01,
638
- maximum=1.0,
639
- value=0.5,
640
- step=0.01,
641
- label=t("generation.score_sensitivity_label"),
642
- info=t("generation.score_sensitivity_info"),
643
- scale=1,
644
- visible=not service_mode # Hidden in service mode
645
  )
646
 
647
- # Set generate_btn to interactive if service is pre-initialized
648
- generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
649
- with gr.Row(equal_height=True):
650
- with gr.Column(scale=1, variant="compact"):
651
- think_checkbox = gr.Checkbox(
652
- label=t("generation.think_label"),
653
- value=True,
654
- scale=1,
655
- )
656
- allow_lm_batch = gr.Checkbox(
657
- label=t("generation.parallel_thinking_label"),
658
- value=True,
659
- scale=1,
660
- )
661
- with gr.Column(scale=18):
662
- generate_btn = gr.Button(t("generation.generate_btn"), variant="primary", size="lg", interactive=generate_btn_interactive)
663
- with gr.Column(scale=1, variant="compact"):
664
- autogen_checkbox = gr.Checkbox(
665
- label=t("generation.autogen_label"),
666
- value=False, # Default to False for both service and local modes
667
- scale=1,
668
- interactive=not service_mode # Not selectable in service mode
669
- )
670
- use_cot_caption = gr.Checkbox(
671
- label=t("generation.caption_rewrite_label"),
672
- value=True,
673
- scale=1,
674
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
 
676
  return {
677
  "service_config_accordion": service_config_accordion,
@@ -694,6 +613,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
694
  "unload_lora_btn": unload_lora_btn,
695
  "use_lora_checkbox": use_lora_checkbox,
696
  "lora_status": lora_status,
 
 
697
  "task_type": task_type,
698
  "instruction_display_gen": instruction_display_gen,
699
  "track_name": track_name,
@@ -717,7 +638,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
717
  "repainting_start": repainting_start,
718
  "repainting_end": repainting_end,
719
  "audio_cover_strength": audio_cover_strength,
720
- # Simple/Custom Mode Components
721
  "generation_mode": generation_mode,
722
  "simple_mode_group": simple_mode_group,
723
  "simple_query_input": simple_query_input,
@@ -729,6 +650,13 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
729
  "caption_accordion": caption_accordion,
730
  "lyrics_accordion": lyrics_accordion,
731
  "optional_params_accordion": optional_params_accordion,
 
 
 
 
 
 
 
732
  # Existing components
733
  "captions": captions,
734
  "sample_btn": sample_btn,
@@ -763,4 +691,3 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
763
  "auto_lrc": auto_lrc,
764
  "lm_batch_chunk_size": lm_batch_chunk_size,
765
  }
766
-
 
1
  """
2
  Gradio UI Generation Section Module
3
+ Contains generation section component definitions - Simplified UI
4
  """
5
  import gradio as gr
6
  from acestep.constants import (
 
14
 
15
 
16
  def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
17
+ """Create generation section with simplified UI
18
 
19
  Args:
20
  dit_handler: DiT handler instance
 
32
  # Get current language from init_params if available
33
  current_language = init_params.get('language', language) if init_params else language
34
 
35
+ # Get available models
36
+ available_dit_models = init_params.get('available_dit_models', []) if init_params else []
37
+ current_model_value = init_params.get('config_path', '') if init_params else ''
38
+ show_model_selector = len(available_dit_models) > 1
39
+
40
  with gr.Group():
41
+ # ==================== Service Configuration (Hidden in service mode) ====================
42
  accordion_open = not service_pre_initialized
43
+ accordion_visible = not service_pre_initialized
44
  with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
45
  # Language selector at the top
46
  with gr.Row():
 
56
  scale=1,
57
  )
58
 
 
59
  with gr.Row(equal_height=True):
60
  with gr.Column(scale=4):
 
61
  checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
62
  checkpoint_dropdown = gr.Dropdown(
63
  label=t("service.checkpoint_label"),
 
69
  refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
70
 
71
  with gr.Row():
 
72
  available_models = dit_handler.get_available_acestep_v15_models()
73
  default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
 
 
74
  config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
75
  config_path = gr.Dropdown(
76
  label=t("service.model_path_label"),
 
78
  value=config_path_value,
79
  info=t("service.model_path_info")
80
  )
 
81
  device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
82
  device = gr.Dropdown(
83
  choices=["auto", "cuda", "cpu"],
 
87
  )
88
 
89
  with gr.Row():
 
90
  available_lm_models = llm_handler.get_available_5hz_lm_models()
91
  default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
 
 
92
  lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
93
  lm_model_path = gr.Dropdown(
94
  label=t("service.lm_model_path_label"),
 
96
  value=lm_model_path_value,
97
  info=t("service.lm_model_path_info")
98
  )
 
99
  backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
100
  backend_dropdown = gr.Dropdown(
101
  choices=["vllm", "pt"],
 
104
  info=t("service.backend_info")
105
  )
106
 
 
107
  with gr.Row():
 
108
  init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
109
  init_llm_checkbox = gr.Checkbox(
110
  label=t("service.init_llm_label"),
111
  value=init_llm_value,
112
  info=t("service.init_llm_info"),
113
  )
 
114
  flash_attn_available = dit_handler.is_flash_attention_available()
 
115
  use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
116
  use_flash_attention_checkbox = gr.Checkbox(
117
  label=t("service.flash_attention_label"),
 
119
  interactive=flash_attn_available,
120
  info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
121
  )
 
122
  offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
123
  offload_to_cpu_checkbox = gr.Checkbox(
124
  label=t("service.offload_cpu_label"),
125
  value=offload_to_cpu_value,
126
  info=t("service.offload_cpu_info")
127
  )
 
128
  offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
129
  offload_dit_to_cpu_checkbox = gr.Checkbox(
130
  label=t("service.offload_dit_cpu_label"),
 
133
  )
134
 
135
  init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
 
136
  init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
137
  init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
138
 
 
161
  scale=2,
162
  )
163
 
164
+ # ==================== Model Selector (Top, only when multiple models) ====================
165
+ with gr.Row(visible=show_model_selector):
166
+ dit_model_selector = gr.Dropdown(
167
+ choices=available_dit_models,
168
+ value=current_model_value,
169
+ label="models",
170
+ scale=1,
171
+ )
172
+
173
+ # Hidden dropdown when only one model (for event handler compatibility)
174
+ if not show_model_selector:
175
+ dit_model_selector = gr.Dropdown(
176
+ choices=available_dit_models if available_dit_models else [current_model_value],
177
+ value=current_model_value,
178
+ visible=False,
179
+ )
180
+
181
+ # ==================== Generation Mode (4 modes) ====================
182
+ gr.HTML("<div style='background: #4a5568; color: white; padding: 8px 16px; border-radius: 4px; font-weight: bold;'>Generation Mode</div>")
183
  with gr.Row():
184
+ generation_mode = gr.Radio(
185
+ choices=[
186
+ ("Simple", "simple"),
187
+ ("Custom", "custom"),
188
+ ("Cover", "cover"),
189
+ ("Repaint", "repaint"),
190
+ ],
191
+ value="custom",
192
+ label="",
193
+ show_label=False,
194
+ )
195
+
196
+ # ==================== Simple Mode Group ====================
197
+ with gr.Column(visible=False) as simple_mode_group:
198
+ # Row: Song Description + Vocal Language + Random button
199
+ with gr.Row(equal_height=True):
200
+ simple_query_input = gr.Textbox(
201
+ label=t("generation.simple_query_label"),
202
+ placeholder=t("generation.simple_query_placeholder"),
203
+ lines=2,
204
+ info=t("generation.simple_query_info"),
205
+ scale=10,
206
+ )
207
+ simple_vocal_language = gr.Dropdown(
208
+ choices=VALID_LANGUAGES,
209
+ value="unknown",
210
+ allow_custom_value=True,
211
+ label=t("generation.simple_vocal_language_label"),
212
+ interactive=True,
213
+ info="use unknown for instrumental",
214
+ scale=2,
215
+ )
216
+ with gr.Column(scale=1, min_width=60):
217
+ random_desc_btn = gr.Button(
218
+ "🎲",
219
+ variant="secondary",
220
+ size="lg",
 
 
 
 
 
 
221
  )
222
+
223
+ # Hidden components (kept for compatibility but not shown)
224
+ simple_instrumental_checkbox = gr.Checkbox(
225
+ label=t("generation.instrumental_label"),
226
+ value=False,
227
+ visible=False,
228
+ )
229
+ create_sample_btn = gr.Button(
230
+ t("generation.create_sample_btn"),
231
+ variant="primary",
232
+ size="lg",
233
+ visible=False,
234
+ )
235
+
236
+ # State to track if sample has been created in Simple mode
237
+ simple_sample_created = gr.State(value=False)
238
+
239
+ # ==================== Source Audio (for Cover/Repaint) ====================
240
+ # This is shown above the main content for Cover and Repaint modes
241
+ with gr.Column(visible=False) as src_audio_group:
242
+ with gr.Row(equal_height=True):
243
+ # Source Audio - scale=10 to match (refer_audio=2 + prompt/lyrics=8)
244
+ src_audio = gr.Audio(
245
+ label="Source Audio",
246
+ type="filepath",
247
+ scale=10,
248
+ )
249
+ # Process button - scale=1 to align with random button
250
+ with gr.Column(scale=1, min_width=80):
251
+ process_src_btn = gr.Button(
252
+ "Analyze",
253
+ variant="secondary",
254
+ size="lg",
255
+ )
256
+
257
+ # Hidden Audio Codes storage (needed internally but not displayed)
258
+ text2music_audio_code_string = gr.Textbox(
259
+ label="Audio Codes",
260
+ visible=False,
261
+ )
262
+
263
+ # ==================== Custom/Cover/Repaint Mode Content ====================
264
+ with gr.Column() as custom_mode_content:
265
+ with gr.Row(equal_height=True):
266
+ # Left: Reference Audio
267
+ with gr.Column(scale=2, min_width=200):
268
+ reference_audio = gr.Audio(
269
+ label="Reference Audio (optional)",
270
+ type="filepath",
271
+ show_label=True,
272
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
+ # Middle: Prompt + Lyrics + Format button
275
+ with gr.Column(scale=8):
276
+ # Row 1: Prompt and Lyrics
277
  with gr.Row(equal_height=True):
278
  captions = gr.Textbox(
279
+ label="Prompt",
280
+ placeholder="Describe the music style, mood, instruments...",
281
+ lines=12,
282
+ max_lines=12,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  scale=1,
 
 
284
  )
285
+ lyrics = gr.Textbox(
286
+ label="Lyrics",
287
+ placeholder="Enter lyrics here... Use [Verse], [Chorus] etc. for structure",
288
+ lines=12,
289
+ max_lines=12,
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  scale=1,
 
291
  )
292
+
293
+ # Row 2: Format button (only below Prompt and Lyrics)
294
+ format_btn = gr.Button(
295
+ "Format",
296
+ variant="secondary",
297
+ )
298
 
299
+ # Right: Random button
300
+ with gr.Column(scale=1, min_width=60):
301
+ sample_btn = gr.Button(
302
+ "🎲",
303
+ variant="secondary",
304
+ size="lg",
305
+ )
306
+
307
+ # Placeholder for removed audio_uploads_accordion (for compatibility)
308
+ audio_uploads_accordion = gr.Column(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
+ # Legacy cover_mode_group (hidden, for backward compatibility)
311
+ cover_mode_group = gr.Column(visible=False)
312
+ # Legacy convert button (hidden, for backward compatibility)
313
+ convert_src_to_codes_btn = gr.Button("Convert to Codes", visible=False)
314
+
315
+ # ==================== Repaint Mode: Source + Time Range ====================
316
+ with gr.Column(visible=False) as repainting_group:
317
+ with gr.Row():
318
+ repainting_start = gr.Number(
319
+ label="Start (seconds)",
320
+ value=0.0,
321
+ step=0.1,
322
+ scale=1,
323
+ )
324
+ repainting_end = gr.Number(
325
+ label="End (seconds, -1 for end)",
326
+ value=-1,
327
+ minimum=-1,
328
+ step=0.1,
329
+ scale=1,
330
+ )
331
+
332
+ # ==================== Optional Parameters ====================
333
+ with gr.Accordion("⚙️ Optional Parameters", open=False, visible=False) as optional_params_accordion:
334
+ pass
335
+
336
+ # ==================== Advanced Settings ====================
337
+ with gr.Accordion("🔧 Advanced Settings", open=False) as advanced_options_accordion:
338
+ with gr.Row():
339
+ bpm = gr.Number(
340
+ label="BPM (optional)",
341
+ value=0,
342
+ step=1,
343
+ info="leave empty for N/A",
344
+ scale=1,
345
+ )
346
+ key_scale = gr.Textbox(
347
+ label="Key Signature (optional)",
348
+ placeholder="Leave empty for N/A",
349
+ value="",
350
+ info="A-G, #/♭, major/minor",
351
+ scale=1,
352
+ )
353
+ time_signature = gr.Dropdown(
354
+ choices=["", "2", "3", "4"],
355
+ value="",
356
+ label="Time Signature (optional)",
357
+ allow_custom_value=True,
358
+ info="2/4, 3/4, 4/4...",
359
+ scale=1,
360
+ )
361
+ audio_duration = gr.Number(
362
+ label="Audio Duration (seconds)",
363
+ value=-1,
364
+ minimum=-1,
365
+ maximum=600.0,
366
+ step=1,
367
+ info="Use -1 for random",
368
+ scale=1,
369
+ )
370
+ vocal_language = gr.Dropdown(
371
+ choices=VALID_LANGUAGES,
372
+ value="unknown",
373
+ label="Vocal Language",
374
+ allow_custom_value=True,
375
+ info="use `unknown` for instrumental",
376
+ scale=1,
377
+ )
378
+ batch_size_input = gr.Number(
379
+ label="batch size",
380
+ info="max 8",
381
+ value=2,
382
+ minimum=1,
383
+ maximum=8,
384
+ step=1,
385
+ scale=1,
386
+ )
387
+
388
+ # Row 1: DiT Inference Steps, Seed, Audio Format
389
  with gr.Row():
390
  inference_steps = gr.Slider(
391
  minimum=1,
392
  maximum=20,
393
  value=8,
394
  step=1,
395
+ label="DiT Inference Steps",
396
+ info="Turbo: max 8, Base: max 200",
397
+ )
398
+ seed = gr.Textbox(
399
+ label="Seed",
400
+ value="-1",
401
+ info="Use comma-separated values for batches",
402
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  audio_format = gr.Dropdown(
404
  choices=["mp3", "flac"],
405
  value="mp3",
406
+ label="Audio Format",
407
+ info="Audio format for saved files",
 
408
  )
409
 
410
+ # Row 2: Shift, Random Seed, Inference Method
411
  with gr.Row():
 
 
 
 
 
 
412
  shift = gr.Slider(
413
  minimum=1.0,
414
  maximum=5.0,
415
  value=3.0,
416
  step=0.1,
417
+ label="Shift",
418
+ info="Timestep shift factor for base models (range 1.0-5.0, default 3.0). Not effective for turbo models.",
419
+ )
420
+ random_seed_checkbox = gr.Checkbox(
421
+ label="Random Seed",
422
+ value=True,
423
+ info="Enable to auto-generate seeds",
424
  )
425
  infer_method = gr.Dropdown(
426
  choices=["ode", "sde"],
427
  value="ode",
428
+ label="Inference Method",
429
+ info="Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
430
  )
431
 
432
+ # Row 3: Custom Timesteps (full width)
433
+ custom_timesteps = gr.Textbox(
434
+ label="Custom Timesteps",
435
+ placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
436
+ value="",
437
+ info="Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
438
+ )
439
 
440
+ # Section: LM Generation Parameters
441
+ gr.HTML("<h4>🎵 LM Generation Parameters</h4>")
442
+
443
+ # Row 4: LM Temperature, LM CFG Scale, LM Top-K, LM Top-P
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  with gr.Row():
445
  lm_temperature = gr.Slider(
 
446
  minimum=0.0,
447
  maximum=2.0,
448
  value=0.85,
449
+ step=0.05,
450
+ label="LM Temperature",
451
+ info="5Hz LM temperature (higher = more random)",
452
  )
453
  lm_cfg_scale = gr.Slider(
 
454
  minimum=1.0,
455
  maximum=3.0,
456
  value=2.0,
457
  step=0.1,
458
+ label="LM CFG Scale",
459
+ info="5Hz LM CFG (1.0 = no CFG)",
460
  )
461
  lm_top_k = gr.Slider(
 
462
  minimum=0,
463
  maximum=100,
464
  value=0,
465
  step=1,
466
+ label="LM Top-K",
467
+ info="Top-k (0 = disabled)",
468
  )
469
  lm_top_p = gr.Slider(
 
470
  minimum=0.0,
471
  maximum=1.0,
472
  value=0.9,
473
  step=0.01,
474
+ label="LM Top-P",
475
+ info="Top-p (1.0 = disabled)",
 
 
 
 
 
 
 
 
 
 
476
  )
477
 
478
+ # Row 5: LM Negative Prompt (full width)
479
+ lm_negative_prompt = gr.Textbox(
480
+ label="LM Negative Prompt",
481
+ value="NO USER INPUT",
482
+ placeholder="Things to avoid in generation...",
483
+ lines=2,
484
+ info="Negative prompt (use when LM CFG Scale > 1.0)",
485
+ )
486
+ # audio_cover_strength remains hidden for now
487
+ audio_cover_strength = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, visible=False)
488
+
489
+ # Note: audio_duration, bpm, key_scale, time_signature are now visible in Optional Parameters
490
+ # ==================== Generate Button Row ====================
491
+ generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
492
+ with gr.Row(equal_height=True):
493
+ # Left: Thinking and Instrumental checkboxes
494
+ with gr.Column(scale=1, min_width=120):
495
+ think_checkbox = gr.Checkbox(
496
+ label="Thinking",
497
  value=True,
 
 
498
  )
499
+ instrumental_checkbox = gr.Checkbox(
500
+ label="Instrumental",
501
  value=False,
 
 
 
502
  )
503
 
504
+ # Center: Generate button
505
+ with gr.Column(scale=4):
506
+ generate_btn = gr.Button(
507
+ "🎵 Generate Music",
508
+ variant="primary",
509
+ size="lg",
510
+ interactive=generate_btn_interactive,
511
+ )
512
+
513
+ # Right: auto_score, auto_lrc
514
+ with gr.Column(scale=1, min_width=120):
515
  auto_score = gr.Checkbox(
516
+ label="Get Scores",
517
  value=False,
 
 
 
518
  )
519
  auto_lrc = gr.Checkbox(
520
+ label="Get LRC",
521
  value=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
  )
523
 
524
+ # ==================== Hidden Components (for internal use) ====================
525
+ # These are needed for event handlers but not shown in UI
526
+
527
+ # Task type (set automatically based on generation_mode)
528
+ actual_model = init_params.get('config_path', 'acestep-v15-turbo') if service_pre_initialized else 'acestep-v15-turbo'
529
+ actual_model_lower = (actual_model or "").lower()
530
+ if "turbo" in actual_model_lower:
531
+ initial_task_choices = TASK_TYPES_TURBO
532
+ else:
533
+ initial_task_choices = TASK_TYPES_BASE
534
+
535
+ task_type = gr.Dropdown(
536
+ choices=initial_task_choices,
537
+ value="text2music",
538
+ visible=False,
539
+ )
540
+
541
+ instruction_display_gen = gr.Textbox(
542
+ value=DEFAULT_DIT_INSTRUCTION,
543
+ visible=False,
544
+ )
545
+
546
+ track_name = gr.Dropdown(
547
+ choices=TRACK_NAMES,
548
+ value=None,
549
+ visible=False,
550
+ )
551
+
552
+ complete_track_classes = gr.CheckboxGroup(
553
+ choices=TRACK_NAMES,
554
+ visible=False,
555
+ )
556
+
557
+ # Note: lyrics, vocal_language, instrumental_checkbox, format_btn are now visible in custom_mode_content
558
+
559
+ # Hidden advanced settings (keep defaults)
560
+ # Note: Most parameters are now visible in Advanced Settings section above
561
+ guidance_scale = gr.Slider(value=7.0, visible=False)
562
+ use_adg = gr.Checkbox(value=False, visible=False)
563
+ cfg_interval_start = gr.Slider(value=0.0, visible=False)
564
+ cfg_interval_end = gr.Slider(value=1.0, visible=False)
565
+
566
+ # LM parameters (remaining hidden ones)
567
+ use_cot_metas = gr.Checkbox(value=True, visible=False)
568
+ use_cot_caption = gr.Checkbox(value=True, visible=False)
569
+ use_cot_language = gr.Checkbox(value=True, visible=False)
570
+ constrained_decoding_debug = gr.Checkbox(value=False, visible=False)
571
+ allow_lm_batch = gr.Checkbox(value=True, visible=False)
572
+ lm_batch_chunk_size = gr.Number(value=8, visible=False)
573
+ score_scale = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, visible=False)
574
+ autogen_checkbox = gr.Checkbox(value=False, visible=False)
575
+
576
+ # Transcribe button (hidden)
577
+ transcribe_btn = gr.Button(value="Transcribe", visible=False)
578
+ text2music_audio_codes_group = gr.Group(visible=False)
579
+
580
+ # Note: format_btn is now visible in custom_mode_content
581
+
582
+ # Load file button (hidden for now)
583
+ load_file = gr.UploadButton(
584
+ label="Load",
585
+ file_types=[".json"],
586
+ file_count="single",
587
+ visible=False,
588
+ )
589
+
590
+ # Caption/Lyrics accordions (not used in new UI but needed for compatibility)
591
+ caption_accordion = gr.Accordion("Caption", visible=False)
592
+ lyrics_accordion = gr.Accordion("Lyrics", visible=False)
593
+ # Note: optional_params_accordion is now visible above
594
 
595
  return {
596
  "service_config_accordion": service_config_accordion,
 
613
  "unload_lora_btn": unload_lora_btn,
614
  "use_lora_checkbox": use_lora_checkbox,
615
  "lora_status": lora_status,
616
+ # DiT model selector
617
+ "dit_model_selector": dit_model_selector,
618
  "task_type": task_type,
619
  "instruction_display_gen": instruction_display_gen,
620
  "track_name": track_name,
 
638
  "repainting_start": repainting_start,
639
  "repainting_end": repainting_end,
640
  "audio_cover_strength": audio_cover_strength,
641
+ # Generation mode components
642
  "generation_mode": generation_mode,
643
  "simple_mode_group": simple_mode_group,
644
  "simple_query_input": simple_query_input,
 
650
  "caption_accordion": caption_accordion,
651
  "lyrics_accordion": lyrics_accordion,
652
  "optional_params_accordion": optional_params_accordion,
653
+ # Custom mode components
654
+ "custom_mode_content": custom_mode_content,
655
+ "cover_mode_group": cover_mode_group,
656
+ # Source audio group for Cover/Repaint
657
+ "src_audio_group": src_audio_group,
658
+ "process_src_btn": process_src_btn,
659
+ "advanced_options_accordion": advanced_options_accordion,
660
  # Existing components
661
  "captions": captions,
662
  "sample_btn": sample_btn,
 
691
  "auto_lrc": auto_lrc,
692
  "lm_batch_chunk_size": lm_batch_chunk_size,
693
  }
 
acestep/gradio_ui/interfaces/result.py CHANGED
@@ -32,8 +32,14 @@ def create_results_section(dit_handler) -> dict:
32
  buttons=[]
33
  )
34
  with gr.Row(equal_height=True):
35
- send_to_src_btn_1 = gr.Button(
36
- t("results.send_to_src_btn"),
 
 
 
 
 
 
37
  variant="secondary",
38
  size="sm",
39
  scale=1
@@ -48,23 +54,17 @@ def create_results_section(dit_handler) -> dict:
48
  t("results.score_btn"),
49
  variant="secondary",
50
  size="sm",
51
- scale=1
 
52
  )
53
  lrc_btn_1 = gr.Button(
54
  t("results.lrc_btn"),
55
  variant="secondary",
56
  size="sm",
57
- scale=1
 
58
  )
59
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
60
- codes_display_1 = gr.Textbox(
61
- label=t("results.codes_label", n=1),
62
- interactive=False,
63
- buttons=["copy"],
64
- lines=4,
65
- max_lines=4,
66
- visible=True
67
- )
68
  score_display_1 = gr.Textbox(
69
  label=t("results.quality_score_label", n=1),
70
  interactive=False,
@@ -81,6 +81,14 @@ def create_results_section(dit_handler) -> dict:
81
  max_lines=8,
82
  visible=True
83
  )
 
 
 
 
 
 
 
 
84
  with gr.Column(visible=True) as audio_col_2:
85
  generated_audio_2 = gr.Audio(
86
  label=t("results.generated_music", n=2),
@@ -89,8 +97,14 @@ def create_results_section(dit_handler) -> dict:
89
  buttons=[]
90
  )
91
  with gr.Row(equal_height=True):
92
- send_to_src_btn_2 = gr.Button(
93
- t("results.send_to_src_btn"),
 
 
 
 
 
 
94
  variant="secondary",
95
  size="sm",
96
  scale=1
@@ -105,23 +119,17 @@ def create_results_section(dit_handler) -> dict:
105
  t("results.score_btn"),
106
  variant="secondary",
107
  size="sm",
108
- scale=1
 
109
  )
110
  lrc_btn_2 = gr.Button(
111
  t("results.lrc_btn"),
112
  variant="secondary",
113
  size="sm",
114
- scale=1
 
115
  )
116
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
117
- codes_display_2 = gr.Textbox(
118
- label=t("results.codes_label", n=2),
119
- interactive=False,
120
- buttons=["copy"],
121
- lines=4,
122
- max_lines=4,
123
- visible=True
124
- )
125
  score_display_2 = gr.Textbox(
126
  label=t("results.quality_score_label", n=2),
127
  interactive=False,
@@ -138,6 +146,14 @@ def create_results_section(dit_handler) -> dict:
138
  max_lines=8,
139
  visible=True
140
  )
 
 
 
 
 
 
 
 
141
  with gr.Column(visible=False) as audio_col_3:
142
  generated_audio_3 = gr.Audio(
143
  label=t("results.generated_music", n=3),
@@ -146,8 +162,14 @@ def create_results_section(dit_handler) -> dict:
146
  buttons=[]
147
  )
148
  with gr.Row(equal_height=True):
149
- send_to_src_btn_3 = gr.Button(
150
- t("results.send_to_src_btn"),
 
 
 
 
 
 
151
  variant="secondary",
152
  size="sm",
153
  scale=1
@@ -162,23 +184,17 @@ def create_results_section(dit_handler) -> dict:
162
  t("results.score_btn"),
163
  variant="secondary",
164
  size="sm",
165
- scale=1
 
166
  )
167
  lrc_btn_3 = gr.Button(
168
  t("results.lrc_btn"),
169
  variant="secondary",
170
  size="sm",
171
- scale=1
 
172
  )
173
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
174
- codes_display_3 = gr.Textbox(
175
- label=t("results.codes_label", n=3),
176
- interactive=False,
177
- buttons=["copy"],
178
- lines=4,
179
- max_lines=4,
180
- visible=True
181
- )
182
  score_display_3 = gr.Textbox(
183
  label=t("results.quality_score_label", n=3),
184
  interactive=False,
@@ -195,6 +211,14 @@ def create_results_section(dit_handler) -> dict:
195
  max_lines=8,
196
  visible=True
197
  )
 
 
 
 
 
 
 
 
198
  with gr.Column(visible=False) as audio_col_4:
199
  generated_audio_4 = gr.Audio(
200
  label=t("results.generated_music", n=4),
@@ -203,8 +227,14 @@ def create_results_section(dit_handler) -> dict:
203
  buttons=[]
204
  )
205
  with gr.Row(equal_height=True):
206
- send_to_src_btn_4 = gr.Button(
207
- t("results.send_to_src_btn"),
 
 
 
 
 
 
208
  variant="secondary",
209
  size="sm",
210
  scale=1
@@ -219,23 +249,17 @@ def create_results_section(dit_handler) -> dict:
219
  t("results.score_btn"),
220
  variant="secondary",
221
  size="sm",
222
- scale=1
 
223
  )
224
  lrc_btn_4 = gr.Button(
225
  t("results.lrc_btn"),
226
  variant="secondary",
227
  size="sm",
228
- scale=1
 
229
  )
230
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
231
- codes_display_4 = gr.Textbox(
232
- label=t("results.codes_label", n=4),
233
- interactive=False,
234
- buttons=["copy"],
235
- lines=4,
236
- max_lines=4,
237
- visible=True
238
- )
239
  score_display_4 = gr.Textbox(
240
  label=t("results.quality_score_label", n=4),
241
  interactive=False,
@@ -252,6 +276,14 @@ def create_results_section(dit_handler) -> dict:
252
  max_lines=8,
253
  visible=True
254
  )
 
 
 
 
 
 
 
 
255
 
256
  # Second row for batch size 5-8 (initially hidden)
257
  with gr.Row(visible=False) as audio_row_5_8:
@@ -263,19 +295,12 @@ def create_results_section(dit_handler) -> dict:
263
  buttons=[]
264
  )
265
  with gr.Row(equal_height=True):
266
- send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
267
  save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
268
- score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
269
- lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
270
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
271
- codes_display_5 = gr.Textbox(
272
- label=t("results.codes_label", n=5),
273
- interactive=False,
274
- buttons=["copy"],
275
- lines=4,
276
- max_lines=4,
277
- visible=True
278
- )
279
  score_display_5 = gr.Textbox(
280
  label=t("results.quality_score_label", n=5),
281
  interactive=False,
@@ -292,6 +317,14 @@ def create_results_section(dit_handler) -> dict:
292
  max_lines=8,
293
  visible=True
294
  )
 
 
 
 
 
 
 
 
295
  with gr.Column() as audio_col_6:
296
  generated_audio_6 = gr.Audio(
297
  label=t("results.generated_music", n=6),
@@ -300,19 +333,12 @@ def create_results_section(dit_handler) -> dict:
300
  buttons=[]
301
  )
302
  with gr.Row(equal_height=True):
303
- send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
304
  save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
305
- score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
306
- lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
307
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
308
- codes_display_6 = gr.Textbox(
309
- label=t("results.codes_label", n=6),
310
- interactive=False,
311
- buttons=["copy"],
312
- lines=4,
313
- max_lines=4,
314
- visible=True
315
- )
316
  score_display_6 = gr.Textbox(
317
  label=t("results.quality_score_label", n=6),
318
  interactive=False,
@@ -329,6 +355,14 @@ def create_results_section(dit_handler) -> dict:
329
  max_lines=8,
330
  visible=True
331
  )
 
 
 
 
 
 
 
 
332
  with gr.Column() as audio_col_7:
333
  generated_audio_7 = gr.Audio(
334
  label=t("results.generated_music", n=7),
@@ -337,19 +371,12 @@ def create_results_section(dit_handler) -> dict:
337
  buttons=[]
338
  )
339
  with gr.Row(equal_height=True):
340
- send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
341
  save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
342
- score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
343
- lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
344
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
345
- codes_display_7 = gr.Textbox(
346
- label=t("results.codes_label", n=7),
347
- interactive=False,
348
- buttons=["copy"],
349
- lines=4,
350
- max_lines=4,
351
- visible=True
352
- )
353
  score_display_7 = gr.Textbox(
354
  label=t("results.quality_score_label", n=7),
355
  interactive=False,
@@ -366,6 +393,14 @@ def create_results_section(dit_handler) -> dict:
366
  max_lines=8,
367
  visible=True
368
  )
 
 
 
 
 
 
 
 
369
  with gr.Column() as audio_col_8:
370
  generated_audio_8 = gr.Audio(
371
  label=t("results.generated_music", n=8),
@@ -374,19 +409,12 @@ def create_results_section(dit_handler) -> dict:
374
  buttons=[]
375
  )
376
  with gr.Row(equal_height=True):
377
- send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
378
  save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
379
- score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
380
- lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
381
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
382
- codes_display_8 = gr.Textbox(
383
- label=t("results.codes_label", n=8),
384
- interactive=False,
385
- buttons=["copy"],
386
- lines=4,
387
- max_lines=4,
388
- visible=True
389
- )
390
  score_display_8 = gr.Textbox(
391
  label=t("results.quality_score_label", n=8),
392
  interactive=False,
@@ -403,11 +431,19 @@ def create_results_section(dit_handler) -> dict:
403
  max_lines=8,
404
  visible=True
405
  )
 
 
 
 
 
 
 
 
406
 
407
  status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
408
 
409
- # Batch navigation controls
410
- with gr.Row(equal_height=True):
411
  prev_batch_btn = gr.Button(
412
  t("results.prev_btn"),
413
  variant="secondary",
@@ -435,12 +471,13 @@ def create_results_section(dit_handler) -> dict:
435
  size="sm"
436
  )
437
 
438
- # One-click restore parameters button
439
  restore_params_btn = gr.Button(
440
  t("results.restore_params_btn"),
441
  variant="secondary",
442
- interactive=False, # Initially disabled, enabled after generation
443
- size="sm"
 
444
  )
445
 
446
  with gr.Accordion(t("results.batch_results_title"), open=False):
@@ -482,14 +519,22 @@ def create_results_section(dit_handler) -> dict:
482
  "audio_col_6": audio_col_6,
483
  "audio_col_7": audio_col_7,
484
  "audio_col_8": audio_col_8,
485
- "send_to_src_btn_1": send_to_src_btn_1,
486
- "send_to_src_btn_2": send_to_src_btn_2,
487
- "send_to_src_btn_3": send_to_src_btn_3,
488
- "send_to_src_btn_4": send_to_src_btn_4,
489
- "send_to_src_btn_5": send_to_src_btn_5,
490
- "send_to_src_btn_6": send_to_src_btn_6,
491
- "send_to_src_btn_7": send_to_src_btn_7,
492
- "send_to_src_btn_8": send_to_src_btn_8,
 
 
 
 
 
 
 
 
493
  "save_btn_1": save_btn_1,
494
  "save_btn_2": save_btn_2,
495
  "save_btn_3": save_btn_3,
 
32
  buttons=[]
33
  )
34
  with gr.Row(equal_height=True):
35
+ send_to_cover_btn_1 = gr.Button(
36
+ t("results.send_to_cover_btn"),
37
+ variant="secondary",
38
+ size="sm",
39
+ scale=1
40
+ )
41
+ send_to_repaint_btn_1 = gr.Button(
42
+ t("results.send_to_repaint_btn"),
43
  variant="secondary",
44
  size="sm",
45
  scale=1
 
54
  t("results.score_btn"),
55
  variant="secondary",
56
  size="sm",
57
+ scale=1,
58
+ visible=False
59
  )
60
  lrc_btn_1 = gr.Button(
61
  t("results.lrc_btn"),
62
  variant="secondary",
63
  size="sm",
64
+ scale=1,
65
+ visible=False
66
  )
67
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
 
 
 
 
 
 
 
 
68
  score_display_1 = gr.Textbox(
69
  label=t("results.quality_score_label", n=1),
70
  interactive=False,
 
81
  max_lines=8,
82
  visible=True
83
  )
84
+ codes_display_1 = gr.Textbox(
85
+ label=t("results.codes_label", n=1),
86
+ interactive=False,
87
+ buttons=["copy"],
88
+ lines=4,
89
+ max_lines=4,
90
+ visible=True
91
+ )
92
  with gr.Column(visible=True) as audio_col_2:
93
  generated_audio_2 = gr.Audio(
94
  label=t("results.generated_music", n=2),
 
97
  buttons=[]
98
  )
99
  with gr.Row(equal_height=True):
100
+ send_to_cover_btn_2 = gr.Button(
101
+ t("results.send_to_cover_btn"),
102
+ variant="secondary",
103
+ size="sm",
104
+ scale=1
105
+ )
106
+ send_to_repaint_btn_2 = gr.Button(
107
+ t("results.send_to_repaint_btn"),
108
  variant="secondary",
109
  size="sm",
110
  scale=1
 
119
  t("results.score_btn"),
120
  variant="secondary",
121
  size="sm",
122
+ scale=1,
123
+ visible=False
124
  )
125
  lrc_btn_2 = gr.Button(
126
  t("results.lrc_btn"),
127
  variant="secondary",
128
  size="sm",
129
+ scale=1,
130
+ visible=False
131
  )
132
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
 
 
 
 
 
 
 
 
133
  score_display_2 = gr.Textbox(
134
  label=t("results.quality_score_label", n=2),
135
  interactive=False,
 
146
  max_lines=8,
147
  visible=True
148
  )
149
+ codes_display_2 = gr.Textbox(
150
+ label=t("results.codes_label", n=2),
151
+ interactive=False,
152
+ buttons=["copy"],
153
+ lines=4,
154
+ max_lines=4,
155
+ visible=True
156
+ )
157
  with gr.Column(visible=False) as audio_col_3:
158
  generated_audio_3 = gr.Audio(
159
  label=t("results.generated_music", n=3),
 
162
  buttons=[]
163
  )
164
  with gr.Row(equal_height=True):
165
+ send_to_cover_btn_3 = gr.Button(
166
+ t("results.send_to_cover_btn"),
167
+ variant="secondary",
168
+ size="sm",
169
+ scale=1
170
+ )
171
+ send_to_repaint_btn_3 = gr.Button(
172
+ t("results.send_to_repaint_btn"),
173
  variant="secondary",
174
  size="sm",
175
  scale=1
 
184
  t("results.score_btn"),
185
  variant="secondary",
186
  size="sm",
187
+ scale=1,
188
+ visible=False
189
  )
190
  lrc_btn_3 = gr.Button(
191
  t("results.lrc_btn"),
192
  variant="secondary",
193
  size="sm",
194
+ scale=1,
195
+ visible=False
196
  )
197
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
 
 
 
 
 
 
 
 
198
  score_display_3 = gr.Textbox(
199
  label=t("results.quality_score_label", n=3),
200
  interactive=False,
 
211
  max_lines=8,
212
  visible=True
213
  )
214
+ codes_display_3 = gr.Textbox(
215
+ label=t("results.codes_label", n=3),
216
+ interactive=False,
217
+ buttons=["copy"],
218
+ lines=4,
219
+ max_lines=4,
220
+ visible=True
221
+ )
222
  with gr.Column(visible=False) as audio_col_4:
223
  generated_audio_4 = gr.Audio(
224
  label=t("results.generated_music", n=4),
 
227
  buttons=[]
228
  )
229
  with gr.Row(equal_height=True):
230
+ send_to_cover_btn_4 = gr.Button(
231
+ t("results.send_to_cover_btn"),
232
+ variant="secondary",
233
+ size="sm",
234
+ scale=1
235
+ )
236
+ send_to_repaint_btn_4 = gr.Button(
237
+ t("results.send_to_repaint_btn"),
238
  variant="secondary",
239
  size="sm",
240
  scale=1
 
249
  t("results.score_btn"),
250
  variant="secondary",
251
  size="sm",
252
+ scale=1,
253
+ visible=False
254
  )
255
  lrc_btn_4 = gr.Button(
256
  t("results.lrc_btn"),
257
  variant="secondary",
258
  size="sm",
259
+ scale=1,
260
+ visible=False
261
  )
262
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
 
 
 
 
 
 
 
 
263
  score_display_4 = gr.Textbox(
264
  label=t("results.quality_score_label", n=4),
265
  interactive=False,
 
276
  max_lines=8,
277
  visible=True
278
  )
279
+ codes_display_4 = gr.Textbox(
280
+ label=t("results.codes_label", n=4),
281
+ interactive=False,
282
+ buttons=["copy"],
283
+ lines=4,
284
+ max_lines=4,
285
+ visible=True
286
+ )
287
 
288
  # Second row for batch size 5-8 (initially hidden)
289
  with gr.Row(visible=False) as audio_row_5_8:
 
295
  buttons=[]
296
  )
297
  with gr.Row(equal_height=True):
298
+ send_to_cover_btn_5 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
299
+ send_to_repaint_btn_5 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
300
  save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
301
+ score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
302
+ lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
303
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
 
 
 
 
 
 
 
 
304
  score_display_5 = gr.Textbox(
305
  label=t("results.quality_score_label", n=5),
306
  interactive=False,
 
317
  max_lines=8,
318
  visible=True
319
  )
320
+ codes_display_5 = gr.Textbox(
321
+ label=t("results.codes_label", n=5),
322
+ interactive=False,
323
+ buttons=["copy"],
324
+ lines=4,
325
+ max_lines=4,
326
+ visible=True
327
+ )
328
  with gr.Column() as audio_col_6:
329
  generated_audio_6 = gr.Audio(
330
  label=t("results.generated_music", n=6),
 
333
  buttons=[]
334
  )
335
  with gr.Row(equal_height=True):
336
+ send_to_cover_btn_6 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
337
+ send_to_repaint_btn_6 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
338
  save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
339
+ score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
340
+ lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
341
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
 
 
 
 
 
 
 
 
342
  score_display_6 = gr.Textbox(
343
  label=t("results.quality_score_label", n=6),
344
  interactive=False,
 
355
  max_lines=8,
356
  visible=True
357
  )
358
+ codes_display_6 = gr.Textbox(
359
+ label=t("results.codes_label", n=6),
360
+ interactive=False,
361
+ buttons=["copy"],
362
+ lines=4,
363
+ max_lines=4,
364
+ visible=True
365
+ )
366
  with gr.Column() as audio_col_7:
367
  generated_audio_7 = gr.Audio(
368
  label=t("results.generated_music", n=7),
 
371
  buttons=[]
372
  )
373
  with gr.Row(equal_height=True):
374
+ send_to_cover_btn_7 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
375
+ send_to_repaint_btn_7 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
376
  save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
377
+ score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
378
+ lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
379
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
 
 
 
 
 
 
 
 
380
  score_display_7 = gr.Textbox(
381
  label=t("results.quality_score_label", n=7),
382
  interactive=False,
 
393
  max_lines=8,
394
  visible=True
395
  )
396
+ codes_display_7 = gr.Textbox(
397
+ label=t("results.codes_label", n=7),
398
+ interactive=False,
399
+ buttons=["copy"],
400
+ lines=4,
401
+ max_lines=4,
402
+ visible=True
403
+ )
404
  with gr.Column() as audio_col_8:
405
  generated_audio_8 = gr.Audio(
406
  label=t("results.generated_music", n=8),
 
409
  buttons=[]
410
  )
411
  with gr.Row(equal_height=True):
412
+ send_to_cover_btn_8 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
413
+ send_to_repaint_btn_8 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
414
  save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
415
+ score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
416
+ lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
417
  with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
 
 
 
 
 
 
 
 
418
  score_display_8 = gr.Textbox(
419
  label=t("results.quality_score_label", n=8),
420
  interactive=False,
 
431
  max_lines=8,
432
  visible=True
433
  )
434
+ codes_display_8 = gr.Textbox(
435
+ label=t("results.codes_label", n=8),
436
+ interactive=False,
437
+ buttons=["copy"],
438
+ lines=4,
439
+ max_lines=4,
440
+ visible=True
441
+ )
442
 
443
  status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
444
 
445
+ # Batch navigation controls (hidden for simplified UI)
446
+ with gr.Row(equal_height=True, visible=False):
447
  prev_batch_btn = gr.Button(
448
  t("results.prev_btn"),
449
  variant="secondary",
 
471
  size="sm"
472
  )
473
 
474
+ # One-click restore parameters button (hidden for simplified UI)
475
  restore_params_btn = gr.Button(
476
  t("results.restore_params_btn"),
477
  variant="secondary",
478
+ interactive=False,
479
+ size="sm",
480
+ visible=False
481
  )
482
 
483
  with gr.Accordion(t("results.batch_results_title"), open=False):
 
519
  "audio_col_6": audio_col_6,
520
  "audio_col_7": audio_col_7,
521
  "audio_col_8": audio_col_8,
522
+ "send_to_cover_btn_1": send_to_cover_btn_1,
523
+ "send_to_cover_btn_2": send_to_cover_btn_2,
524
+ "send_to_cover_btn_3": send_to_cover_btn_3,
525
+ "send_to_cover_btn_4": send_to_cover_btn_4,
526
+ "send_to_cover_btn_5": send_to_cover_btn_5,
527
+ "send_to_cover_btn_6": send_to_cover_btn_6,
528
+ "send_to_cover_btn_7": send_to_cover_btn_7,
529
+ "send_to_cover_btn_8": send_to_cover_btn_8,
530
+ "send_to_repaint_btn_1": send_to_repaint_btn_1,
531
+ "send_to_repaint_btn_2": send_to_repaint_btn_2,
532
+ "send_to_repaint_btn_3": send_to_repaint_btn_3,
533
+ "send_to_repaint_btn_4": send_to_repaint_btn_4,
534
+ "send_to_repaint_btn_5": send_to_repaint_btn_5,
535
+ "send_to_repaint_btn_6": send_to_repaint_btn_6,
536
+ "send_to_repaint_btn_7": send_to_repaint_btn_7,
537
+ "send_to_repaint_btn_8": send_to_repaint_btn_8,
538
  "save_btn_1": save_btn_1,
539
  "save_btn_2": save_btn_2,
540
  "save_btn_3": save_btn_3,
acestep/handler.py CHANGED
@@ -315,6 +315,11 @@ class AceStepHandler:
315
  offload_to_cpu: bool = False,
316
  offload_dit_to_cpu: bool = False,
317
  quantization: Optional[str] = None,
 
 
 
 
 
318
  ) -> Tuple[str, bool]:
319
  """
320
  Initialize DiT model service
@@ -327,6 +332,10 @@ class AceStepHandler:
327
  compile_model: Whether to use torch.compile to optimize the model
328
  offload_to_cpu: Whether to offload models to CPU when not in use
329
  offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
 
 
 
 
330
 
331
  Returns:
332
  (status_message, enable_generate_button)
@@ -440,54 +449,77 @@ class AceStepHandler:
440
  logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
441
 
442
 
443
- silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
444
- if os.path.exists(silence_latent_path):
445
- self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
446
- # Always keep silence_latent on GPU - it's used in many places outside model context
447
- # and is small enough that it won't significantly impact VRAM
448
- self.silence_latent = self.silence_latent.to(device).to(self.dtype)
449
  else:
450
- raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
 
 
 
 
 
 
 
451
  else:
452
  raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")
453
 
454
- # 2. Load VAE
455
- vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
456
- if os.path.exists(vae_checkpoint_path):
457
- self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
458
- # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
459
- vae_dtype = self._get_vae_dtype(device)
460
- if not self.offload_to_cpu:
461
- self.vae = self.vae.to(device).to(vae_dtype)
462
- else:
463
- self.vae = self.vae.to("cpu").to(vae_dtype)
464
- self.vae.eval()
465
  else:
466
- raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
467
-
468
- if compile_model:
469
- self.vae = torch.compile(self.vae)
470
-
471
- # 3. Load text encoder and tokenizer
472
- text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
473
- if os.path.exists(text_encoder_path):
474
- self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
475
- self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
476
- if not self.offload_to_cpu:
477
- self.text_encoder = self.text_encoder.to(device).to(self.dtype)
478
  else:
479
- self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
480
- self.text_encoder.eval()
 
 
 
 
 
 
 
 
 
481
  else:
482
- raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
 
 
 
 
 
 
 
 
 
483
 
484
  # Determine actual attention implementation used
485
  actual_attn = getattr(self.config, "_attn_implementation", "eager")
486
 
 
 
 
487
  status_msg = f"✅ Model initialized successfully on {device}\n"
488
  status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
489
- status_msg += f"VAE: {vae_checkpoint_path}\n"
490
- status_msg += f"Text encoder: {text_encoder_path}\n"
 
 
 
 
 
 
491
  status_msg += f"Dtype: {self.dtype}\n"
492
  status_msg += f"Attention: {actual_attn}\n"
493
  status_msg += f"Compiled: {compile_model}\n"
 
315
  offload_to_cpu: bool = False,
316
  offload_dit_to_cpu: bool = False,
317
  quantization: Optional[str] = None,
318
+ # Shared components (for multi-model setup to save memory)
319
+ shared_vae = None,
320
+ shared_text_encoder = None,
321
+ shared_text_tokenizer = None,
322
+ shared_silence_latent = None,
323
  ) -> Tuple[str, bool]:
324
  """
325
  Initialize DiT model service
 
332
  compile_model: Whether to use torch.compile to optimize the model
333
  offload_to_cpu: Whether to offload models to CPU when not in use
334
  offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
335
+ shared_vae: Optional shared VAE instance (for multi-model setup)
336
+ shared_text_encoder: Optional shared text encoder instance (for multi-model setup)
337
+ shared_text_tokenizer: Optional shared text tokenizer instance (for multi-model setup)
338
+ shared_silence_latent: Optional shared silence latent tensor (for multi-model setup)
339
 
340
  Returns:
341
  (status_message, enable_generate_button)
 
449
  logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
450
 
451
 
452
+ # Load or use shared silence_latent
453
+ if shared_silence_latent is not None:
454
+ self.silence_latent = shared_silence_latent
455
+ logger.info("[initialize_service] Using shared silence_latent")
 
 
456
  else:
457
+ silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
458
+ if os.path.exists(silence_latent_path):
459
+ self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
460
+ # Always keep silence_latent on GPU - it's used in many places outside model context
461
+ # and is small enough that it won't significantly impact VRAM
462
+ self.silence_latent = self.silence_latent.to(device).to(self.dtype)
463
+ else:
464
+ raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
465
  else:
466
  raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")
467
 
468
+ # 2. Load or use shared VAE
469
+ vae_checkpoint_path = os.path.join(checkpoint_dir, "vae") # Define for status message
470
+ if shared_vae is not None:
471
+ self.vae = shared_vae
472
+ logger.info("[initialize_service] Using shared VAE")
 
 
 
 
 
 
473
  else:
474
+ if os.path.exists(vae_checkpoint_path):
475
+ self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
476
+ # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
477
+ vae_dtype = self._get_vae_dtype(device)
478
+ if not self.offload_to_cpu:
479
+ self.vae = self.vae.to(device).to(vae_dtype)
480
+ else:
481
+ self.vae = self.vae.to("cpu").to(vae_dtype)
482
+ self.vae.eval()
 
 
 
483
  else:
484
+ raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
485
+
486
+ if compile_model:
487
+ self.vae = torch.compile(self.vae)
488
+
489
+ # 3. Load or use shared text encoder and tokenizer
490
+ text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B") # Define for status message
491
+ if shared_text_encoder is not None and shared_text_tokenizer is not None:
492
+ self.text_encoder = shared_text_encoder
493
+ self.text_tokenizer = shared_text_tokenizer
494
+ logger.info("[initialize_service] Using shared text encoder and tokenizer")
495
  else:
496
+ if os.path.exists(text_encoder_path):
497
+ self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
498
+ self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
499
+ if not self.offload_to_cpu:
500
+ self.text_encoder = self.text_encoder.to(device).to(self.dtype)
501
+ else:
502
+ self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
503
+ self.text_encoder.eval()
504
+ else:
505
+ raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
506
 
507
  # Determine actual attention implementation used
508
  actual_attn = getattr(self.config, "_attn_implementation", "eager")
509
 
510
+ # Determine if using shared components
511
+ using_shared = shared_vae is not None or shared_text_encoder is not None
512
+
513
  status_msg = f"✅ Model initialized successfully on {device}\n"
514
  status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
515
+ if shared_vae is None:
516
+ status_msg += f"VAE: {vae_checkpoint_path}\n"
517
+ else:
518
+ status_msg += f"VAE: shared\n"
519
+ if shared_text_encoder is None:
520
+ status_msg += f"Text encoder: {text_encoder_path}\n"
521
+ else:
522
+ status_msg += f"Text encoder: shared\n"
523
  status_msg += f"Dtype: {self.dtype}\n"
524
  status_msg += f"Attention: {actual_attn}\n"
525
  status_msg += f"Compiled: {compile_model}\n"
app.py CHANGED
@@ -53,7 +53,24 @@ def get_persistent_storage_path():
53
  1. Must be enabled in Space settings
54
  2. Path is typically /data for Docker SDK
55
  3. Falls back to app directory if /data is not writable
 
 
 
 
 
56
  """
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Try HuggingFace Space persistent storage first
58
  hf_data_path = "/data"
59
 
@@ -80,6 +97,14 @@ def get_persistent_storage_path():
80
  def main():
81
  """Main entry point for HuggingFace Space"""
82
 
 
 
 
 
 
 
 
 
83
  # Get persistent storage path (auto-detect)
84
  persistent_storage_path = get_persistent_storage_path()
85
 
@@ -87,14 +112,15 @@ def main():
87
  gpu_memory_gb = get_gpu_memory_gb()
88
  auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
89
 
90
- if auto_offload:
91
- print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
92
- print("Auto-enabling CPU offload to reduce GPU memory usage")
93
- elif gpu_memory_gb > 0:
94
- print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
95
- print("CPU offload disabled by default")
96
- else:
97
- print("No GPU detected, running on CPU")
 
98
 
99
  # Create handler instances
100
  print("Creating handlers...")
@@ -107,6 +133,9 @@ def main():
107
  "SERVICE_MODE_DIT_MODEL",
108
  "acestep-v15-turbo"
109
  )
 
 
 
110
  lm_model_path = os.environ.get(
111
  "SERVICE_MODE_LM_MODEL",
112
  "acestep-5Hz-lm-1.7B"
@@ -115,50 +144,97 @@ def main():
115
  device = "auto"
116
 
117
  print(f"Service mode configuration:")
118
- print(f" DiT model: {config_path}")
 
 
119
  print(f" LM model: {lm_model_path}")
120
  print(f" Backend: {backend}")
121
  print(f" Offload to CPU: {auto_offload}")
 
122
 
123
  # Determine flash attention availability
124
  use_flash_attention = dit_handler.is_flash_attention_available()
125
  print(f" Flash Attention: {use_flash_attention}")
126
 
127
- # Initialize DiT model
128
- print(f"Initializing DiT model: {config_path}...")
129
- init_status, enable_generate = dit_handler.initialize_service(
130
- project_root=current_dir,
131
- config_path=config_path,
132
- device=device,
133
- use_flash_attention=use_flash_attention,
134
- compile_model=False,
135
- offload_to_cpu=auto_offload,
136
- offload_dit_to_cpu=False
137
- )
138
 
139
- if not enable_generate:
140
- print(f"Warning: DiT model initialization issue: {init_status}", file=sys.stderr)
 
 
 
141
  else:
142
- print("DiT model initialized successfully")
143
-
144
- # Initialize LM model
145
- checkpoint_dir = dit_handler._get_checkpoint_dir()
146
- print(f"Initializing 5Hz LM: {lm_model_path}...")
147
- lm_status, lm_success = llm_handler.initialize(
148
- checkpoint_dir=checkpoint_dir,
149
- lm_model_path=lm_model_path,
150
- backend=backend,
151
- device=device,
152
- offload_to_cpu=auto_offload,
153
- dtype=dit_handler.dtype
154
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- if lm_success:
157
- print("5Hz LM initialized successfully")
158
- init_status += f"\n{lm_status}"
159
- else:
160
- print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
161
- init_status += f"\n{lm_status}"
162
 
163
  # Prepare initialization parameters for UI
164
  init_params = {
@@ -166,6 +242,7 @@ def main():
166
  'service_mode': True,
167
  'checkpoint': None,
168
  'config_path': config_path,
 
169
  'device': device,
170
  'init_llm': True,
171
  'lm_model_path': lm_model_path,
@@ -176,9 +253,12 @@ def main():
176
  'init_status': init_status,
177
  'enable_generate': enable_generate,
178
  'dit_handler': dit_handler,
 
 
179
  'llm_handler': llm_handler,
180
  'language': 'en',
181
  'persistent_storage_path': persistent_storage_path,
 
182
  }
183
 
184
  print("Service initialization completed!")
 
53
  1. Must be enabled in Space settings
54
  2. Path is typically /data for Docker SDK
55
  3. Falls back to app directory if /data is not writable
56
+
57
+ Local development:
58
+ - Set CHECKPOINT_DIR environment variable to use local checkpoints
59
+ Example: CHECKPOINT_DIR=/path/to/checkpoints python app.py
60
+ The path should be the parent directory of 'checkpoints' folder
61
  """
62
+ # Check for local checkpoint directory override (for development)
63
+ checkpoint_dir_override = os.environ.get("CHECKPOINT_DIR")
64
+ if checkpoint_dir_override:
65
+ # If user specifies the checkpoints folder directly, use its parent
66
+ if checkpoint_dir_override.endswith("/checkpoints") or checkpoint_dir_override.endswith("\\checkpoints"):
67
+ checkpoint_dir_override = os.path.dirname(checkpoint_dir_override)
68
+ if os.path.exists(checkpoint_dir_override):
69
+ print(f"Using local checkpoint directory (CHECKPOINT_DIR): {checkpoint_dir_override}")
70
+ return checkpoint_dir_override
71
+ else:
72
+ print(f"Warning: CHECKPOINT_DIR path does not exist: {checkpoint_dir_override}")
73
+
74
  # Try HuggingFace Space persistent storage first
75
  hf_data_path = "/data"
76
 
 
97
  def main():
98
  """Main entry point for HuggingFace Space"""
99
 
100
+ # Check for DEBUG_UI mode (skip model initialization for UI development)
101
+ debug_ui = os.environ.get("DEBUG_UI", "").lower() in ("1", "true", "yes")
102
+ if debug_ui:
103
+ print("=" * 60)
104
+ print("DEBUG_UI mode enabled - skipping model initialization")
105
+ print("UI will be fully functional but generation is disabled")
106
+ print("=" * 60)
107
+
108
  # Get persistent storage path (auto-detect)
109
  persistent_storage_path = get_persistent_storage_path()
110
 
 
112
  gpu_memory_gb = get_gpu_memory_gb()
113
  auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
114
 
115
+ if not debug_ui:
116
+ if auto_offload:
117
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
118
+ print("Auto-enabling CPU offload to reduce GPU memory usage")
119
+ elif gpu_memory_gb > 0:
120
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
121
+ print("CPU offload disabled by default")
122
+ else:
123
+ print("No GPU detected, running on CPU")
124
 
125
  # Create handler instances
126
  print("Creating handlers...")
 
133
  "SERVICE_MODE_DIT_MODEL",
134
  "acestep-v15-turbo"
135
  )
136
+ # Second DiT model - default to turbo-shift3 for two-model setup
137
+ config_path_2 = os.environ.get("SERVICE_MODE_DIT_MODEL_2", "acestep-v15-turbo-shift3").strip()
138
+
139
  lm_model_path = os.environ.get(
140
  "SERVICE_MODE_LM_MODEL",
141
  "acestep-5Hz-lm-1.7B"
 
144
  device = "auto"
145
 
146
  print(f"Service mode configuration:")
147
+ print(f" DiT model 1: {config_path}")
148
+ if config_path_2:
149
+ print(f" DiT model 2: {config_path_2}")
150
  print(f" LM model: {lm_model_path}")
151
  print(f" Backend: {backend}")
152
  print(f" Offload to CPU: {auto_offload}")
153
+ print(f" DEBUG_UI: {debug_ui}")
154
 
155
  # Determine flash attention availability
156
  use_flash_attention = dit_handler.is_flash_attention_available()
157
  print(f" Flash Attention: {use_flash_attention}")
158
 
159
+ # Initialize models (skip in DEBUG_UI mode)
160
+ init_status = ""
161
+ enable_generate = False
162
+ dit_handler_2 = None
 
 
 
 
 
 
 
163
 
164
+ if debug_ui:
165
+ # In DEBUG_UI mode, skip all model initialization
166
+ init_status = "⚠️ DEBUG_UI mode - models not loaded\nUI is functional but generation is disabled"
167
+ enable_generate = False
168
+ print("Skipping model initialization (DEBUG_UI mode)")
169
  else:
170
+ # Initialize primary DiT model
171
+ print(f"Initializing DiT model 1: {config_path}...")
172
+ init_status, enable_generate = dit_handler.initialize_service(
173
+ project_root=current_dir,
174
+ config_path=config_path,
175
+ device=device,
176
+ use_flash_attention=use_flash_attention,
177
+ compile_model=False,
178
+ offload_to_cpu=auto_offload,
179
+ offload_dit_to_cpu=False
180
+ )
181
+
182
+ if not enable_generate:
183
+ print(f"Warning: DiT model 1 initialization issue: {init_status}", file=sys.stderr)
184
+ else:
185
+ print("DiT model 1 initialized successfully")
186
+
187
+ # Initialize second DiT model if configured
188
+ if config_path_2:
189
+ print(f"Initializing DiT model 2: {config_path_2}...")
190
+ dit_handler_2 = AceStepHandler(persistent_storage_path=persistent_storage_path)
191
+
192
+ # Share VAE, text_encoder, and silence_latent from the first handler to save memory
193
+ init_status_2, enable_generate_2 = dit_handler_2.initialize_service(
194
+ project_root=current_dir,
195
+ config_path=config_path_2,
196
+ device=device,
197
+ use_flash_attention=use_flash_attention,
198
+ compile_model=False,
199
+ offload_to_cpu=auto_offload,
200
+ offload_dit_to_cpu=False,
201
+ # Share components from first handler
202
+ shared_vae=dit_handler.vae,
203
+ shared_text_encoder=dit_handler.text_encoder,
204
+ shared_text_tokenizer=dit_handler.text_tokenizer,
205
+ shared_silence_latent=dit_handler.silence_latent,
206
+ )
207
+
208
+ if not enable_generate_2:
209
+ print(f"Warning: DiT model 2 initialization issue: {init_status_2}", file=sys.stderr)
210
+ init_status += f"\n⚠️ DiT model 2 failed: {init_status_2}"
211
+ else:
212
+ print("DiT model 2 initialized successfully")
213
+ init_status += f"\n✅ DiT model 2: {config_path_2}"
214
+
215
+ # Initialize LM model
216
+ checkpoint_dir = dit_handler._get_checkpoint_dir()
217
+ print(f"Initializing 5Hz LM: {lm_model_path}...")
218
+ lm_status, lm_success = llm_handler.initialize(
219
+ checkpoint_dir=checkpoint_dir,
220
+ lm_model_path=lm_model_path,
221
+ backend=backend,
222
+ device=device,
223
+ offload_to_cpu=auto_offload,
224
+ dtype=dit_handler.dtype
225
+ )
226
+
227
+ if lm_success:
228
+ print("5Hz LM initialized successfully")
229
+ init_status += f"\n{lm_status}"
230
+ else:
231
+ print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
232
+ init_status += f"\n{lm_status}"
233
 
234
+ # Build available models list for UI
235
+ available_dit_models = [config_path]
236
+ if config_path_2 and dit_handler_2 is not None:
237
+ available_dit_models.append(config_path_2)
 
 
238
 
239
  # Prepare initialization parameters for UI
240
  init_params = {
 
242
  'service_mode': True,
243
  'checkpoint': None,
244
  'config_path': config_path,
245
+ 'config_path_2': config_path_2 if config_path_2 else None,
246
  'device': device,
247
  'init_llm': True,
248
  'lm_model_path': lm_model_path,
 
253
  'init_status': init_status,
254
  'enable_generate': enable_generate,
255
  'dit_handler': dit_handler,
256
+ 'dit_handler_2': dit_handler_2,
257
+ 'available_dit_models': available_dit_models,
258
  'llm_handler': llm_handler,
259
  'language': 'en',
260
  'persistent_storage_path': persistent_storage_path,
261
+ 'debug_ui': debug_ui,
262
  }
263
 
264
  print("Service initialization completed!")