ChuxiJ committed on
Commit
a368091
·
1 Parent(s): 4477394

support save & load

Browse files
Files changed (2) hide show
  1. acestep/gradio_ui.py +420 -43
  2. acestep/llm_inference.py +2 -1
acestep/gradio_ui.py CHANGED
@@ -308,7 +308,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
308
  else:
309
  initial_task_choices = TASK_TYPES_BASE
310
 
311
- with gr.Row():
312
  with gr.Column(scale=2):
313
  task_type = gr.Dropdown(
314
  choices=initial_task_choices,
@@ -316,7 +316,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
316
  label="Task Type",
317
  info="Select the task type for generation",
318
  )
319
- with gr.Column(scale=8):
320
  instruction_display_gen = gr.Textbox(
321
  label="Instruction",
322
  value=DEFAULT_DIT_INSTRUCTION,
@@ -324,6 +324,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
324
  lines=1,
325
  info="Instruction is automatically generated based on task type",
326
  )
 
 
 
 
 
 
 
 
327
 
328
  track_name = gr.Dropdown(
329
  choices=TRACK_NAMES,
@@ -486,15 +494,22 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
486
  info="Higher values follow text more closely",
487
  visible=False
488
  )
489
- seed = gr.Textbox(
490
- label="Seed",
491
- value="-1",
492
- info="Use comma-separated values for batches"
493
- )
494
- random_seed_checkbox = gr.Checkbox(
495
- label="Random Seed",
496
- value=True,
497
- info="Enable to auto-generate seeds"
 
 
 
 
 
 
 
498
  )
499
 
500
  with gr.Row():
@@ -522,15 +537,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
522
  label="CFG Interval End",
523
  visible=False
524
  )
525
-
526
- with gr.Row():
527
- audio_format = gr.Dropdown(
528
- choices=["mp3", "flac"],
529
- value="mp3",
530
- label="Audio Format",
531
- info="Audio format for saved files"
532
- )
533
-
534
  # LM (Language Model) Parameters
535
  gr.HTML("<h4>🤖 LM Generation Parameters</h4>")
536
  with gr.Row():
@@ -582,6 +589,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
582
  )
583
 
584
  with gr.Row():
 
 
 
 
 
 
585
  use_cot_caption = gr.Checkbox(
586
  label="CoT Caption",
587
  value=True,
@@ -654,6 +667,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
654
  "lm_top_k": lm_top_k,
655
  "lm_top_p": lm_top_p,
656
  "lm_negative_prompt": lm_negative_prompt,
 
657
  "use_cot_caption": use_cot_caption,
658
  "use_cot_language": use_cot_language,
659
  "repainting_group": repainting_group,
@@ -662,6 +676,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
662
  "audio_cover_strength": audio_cover_strength,
663
  "captions": captions,
664
  "sample_btn": sample_btn,
 
665
  "lyrics": lyrics,
666
  "vocal_language": vocal_language,
667
  "bpm": bpm,
@@ -691,6 +706,9 @@ def create_results_section(dit_handler) -> dict:
691
  # Hidden state to store LM-generated metadata
692
  lm_metadata_state = gr.State(value=None)
693
 
 
 
 
694
  status_output = gr.Textbox(label="Generation Status", interactive=False)
695
 
696
  with gr.Row():
@@ -700,22 +718,38 @@ def create_results_section(dit_handler) -> dict:
700
  type="filepath",
701
  interactive=False
702
  )
703
- send_to_src_btn_1 = gr.Button(
704
- "Send To Src Audio",
705
- variant="secondary",
706
- size="sm"
707
- )
 
 
 
 
 
 
 
 
708
  with gr.Column():
709
  generated_audio_2 = gr.Audio(
710
  label="🎵 Generated Music (Sample 2)",
711
  type="filepath",
712
  interactive=False
713
  )
714
- send_to_src_btn_2 = gr.Button(
715
- "Send To Src Audio",
716
- variant="secondary",
717
- size="sm"
718
- )
 
 
 
 
 
 
 
 
719
 
720
  with gr.Accordion("📁 Batch Results & Generation Details", open=False):
