ChuxiJ committed
Commit 4166944 · 1 Parent(s): 12f9f66

cover & refer audio test ok

Files changed (2):
  1. acestep/gradio_ui.py +35 -15
  2. acestep/handler.py +6 -2
acestep/gradio_ui.py CHANGED

@@ -203,7 +203,8 @@ def create_generation_section(handler) -> dict:
             init_llm_checkbox = gr.Checkbox(
                 label="Initialize 5Hz LM",
                 value=False,
-                info="Check to initialize 5Hz LM during service initialization"
+                info="Check to initialize 5Hz LM during service initialization",
+                interactive=False
             )
 
         with gr.Row():
@@ -224,10 +225,17 @@ def create_generation_section(handler) -> dict:
         with gr.Column(scale=2):
             with gr.Accordion("📝 Required Inputs", open=True):
                 # Task type
+                # Determine initial task_type choices based on default model
+                default_model_lower = (default_model or "").lower()
+                if "turbo" in default_model_lower:
+                    initial_task_choices = ["text2music", "repaint", "cover"]
+                else:
+                    initial_task_choices = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
+
                 with gr.Row():
                     with gr.Column(scale=2):
                         task_type = gr.Dropdown(
-                            choices=["text2music", "repaint", "cover", "extract", "lego", "complete"],
+                            choices=initial_task_choices,
                             value="text2music",
                             label="Task Type",
                             info="Select the task type for generation",
@@ -458,6 +466,14 @@ def create_generation_section(handler) -> dict:
                     label="Audio Format",
                     info="Audio format for saved files"
                 )
+
+            with gr.Row():
+                output_alignment_preference = gr.Checkbox(
+                    label="Output Attention Focus Score (disabled)",
+                    value=False,
+                    info="Output attention focus score analysis",
+                    interactive=False
+                )
 
         generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=False)
 
@@ -503,6 +519,7 @@ def create_generation_section(handler) -> dict:
         "cfg_interval_start": cfg_interval_start,
         "cfg_interval_end": cfg_interval_end,
         "audio_format": audio_format,
+        "output_alignment_preference": output_alignment_preference,
         "generate_btn": generate_btn,
     }
 
@@ -536,17 +553,16 @@ def create_results_section(handler) -> dict:
         )
         generation_info = gr.Markdown(label="Generation Details")
 
-        gr.Markdown("### ⚖️ Alignment Preference Analysis")
-
-        with gr.Row():
-            with gr.Column():
-                align_score_1 = gr.Textbox(label="Alignment Score (Sample 1)", interactive=False)
-                align_text_1 = gr.Textbox(label="Lyric Timestamps (Sample 1)", interactive=False, lines=10)
-                align_plot_1 = gr.Plot(label="Alignment Heatmap (Sample 1)")
-            with gr.Column():
-                align_score_2 = gr.Textbox(label="Alignment Score (Sample 2)", interactive=False)
-                align_text_2 = gr.Textbox(label="Lyric Timestamps (Sample 2)", interactive=False, lines=10)
-                align_plot_2 = gr.Plot(label="Alignment Heatmap (Sample 2)")
+        with gr.Accordion("⚖️ Attention Focus Score Analysis", open=False):
+            with gr.Row():
+                with gr.Column():
+                    align_score_1 = gr.Textbox(label="Attention Focus Score (Sample 1)", interactive=False)
+                    align_text_1 = gr.Textbox(label="Lyric Timestamps (Sample 1)", interactive=False, lines=10)
+                    align_plot_1 = gr.Plot(label="Attention Focus Score Heatmap (Sample 1)")
+                with gr.Column():
+                    align_score_2 = gr.Textbox(label="Attention Focus Score (Sample 2)", interactive=False)
+                    align_text_2 = gr.Textbox(label="Lyric Timestamps (Sample 2)", interactive=False, lines=10)
+                    align_plot_2 = gr.Plot(label="Attention Focus Score Heatmap (Sample 2)")
 
     return {
         "status_output": status_output,
@@ -595,22 +611,24 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
         config_path_lower = config_path.lower()
 
         if "turbo" in config_path_lower:
-            # Turbo model: max 8 steps, hide CFG/ADG
+            # Turbo model: max 8 steps, hide CFG/ADG, only show text2music/repaint/cover
            return (
                gr.update(value=8, maximum=8, minimum=1),  # inference_steps
                gr.update(visible=False),  # guidance_scale
                gr.update(visible=False),  # use_adg
                gr.update(visible=False),  # cfg_interval_start
                gr.update(visible=False),  # cfg_interval_end
+               gr.update(choices=["text2music", "repaint", "cover"]),  # task_type
            )
        elif "base" in config_path_lower:
-            # Base model: max 100 steps, show CFG/ADG
+            # Base model: max 100 steps, show CFG/ADG, show all task types
            return (
                gr.update(value=32, maximum=100, minimum=1),  # inference_steps
                gr.update(visible=True),  # guidance_scale
                gr.update(visible=True),  # use_adg
                gr.update(visible=True),  # cfg_interval_start
                gr.update(visible=True),  # cfg_interval_end
+               gr.update(choices=["text2music", "repaint", "cover", "extract", "lego", "complete"]),  # task_type
            )
        else:
            # Default to turbo settings
@@ -620,6 +638,7 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
+               gr.update(choices=["text2music", "repaint", "cover"]),  # task_type
            )
 
    generation_section["config_path"].change(
@@ -631,6 +650,7 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
            generation_section["use_adg"],
            generation_section["cfg_interval_start"],
            generation_section["cfg_interval_end"],
+           generation_section["task_type"],
        ]
    )
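The gradio_ui.py change wires one more output into the existing config-change callback, so switching between a turbo and a base config now also swaps the Task Type dropdown's choices. A minimal, self-contained sketch of that pattern (the component names and the config value below are illustrative, not the repo's exact ones):

import gradio as gr

TURBO_TASKS = ["text2music", "repaint", "cover"]
BASE_TASKS = TURBO_TASKS + ["extract", "lego", "complete"]

def on_config_change(config_path):
    # Mirrors the diff: turbo configs cap steps at 8 and offer fewer tasks.
    is_turbo = "turbo" in (config_path or "").lower()
    return (
        gr.update(value=8 if is_turbo else 32, maximum=8 if is_turbo else 100, minimum=1),
        gr.update(choices=TURBO_TASKS if is_turbo else BASE_TASKS),
    )

with gr.Blocks() as demo:
    config_path = gr.Textbox(label="Config Path", value="example-turbo-config")
    inference_steps = gr.Slider(minimum=1, maximum=8, value=8, label="Inference Steps")
    task_type = gr.Dropdown(choices=TURBO_TASKS, value="text2music", label="Task Type")
    # Each gr.update in the returned tuple maps positionally onto `outputs`.
    config_path.change(on_config_change, inputs=[config_path], outputs=[inference_steps, task_type])

One caveat with this pattern: when the choices shrink, the dropdown's current value can fall outside the new list, which is presumably why the diff defaults to "text2music", a task present in both lists.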
acestep/handler.py CHANGED

@@ -675,6 +675,10 @@ class AceStepHandler:
         # Load audio file
         audio, sr = torchaudio.load(audio_file)
 
+        logger.info(f"Reference audio shape: {audio.shape}")
+        logger.info(f"Reference audio sample rate: {sr}")
+        logger.info(f"Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
+
         # Convert to stereo (duplicate channel if mono)
         if audio.shape[0] == 1:
             audio = torch.cat([audio, audio], dim=0)
@@ -1074,7 +1078,7 @@
                 expected_latent_length = current_wav.shape[-1] // 1920
                 target_latent = self.silence_latent[0, :expected_latent_length, :]
             else:
-                target_latent = self.vae.encode(current_wav)
+                target_latent = self.vae.encode(current_wav.to(self.device).to(self.dtype)).latent_dist.sample()
                 target_latent = target_latent.squeeze(0).transpose(0, 1)
             target_latents_list.append(target_latent)
             latent_lengths.append(target_latent.shape[0])
@@ -1430,7 +1434,7 @@
                 refer_audio_order_mask.append(batch_idx)
             else:
                 for refer_audio in refer_audios:
-                    refer_audio_latent = self.vae.encode(refer_audio.unsqueeze(0), chunked=False)
+                    refer_audio_latent = self.vae.encode(refer_audio.unsqueeze(0)).latent_dist.sample()
                     refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
                     refer_audio_order_mask.append(batch_idx)
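The handler.py change swaps the bare self.vae.encode(...) calls for encode(...).latent_dist.sample(), moving the input onto the model's device/dtype first. A sketch of that encode path, assuming the VAE exposes a diffusers-AutoencoderKL-style interface (encode() returning an output with a .latent_dist Gaussian), which the diff implies; the helper name and tensor layout are illustrative. It also derives the logged duration from the loaded sample rate rather than the hardcoded 48000.0 in the diff, which is only valid if every input is already 48 kHz:

import torch
import torchaudio

def encode_reference_audio(vae, audio_file, device, dtype):
    # Load as (channels, samples); sr is the file's native rate.
    audio, sr = torchaudio.load(audio_file)
    print(f"shape={tuple(audio.shape)}, sr={sr}, duration={audio.shape[-1] / sr:.2f}s")

    # Duplicate the channel if mono, matching the handler's stereo convention.
    if audio.shape[0] == 1:
        audio = torch.cat([audio, audio], dim=0)

    # Match the diff: cast to the model's device/dtype before encoding.
    wav = audio.unsqueeze(0).to(device).to(dtype)

    with torch.no_grad():
        # .sample() draws from the posterior, as the diff does;
        # .mode() would give a deterministic encoding instead.
        latent = vae.encode(wav).latent_dist.sample()
    return latent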