Spaces:
Configuration error
Configuration error
support save & load
Browse files- acestep/gradio_ui.py +420 -43
- acestep/llm_inference.py +2 -1
acestep/gradio_ui.py
CHANGED
|
@@ -308,7 +308,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 308 |
else:
|
| 309 |
initial_task_choices = TASK_TYPES_BASE
|
| 310 |
|
| 311 |
-
with gr.Row():
|
| 312 |
with gr.Column(scale=2):
|
| 313 |
task_type = gr.Dropdown(
|
| 314 |
choices=initial_task_choices,
|
|
@@ -316,7 +316,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 316 |
label="Task Type",
|
| 317 |
info="Select the task type for generation",
|
| 318 |
)
|
| 319 |
-
with gr.Column(scale=
|
| 320 |
instruction_display_gen = gr.Textbox(
|
| 321 |
label="Instruction",
|
| 322 |
value=DEFAULT_DIT_INSTRUCTION,
|
|
@@ -324,6 +324,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 324 |
lines=1,
|
| 325 |
info="Instruction is automatically generated based on task type",
|
| 326 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
track_name = gr.Dropdown(
|
| 329 |
choices=TRACK_NAMES,
|
|
@@ -486,15 +494,22 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 486 |
info="Higher values follow text more closely",
|
| 487 |
visible=False
|
| 488 |
)
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
)
|
| 499 |
|
| 500 |
with gr.Row():
|
|
@@ -522,15 +537,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 522 |
label="CFG Interval End",
|
| 523 |
visible=False
|
| 524 |
)
|
| 525 |
-
|
| 526 |
-
with gr.Row():
|
| 527 |
-
audio_format = gr.Dropdown(
|
| 528 |
-
choices=["mp3", "flac"],
|
| 529 |
-
value="mp3",
|
| 530 |
-
label="Audio Format",
|
| 531 |
-
info="Audio format for saved files"
|
| 532 |
-
)
|
| 533 |
-
|
| 534 |
# LM (Language Model) Parameters
|
| 535 |
gr.HTML("<h4>🤖 LM Generation Parameters</h4>")
|
| 536 |
with gr.Row():
|
|
@@ -582,6 +589,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 582 |
)
|
| 583 |
|
| 584 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
use_cot_caption = gr.Checkbox(
|
| 586 |
label="CoT Caption",
|
| 587 |
value=True,
|
|
@@ -654,6 +667,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 654 |
"lm_top_k": lm_top_k,
|
| 655 |
"lm_top_p": lm_top_p,
|
| 656 |
"lm_negative_prompt": lm_negative_prompt,
|
|
|
|
| 657 |
"use_cot_caption": use_cot_caption,
|
| 658 |
"use_cot_language": use_cot_language,
|
| 659 |
"repainting_group": repainting_group,
|
|
@@ -662,6 +676,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 662 |
"audio_cover_strength": audio_cover_strength,
|
| 663 |
"captions": captions,
|
| 664 |
"sample_btn": sample_btn,
|
|
|
|
| 665 |
"lyrics": lyrics,
|
| 666 |
"vocal_language": vocal_language,
|
| 667 |
"bpm": bpm,
|
|
@@ -691,6 +706,9 @@ def create_results_section(dit_handler) -> dict:
|
|
| 691 |
# Hidden state to store LM-generated metadata
|
| 692 |
lm_metadata_state = gr.State(value=None)
|
| 693 |
|
|
|
|
|
|
|
|
|
|
| 694 |
status_output = gr.Textbox(label="Generation Status", interactive=False)
|
| 695 |
|
| 696 |
with gr.Row():
|
|
@@ -700,22 +718,38 @@ def create_results_section(dit_handler) -> dict:
|
|
| 700 |
type="filepath",
|
| 701 |
interactive=False
|
| 702 |
)
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
with gr.Column():
|
| 709 |
generated_audio_2 = gr.Audio(
|
| 710 |
label="🎵 Generated Music (Sample 2)",
|
| 711 |
type="filepath",
|
| 712 |
interactive=False
|
| 713 |
)
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
|
| 720 |
with gr.Accordion("📁 Batch Results & Generation Details", open=False):
|
| 721 |
generated_audio_batch = gr.File(
|
|
@@ -738,11 +772,14 @@ def create_results_section(dit_handler) -> dict:
|
|
| 738 |
|
| 739 |
return {
|
| 740 |
"lm_metadata_state": lm_metadata_state,
|
|
|
|
| 741 |
"status_output": status_output,
|
| 742 |
"generated_audio_1": generated_audio_1,
|
| 743 |
"generated_audio_2": generated_audio_2,
|
| 744 |
"send_to_src_btn_1": send_to_src_btn_1,
|
| 745 |
"send_to_src_btn_2": send_to_src_btn_2,
|
|
|
|
|
|
|
| 746 |
"generated_audio_batch": generated_audio_batch,
|
| 747 |
"generation_info": generation_info,
|
| 748 |
"align_score_1": align_score_1,
|
|
@@ -757,6 +794,161 @@ def create_results_section(dit_handler) -> dict:
|
|
| 757 |
def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
|
| 758 |
"""Setup event handlers connecting UI components and business logic"""
|
| 759 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 760 |
def load_random_example(task_type: str):
|
| 761 |
"""Load a random example from the task-specific examples directory
|
| 762 |
|
|
@@ -1092,14 +1284,14 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1092 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1093 |
use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
|
| 1094 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1095 |
-
use_cot_caption, use_cot_language,
|
| 1096 |
progress=gr.Progress(track_tqdm=True)
|
| 1097 |
):
|
| 1098 |
-
# If think is enabled (llm_dit mode), generate audio codes using LM first
|
| 1099 |
audio_code_string_to_use = text2music_audio_code_string
|
| 1100 |
lm_generated_metadata = None # Store LM-generated metadata for display
|
| 1101 |
lm_generated_audio_codes = None # Store LM-generated audio codes for display
|
| 1102 |
-
if think_checkbox and llm_handler.llm_initialized:
|
| 1103 |
# Convert top_k: 0 means None (disabled)
|
| 1104 |
top_k_value = None if lm_top_k == 0 else int(lm_top_k)
|
| 1105 |
# Convert top_p: 1.0 means None (disabled)
|
|
@@ -1149,6 +1341,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1149 |
user_metadata=user_metadata_to_pass,
|
| 1150 |
use_cot_caption=use_cot_caption,
|
| 1151 |
use_cot_language=use_cot_language,
|
|
|
|
| 1152 |
)
|
| 1153 |
|
| 1154 |
# Store LM-generated metadata and audio codes for display
|
|
@@ -1238,7 +1431,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1238 |
align_text_2,
|
| 1239 |
align_plot_2,
|
| 1240 |
updated_audio_codes, # Update audio codes in UI
|
| 1241 |
-
lm_generated_metadata # Store metadata for "Send to src audio" buttons
|
|
|
|
| 1242 |
)
|
| 1243 |
|
| 1244 |
generation_section["generate_btn"].click(
|
|
@@ -1274,8 +1468,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1274 |
generation_section["lm_top_k"],
|
| 1275 |
generation_section["lm_top_p"],
|
| 1276 |
generation_section["lm_negative_prompt"],
|
|
|
|
| 1277 |
generation_section["use_cot_caption"],
|
| 1278 |
-
generation_section["use_cot_language"]
|
|
|
|
| 1279 |
],
|
| 1280 |
outputs=[
|
| 1281 |
results_section["generated_audio_1"],
|
|
@@ -1291,7 +1487,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1291 |
results_section["align_text_2"],
|
| 1292 |
results_section["align_plot_2"],
|
| 1293 |
generation_section["text2music_audio_code_string"], # Update audio codes display
|
| 1294 |
-
results_section["lm_metadata_state"] # Store metadata
|
|
|
|
| 1295 |
]
|
| 1296 |
)
|
| 1297 |
|
|
@@ -1420,10 +1617,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1420 |
lm_metadata: Dictionary containing LM-generated metadata
|
| 1421 |
|
| 1422 |
Returns:
|
| 1423 |
-
Tuple of (audio_file, bpm, caption, duration, key_scale, language, time_signature)
|
| 1424 |
"""
|
| 1425 |
if audio_file is None:
|
| 1426 |
-
return None, None, None, None, None, None, None
|
| 1427 |
|
| 1428 |
# Extract metadata fields if available
|
| 1429 |
bpm_value = None
|
|
@@ -1481,7 +1678,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1481 |
duration_value,
|
| 1482 |
key_scale_value,
|
| 1483 |
language_value,
|
| 1484 |
-
time_signature_value
|
|
|
|
| 1485 |
)
|
| 1486 |
|
| 1487 |
results_section["send_to_src_btn_1"].click(
|
|
@@ -1497,7 +1695,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1497 |
generation_section["audio_duration"],
|
| 1498 |
generation_section["key_scale"],
|
| 1499 |
generation_section["vocal_language"],
|
| 1500 |
-
generation_section["time_signature"]
|
|
|
|
| 1501 |
]
|
| 1502 |
)
|
| 1503 |
|
|
@@ -1514,13 +1713,21 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1514 |
generation_section["audio_duration"],
|
| 1515 |
generation_section["key_scale"],
|
| 1516 |
generation_section["vocal_language"],
|
| 1517 |
-
generation_section["time_signature"]
|
|
|
|
| 1518 |
]
|
| 1519 |
)
|
| 1520 |
|
| 1521 |
# Sample button - smart sample (uses LM if initialized, otherwise examples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1522 |
generation_section["sample_btn"].click(
|
| 1523 |
-
fn=
|
| 1524 |
inputs=[generation_section["task_type"]],
|
| 1525 |
outputs=[
|
| 1526 |
generation_section["captions"],
|
|
@@ -1531,6 +1738,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1531 |
generation_section["key_scale"],
|
| 1532 |
generation_section["vocal_language"],
|
| 1533 |
generation_section["time_signature"],
|
|
|
|
| 1534 |
]
|
| 1535 |
)
|
| 1536 |
|
|
@@ -1585,7 +1793,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1585 |
duration,
|
| 1586 |
keyscale,
|
| 1587 |
language,
|
| 1588 |
-
timesignature
|
|
|
|
| 1589 |
)
|
| 1590 |
|
| 1591 |
# Update transcribe button text based on whether codes are present
|
|
@@ -1619,9 +1828,58 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1619 |
generation_section["key_scale"], # Update keyscale field
|
| 1620 |
generation_section["vocal_language"], # Update language field
|
| 1621 |
generation_section["time_signature"], # Update time signature field
|
|
|
|
| 1622 |
]
|
| 1623 |
)
|
| 1624 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1625 |
# Auto-expand Audio Uploads accordion when audio is uploaded
|
| 1626 |
def update_audio_uploads_accordion(reference_audio, src_audio):
|
| 1627 |
"""Update Audio Uploads accordion open state based on whether audio files are present"""
|
|
@@ -1640,4 +1898,123 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1640 |
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
| 1641 |
outputs=[generation_section["audio_uploads_accordion"]]
|
| 1642 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1643 |
|
|
|
|
| 308 |
else:
|
| 309 |
initial_task_choices = TASK_TYPES_BASE
|
| 310 |
|
| 311 |
+
with gr.Row(equal_height=True):
|
| 312 |
with gr.Column(scale=2):
|
| 313 |
task_type = gr.Dropdown(
|
| 314 |
choices=initial_task_choices,
|
|
|
|
| 316 |
label="Task Type",
|
| 317 |
info="Select the task type for generation",
|
| 318 |
)
|
| 319 |
+
with gr.Column(scale=7):
|
| 320 |
instruction_display_gen = gr.Textbox(
|
| 321 |
label="Instruction",
|
| 322 |
value=DEFAULT_DIT_INSTRUCTION,
|
|
|
|
| 324 |
lines=1,
|
| 325 |
info="Instruction is automatically generated based on task type",
|
| 326 |
)
|
| 327 |
+
with gr.Column(scale=1, min_width=100):
|
| 328 |
+
load_file = gr.UploadButton(
|
| 329 |
+
"Load",
|
| 330 |
+
file_types=[".json"],
|
| 331 |
+
file_count="single",
|
| 332 |
+
variant="secondary",
|
| 333 |
+
size="sm",
|
| 334 |
+
)
|
| 335 |
|
| 336 |
track_name = gr.Dropdown(
|
| 337 |
choices=TRACK_NAMES,
|
|
|
|
| 494 |
info="Higher values follow text more closely",
|
| 495 |
visible=False
|
| 496 |
)
|
| 497 |
+
with gr.Column():
|
| 498 |
+
seed = gr.Textbox(
|
| 499 |
+
label="Seed",
|
| 500 |
+
value="-1",
|
| 501 |
+
info="Use comma-separated values for batches"
|
| 502 |
+
)
|
| 503 |
+
random_seed_checkbox = gr.Checkbox(
|
| 504 |
+
label="Random Seed",
|
| 505 |
+
value=True,
|
| 506 |
+
info="Enable to auto-generate seeds"
|
| 507 |
+
)
|
| 508 |
+
audio_format = gr.Dropdown(
|
| 509 |
+
choices=["mp3", "flac"],
|
| 510 |
+
value="mp3",
|
| 511 |
+
label="Audio Format",
|
| 512 |
+
info="Audio format for saved files"
|
| 513 |
)
|
| 514 |
|
| 515 |
with gr.Row():
|
|
|
|
| 537 |
label="CFG Interval End",
|
| 538 |
visible=False
|
| 539 |
)
|
| 540 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
# LM (Language Model) Parameters
|
| 542 |
gr.HTML("<h4>🤖 LM Generation Parameters</h4>")
|
| 543 |
with gr.Row():
|
|
|
|
| 589 |
)
|
| 590 |
|
| 591 |
with gr.Row():
|
| 592 |
+
use_cot_metas = gr.Checkbox(
|
| 593 |
+
label="CoT Metas",
|
| 594 |
+
value=True,
|
| 595 |
+
info="Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
|
| 596 |
+
scale=1,
|
| 597 |
+
)
|
| 598 |
use_cot_caption = gr.Checkbox(
|
| 599 |
label="CoT Caption",
|
| 600 |
value=True,
|
|
|
|
| 667 |
"lm_top_k": lm_top_k,
|
| 668 |
"lm_top_p": lm_top_p,
|
| 669 |
"lm_negative_prompt": lm_negative_prompt,
|
| 670 |
+
"use_cot_metas": use_cot_metas,
|
| 671 |
"use_cot_caption": use_cot_caption,
|
| 672 |
"use_cot_language": use_cot_language,
|
| 673 |
"repainting_group": repainting_group,
|
|
|
|
| 676 |
"audio_cover_strength": audio_cover_strength,
|
| 677 |
"captions": captions,
|
| 678 |
"sample_btn": sample_btn,
|
| 679 |
+
"load_file": load_file,
|
| 680 |
"lyrics": lyrics,
|
| 681 |
"vocal_language": vocal_language,
|
| 682 |
"bpm": bpm,
|
|
|
|
| 706 |
# Hidden state to store LM-generated metadata
|
| 707 |
lm_metadata_state = gr.State(value=None)
|
| 708 |
|
| 709 |
+
# Hidden state to track if caption/metadata is from formatted source (LM/transcription)
|
| 710 |
+
is_format_caption_state = gr.State(value=False)
|
| 711 |
+
|
| 712 |
status_output = gr.Textbox(label="Generation Status", interactive=False)
|
| 713 |
|
| 714 |
with gr.Row():
|
|
|
|
| 718 |
type="filepath",
|
| 719 |
interactive=False
|
| 720 |
)
|
| 721 |
+
with gr.Row(equal_height=True):
|
| 722 |
+
send_to_src_btn_1 = gr.Button(
|
| 723 |
+
"Send To Src Audio",
|
| 724 |
+
variant="secondary",
|
| 725 |
+
size="sm",
|
| 726 |
+
scale=1
|
| 727 |
+
)
|
| 728 |
+
save_btn_1 = gr.Button(
|
| 729 |
+
"💾 Save",
|
| 730 |
+
variant="primary",
|
| 731 |
+
size="sm",
|
| 732 |
+
scale=1
|
| 733 |
+
)
|
| 734 |
with gr.Column():
|
| 735 |
generated_audio_2 = gr.Audio(
|
| 736 |
label="🎵 Generated Music (Sample 2)",
|
| 737 |
type="filepath",
|
| 738 |
interactive=False
|
| 739 |
)
|
| 740 |
+
with gr.Row(equal_height=True):
|
| 741 |
+
send_to_src_btn_2 = gr.Button(
|
| 742 |
+
"Send To Src Audio",
|
| 743 |
+
variant="secondary",
|
| 744 |
+
size="sm",
|
| 745 |
+
scale=1
|
| 746 |
+
)
|
| 747 |
+
save_btn_2 = gr.Button(
|
| 748 |
+
"💾 Save",
|
| 749 |
+
variant="primary",
|
| 750 |
+
size="sm",
|
| 751 |
+
scale=1
|
| 752 |
+
)
|
| 753 |
|
| 754 |
with gr.Accordion("📁 Batch Results & Generation Details", open=False):
|
| 755 |
generated_audio_batch = gr.File(
|
|
|
|
| 772 |
|
| 773 |
return {
|
| 774 |
"lm_metadata_state": lm_metadata_state,
|
| 775 |
+
"is_format_caption_state": is_format_caption_state,
|
| 776 |
"status_output": status_output,
|
| 777 |
"generated_audio_1": generated_audio_1,
|
| 778 |
"generated_audio_2": generated_audio_2,
|
| 779 |
"send_to_src_btn_1": send_to_src_btn_1,
|
| 780 |
"send_to_src_btn_2": send_to_src_btn_2,
|
| 781 |
+
"save_btn_1": save_btn_1,
|
| 782 |
+
"save_btn_2": save_btn_2,
|
| 783 |
"generated_audio_batch": generated_audio_batch,
|
| 784 |
"generation_info": generation_info,
|
| 785 |
"align_score_1": align_score_1,
|
|
|
|
| 794 |
def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
|
| 795 |
"""Setup event handlers connecting UI components and business logic"""
|
| 796 |
|
| 797 |
+
def save_metadata(
|
| 798 |
+
task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
|
| 799 |
+
batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
|
| 800 |
+
use_adg, cfg_interval_start, cfg_interval_end, audio_format,
|
| 801 |
+
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 802 |
+
use_cot_caption, use_cot_language, audio_cover_strength,
|
| 803 |
+
think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
|
| 804 |
+
track_name, complete_track_classes, lm_metadata
|
| 805 |
+
):
|
| 806 |
+
"""Save all generation parameters to a JSON file"""
|
| 807 |
+
import datetime
|
| 808 |
+
|
| 809 |
+
# Create metadata dictionary
|
| 810 |
+
metadata = {
|
| 811 |
+
"saved_at": datetime.datetime.now().isoformat(),
|
| 812 |
+
"task_type": task_type,
|
| 813 |
+
"caption": captions or "",
|
| 814 |
+
"lyrics": lyrics or "",
|
| 815 |
+
"vocal_language": vocal_language,
|
| 816 |
+
"bpm": bpm if bpm is not None else None,
|
| 817 |
+
"keyscale": key_scale or "",
|
| 818 |
+
"timesignature": time_signature or "",
|
| 819 |
+
"duration": audio_duration if audio_duration is not None else -1,
|
| 820 |
+
"batch_size": batch_size_input,
|
| 821 |
+
"inference_steps": inference_steps,
|
| 822 |
+
"guidance_scale": guidance_scale,
|
| 823 |
+
"seed": seed,
|
| 824 |
+
"random_seed": False, # Disable random seed for reproducibility
|
| 825 |
+
"use_adg": use_adg,
|
| 826 |
+
"cfg_interval_start": cfg_interval_start,
|
| 827 |
+
"cfg_interval_end": cfg_interval_end,
|
| 828 |
+
"audio_format": audio_format,
|
| 829 |
+
"lm_temperature": lm_temperature,
|
| 830 |
+
"lm_cfg_scale": lm_cfg_scale,
|
| 831 |
+
"lm_top_k": lm_top_k,
|
| 832 |
+
"lm_top_p": lm_top_p,
|
| 833 |
+
"lm_negative_prompt": lm_negative_prompt,
|
| 834 |
+
"use_cot_caption": use_cot_caption,
|
| 835 |
+
"use_cot_language": use_cot_language,
|
| 836 |
+
"audio_cover_strength": audio_cover_strength,
|
| 837 |
+
"think": think_checkbox,
|
| 838 |
+
"audio_codes": text2music_audio_code_string or "",
|
| 839 |
+
"repainting_start": repainting_start,
|
| 840 |
+
"repainting_end": repainting_end,
|
| 841 |
+
"track_name": track_name,
|
| 842 |
+
"complete_track_classes": complete_track_classes or [],
|
| 843 |
+
}
|
| 844 |
+
|
| 845 |
+
# Add LM-generated metadata if available
|
| 846 |
+
if lm_metadata:
|
| 847 |
+
metadata["lm_generated_metadata"] = lm_metadata
|
| 848 |
+
|
| 849 |
+
# Save to file
|
| 850 |
+
try:
|
| 851 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 852 |
+
filename = f"generation_params_{timestamp}.json"
|
| 853 |
+
|
| 854 |
+
with open(filename, 'w', encoding='utf-8') as f:
|
| 855 |
+
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
| 856 |
+
|
| 857 |
+
gr.Info(f"✅ Parameters saved to {filename}")
|
| 858 |
+
return filename
|
| 859 |
+
except Exception as e:
|
| 860 |
+
gr.Warning(f"❌ Failed to save parameters: {str(e)}")
|
| 861 |
+
return None
|
| 862 |
+
|
| 863 |
+
def load_metadata(file_obj):
|
| 864 |
+
"""Load generation parameters from a JSON file"""
|
| 865 |
+
if file_obj is None:
|
| 866 |
+
gr.Warning("⚠️ No file selected")
|
| 867 |
+
return [None] * 31 + [False] # Return None for all fields, False for is_format_caption
|
| 868 |
+
|
| 869 |
+
try:
|
| 870 |
+
# Read the uploaded file
|
| 871 |
+
if hasattr(file_obj, 'name'):
|
| 872 |
+
filepath = file_obj.name
|
| 873 |
+
else:
|
| 874 |
+
filepath = file_obj
|
| 875 |
+
|
| 876 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 877 |
+
metadata = json.load(f)
|
| 878 |
+
|
| 879 |
+
# Extract all fields
|
| 880 |
+
task_type = metadata.get('task_type', 'text2music')
|
| 881 |
+
captions = metadata.get('caption', '')
|
| 882 |
+
lyrics = metadata.get('lyrics', '')
|
| 883 |
+
vocal_language = metadata.get('vocal_language', 'unknown')
|
| 884 |
+
|
| 885 |
+
# Convert bpm
|
| 886 |
+
bpm_value = metadata.get('bpm')
|
| 887 |
+
if bpm_value is not None and bpm_value != "N/A":
|
| 888 |
+
try:
|
| 889 |
+
bpm = int(bpm_value) if bpm_value else None
|
| 890 |
+
except:
|
| 891 |
+
bpm = None
|
| 892 |
+
else:
|
| 893 |
+
bpm = None
|
| 894 |
+
|
| 895 |
+
key_scale = metadata.get('keyscale', '')
|
| 896 |
+
time_signature = metadata.get('timesignature', '')
|
| 897 |
+
|
| 898 |
+
# Convert duration
|
| 899 |
+
duration_value = metadata.get('duration', -1)
|
| 900 |
+
if duration_value is not None and duration_value != "N/A":
|
| 901 |
+
try:
|
| 902 |
+
audio_duration = float(duration_value)
|
| 903 |
+
except:
|
| 904 |
+
audio_duration = -1
|
| 905 |
+
else:
|
| 906 |
+
audio_duration = -1
|
| 907 |
+
|
| 908 |
+
batch_size = metadata.get('batch_size', 2)
|
| 909 |
+
inference_steps = metadata.get('inference_steps', 8)
|
| 910 |
+
guidance_scale = metadata.get('guidance_scale', 7.0)
|
| 911 |
+
seed = metadata.get('seed', '-1')
|
| 912 |
+
random_seed = metadata.get('random_seed', True)
|
| 913 |
+
use_adg = metadata.get('use_adg', False)
|
| 914 |
+
cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
|
| 915 |
+
cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
|
| 916 |
+
audio_format = metadata.get('audio_format', 'mp3')
|
| 917 |
+
lm_temperature = metadata.get('lm_temperature', 0.85)
|
| 918 |
+
lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
|
| 919 |
+
lm_top_k = metadata.get('lm_top_k', 0)
|
| 920 |
+
lm_top_p = metadata.get('lm_top_p', 0.9)
|
| 921 |
+
lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
|
| 922 |
+
use_cot_caption = metadata.get('use_cot_caption', True)
|
| 923 |
+
use_cot_language = metadata.get('use_cot_language', True)
|
| 924 |
+
audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
|
| 925 |
+
think = metadata.get('think', True)
|
| 926 |
+
audio_codes = metadata.get('audio_codes', '')
|
| 927 |
+
repainting_start = metadata.get('repainting_start', 0.0)
|
| 928 |
+
repainting_end = metadata.get('repainting_end', -1)
|
| 929 |
+
track_name = metadata.get('track_name')
|
| 930 |
+
complete_track_classes = metadata.get('complete_track_classes', [])
|
| 931 |
+
|
| 932 |
+
gr.Info(f"✅ Parameters loaded from {os.path.basename(filepath)}")
|
| 933 |
+
|
| 934 |
+
return (
|
| 935 |
+
task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
|
| 936 |
+
audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
|
| 937 |
+
use_adg, cfg_interval_start, cfg_interval_end, audio_format,
|
| 938 |
+
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 939 |
+
use_cot_caption, use_cot_language, audio_cover_strength,
|
| 940 |
+
think, audio_codes, repainting_start, repainting_end,
|
| 941 |
+
track_name, complete_track_classes,
|
| 942 |
+
True # Set is_format_caption to True when loading from file
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
except json.JSONDecodeError as e:
|
| 946 |
+
gr.Warning(f"❌ Invalid JSON file: {str(e)}")
|
| 947 |
+
return [None] * 31 + [False]
|
| 948 |
+
except Exception as e:
|
| 949 |
+
gr.Warning(f"❌ Error loading file: {str(e)}")
|
| 950 |
+
return [None] * 31 + [False]
|
| 951 |
+
|
| 952 |
def load_random_example(task_type: str):
|
| 953 |
"""Load a random example from the task-specific examples directory
|
| 954 |
|
|
|
|
| 1284 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1285 |
use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
|
| 1286 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1287 |
+
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 1288 |
progress=gr.Progress(track_tqdm=True)
|
| 1289 |
):
|
| 1290 |
+
# If think is enabled (llm_dit mode) and use_cot_metas is True, generate audio codes using LM first
|
| 1291 |
audio_code_string_to_use = text2music_audio_code_string
|
| 1292 |
lm_generated_metadata = None # Store LM-generated metadata for display
|
| 1293 |
lm_generated_audio_codes = None # Store LM-generated audio codes for display
|
| 1294 |
+
if think_checkbox and llm_handler.llm_initialized and use_cot_metas:
|
| 1295 |
# Convert top_k: 0 means None (disabled)
|
| 1296 |
top_k_value = None if lm_top_k == 0 else int(lm_top_k)
|
| 1297 |
# Convert top_p: 1.0 means None (disabled)
|
|
|
|
| 1341 |
user_metadata=user_metadata_to_pass,
|
| 1342 |
use_cot_caption=use_cot_caption,
|
| 1343 |
use_cot_language=use_cot_language,
|
| 1344 |
+
is_format_caption=is_format_caption,
|
| 1345 |
)
|
| 1346 |
|
| 1347 |
# Store LM-generated metadata and audio codes for display
|
|
|
|
| 1431 |
align_text_2,
|
| 1432 |
align_plot_2,
|
| 1433 |
updated_audio_codes, # Update audio codes in UI
|
| 1434 |
+
lm_generated_metadata, # Store metadata for "Send to src audio" buttons
|
| 1435 |
+
is_format_caption # Keep is_format_caption unchanged (LM doesn't modify user input fields)
|
| 1436 |
)
|
| 1437 |
|
| 1438 |
generation_section["generate_btn"].click(
|
|
|
|
| 1468 |
generation_section["lm_top_k"],
|
| 1469 |
generation_section["lm_top_p"],
|
| 1470 |
generation_section["lm_negative_prompt"],
|
| 1471 |
+
generation_section["use_cot_metas"],
|
| 1472 |
generation_section["use_cot_caption"],
|
| 1473 |
+
generation_section["use_cot_language"],
|
| 1474 |
+
results_section["is_format_caption_state"]
|
| 1475 |
],
|
| 1476 |
outputs=[
|
| 1477 |
results_section["generated_audio_1"],
|
|
|
|
| 1487 |
results_section["align_text_2"],
|
| 1488 |
results_section["align_plot_2"],
|
| 1489 |
generation_section["text2music_audio_code_string"], # Update audio codes display
|
| 1490 |
+
results_section["lm_metadata_state"], # Store metadata
|
| 1491 |
+
results_section["is_format_caption_state"] # Update is_format_caption state
|
| 1492 |
]
|
| 1493 |
)
|
| 1494 |
|
|
|
|
| 1617 |
lm_metadata: Dictionary containing LM-generated metadata
|
| 1618 |
|
| 1619 |
Returns:
|
| 1620 |
+
Tuple of (audio_file, bpm, caption, duration, key_scale, language, time_signature, is_format_caption)
|
| 1621 |
"""
|
| 1622 |
if audio_file is None:
|
| 1623 |
+
return None, None, None, None, None, None, None, True # Keep is_format_caption as True
|
| 1624 |
|
| 1625 |
# Extract metadata fields if available
|
| 1626 |
bpm_value = None
|
|
|
|
| 1678 |
duration_value,
|
| 1679 |
key_scale_value,
|
| 1680 |
language_value,
|
| 1681 |
+
time_signature_value,
|
| 1682 |
+
True # Set is_format_caption to True (from LM-generated metadata)
|
| 1683 |
)
|
| 1684 |
|
| 1685 |
results_section["send_to_src_btn_1"].click(
|
|
|
|
| 1695 |
generation_section["audio_duration"],
|
| 1696 |
generation_section["key_scale"],
|
| 1697 |
generation_section["vocal_language"],
|
| 1698 |
+
generation_section["time_signature"],
|
| 1699 |
+
results_section["is_format_caption_state"]
|
| 1700 |
]
|
| 1701 |
)
|
| 1702 |
|
|
|
|
| 1713 |
generation_section["audio_duration"],
|
| 1714 |
generation_section["key_scale"],
|
| 1715 |
generation_section["vocal_language"],
|
| 1716 |
+
generation_section["time_signature"],
|
| 1717 |
+
results_section["is_format_caption_state"]
|
| 1718 |
]
|
| 1719 |
)
|
| 1720 |
|
| 1721 |
# Sample button - smart sample (uses LM if initialized, otherwise examples)
|
| 1722 |
+
# Need to add is_format_caption return value to sample_example_smart
|
| 1723 |
+
def sample_example_smart_with_flag(task_type: str):
|
| 1724 |
+
"""Wrapper for sample_example_smart that adds is_format_caption flag"""
|
| 1725 |
+
result = sample_example_smart(task_type)
|
| 1726 |
+
# Add True at the end to set is_format_caption
|
| 1727 |
+
return result + (True,)
|
| 1728 |
+
|
| 1729 |
generation_section["sample_btn"].click(
|
| 1730 |
+
fn=sample_example_smart_with_flag,
|
| 1731 |
inputs=[generation_section["task_type"]],
|
| 1732 |
outputs=[
|
| 1733 |
generation_section["captions"],
|
|
|
|
| 1738 |
generation_section["key_scale"],
|
| 1739 |
generation_section["vocal_language"],
|
| 1740 |
generation_section["time_signature"],
|
| 1741 |
+
results_section["is_format_caption_state"] # Set is_format_caption to True (from Sample/LM)
|
| 1742 |
]
|
| 1743 |
)
|
| 1744 |
|
|
|
|
| 1793 |
duration,
|
| 1794 |
keyscale,
|
| 1795 |
language,
|
| 1796 |
+
timesignature,
|
| 1797 |
+
True # Set is_format_caption to True (from Transcribe/LM understanding)
|
| 1798 |
)
|
| 1799 |
|
| 1800 |
# Update transcribe button text based on whether codes are present
|
|
|
|
| 1828 |
generation_section["key_scale"], # Update keyscale field
|
| 1829 |
generation_section["vocal_language"], # Update language field
|
| 1830 |
generation_section["time_signature"], # Update time signature field
|
| 1831 |
+
results_section["is_format_caption_state"] # Set is_format_caption to True
|
| 1832 |
]
|
| 1833 |
)
|
| 1834 |
|
| 1835 |
+
# Reset is_format_caption to False when user manually edits fields
|
| 1836 |
+
def reset_format_caption_flag():
|
| 1837 |
+
"""Reset is_format_caption to False when user manually edits caption/metadata"""
|
| 1838 |
+
return False
|
| 1839 |
+
|
| 1840 |
+
# Connect reset function to all user-editable metadata fields
|
| 1841 |
+
generation_section["captions"].change(
|
| 1842 |
+
fn=reset_format_caption_flag,
|
| 1843 |
+
inputs=[],
|
| 1844 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 1845 |
+
)
|
| 1846 |
+
|
| 1847 |
+
generation_section["lyrics"].change(
|
| 1848 |
+
fn=reset_format_caption_flag,
|
| 1849 |
+
inputs=[],
|
| 1850 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 1851 |
+
)
|
| 1852 |
+
|
| 1853 |
+
generation_section["bpm"].change(
|
| 1854 |
+
fn=reset_format_caption_flag,
|
| 1855 |
+
inputs=[],
|
| 1856 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 1857 |
+
)
|
| 1858 |
+
|
| 1859 |
+
generation_section["key_scale"].change(
|
| 1860 |
+
fn=reset_format_caption_flag,
|
| 1861 |
+
inputs=[],
|
| 1862 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 1863 |
+
)
|
| 1864 |
+
|
| 1865 |
+
generation_section["time_signature"].change(
|
| 1866 |
+
fn=reset_format_caption_flag,
|
| 1867 |
+
inputs=[],
|
| 1868 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 1869 |
+
)
|
| 1870 |
+
|
| 1871 |
+
generation_section["vocal_language"].change(
|
| 1872 |
+
fn=reset_format_caption_flag,
|
| 1873 |
+
inputs=[],
|
| 1874 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 1875 |
+
)
|
| 1876 |
+
|
| 1877 |
+
generation_section["audio_duration"].change(
|
| 1878 |
+
fn=reset_format_caption_flag,
|
| 1879 |
+
inputs=[],
|
| 1880 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 1881 |
+
)
|
| 1882 |
+
|
| 1883 |
# Auto-expand Audio Uploads accordion when audio is uploaded
|
| 1884 |
def update_audio_uploads_accordion(reference_audio, src_audio):
|
| 1885 |
"""Update Audio Uploads accordion open state based on whether audio files are present"""
|
|
|
|
| 1898 |
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
| 1899 |
outputs=[generation_section["audio_uploads_accordion"]]
|
| 1900 |
)
|
| 1901 |
+
|
| 1902 |
+
# Save metadata handlers
|
| 1903 |
+
results_section["save_btn_1"].click(
|
| 1904 |
+
fn=save_metadata,
|
| 1905 |
+
inputs=[
|
| 1906 |
+
generation_section["task_type"],
|
| 1907 |
+
generation_section["captions"],
|
| 1908 |
+
generation_section["lyrics"],
|
| 1909 |
+
generation_section["vocal_language"],
|
| 1910 |
+
generation_section["bpm"],
|
| 1911 |
+
generation_section["key_scale"],
|
| 1912 |
+
generation_section["time_signature"],
|
| 1913 |
+
generation_section["audio_duration"],
|
| 1914 |
+
generation_section["batch_size_input"],
|
| 1915 |
+
generation_section["inference_steps"],
|
| 1916 |
+
generation_section["guidance_scale"],
|
| 1917 |
+
generation_section["seed"],
|
| 1918 |
+
generation_section["random_seed_checkbox"],
|
| 1919 |
+
generation_section["use_adg"],
|
| 1920 |
+
generation_section["cfg_interval_start"],
|
| 1921 |
+
generation_section["cfg_interval_end"],
|
| 1922 |
+
generation_section["audio_format"],
|
| 1923 |
+
generation_section["lm_temperature"],
|
| 1924 |
+
generation_section["lm_cfg_scale"],
|
| 1925 |
+
generation_section["lm_top_k"],
|
| 1926 |
+
generation_section["lm_top_p"],
|
| 1927 |
+
generation_section["lm_negative_prompt"],
|
| 1928 |
+
generation_section["use_cot_caption"],
|
| 1929 |
+
generation_section["use_cot_language"],
|
| 1930 |
+
generation_section["audio_cover_strength"],
|
| 1931 |
+
generation_section["think_checkbox"],
|
| 1932 |
+
generation_section["text2music_audio_code_string"],
|
| 1933 |
+
generation_section["repainting_start"],
|
| 1934 |
+
generation_section["repainting_end"],
|
| 1935 |
+
generation_section["track_name"],
|
| 1936 |
+
generation_section["complete_track_classes"],
|
| 1937 |
+
results_section["lm_metadata_state"],
|
| 1938 |
+
],
|
| 1939 |
+
outputs=[]
|
| 1940 |
+
)
|
| 1941 |
+
|
| 1942 |
+
results_section["save_btn_2"].click(
|
| 1943 |
+
fn=save_metadata,
|
| 1944 |
+
inputs=[
|
| 1945 |
+
generation_section["task_type"],
|
| 1946 |
+
generation_section["captions"],
|
| 1947 |
+
generation_section["lyrics"],
|
| 1948 |
+
generation_section["vocal_language"],
|
| 1949 |
+
generation_section["bpm"],
|
| 1950 |
+
generation_section["key_scale"],
|
| 1951 |
+
generation_section["time_signature"],
|
| 1952 |
+
generation_section["audio_duration"],
|
| 1953 |
+
generation_section["batch_size_input"],
|
| 1954 |
+
generation_section["inference_steps"],
|
| 1955 |
+
generation_section["guidance_scale"],
|
| 1956 |
+
generation_section["seed"],
|
| 1957 |
+
generation_section["random_seed_checkbox"],
|
| 1958 |
+
generation_section["use_adg"],
|
| 1959 |
+
generation_section["cfg_interval_start"],
|
| 1960 |
+
generation_section["cfg_interval_end"],
|
| 1961 |
+
generation_section["audio_format"],
|
| 1962 |
+
generation_section["lm_temperature"],
|
| 1963 |
+
generation_section["lm_cfg_scale"],
|
| 1964 |
+
generation_section["lm_top_k"],
|
| 1965 |
+
generation_section["lm_top_p"],
|
| 1966 |
+
generation_section["lm_negative_prompt"],
|
| 1967 |
+
generation_section["use_cot_caption"],
|
| 1968 |
+
generation_section["use_cot_language"],
|
| 1969 |
+
generation_section["audio_cover_strength"],
|
| 1970 |
+
generation_section["think_checkbox"],
|
| 1971 |
+
generation_section["text2music_audio_code_string"],
|
| 1972 |
+
generation_section["repainting_start"],
|
| 1973 |
+
generation_section["repainting_end"],
|
| 1974 |
+
generation_section["track_name"],
|
| 1975 |
+
generation_section["complete_track_classes"],
|
| 1976 |
+
results_section["lm_metadata_state"],
|
| 1977 |
+
],
|
| 1978 |
+
outputs=[]
|
| 1979 |
+
)
|
| 1980 |
+
|
| 1981 |
+
# Load metadata handler - triggered when file is uploaded via UploadButton
|
| 1982 |
+
generation_section["load_file"].upload(
|
| 1983 |
+
fn=load_metadata,
|
| 1984 |
+
inputs=[generation_section["load_file"]],
|
| 1985 |
+
outputs=[
|
| 1986 |
+
generation_section["task_type"],
|
| 1987 |
+
generation_section["captions"],
|
| 1988 |
+
generation_section["lyrics"],
|
| 1989 |
+
generation_section["vocal_language"],
|
| 1990 |
+
generation_section["bpm"],
|
| 1991 |
+
generation_section["key_scale"],
|
| 1992 |
+
generation_section["time_signature"],
|
| 1993 |
+
generation_section["audio_duration"],
|
| 1994 |
+
generation_section["batch_size_input"],
|
| 1995 |
+
generation_section["inference_steps"],
|
| 1996 |
+
generation_section["guidance_scale"],
|
| 1997 |
+
generation_section["seed"],
|
| 1998 |
+
generation_section["random_seed_checkbox"],
|
| 1999 |
+
generation_section["use_adg"],
|
| 2000 |
+
generation_section["cfg_interval_start"],
|
| 2001 |
+
generation_section["cfg_interval_end"],
|
| 2002 |
+
generation_section["audio_format"],
|
| 2003 |
+
generation_section["lm_temperature"],
|
| 2004 |
+
generation_section["lm_cfg_scale"],
|
| 2005 |
+
generation_section["lm_top_k"],
|
| 2006 |
+
generation_section["lm_top_p"],
|
| 2007 |
+
generation_section["lm_negative_prompt"],
|
| 2008 |
+
generation_section["use_cot_caption"],
|
| 2009 |
+
generation_section["use_cot_language"],
|
| 2010 |
+
generation_section["audio_cover_strength"],
|
| 2011 |
+
generation_section["think_checkbox"],
|
| 2012 |
+
generation_section["text2music_audio_code_string"],
|
| 2013 |
+
generation_section["repainting_start"],
|
| 2014 |
+
generation_section["repainting_end"],
|
| 2015 |
+
generation_section["track_name"],
|
| 2016 |
+
generation_section["complete_track_classes"],
|
| 2017 |
+
results_section["is_format_caption_state"]
|
| 2018 |
+
]
|
| 2019 |
+
)
|
| 2020 |
|
acestep/llm_inference.py
CHANGED
|
@@ -551,6 +551,7 @@ class LLMHandler:
|
|
| 551 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 552 |
use_cot_caption: bool = True,
|
| 553 |
use_cot_language: bool = True,
|
|
|
|
| 554 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 555 |
"""Two-phase LM generation: CoT generation followed by audio codes generation.
|
| 556 |
|
|
@@ -575,7 +576,7 @@ class LLMHandler:
|
|
| 575 |
|
| 576 |
# ========== PHASE 1: CoT Generation ==========
|
| 577 |
# Always generate CoT unless all metadata are user-provided
|
| 578 |
-
if not has_all_metas:
|
| 579 |
logger.info("Phase 1: Generating CoT metadata...")
|
| 580 |
|
| 581 |
# Build formatted prompt for CoT phase
|
|
|
|
| 551 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 552 |
use_cot_caption: bool = True,
|
| 553 |
use_cot_language: bool = True,
|
| 554 |
+
is_format_caption: bool = False,
|
| 555 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 556 |
"""Two-phase LM generation: CoT generation followed by audio codes generation.
|
| 557 |
|
|
|
|
| 576 |
|
| 577 |
# ========== PHASE 1: CoT Generation ==========
|
| 578 |
# Always generate CoT unless all metadata are user-provided
|
| 579 |
+
if not has_all_metas or not is_format_caption:
|
| 580 |
logger.info("Phase 1: Generating CoT metadata...")
|
| 581 |
|
| 582 |
# Build formatted prompt for CoT phase
|