721
  generated_audio_batch = gr.File(
@@ -738,11 +772,14 @@ def create_results_section(dit_handler) -> dict:
738
 
739
  return {
740
  "lm_metadata_state": lm_metadata_state,
 
741
  "status_output": status_output,
742
  "generated_audio_1": generated_audio_1,
743
  "generated_audio_2": generated_audio_2,
744
  "send_to_src_btn_1": send_to_src_btn_1,
745
  "send_to_src_btn_2": send_to_src_btn_2,
 
 
746
  "generated_audio_batch": generated_audio_batch,
747
  "generation_info": generation_info,
748
  "align_score_1": align_score_1,
@@ -757,6 +794,161 @@ def create_results_section(dit_handler) -> dict:
757
  def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
758
  """Setup event handlers connecting UI components and business logic"""
759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
760
  def load_random_example(task_type: str):
761
  """Load a random example from the task-specific examples directory
762
 
@@ -1092,14 +1284,14 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1092
  instruction_display_gen, audio_cover_strength, task_type,
1093
  use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
1094
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
1095
- use_cot_caption, use_cot_language,
1096
  progress=gr.Progress(track_tqdm=True)
1097
  ):
1098
- # If think is enabled (llm_dit mode), generate audio codes using LM first
1099
  audio_code_string_to_use = text2music_audio_code_string
1100
  lm_generated_metadata = None # Store LM-generated metadata for display
1101
  lm_generated_audio_codes = None # Store LM-generated audio codes for display
1102
- if think_checkbox and llm_handler.llm_initialized:
1103
  # Convert top_k: 0 means None (disabled)
1104
  top_k_value = None if lm_top_k == 0 else int(lm_top_k)
1105
  # Convert top_p: 1.0 means None (disabled)
@@ -1149,6 +1341,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1149
  user_metadata=user_metadata_to_pass,
1150
  use_cot_caption=use_cot_caption,
1151
  use_cot_language=use_cot_language,
 
1152
  )
1153
 
1154
  # Store LM-generated metadata and audio codes for display
@@ -1238,7 +1431,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1238
  align_text_2,
1239
  align_plot_2,
1240
  updated_audio_codes, # Update audio codes in UI
1241
- lm_generated_metadata # Store metadata for "Send to src audio" buttons
 
1242
  )
1243
 
1244
  generation_section["generate_btn"].click(
@@ -1274,8 +1468,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1274
  generation_section["lm_top_k"],
1275
  generation_section["lm_top_p"],
1276
  generation_section["lm_negative_prompt"],
 
1277
  generation_section["use_cot_caption"],
1278
- generation_section["use_cot_language"]
 
1279
  ],
1280
  outputs=[
1281
  results_section["generated_audio_1"],
@@ -1291,7 +1487,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1291
  results_section["align_text_2"],
1292
  results_section["align_plot_2"],
1293
  generation_section["text2music_audio_code_string"], # Update audio codes display
1294
- results_section["lm_metadata_state"] # Store metadata
 
1295
  ]
1296
  )
1297
 
@@ -1420,10 +1617,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1420
  lm_metadata: Dictionary containing LM-generated metadata
1421
 
1422
  Returns:
1423
- Tuple of (audio_file, bpm, caption, duration, key_scale, language, time_signature)
1424
  """
1425
  if audio_file is None:
1426
- return None, None, None, None, None, None, None
1427
 
1428
  # Extract metadata fields if available
1429
  bpm_value = None
@@ -1481,7 +1678,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1481
  duration_value,
1482
  key_scale_value,
1483
  language_value,
1484
- time_signature_value
 
1485
  )
1486
 
1487
  results_section["send_to_src_btn_1"].click(
@@ -1497,7 +1695,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1497
  generation_section["audio_duration"],
1498
  generation_section["key_scale"],
1499
  generation_section["vocal_language"],
1500
- generation_section["time_signature"]
 
1501
  ]
1502
  )
1503
 
@@ -1514,13 +1713,21 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1514
  generation_section["audio_duration"],
1515
  generation_section["key_scale"],
1516
  generation_section["vocal_language"],
1517
- generation_section["time_signature"]
 
1518
  ]
1519
  )
1520
 
1521
  # Sample button - smart sample (uses LM if initialized, otherwise examples)
 
 
 
 
 
 
 
1522
  generation_section["sample_btn"].click(
1523
- fn=sample_example_smart,
1524
  inputs=[generation_section["task_type"]],
1525
  outputs=[
1526
  generation_section["captions"],
@@ -1531,6 +1738,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1531
  generation_section["key_scale"],
1532
  generation_section["vocal_language"],
1533
  generation_section["time_signature"],
 
1534
  ]
1535
  )
1536
 
@@ -1585,7 +1793,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1585
  duration,
1586
  keyscale,
1587
  language,
1588
- timesignature
 
1589
  )
1590
 
1591
  # Update transcribe button text based on whether codes are present
@@ -1619,9 +1828,58 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1619
  generation_section["key_scale"], # Update keyscale field
1620
  generation_section["vocal_language"], # Update language field
1621
  generation_section["time_signature"], # Update time signature field
 
1622
  ]
1623
  )
1624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1625
  # Auto-expand Audio Uploads accordion when audio is uploaded
1626
  def update_audio_uploads_accordion(reference_audio, src_audio):
1627
  """Update Audio Uploads accordion open state based on whether audio files are present"""
@@ -1640,4 +1898,123 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1640
  inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
1641
  outputs=[generation_section["audio_uploads_accordion"]]
1642
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1643
 
 
308
  else:
309
  initial_task_choices = TASK_TYPES_BASE
310
 
311
+ with gr.Row(equal_height=True):
312
  with gr.Column(scale=2):
313
  task_type = gr.Dropdown(
314
  choices=initial_task_choices,
 
316
  label="Task Type",
317
  info="Select the task type for generation",
318
  )
319
+ with gr.Column(scale=7):
320
  instruction_display_gen = gr.Textbox(
321
  label="Instruction",
322
  value=DEFAULT_DIT_INSTRUCTION,
 
324
  lines=1,
325
  info="Instruction is automatically generated based on task type",
326
  )
327
+ with gr.Column(scale=1, min_width=100):
328
+ load_file = gr.UploadButton(
329
+ "Load",
330
+ file_types=[".json"],
331
+ file_count="single",
332
+ variant="secondary",
333
+ size="sm",
334
+ )
335
 
336
  track_name = gr.Dropdown(
337
  choices=TRACK_NAMES,
 
494
  info="Higher values follow text more closely",
495
  visible=False
496
  )
497
+ with gr.Column():
498
+ seed = gr.Textbox(
499
+ label="Seed",
500
+ value="-1",
501
+ info="Use comma-separated values for batches"
502
+ )
503
+ random_seed_checkbox = gr.Checkbox(
504
+ label="Random Seed",
505
+ value=True,
506
+ info="Enable to auto-generate seeds"
507
+ )
508
+ audio_format = gr.Dropdown(
509
+ choices=["mp3", "flac"],
510
+ value="mp3",
511
+ label="Audio Format",
512
+ info="Audio format for saved files"
513
  )
514
 
515
  with gr.Row():
 
537
  label="CFG Interval End",
538
  visible=False
539
  )
540
+
 
 
 
 
 
 
 
 
541
  # LM (Language Model) Parameters
542
  gr.HTML("<h4>🤖 LM Generation Parameters</h4>")
543
  with gr.Row():
 
589
  )
590
 
591
  with gr.Row():
592
+ use_cot_metas = gr.Checkbox(
593
+ label="CoT Metas",
594
+ value=True,
595
+ info="Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
596
+ scale=1,
597
+ )
598
  use_cot_caption = gr.Checkbox(
599
  label="CoT Caption",
600
  value=True,
 
667
  "lm_top_k": lm_top_k,
668
  "lm_top_p": lm_top_p,
669
  "lm_negative_prompt": lm_negative_prompt,
670
+ "use_cot_metas": use_cot_metas,
671
  "use_cot_caption": use_cot_caption,
672
  "use_cot_language": use_cot_language,
673
  "repainting_group": repainting_group,
 
676
  "audio_cover_strength": audio_cover_strength,
677
  "captions": captions,
678
  "sample_btn": sample_btn,
679
+ "load_file": load_file,
680
  "lyrics": lyrics,
681
  "vocal_language": vocal_language,
682
  "bpm": bpm,
 
706
  # Hidden state to store LM-generated metadata
707
  lm_metadata_state = gr.State(value=None)
708
 
709
+ # Hidden state to track if caption/metadata is from formatted source (LM/transcription)
710
+ is_format_caption_state = gr.State(value=False)
711
+
712
  status_output = gr.Textbox(label="Generation Status", interactive=False)
713
 
714
  with gr.Row():
 
718
  type="filepath",
719
  interactive=False
720
  )
721
+ with gr.Row(equal_height=True):
722
+ send_to_src_btn_1 = gr.Button(
723
+ "Send To Src Audio",
724
+ variant="secondary",
725
+ size="sm",
726
+ scale=1
727
+ )
728
+ save_btn_1 = gr.Button(
729
+ "💾 Save",
730
+ variant="primary",
731
+ size="sm",
732
+ scale=1
733
+ )
734
  with gr.Column():
735
  generated_audio_2 = gr.Audio(
736
  label="🎵 Generated Music (Sample 2)",
737
  type="filepath",
738
  interactive=False
739
  )
740
+ with gr.Row(equal_height=True):
741
+ send_to_src_btn_2 = gr.Button(
742
+ "Send To Src Audio",
743
+ variant="secondary",
744
+ size="sm",
745
+ scale=1
746
+ )
747
+ save_btn_2 = gr.Button(
748
+ "💾 Save",
749
+ variant="primary",
750
+ size="sm",
751
+ scale=1
752
+ )
753
 
754
  with gr.Accordion("📁 Batch Results & Generation Details", open=False):
755
  generated_audio_batch = gr.File(
 
772
 
773
  return {
774
  "lm_metadata_state": lm_metadata_state,
775
+ "is_format_caption_state": is_format_caption_state,
776
  "status_output": status_output,
777
  "generated_audio_1": generated_audio_1,
778
  "generated_audio_2": generated_audio_2,
779
  "send_to_src_btn_1": send_to_src_btn_1,
780
  "send_to_src_btn_2": send_to_src_btn_2,
781
+ "save_btn_1": save_btn_1,
782
+ "save_btn_2": save_btn_2,
783
  "generated_audio_batch": generated_audio_batch,
784
  "generation_info": generation_info,
785
  "align_score_1": align_score_1,
 
794
  def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
795
  """Setup event handlers connecting UI components and business logic"""
796
 
797
def save_metadata(
    task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
    batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
    use_adg, cfg_interval_start, cfg_interval_end, audio_format,
    lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
    use_cot_caption, use_cot_language, audio_cover_strength,
    think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
    track_name, complete_track_classes, lm_metadata
):
    """Save all generation parameters to a timestamped JSON file.

    The file is written to the current working directory as
    ``generation_params_<YYYYmmdd_HHMMSS>.json``. The keys written here are
    the ones ``load_metadata`` reads back (``caption``, ``keyscale``,
    ``timesignature``, ``duration``, ``think``, ``audio_codes``, ...).

    Args:
        All parameters except ``lm_metadata`` are the current values of the
        UI components wired as inputs of the save buttons.
        ``random_seed_checkbox`` is accepted for signature parity with the
        wiring but deliberately ignored: the file always stores
        ``"random_seed": False`` so that loading it reproduces the saved seed.
        ``lm_metadata`` is the LM-generated metadata state; it is stored under
        ``"lm_generated_metadata"`` only when truthy.

    Returns:
        The name of the written file, or ``None`` if saving failed (a warning
        toast is shown in that case).
    """
    import datetime

    # Take a single timestamp so that "saved_at" and the file name always agree.
    now = datetime.datetime.now()

    # Create metadata dictionary
    metadata = {
        "saved_at": now.isoformat(),
        "task_type": task_type,
        "caption": captions or "",
        "lyrics": lyrics or "",
        "vocal_language": vocal_language,
        "bpm": bpm,
        "keyscale": key_scale or "",
        "timesignature": time_signature or "",
        "duration": audio_duration if audio_duration is not None else -1,
        "batch_size": batch_size_input,
        "inference_steps": inference_steps,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "random_seed": False,  # Disable random seed for reproducibility
        "use_adg": use_adg,
        "cfg_interval_start": cfg_interval_start,
        "cfg_interval_end": cfg_interval_end,
        "audio_format": audio_format,
        "lm_temperature": lm_temperature,
        "lm_cfg_scale": lm_cfg_scale,
        "lm_top_k": lm_top_k,
        "lm_top_p": lm_top_p,
        "lm_negative_prompt": lm_negative_prompt,
        "use_cot_caption": use_cot_caption,
        "use_cot_language": use_cot_language,
        "audio_cover_strength": audio_cover_strength,
        "think": think_checkbox,
        "audio_codes": text2music_audio_code_string or "",
        "repainting_start": repainting_start,
        "repainting_end": repainting_end,
        "track_name": track_name,
        "complete_track_classes": complete_track_classes or [],
    }

    # Add LM-generated metadata if available
    if lm_metadata:
        metadata["lm_generated_metadata"] = lm_metadata

    # Save to file
    try:
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        filename = f"generation_params_{timestamp}.json"

        # Serialise first: if a value is not JSON-serialisable we fail before
        # the file is created instead of leaving a truncated JSON file behind.
        payload = json.dumps(metadata, indent=2, ensure_ascii=False)

        with open(filename, 'w', encoding='utf-8') as f:
            f.write(payload)

        gr.Info(f"✅ Parameters saved to {filename}")
        return filename
    except Exception as e:
        # UI callback boundary: report the failure as a toast rather than
        # letting the exception propagate out of the event handler.
        gr.Warning(f"❌ Failed to save parameters: {str(e)}")
        return None
862
+
863
def load_metadata(file_obj):
    """Load generation parameters from a JSON file.

    Args:
        file_obj: The uploaded file. Either an object exposing the path as
            ``.name`` or the path itself; ``None`` when nothing was selected.

    Returns:
        A sequence of 32 values: 31 values for the UI components bound as
        outputs of the load handler (same order as the success tuple below),
        followed by the ``is_format_caption`` flag. On success the flag is
        ``True``. On any failure the 31 component values are no-op
        ``gr.update()`` objects, so the form keeps whatever the user had
        entered, and the flag is ``False``.
    """
    # Number of UI components updated before the trailing is_format_caption flag.
    num_fields = 31

    def _keep_current_values():
        """Build a failure result that leaves every bound component untouched."""
        return [gr.update() for _ in range(num_fields)] + [False]

    if file_obj is None:
        gr.Warning("⚠️ No file selected")
        return _keep_current_values()

    try:
        # Read the uploaded file
        if hasattr(file_obj, 'name'):
            filepath = file_obj.name
        else:
            filepath = file_obj

        with open(filepath, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        # A valid JSON document may still be a list/number/string; fail with a
        # readable message instead of an AttributeError on ``.get``.
        if not isinstance(metadata, dict):
            raise ValueError("top-level JSON value must be an object")

        # Extract all fields
        task_type = metadata.get('task_type', 'text2music')
        captions = metadata.get('caption', '')
        lyrics = metadata.get('lyrics', '')
        vocal_language = metadata.get('vocal_language', 'unknown')

        # Convert bpm: missing, "N/A", falsy or unparsable values become None
        bpm = None
        bpm_value = metadata.get('bpm')
        if bpm_value is not None and bpm_value != "N/A":
            try:
                bpm = int(bpm_value) if bpm_value else None
            except (TypeError, ValueError, OverflowError):
                bpm = None

        key_scale = metadata.get('keyscale', '')
        time_signature = metadata.get('timesignature', '')

        # Convert duration: missing, "N/A" or unparsable values become -1
        audio_duration = -1
        duration_value = metadata.get('duration', -1)
        if duration_value is not None and duration_value != "N/A":
            try:
                audio_duration = float(duration_value)
            except (TypeError, ValueError, OverflowError):
                audio_duration = -1

        batch_size = metadata.get('batch_size', 2)
        inference_steps = metadata.get('inference_steps', 8)
        guidance_scale = metadata.get('guidance_scale', 7.0)
        seed = metadata.get('seed', '-1')
        random_seed = metadata.get('random_seed', True)
        use_adg = metadata.get('use_adg', False)
        cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
        cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
        audio_format = metadata.get('audio_format', 'mp3')
        lm_temperature = metadata.get('lm_temperature', 0.85)
        lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
        lm_top_k = metadata.get('lm_top_k', 0)
        lm_top_p = metadata.get('lm_top_p', 0.9)
        lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
        use_cot_caption = metadata.get('use_cot_caption', True)
        use_cot_language = metadata.get('use_cot_language', True)
        audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
        think = metadata.get('think', True)
        audio_codes = metadata.get('audio_codes', '')
        repainting_start = metadata.get('repainting_start', 0.0)
        repainting_end = metadata.get('repainting_end', -1)
        track_name = metadata.get('track_name')
        complete_track_classes = metadata.get('complete_track_classes', [])

        gr.Info(f"✅ Parameters loaded from {os.path.basename(filepath)}")

        return (
            task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
            audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
            use_adg, cfg_interval_start, cfg_interval_end, audio_format,
            lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
            use_cot_caption, use_cot_language, audio_cover_strength,
            think, audio_codes, repainting_start, repainting_end,
            track_name, complete_track_classes,
            True  # Set is_format_caption to True when loading from file
        )

    except json.JSONDecodeError as e:
        gr.Warning(f"❌ Invalid JSON file: {str(e)}")
        return _keep_current_values()
    except Exception as e:
        # UI callback boundary: surface any other failure (I/O, decoding,
        # unexpected structure) as a toast instead of crashing the handler.
        gr.Warning(f"❌ Error loading file: {str(e)}")
        return _keep_current_values()
951
+
952
  def load_random_example(task_type: str):
953
  """Load a random example from the task-specific examples directory
954
 
 
1284
  instruction_display_gen, audio_cover_strength, task_type,
1285
  use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
1286
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
1287
+ use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
1288
  progress=gr.Progress(track_tqdm=True)
1289
  ):
1290
+ # If think is enabled (llm_dit mode) and use_cot_metas is True, generate audio codes using LM first
1291
  audio_code_string_to_use = text2music_audio_code_string
1292
  lm_generated_metadata = None # Store LM-generated metadata for display
1293
  lm_generated_audio_codes = None # Store LM-generated audio codes for display
1294
+ if think_checkbox and llm_handler.llm_initialized and use_cot_metas:
1295
  # Convert top_k: 0 means None (disabled)
1296
  top_k_value = None if lm_top_k == 0 else int(lm_top_k)
1297
  # Convert top_p: 1.0 means None (disabled)
 
1341
  user_metadata=user_metadata_to_pass,
1342
  use_cot_caption=use_cot_caption,
1343
  use_cot_language=use_cot_language,
1344
+ is_format_caption=is_format_caption,
1345
  )
1346
 
1347
  # Store LM-generated metadata and audio codes for display
 
1431
  align_text_2,
1432
  align_plot_2,
1433
  updated_audio_codes, # Update audio codes in UI
1434
+ lm_generated_metadata, # Store metadata for "Send to src audio" buttons
1435
+ is_format_caption # Keep is_format_caption unchanged (LM doesn't modify user input fields)
1436
  )
1437
 
1438
  generation_section["generate_btn"].click(
 
1468
  generation_section["lm_top_k"],
1469
  generation_section["lm_top_p"],
1470
  generation_section["lm_negative_prompt"],
1471
+ generation_section["use_cot_metas"],
1472
  generation_section["use_cot_caption"],
1473
+ generation_section["use_cot_language"],
1474
+ results_section["is_format_caption_state"]
1475
  ],
1476
  outputs=[
1477
  results_section["generated_audio_1"],
 
1487
  results_section["align_text_2"],
1488
  results_section["align_plot_2"],
1489
  generation_section["text2music_audio_code_string"], # Update audio codes display
1490
+ results_section["lm_metadata_state"], # Store metadata
1491
+ results_section["is_format_caption_state"] # Update is_format_caption state
1492
  ]
1493
  )
1494
 
 
1617
  lm_metadata: Dictionary containing LM-generated metadata
1618
 
1619
  Returns:
1620
+ Tuple of (audio_file, bpm, caption, duration, key_scale, language, time_signature, is_format_caption)
1621
  """
1622
  if audio_file is None:
1623
+ return None, None, None, None, None, None, None, True # Keep is_format_caption as True
1624
 
1625
  # Extract metadata fields if available
1626
  bpm_value = None
 
1678
  duration_value,
1679
  key_scale_value,
1680
  language_value,
1681
+ time_signature_value,
1682
+ True # Set is_format_caption to True (from LM-generated metadata)
1683
  )
1684
 
1685
  results_section["send_to_src_btn_1"].click(
 
1695
  generation_section["audio_duration"],
1696
  generation_section["key_scale"],
1697
  generation_section["vocal_language"],
1698
+ generation_section["time_signature"],
1699
+ results_section["is_format_caption_state"]
1700
  ]
1701
  )
1702
 
 
1713
  generation_section["audio_duration"],
1714
  generation_section["key_scale"],
1715
  generation_section["vocal_language"],
1716
+ generation_section["time_signature"],
1717
+ results_section["is_format_caption_state"]
1718
  ]
1719
  )
1720
 
1721
  # Sample button - smart sample (uses LM if initialized, otherwise examples)
1722
+ # Need to add is_format_caption return value to sample_example_smart
1723
def sample_example_smart_with_flag(task_type: str):
    """Call ``sample_example_smart`` and append the is_format_caption flag.

    Args:
        task_type: Task type forwarded unchanged to ``sample_example_smart``.

    Returns:
        A tuple holding every value returned by ``sample_example_smart``
        followed by ``True``, which marks the sampled caption/metadata as
        coming from a formatted source.
    """
    result = sample_example_smart(task_type)
    # Unpack instead of ``result + (True,)`` so the wrapper also works if the
    # helper returns a list rather than a tuple.
    return (*result, True)
1728
+
1729
  generation_section["sample_btn"].click(
1730
+ fn=sample_example_smart_with_flag,
1731
  inputs=[generation_section["task_type"]],
1732
  outputs=[
1733
  generation_section["captions"],
 
1738
  generation_section["key_scale"],
1739
  generation_section["vocal_language"],
1740
  generation_section["time_signature"],
1741
+ results_section["is_format_caption_state"] # Set is_format_caption to True (from Sample/LM)
1742
  ]
1743
  )
1744
 
 
1793
  duration,
1794
  keyscale,
1795
  language,
1796
+ timesignature,
1797
+ True # Set is_format_caption to True (from Transcribe/LM understanding)
1798
  )
1799
 
1800
  # Update transcribe button text based on whether codes are present
 
1828
  generation_section["key_scale"], # Update keyscale field
1829
  generation_section["vocal_language"], # Update language field
1830
  generation_section["time_signature"], # Update time signature field
1831
+ results_section["is_format_caption_state"] # Set is_format_caption to True
1832
  ]
1833
  )
1834
 
1835
+ # Reset is_format_caption to False when user manually edits fields
1836
def reset_format_caption_flag():
    """Return the cleared value for the is_format_caption state.

    Used as the callback for ``.change`` events on the caption/metadata
    fields: any edit to those fields sets the flag back to ``False``.
    """
    is_format_caption = False
    return is_format_caption
1839
+
1840
+ # Connect reset function to all user-editable metadata fields
1841
+ generation_section["captions"].change(
1842
+ fn=reset_format_caption_flag,
1843
+ inputs=[],
1844
+ outputs=[results_section["is_format_caption_state"]]
1845
+ )
1846
+
1847
+ generation_section["lyrics"].change(
1848
+ fn=reset_format_caption_flag,
1849
+ inputs=[],
1850
+ outputs=[results_section["is_format_caption_state"]]
1851
+ )
1852
+
1853
+ generation_section["bpm"].change(
1854
+ fn=reset_format_caption_flag,
1855
+ inputs=[],
1856
+ outputs=[results_section["is_format_caption_state"]]
1857
+ )
1858
+
1859
+ generation_section["key_scale"].change(
1860
+ fn=reset_format_caption_flag,
1861
+ inputs=[],
1862
+ outputs=[results_section["is_format_caption_state"]]
1863
+ )
1864
+
1865
+ generation_section["time_signature"].change(
1866
+ fn=reset_format_caption_flag,
1867
+ inputs=[],
1868
+ outputs=[results_section["is_format_caption_state"]]
1869
+ )
1870
+
1871
+ generation_section["vocal_language"].change(
1872
+ fn=reset_format_caption_flag,
1873
+ inputs=[],
1874
+ outputs=[results_section["is_format_caption_state"]]
1875
+ )
1876
+
1877
+ generation_section["audio_duration"].change(
1878
+ fn=reset_format_caption_flag,
1879
+ inputs=[],
1880
+ outputs=[results_section["is_format_caption_state"]]
1881
+ )
1882
+
1883
  # Auto-expand Audio Uploads accordion when audio is uploaded
1884
  def update_audio_uploads_accordion(reference_audio, src_audio):
1885
  """Update Audio Uploads accordion open state based on whether audio files are present"""
 
1898
  inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
1899
  outputs=[generation_section["audio_uploads_accordion"]]
1900
  )
1901
+
1902
+ # Save metadata handlers
1903
+ results_section["save_btn_1"].click(
1904
+ fn=save_metadata,
1905
+ inputs=[
1906
+ generation_section["task_type"],
1907
+ generation_section["captions"],
1908
+ generation_section["lyrics"],
1909
+ generation_section["vocal_language"],
1910
+ generation_section["bpm"],
1911
+ generation_section["key_scale"],
1912
+ generation_section["time_signature"],
1913
+ generation_section["audio_duration"],
1914
+ generation_section["batch_size_input"],
1915
+ generation_section["inference_steps"],
1916
+ generation_section["guidance_scale"],
1917
+ generation_section["seed"],
1918
+ generation_section["random_seed_checkbox"],
1919
+ generation_section["use_adg"],
1920
+ generation_section["cfg_interval_start"],
1921
+ generation_section["cfg_interval_end"],
1922
+ generation_section["audio_format"],
1923
+ generation_section["lm_temperature"],
1924
+ generation_section["lm_cfg_scale"],
1925
+ generation_section["lm_top_k"],
1926
+ generation_section["lm_top_p"],
1927
+ generation_section["lm_negative_prompt"],
1928
+ generation_section["use_cot_caption"],
1929
+ generation_section["use_cot_language"],
1930
+ generation_section["audio_cover_strength"],
1931
+ generation_section["think_checkbox"],
1932
+ generation_section["text2music_audio_code_string"],
1933
+ generation_section["repainting_start"],
1934
+ generation_section["repainting_end"],
1935
+ generation_section["track_name"],
1936
+ generation_section["complete_track_classes"],
1937
+ results_section["lm_metadata_state"],
1938
+ ],
1939
+ outputs=[]
1940
+ )
1941
+
1942
+ results_section["save_btn_2"].click(
1943
+ fn=save_metadata,
1944
+ inputs=[
1945
+ generation_section["task_type"],
1946
+ generation_section["captions"],
1947
+ generation_section["lyrics"],
1948
+ generation_section["vocal_language"],
1949
+ generation_section["bpm"],
1950
+ generation_section["key_scale"],
1951
+ generation_section["time_signature"],
1952
+ generation_section["audio_duration"],
1953
+ generation_section["batch_size_input"],
1954
+ generation_section["inference_steps"],
1955
+ generation_section["guidance_scale"],
1956
+ generation_section["seed"],
1957
+ generation_section["random_seed_checkbox"],
1958
+ generation_section["use_adg"],
1959
+ generation_section["cfg_interval_start"],
1960
+ generation_section["cfg_interval_end"],
1961
+ generation_section["audio_format"],
1962
+ generation_section["lm_temperature"],
1963
+ generation_section["lm_cfg_scale"],
1964
+ generation_section["lm_top_k"],
1965
+ generation_section["lm_top_p"],
1966
+ generation_section["lm_negative_prompt"],
1967
+ generation_section["use_cot_caption"],
1968
+ generation_section["use_cot_language"],
1969
+ generation_section["audio_cover_strength"],
1970
+ generation_section["think_checkbox"],
1971
+ generation_section["text2music_audio_code_string"],
1972
+ generation_section["repainting_start"],
1973
+ generation_section["repainting_end"],
1974
+ generation_section["track_name"],
1975
+ generation_section["complete_track_classes"],
1976
+ results_section["lm_metadata_state"],
1977
+ ],
1978
+ outputs=[]
1979
+ )
1980
+
1981
+ # Load metadata handler - triggered when file is uploaded via UploadButton
1982
+ generation_section["load_file"].upload(
1983
+ fn=load_metadata,
1984
+ inputs=[generation_section["load_file"]],
1985
+ outputs=[
1986
+ generation_section["task_type"],
1987
+ generation_section["captions"],
1988
+ generation_section["lyrics"],
1989
+ generation_section["vocal_language"],
1990
+ generation_section["bpm"],
1991
+ generation_section["key_scale"],
1992
+ generation_section["time_signature"],
1993
+ generation_section["audio_duration"],
1994
+ generation_section["batch_size_input"],
1995
+ generation_section["inference_steps"],
1996
+ generation_section["guidance_scale"],
1997
+ generation_section["seed"],
1998
+ generation_section["random_seed_checkbox"],
1999
+ generation_section["use_adg"],
2000
+ generation_section["cfg_interval_start"],
2001
+ generation_section["cfg_interval_end"],
2002
+ generation_section["audio_format"],
2003
+ generation_section["lm_temperature"],
2004
+ generation_section["lm_cfg_scale"],
2005
+ generation_section["lm_top_k"],
2006
+ generation_section["lm_top_p"],
2007
+ generation_section["lm_negative_prompt"],
2008
+ generation_section["use_cot_caption"],
2009
+ generation_section["use_cot_language"],
2010
+ generation_section["audio_cover_strength"],
2011
+ generation_section["think_checkbox"],
2012
+ generation_section["text2music_audio_code_string"],
2013
+ generation_section["repainting_start"],
2014
+ generation_section["repainting_end"],
2015
+ generation_section["track_name"],
2016
+ generation_section["complete_track_classes"],
2017
+ results_section["is_format_caption_state"]
2018
+ ]
2019
+ )
2020
 
acestep/llm_inference.py CHANGED
@@ -551,6 +551,7 @@ class LLMHandler:
551
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
552
  use_cot_caption: bool = True,
553
  use_cot_language: bool = True,
 
554
  ) -> Tuple[Dict[str, Any], str, str]:
555
  """Two-phase LM generation: CoT generation followed by audio codes generation.
556
 
@@ -575,7 +576,7 @@ class LLMHandler:
575
 
576
  # ========== PHASE 1: CoT Generation ==========
577
  # Always generate CoT unless all metadata are user-provided
578
- if not has_all_metas:
579
  logger.info("Phase 1: Generating CoT metadata...")
580
 
581
  # Build formatted prompt for CoT phase
 
551
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
552
  use_cot_caption: bool = True,
553
  use_cot_language: bool = True,
554
+ is_format_caption: bool = False,
555
  ) -> Tuple[Dict[str, Any], str, str]:
556
  """Two-phase LM generation: CoT generation followed by audio codes generation.
557
 
 
576
 
577
  # ========== PHASE 1: CoT Generation ==========
578
  # Skip CoT only when all metadata are user-provided AND the caption is already formatted (is_format_caption)
579
+ if not has_all_metas or not is_format_caption:
580
  logger.info("Phase 1: Generating CoT metadata...")
581
 
582
  # Build formatted prompt for CoT phase