Gong Junmin committed on
Commit 12f9f66 · 1 Parent(s): 11a221a
Files changed (1)
  1. acestep/handler.py +1293 -432
acestep/handler.py CHANGED
@@ -12,16 +12,33 @@ import random
12
  from typing import Optional, Dict, Any, Tuple, List, Union
13
 
14
  import torch
 
15
  import matplotlib.pyplot as plt
16
  import numpy as np
17
  import scipy.io.wavfile as wavfile
 
18
  import soundfile as sf
19
  import time
 
 
20
 
21
  from transformers import AutoTokenizer, AutoModel
22
  from diffusers.models import AutoencoderOobleck
23
 
24
 
25
  class AceStepHandler:
26
  """ACE-Step Business Logic Handler"""
27
 
@@ -166,18 +183,17 @@ class AceStepHandler:
166
  if os.path.exists(acestep_v15_checkpoint_path):
167
  # Determine attention implementation
168
  attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
169
- self.model = AutoModel.from_pretrained(
170
- acestep_v15_checkpoint_path,
171
- trust_remote_code=True,
172
- attn_implementation=attn_implementation
173
- )
174
  self.config = self.model.config
175
  # Move model to device and set dtype
176
  self.model = self.model.to(device).to(self.dtype)
177
  self.model.eval()
178
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
179
  if os.path.exists(silence_latent_path):
180
- self.silence_latent = torch.load(silence_latent_path).transpose(1, 2).squeeze(0) # [L, C]
181
  self.silence_latent = self.silence_latent.to(device).to(self.dtype)
182
  else:
183
  raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
@@ -395,50 +411,7 @@ class AceStepHandler:
395
  metadata[key] = value
396
 
397
  return metadata, audio_codes
398
-
399
- def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
400
- """Process reference audio"""
401
- if audio_file is None:
402
- return None
403
-
404
- try:
405
- # Load audio using soundfile
406
- audio_np, sr = sf.read(audio_file, dtype='float32')
407
- # Convert to torch: [samples, channels] or [samples] -> [channels, samples]
408
- if audio_np.ndim == 1:
409
- audio = torch.from_numpy(audio_np).unsqueeze(0)
410
- else:
411
- audio = torch.from_numpy(audio_np.T)
412
-
413
- if audio.shape[0] == 1:
414
- audio = torch.cat([audio, audio], dim=0)
415
-
416
- audio = audio[:2]
417
-
418
- # Resample if needed
419
- if sr != 48000:
420
- import torch.nn.functional as F
421
- # Simple resampling using interpolate
422
- ratio = 48000 / sr
423
- new_length = int(audio.shape[-1] * ratio)
424
- audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
425
-
426
- audio = torch.clamp(audio, -1.0, 1.0)
427
-
428
- target_frames = 30 * 48000
429
- if audio.shape[-1] > target_frames:
430
- start_frame = (audio.shape[-1] - target_frames) // 2
431
- audio = audio[:, start_frame:start_frame + target_frames]
432
- elif audio.shape[-1] < target_frames:
433
- audio = torch.nn.functional.pad(
434
- audio, (0, target_frames - audio.shape[-1]), 'constant', 0
435
- )
436
-
437
- return audio
438
- except Exception as e:
439
- print(f"Error processing reference audio: {e}")
440
- return None
441
-
442
  def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
443
  """Process target audio"""
444
  if audio_file is None:
@@ -541,18 +514,32 @@ class AceStepHandler:
541
  )
542
 
543
  def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
544
- """Parse and normalize metadata with fallbacks."""
545
  parsed_metas = []
546
  for meta in metas:
547
  if meta is None:
 
548
  parsed_meta = self._create_default_meta()
549
  elif isinstance(meta, str):
 
550
  parsed_meta = meta
551
  elif isinstance(meta, dict):
 
552
  parsed_meta = self._dict_to_meta_string(meta)
553
  else:
 
554
  parsed_meta = self._create_default_meta()
 
555
  parsed_metas.append(parsed_meta)
 
556
  return parsed_metas
557
 
558
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -586,12 +573,1157 @@ class AceStepHandler:
586
  return text_hidden_states, text_attention_mask
587
 
588
  def extract_caption_from_sft_format(self, caption: str) -> str:
589
- """Extract caption from SFT format if needed."""
590
- # Simple extraction - can be enhanced if needed
591
- if caption and isinstance(caption, str):
592
- return caption.strip()
593
- return caption if caption else ""
594
595
  def generate_music(
596
  self,
597
  captions: str,
@@ -631,384 +1763,114 @@ class AceStepHandler:
631
  """
632
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
633
  return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
634
-
635
- try:
636
- print("[generate_music] Starting generation...")
637
- if progress:
638
- progress(0.05, desc="Preparing inputs...")
639
- print("[generate_music] Preparing inputs...")
640
-
641
- # Determine actual batch size
642
- actual_batch_size = batch_size if batch_size is not None else self.batch_size
643
- actual_batch_size = max(1, min(actual_batch_size, 8)) # Limit to 8 for memory safety
644
-
645
- # Process seeds
646
- if use_random_seed:
647
- seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
648
- else:
649
- # Parse seed input
650
- if isinstance(seed, str):
651
- seed_parts = [s.strip() for s in seed.split(",")]
652
- seed_list = [int(float(s)) if s != "-1" and s else random.randint(0, 2**32 - 1) for s in seed_parts[:actual_batch_size]]
653
- elif isinstance(seed, (int, float)) and seed >= 0:
654
- seed_list = [int(seed)] * actual_batch_size
655
- else:
656
- seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
657
-
658
- # Pad if needed
659
- while len(seed_list) < actual_batch_size:
660
- seed_list.append(random.randint(0, 2**32 - 1))
661
-
662
- seed_value_for_ui = ", ".join(str(s) for s in seed_list)
663
-
664
- # Process audio inputs
665
- processed_ref_audio = self.process_reference_audio(reference_audio) if reference_audio else None
666
- processed_src_audio = self.process_target_audio(src_audio) if src_audio else None
667
-
668
- # Extract caption
669
- pure_caption = self.extract_caption_from_sft_format(captions)
670
-
671
- # Determine task type and update instruction if needed
672
- if task_type == "text2music" and audio_code_string and str(audio_code_string).strip():
673
  task_type = "cover"
 
674
  instruction = "Generate audio semantic tokens based on the given conditions:"
675
-
676
- # Build metadata
677
- metadata_dict = {
678
- "bpm": bpm if bpm else "N/A",
679
- "keyscale": key_scale if key_scale else "N/A",
680
- "timesignature": time_signature if time_signature else "N/A",
681
- }
682
-
683
- # Calculate duration
684
- if processed_src_audio is not None:
685
- calculated_duration = processed_src_audio.shape[-1] / self.sample_rate
686
- elif audio_duration is not None and audio_duration > 0:
687
- calculated_duration = audio_duration
688
- else:
689
- calculated_duration = 30.0 # Default 30 seconds
690
-
691
- metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
692
-
693
- if progress:
694
- progress(0.1, desc="Processing audio inputs...")
695
- print("[generate_music] Processing audio inputs...")
696
-
697
- # Prepare batch data
698
- captions_batch = [pure_caption] * actual_batch_size
699
- lyrics_batch = [lyrics] * actual_batch_size
700
- vocal_languages_batch = [vocal_language] * actual_batch_size
701
- instructions_batch = [instruction] * actual_batch_size
702
- metas_batch = [metadata_dict.copy()] * actual_batch_size
703
- audio_code_hints_batch = [audio_code_string if audio_code_string else None] * actual_batch_size
704
-
705
- # Process reference audios
706
- if processed_ref_audio is not None:
707
- refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
708
- else:
709
- # Create silence as fallback
710
- silence_frames = 30 * self.sample_rate
711
- silence = torch.zeros(2, silence_frames)
712
- refer_audios = [[silence] for _ in range(actual_batch_size)]
713
-
714
- # Process target wavs (src_audio)
715
- if processed_src_audio is not None:
716
- target_wavs_list = [processed_src_audio.clone() for _ in range(actual_batch_size)]
717
- else:
718
- # Create silence based on duration
719
- target_frames = int(calculated_duration * self.sample_rate)
720
- silence = torch.zeros(2, target_frames)
721
- target_wavs_list = [silence for _ in range(actual_batch_size)]
722
-
723
- # Pad target_wavs to consistent length
724
- max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
725
- target_wavs = torch.stack([
726
- torch.nn.functional.pad(wav, (0, max_target_frames - wav.shape[-1]), 'constant', 0)
727
- for wav in target_wavs_list
728
- ])
729
-
730
- if progress:
731
- progress(0.2, desc="Encoding audio to latents...")
732
- print("[generate_music] Encoding audio to latents...")
733
-
734
- # Encode target_wavs to latents using VAE
735
- target_latents_list = []
736
- latent_lengths = []
737
-
738
- with torch.no_grad():
739
- for i in range(actual_batch_size):
740
- # Check if audio codes are provided
741
- code_hint = audio_code_hints_batch[i]
742
- if code_hint:
743
- decoded_latents = self._decode_audio_codes_to_latents(code_hint)
744
- if decoded_latents is not None:
745
- decoded_latents = decoded_latents.squeeze(0) # Remove batch dim
746
- target_latents_list.append(decoded_latents)
747
- latent_lengths.append(decoded_latents.shape[0])
748
- continue
749
-
750
- # If no src_audio provided, use silence_latent directly (skip VAE)
751
- if processed_src_audio is None:
752
- # Calculate required latent length based on duration
753
- # VAE downsample ratio is 1920 (2*4*4*6*10), so latent rate is 48000/1920 = 25Hz
754
- latent_length = int(calculated_duration * 25) # 25Hz latent rate
755
- latent_length = max(128, latent_length) # Minimum 128
756
-
757
- # Tile silence_latent to required length
758
- if self.silence_latent.shape[0] >= latent_length:
759
- target_latent = self.silence_latent[:latent_length].to(self.device).to(self.dtype)
760
- else:
761
- repeat_times = (latent_length // self.silence_latent.shape[0]) + 1
762
- target_latent = self.silence_latent.repeat(repeat_times, 1)[:latent_length].to(self.device).to(self.dtype)
763
- target_latents_list.append(target_latent)
764
- latent_lengths.append(target_latent.shape[0])
765
- continue
766
-
767
- # Encode from audio using VAE
768
- current_wav = target_wavs[i].unsqueeze(0).to(self.device).to(self.dtype)
769
- target_latent = self.vae.encode(current_wav)
770
- target_latent = target_latent.squeeze(0).transpose(0, 1) # [latent_length, latent_dim]
771
- target_latents_list.append(target_latent)
772
- latent_lengths.append(target_latent.shape[0])
773
-
774
- # Pad latents to same length
775
- max_latent_length = max(latent_lengths)
776
- max_latent_length = max(128, max_latent_length) # Minimum 128
777
-
778
- padded_latents = []
779
- for i, latent in enumerate(target_latents_list):
780
- if latent.shape[0] < max_latent_length:
781
- pad_length = max_latent_length - latent.shape[0]
782
- # Tile silence_latent to pad_length (silence_latent is [L, C])
783
- if self.silence_latent.shape[0] >= pad_length:
784
- pad_latent = self.silence_latent[:pad_length]
785
- else:
786
- repeat_times = (pad_length // self.silence_latent.shape[0]) + 1
787
- pad_latent = self.silence_latent.repeat(repeat_times, 1)[:pad_length]
788
- latent = torch.cat([latent, pad_latent.to(self.device).to(self.dtype)], dim=0)
789
- padded_latents.append(latent)
790
-
791
- target_latents = torch.stack(padded_latents).to(self.device).to(self.dtype)
792
- latent_masks = torch.stack([
793
- torch.cat([
794
- torch.ones(l, dtype=torch.long, device=self.device),
795
- torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device)
796
- ])
797
- for l in latent_lengths
798
- ])
799
-
800
- if progress:
801
- progress(0.3, desc="Preparing conditions...")
802
- print("[generate_music] Preparing conditions...")
803
-
804
- # Determine task type and create chunk masks
805
- is_covers = []
806
- chunk_masks = []
807
- repainting_ranges = {}
808
-
809
- for i in range(actual_batch_size):
810
- has_code_hint = audio_code_hints_batch[i] is not None
811
- has_repainting = (repainting_end is not None and repainting_end > repainting_start)
812
-
813
- if has_repainting:
814
- # Repainting mode
815
- start_sec = max(0, repainting_start)
816
- end_sec = repainting_end if repainting_end is not None else calculated_duration
817
-
818
- start_latent = int(start_sec * self.sample_rate // 1920)
819
- end_latent = int(end_sec * self.sample_rate // 1920)
820
- start_latent = max(0, min(start_latent, max_latent_length - 1))
821
- end_latent = max(start_latent + 1, min(end_latent, max_latent_length))
822
-
823
- mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device)
824
- mask[start_latent:end_latent] = True
825
- chunk_masks.append(mask)
826
- repainting_ranges[i] = (start_latent, end_latent)
827
- is_covers.append(False)
828
- else:
829
- # Full generation or cover
830
- chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
831
- # Check if cover task
832
- instruction_lower = instructions_batch[i].lower()
833
- is_cover = ("generate audio semantic tokens" in instruction_lower and
834
- "based on the given conditions" in instruction_lower) or has_code_hint
835
- is_covers.append(is_cover)
836
-
837
- chunk_masks = torch.stack(chunk_masks).unsqueeze(-1).expand(-1, -1, 64) # [batch, length, 64]
838
- is_covers = torch.tensor(is_covers, dtype=torch.bool, device=self.device)
839
-
840
- # Create src_latents
841
- # Tile silence_latent to max_latent_length (silence_latent is now [L, C])
842
- if self.silence_latent.shape[0] >= max_latent_length:
843
- silence_latent_tiled = self.silence_latent[:max_latent_length].to(self.device).to(self.dtype)
844
- else:
845
- repeat_times = (max_latent_length // self.silence_latent.shape[0]) + 1
846
- silence_latent_tiled = self.silence_latent.repeat(repeat_times, 1)[:max_latent_length].to(self.device).to(self.dtype)
847
- src_latents_list = []
848
-
849
- for i in range(actual_batch_size):
850
- has_target_audio = (target_wavs[i].abs().sum() > 1e-6) or (audio_code_hints_batch[i] is not None)
851
-
852
- if has_target_audio:
853
- if i in repainting_ranges:
854
- # Repaint: replace inpainting region with silence
855
- src_latent = target_latents[i].clone()
856
- start_latent, end_latent = repainting_ranges[i]
857
- src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent]
858
- src_latents_list.append(src_latent)
859
- else:
860
- # Cover/extract/complete/lego: use target_latents
861
- src_latents_list.append(target_latents[i].clone())
862
- else:
863
- # Text2music: use silence
864
- src_latents_list.append(silence_latent_tiled.clone())
865
-
866
- src_latents = torch.stack(src_latents_list) # [batch, length, channels]
867
-
868
- if progress:
869
- progress(0.4, desc="Tokenizing text inputs...")
870
- print("[generate_music] Tokenizing text inputs...")
871
-
872
- # Prepare text and lyric hidden states
873
- SFT_GEN_PROMPT = """# Instruction
874
- {}
875
 
876
- # Caption
877
- {}
 
 
878
 
879
- # Metas
880
- {}<|endoftext|>
881
- """
882
 
883
- text_hidden_states_list = []
884
- text_attention_masks_list = []
885
- lyric_hidden_states_list = []
886
- lyric_attention_masks_list = []
887
 
888
- with torch.no_grad():
889
- for i in range(actual_batch_size):
890
- # Format text prompt
891
- inst = instructions_batch[i]
892
- if not inst.endswith(":"):
893
- inst = inst + ":"
894
-
895
- meta_str = self._dict_to_meta_string(metas_batch[i])
896
- text_prompt = SFT_GEN_PROMPT.format(inst, captions_batch[i], meta_str)
897
-
898
- # Tokenize and encode text
899
- text_hidden, text_mask = self._get_text_hidden_states(text_prompt)
900
- text_hidden_states_list.append(text_hidden.squeeze(0))
901
- text_attention_masks_list.append(text_mask.squeeze(0))
902
-
903
- # Format and tokenize lyrics
904
- lyrics_text = f"# Languages\n{vocal_languages_batch[i]}\n\n# Lyric\n{lyrics_batch[i]}<|endoftext|>"
905
- lyric_hidden, lyric_mask = self._get_text_hidden_states(lyrics_text)
906
- lyric_hidden_states_list.append(lyric_hidden.squeeze(0))
907
- lyric_attention_masks_list.append(lyric_mask.squeeze(0))
908
-
909
- # Pad sequences
910
- max_text_length = max(h.shape[0] for h in text_hidden_states_list)
911
- max_lyric_length = max(h.shape[0] for h in lyric_hidden_states_list)
912
 
913
- text_hidden_states = torch.stack([
914
- torch.nn.functional.pad(h, (0, 0, 0, max_text_length - h.shape[0]), 'constant', 0)
915
- for h in text_hidden_states_list
916
- ]).to(self.device).to(self.dtype)
917
-
918
- text_attention_mask = torch.stack([
919
- torch.nn.functional.pad(m, (0, max_text_length - m.shape[0]), 'constant', 0)
920
- for m in text_attention_masks_list
921
- ]).to(self.device)
922
-
923
- lyric_hidden_states = torch.stack([
924
- torch.nn.functional.pad(h, (0, 0, 0, max_lyric_length - h.shape[0]), 'constant', 0)
925
- for h in lyric_hidden_states_list
926
- ]).to(self.device).to(self.dtype)
927
-
928
- lyric_attention_mask = torch.stack([
929
- torch.nn.functional.pad(m, (0, max_lyric_length - m.shape[0]), 'constant', 0)
930
- for m in lyric_attention_masks_list
931
- ]).to(self.device)
932
-
933
- if progress:
934
- progress(0.5, desc="Processing reference audio...")
935
- print("[generate_music] Processing reference audio...")
936
 
937
- # Process reference audio for timbre
938
- # Model expects: refer_audio_acoustic_hidden_states_packed [N, timbre_fix_frame, audio_acoustic_hidden_dim]
939
- # refer_audio_order_mask [N] indicating batch assignment
940
- timbre_fix_frame = getattr(self.config, 'timbre_fix_frame', 750)
941
- refer_audio_acoustic_hidden_states_packed_list = []
942
- refer_audio_order_mask_list = []
943
 
944
- with torch.no_grad():
945
- for i, ref_audio_list in enumerate(refer_audios):
946
- if ref_audio_list and len(ref_audio_list) > 0 and ref_audio_list[0].abs().sum() > 1e-6:
947
- # Encode reference audio: [channels, samples] -> [1, latent_dim, T] -> [T, latent_dim]
948
- ref_audio = ref_audio_list[0].unsqueeze(0).to(self.device).to(self.dtype)
949
- ref_latent = self.vae.encode(ref_audio).latent_dist.sample() # [1, latent_dim, T]
950
- ref_latent = ref_latent.squeeze(0).transpose(0, 1) # [T, latent_dim]
951
- # Ensure dimension matches audio_acoustic_hidden_dim (64)
952
- if ref_latent.shape[-1] != self.config.audio_acoustic_hidden_dim:
953
- ref_latent = ref_latent[:, :self.config.audio_acoustic_hidden_dim]
954
- # Pad or truncate to timbre_fix_frame
955
- if ref_latent.shape[0] < timbre_fix_frame:
956
- pad_length = timbre_fix_frame - ref_latent.shape[0]
957
- padding = torch.zeros(pad_length, ref_latent.shape[1], device=self.device, dtype=self.dtype)
958
- ref_latent = torch.cat([ref_latent, padding], dim=0)
959
- else:
960
- ref_latent = ref_latent[:timbre_fix_frame]
961
- refer_audio_acoustic_hidden_states_packed_list.append(ref_latent)
962
- refer_audio_order_mask_list.append(i)
963
- else:
964
- # Use silence_latent directly instead of running VAE
965
- if self.silence_latent.shape[0] >= timbre_fix_frame:
966
- silence_ref = self.silence_latent[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
967
- else:
968
- repeat_times = (timbre_fix_frame // self.silence_latent.shape[0]) + 1
969
- silence_ref = self.silence_latent.repeat(repeat_times, 1)[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
970
- refer_audio_acoustic_hidden_states_packed_list.append(silence_ref.to(self.device).to(self.dtype))
971
- refer_audio_order_mask_list.append(i)
972
-
973
- # Stack all reference audios: [N, timbre_fix_frame, audio_acoustic_hidden_dim]
974
- refer_audio_acoustic_hidden_states_packed = torch.stack(refer_audio_acoustic_hidden_states_packed_list, dim=0).to(self.device).to(self.dtype)
975
- # Order mask: [N] indicating which batch item each reference belongs to
976
- refer_audio_order_mask = torch.tensor(refer_audio_order_mask_list, dtype=torch.long, device=self.device)
977
 
978
- if progress:
979
- progress(0.6, desc="Generating audio...")
980
- print("[generate_music] Calling model.generate_audio()...")
981
- print(f" - text_hidden_states: {text_hidden_states.shape}, dtype={text_hidden_states.dtype}")
982
- print(f" - text_attention_mask: {text_attention_mask.shape}, dtype={text_attention_mask.dtype}")
983
- print(f" - lyric_hidden_states: {lyric_hidden_states.shape}, dtype={lyric_hidden_states.dtype}")
984
- print(f" - lyric_attention_mask: {lyric_attention_mask.shape}, dtype={lyric_attention_mask.dtype}")
985
- print(f" - refer_audio_acoustic_hidden_states_packed: {refer_audio_acoustic_hidden_states_packed.shape}, dtype={refer_audio_acoustic_hidden_states_packed.dtype}")
986
- print(f" - refer_audio_order_mask: {refer_audio_order_mask.shape}, dtype={refer_audio_order_mask.dtype}")
987
- print(f" - src_latents: {src_latents.shape}, dtype={src_latents.dtype}")
988
- print(f" - chunk_masks: {chunk_masks.shape}, dtype={chunk_masks.dtype}")
989
- print(f" - is_covers: {is_covers.shape}, dtype={is_covers.dtype}")
990
- print(f" - silence_latent: {self.silence_latent.unsqueeze(0).shape}")
991
- print(f" - seed: {seed_list[0] if len(seed_list) > 0 else None}")
992
- print(f" - fix_nfe: {inference_steps}")
993
-
994
- # Call model to generate
995
- with torch.no_grad():
996
- outputs = self.model.generate_audio(
997
- text_hidden_states=text_hidden_states,
998
- text_attention_mask=text_attention_mask,
999
- lyric_hidden_states=lyric_hidden_states,
1000
- lyric_attention_mask=lyric_attention_mask,
1001
- refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
1002
- refer_audio_order_mask=refer_audio_order_mask,
1003
- src_latents=src_latents,
1004
- chunk_masks=chunk_masks,
1005
- is_covers=is_covers,
1006
- silence_latent=self.silence_latent.unsqueeze(0), # [1, L, C]
1007
- seed=seed_list[0] if len(seed_list) > 0 else None,
1008
- fix_nfe=inference_steps,
1009
- infer_method="ode",
1010
- use_cache=True,
1011
- )
1012
 
1013
  print("[generate_music] Model generation completed. Decoding latents...")
1014
  pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
@@ -1040,7 +1902,7 @@ class AceStepHandler:
1040
 
1041
  saved_files = []
1042
  for i in range(actual_batch_size):
1043
- audio_file = os.path.join(self.temp_dir, f"generated_{i}_{seed_list[i]}.{audio_format_lower}")
1044
  # Convert to numpy: [channels, samples] -> [samples, channels]
1045
  audio_np = pred_wavs[i].cpu().float().numpy().T
1046
  sf.write(audio_file, audio_np, self.sample_rate)
@@ -1064,10 +1926,9 @@ class AceStepHandler:
1064
 
1065
  generation_info = f"""**🎵 Generation Complete**
1066
 
1067
- **Seeds:** {seed_value_for_ui}
1068
- **Duration:** {calculated_duration:.1f}s
1069
- **Steps:** {inference_steps}
1070
- **Files:** {len(saved_files)} audio(s){time_costs_str}"""
1071
  status_message = f"✅ Generation completed successfully!"
1072
  print(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
1073
 
@@ -1093,8 +1954,8 @@ class AceStepHandler:
1093
  align_text_2,
1094
  align_plot_2,
1095
  )
1096
-
1097
  except Exception as e:
1098
- error_msg = f"❌ Error generating music: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
1099
- return None, None, [], "", error_msg, "-1", "", "", None, "", "", None
1100
 
 
12
  from typing import Optional, Dict, Any, Tuple, List, Union
13
 
14
  import torch
15
+ import torch.nn.functional as F
16
  import matplotlib.pyplot as plt
17
  import numpy as np
18
  import scipy.io.wavfile as wavfile
19
+ import torchaudio
20
  import soundfile as sf
21
  import time
22
+ from loguru import logger
23
+ import warnings
24
 
25
  from transformers import AutoTokenizer, AutoModel
26
  from diffusers.models import AutoencoderOobleck
27
 
28
 
29
+ warnings.filterwarnings("ignore")
30
+
31
+
32
+ SFT_GEN_PROMPT = """# Instruction
33
+ {}
34
+
35
+ # Caption
36
+ {}
37
+
38
+ # Metas
39
+ {}<|endoftext|>
40
+ """
41
+
42
  class AceStepHandler:
43
  """ACE-Step Business Logic Handler"""
44
 
 
183
  if os.path.exists(acestep_v15_checkpoint_path):
184
  # Determine attention implementation
185
  attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
186
+ if use_flash_attention and self.is_flash_attention_available():
187
+ self.dtype = torch.bfloat16
188
+ self.model = AutoModel.from_pretrained(acestep_v15_checkpoint_path, trust_remote_code=True, dtype=self.dtype)
189
+ self.model.config._attn_implementation = attn_implementation
 
190
  self.config = self.model.config
191
  # Move model to device and set dtype
192
  self.model = self.model.to(device).to(self.dtype)
193
  self.model.eval()
194
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
195
  if os.path.exists(silence_latent_path):
196
+ self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
197
  self.silence_latent = self.silence_latent.to(device).to(self.dtype)
198
  else:
199
  raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
 
411
  metadata[key] = value
412
 
413
  return metadata, audio_codes
414
+
415
  def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
416
  """Process target audio"""
417
  if audio_file is None:
 
514
  )
515
 
516
  def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
517
+ """
518
+ Parse and normalize metadata with fallbacks.
519
+
520
+ Args:
521
+ metas: List of metadata (can be strings, dicts, or None)
522
+
523
+ Returns:
524
+ List of formatted metadata strings
525
+ """
526
  parsed_metas = []
527
  for meta in metas:
528
  if meta is None:
529
+ # Default fallback metadata
530
  parsed_meta = self._create_default_meta()
531
  elif isinstance(meta, str):
532
+ # Already formatted string
533
  parsed_meta = meta
534
  elif isinstance(meta, dict):
535
+ # Convert dict to formatted string
536
  parsed_meta = self._dict_to_meta_string(meta)
537
  else:
538
+ # Fallback for any other type
539
  parsed_meta = self._create_default_meta()
540
+
541
  parsed_metas.append(parsed_meta)
542
+
543
  return parsed_metas
544
 
545
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
 
573
  return text_hidden_states, text_attention_mask
574
 
575
  def extract_caption_from_sft_format(self, caption: str) -> str:
576
+ try:
577
+ if "# Instruction" in caption and "# Caption" in caption:
578
+ pattern = r'#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)'
579
+ match = re.search(pattern, caption, re.DOTALL)
580
+ if match:
581
+ return match.group(1).strip()
582
+ return caption
583
+ except Exception as e:
584
+ print(f"Error extracting caption: {e}")
585
+ return caption
586
+
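When the incoming caption is already a full SFT-formatted prompt, the regex above pulls out only the "# Caption" section; plain captions fall through unchanged. A small, self-contained check of that pattern (the sample text is illustrative):

import re

sft_text = (
    "# Instruction\n"
    "Generate audio semantic tokens based on the given conditions:\n\n"
    "# Caption\n"
    "A mellow lo-fi beat with vinyl crackle\n\n"
    "# Metas\n"
    "bpm: 80"
)
pattern = r'#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)'
match = re.search(pattern, sft_text, re.DOTALL)
print(match.group(1).strip())  # -> A mellow lo-fi beat with vinyl crackle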
587
+ def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
588
+ actual_seed_list: List[int] = []
589
+ seed_value_for_ui = ""
590
+
591
+ if use_random_seed:
592
+ # Generate brand new seeds and expose them back to the UI
593
+ actual_seed_list = [random.randint(0, 2 ** 32 - 1) for _ in range(actual_batch_size)]
594
+ seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
595
+ else:
596
+ # Parse seed input: can be a single number, comma-separated numbers, or -1
597
+ # If seed is a string, try to parse it as comma-separated values
598
+ seed_list = []
599
+ if isinstance(seed, str):
600
+ # Handle string input (e.g., "123,456" or "-1")
601
+ seed_str_list = [s.strip() for s in seed.split(",")]
602
+ for s in seed_str_list:
603
+ if s == "-1" or s == "":
604
+ seed_list.append(-1)
605
+ else:
606
+ try:
607
+ seed_list.append(int(float(s)))
608
+ except (ValueError, TypeError):
609
+ seed_list.append(-1)
610
+ elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
611
+ # If seed is None or negative, use -1 for all items
612
+ seed_list = [-1] * actual_batch_size
613
+ elif isinstance(seed, (int, float)):
614
+ # Single seed value
615
+ seed_list = [int(seed)]
616
+ else:
617
+ # Fallback: use -1
618
+ seed_list = [-1] * actual_batch_size
619
+
620
+ # Process seed list according to rules:
621
+ # 1. If all are -1, generate different random seeds for each batch item
622
+ # 2. If one non-negative seed is provided and batch_size > 1, the first item uses that seed and the rest are random (see the sketch after this method)
623
+ # 3. If more seeds than batch_size, use first batch_size seeds
624
+ # Check if user provided only one non-negative seed (not -1)
625
+ has_single_non_negative_seed = (len(seed_list) == 1 and seed_list[0] != -1)
626
+
627
+ for i in range(actual_batch_size):
628
+ if i < len(seed_list):
629
+ seed_val = seed_list[i]
630
+ else:
631
+ # If not enough seeds provided, use -1 (will generate random)
632
+ seed_val = -1
633
+
634
+ # Special case: if only one non-negative seed was provided and batch_size > 1,
635
+ # only the first item uses that seed, others are random
636
+ if has_single_non_negative_seed and actual_batch_size > 1 and i > 0:
637
+ # Generate random seed for remaining items
638
+ actual_seed_list.append(random.randint(0, 2 ** 32 - 1))
639
+ elif seed_val == -1:
640
+ # Generate a random seed for this item
641
+ actual_seed_list.append(random.randint(0, 2 ** 32 - 1))
642
+ else:
643
+ actual_seed_list.append(int(seed_val))
644
+
645
+ seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
646
+ return actual_seed_list, seed_value_for_ui
647
+
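The rules above reduce to: use_random_seed ignores the seed argument, a "-1" entry means "randomize this slot", and a single fixed seed only pins the first batch item. A hedged usage sketch, assuming handler is an initialized AceStepHandler (random values differ per run):

# Single fixed seed, batch of 3: only item 0 is pinned.
seeds, ui_str = handler.prepare_seeds(actual_batch_size=3, seed="1234", use_random_seed=False)
# seeds == [1234, <random>, <random>]

# Comma-separated seeds with a -1 placeholder: only the -1 slot is randomized.
seeds, ui_str = handler.prepare_seeds(3, "111, -1, 333", False)
# seeds == [111, <random>, 333]

# use_random_seed=True ignores the seed argument entirely.
seeds, ui_str = handler.prepare_seeds(3, "42", True)
# seeds == [<random>, <random>, <random>]; ui_str echoes them back for the UI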
648
+ def prepare_metadata(self, bpm, key_scale, time_signature):
649
+ # Build metadata dict - use "N/A" as default for empty fields
650
+ metadata_dict = {}
651
+ if bpm:
652
+ metadata_dict["bpm"] = bpm
653
+ else:
654
+ metadata_dict["bpm"] = "N/A"
655
+
656
+ if key_scale.strip():
657
+ metadata_dict["keyscale"] = key_scale
658
+ else:
659
+ metadata_dict["keyscale"] = "N/A"
660
+
661
+ if time_signature.strip() and time_signature != "N/A" and time_signature:
662
+ metadata_dict["timesignature"] = time_signature
663
+ else:
664
+ metadata_dict["timesignature"] = "N/A"
665
+ return metadata_dict
666
+
667
+ def is_silence(self, audio):
668
+ return torch.all(audio.abs() < 1e-6)
669
+
670
+ def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
671
+ if audio_file is None:
672
+ return None
673
+
674
+ try:
675
+ # Load audio file
676
+ audio, sr = torchaudio.load(audio_file)
677
+
678
+ # Convert to stereo (duplicate channel if mono)
679
+ if audio.shape[0] == 1:
680
+ audio = torch.cat([audio, audio], dim=0)
681
+
682
+ # Keep only first 2 channels
683
+ audio = audio[:2]
684
+
685
+ # Resample to 48kHz if needed
686
+ if sr != 48000:
687
+ audio = torchaudio.transforms.Resample(sr, 48000)(audio)
688
+
689
+ # Clamp values to [-1.0, 1.0]
690
+ audio = torch.clamp(audio, -1.0, 1.0)
691
+
692
+ is_silence = self.is_silence(audio)
693
+ if is_silence:
694
+ return None
695
+
696
+ # Target length: 30 seconds at 48kHz
697
+ target_frames = 30 * 48000
698
+ segment_frames = 10 * 48000 # 10 seconds per segment
699
+
700
+ # If audio is less than 30 seconds, repeat to at least 30 seconds
701
+ if audio.shape[-1] < target_frames:
702
+ repeat_times = math.ceil(target_frames / audio.shape[-1])
703
+ audio = audio.repeat(1, repeat_times)
704
+ # If audio is greater than or equal to 30 seconds, no operation needed
705
+
706
+ # For all cases, select random 10-second segments from front, middle, and back
707
+ # then concatenate them to form 30 seconds
708
+ total_frames = audio.shape[-1]
709
+ segment_size = total_frames // 3
710
+
711
+ # Front segment: [0, segment_size]
712
+ front_start = random.randint(0, max(0, segment_size - segment_frames))
713
+ front_audio = audio[:, front_start:front_start + segment_frames]
714
+
715
+ # Middle segment: [segment_size, 2*segment_size]
716
+ middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
717
+ middle_audio = audio[:, middle_start:middle_start + segment_frames]
718
+
719
+ # Back segment: [2*segment_size, total_frames]
720
+ back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
721
+ back_audio = audio[:, back_start:back_start + segment_frames]
722
+
723
+ # Concatenate three segments to form 30 seconds
724
+ audio = torch.cat([front_audio, middle_audio, back_audio], dim=-1)
725
+
726
+ return audio
727
+
728
+ except Exception as e:
729
+ print(f"Error processing reference audio: {e}")
730
+ return None
731
+
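The method always returns a 30-second, 48 kHz stereo reference built from one random 10-second slice of the front, middle, and back thirds of the (possibly repeated) input. The arithmetic for a 60-second clip, as an illustrative standalone sketch:

sr = 48000
total_frames = 60 * sr            # 2,880,000 frames; already >= 30 s, so no repeating needed
segment_frames = 10 * sr          # 480,000 frames per selected slice
segment_size = total_frames // 3  # 960,000 frames in each of the front/middle/back thirds

# Each slice starts at a random offset within its third, for example:
front_start = 0                          # sampled from [0, segment_size - segment_frames]
middle_start = segment_size + 120_000    # sampled from [segment_size, 2 * segment_size - segment_frames]
back_start = 2 * segment_size + 300_000  # sampled from the last third, minus the slice length

# Concatenating the three 10 s slices gives 3 * 480,000 = 1,440,000 frames = 30 s of audio.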
732
+ def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
733
+ if audio_file is None:
734
+ return None
735
+
736
+ try:
737
+ # Load audio file
738
+ audio, sr = torchaudio.load(audio_file)
739
+
740
+ # Convert to stereo (duplicate channel if mono)
741
+ if audio.shape[0] == 1:
742
+ audio = torch.cat([audio, audio], dim=0)
743
+
744
+ # Keep only first 2 channels
745
+ audio = audio[:2]
746
+
747
+ # Resample to 48kHz if needed
748
+ if sr != 48000:
749
+ audio = torchaudio.transforms.Resample(sr, 48000)(audio)
750
+
751
+ # Clamp values to [-1.0, 1.0]
752
+ audio = torch.clamp(audio, -1.0, 1.0)
753
+
754
+ return audio
755
+
756
+ except Exception as e:
757
+ print(f"Error processing target audio: {e}")
758
+ return None
759
+
760
+ def prepare_batch_data(
761
+ self,
762
+ actual_batch_size,
763
+ processed_src_audio,
764
+ audio_duration,
765
+ captions,
766
+ lyrics,
767
+ vocal_language,
768
+ instruction,
769
+ bpm,
770
+ key_scale,
771
+ time_signature
772
+ ):
773
+ pure_caption = self.extract_caption_from_sft_format(captions)
774
+ captions_batch = [pure_caption] * actual_batch_size
775
+ instructions_batch = [instruction] * actual_batch_size
776
+ lyrics_batch = [lyrics] * actual_batch_size
777
+ vocal_languages_batch = [vocal_language] * actual_batch_size
778
+ # Calculate duration for metadata
779
+ calculated_duration = None
780
+ if processed_src_audio is not None:
781
+ calculated_duration = processed_src_audio.shape[-1] / 48000.0
782
+ elif audio_duration is not None and audio_duration > 0:
783
+ calculated_duration = audio_duration
784
+
785
+ # Build metadata dict - use "N/A" as default for empty fields
786
+ metadata_dict = {}
787
+ if bpm:
788
+ metadata_dict["bpm"] = bpm
789
+ else:
790
+ metadata_dict["bpm"] = "N/A"
791
+
792
+ if key_scale.strip():
793
+ metadata_dict["keyscale"] = key_scale
794
+ else:
795
+ metadata_dict["keyscale"] = "N/A"
796
+
797
+ if time_signature.strip() and time_signature != "N/A" and time_signature:
798
+ metadata_dict["timesignature"] = time_signature
799
+ else:
800
+ metadata_dict["timesignature"] = "N/A"
801
+
802
+ # Add duration to metadata if available (inference service format: "30 seconds")
803
+ if calculated_duration is not None:
804
+ metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
805
+ # If duration not set, inference service will use default (30 seconds)
806
+
807
+ # Format metadata - inference service accepts dict and will convert to string
808
+ # Create a copy for each batch item (in case we modify it)
809
+ metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)]
810
+ return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch
811
 
812
+ def determine_task_type(self, task_type, audio_code_string):
813
+ # Determine task type - repaint and lego tasks can have repainting parameters
814
+ # Other tasks (cover, text2music, extract, complete) should NOT have repainting
815
+ is_repaint_task = (task_type == "repaint")
816
+ is_lego_task = (task_type == "lego")
817
+ is_cover_task = (task_type == "cover")
818
+ if audio_code_string and str(audio_code_string).strip():
819
+ is_cover_task = True
820
+ # Both repaint and lego tasks can use repainting parameters for chunk mask
821
+ can_use_repainting = is_repaint_task or is_lego_task
822
+ return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting
823
+
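Only repaint and lego tasks may carry repainting parameters, and a non-empty audio code string promotes any task to a cover task. A brief usage sketch, assuming handler is an initialized AceStepHandler:

# Returns (is_repaint_task, is_lego_task, is_cover_task, can_use_repainting)
handler.determine_task_type("text2music", None)       # -> (False, False, False, False)
handler.determine_task_type("repaint", None)          # -> (True, False, False, True)
handler.determine_task_type("text2music", "123 456")  # -> (False, False, True, False), codes force cover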
824
+ def create_target_wavs(self, duration_seconds: float) -> torch.Tensor:
825
+ try:
826
+ # Ensure minimum precision of 100ms
827
+ duration_seconds = max(0.1, round(duration_seconds, 1))
828
+ # Calculate frames for 48kHz stereo
829
+ frames = int(duration_seconds * 48000)
830
+ # Create silent stereo audio
831
+ target_wavs = torch.zeros(2, frames)
832
+ return target_wavs
833
+ except Exception as e:
834
+ print(f"Error creating target audio: {e}")
835
+ # Fallback to 30 seconds if error
836
+ return torch.zeros(2, 30 * 48000)
837
+
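Durations in this file are converted between 48 kHz audio samples and latent frames using the VAE hop of 1920 samples, i.e. a 25 Hz latent rate (48000 / 1920 = 25). A small illustrative helper for that conversion (the names below are not part of the handler):

SAMPLE_RATE = 48000
LATENT_HOP = 1920  # audio samples per latent frame (48000 / 25 Hz)

def seconds_to_latent_frames(seconds: float) -> int:
    return int(seconds * SAMPLE_RATE) // LATENT_HOP  # equivalent to int(seconds * 25)

def latent_frames_to_seconds(frames: int) -> float:
    return frames * LATENT_HOP / SAMPLE_RATE

seconds_to_latent_frames(30.0)  # 750 latent frames for a 30 s clip
latent_frames_to_seconds(750)   # 30.0 seconds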
838
+ def prepare_padding_info(
839
+ self,
840
+ actual_batch_size,
841
+ processed_src_audio,
842
+ audio_duration,
843
+ repainting_start,
844
+ repainting_end,
845
+ is_repaint_task,
846
+ is_lego_task,
847
+ is_cover_task,
848
+ can_use_repainting,
849
+ ):
850
+ target_wavs_batch = []
851
+ # Store padding info for each batch item to adjust repainting coordinates
852
+ padding_info_batch = []
853
+ for i in range(actual_batch_size):
854
+ if processed_src_audio is not None:
855
+ if is_cover_task:
856
+ # Cover task: Use src_audio directly without padding
857
+ batch_target_wavs = processed_src_audio
858
+ padding_info_batch.append({
859
+ 'left_padding_duration': 0.0,
860
+ 'right_padding_duration': 0.0
861
+ })
862
+ elif is_repaint_task or is_lego_task:
863
+ # Repaint/lego task: May need padding for outpainting
864
+ src_audio_duration = processed_src_audio.shape[-1] / 48000.0
865
+
866
+ # Determine actual end time
867
+ if repainting_end is None or repainting_end < 0:
868
+ actual_end = src_audio_duration
869
+ else:
870
+ actual_end = repainting_end
871
+
872
+ left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0
873
+ right_padding_duration = max(0, actual_end - src_audio_duration)
874
+
875
+ # Create padded audio
876
+ left_padding_frames = int(left_padding_duration * 48000)
877
+ right_padding_frames = int(right_padding_duration * 48000)
878
+
879
+ if left_padding_frames > 0 or right_padding_frames > 0:
880
+ # Pad the src audio
881
+ batch_target_wavs = torch.nn.functional.pad(
882
+ processed_src_audio,
883
+ (left_padding_frames, right_padding_frames),
884
+ 'constant', 0
885
+ )
886
+ else:
887
+ batch_target_wavs = processed_src_audio
888
+
889
+ # Store padding info for coordinate adjustment
890
+ padding_info_batch.append({
891
+ 'left_padding_duration': left_padding_duration,
892
+ 'right_padding_duration': right_padding_duration
893
+ })
894
+ else:
895
+ # Other tasks: Use src_audio directly without padding
896
+ batch_target_wavs = processed_src_audio
897
+ padding_info_batch.append({
898
+ 'left_padding_duration': 0.0,
899
+ 'right_padding_duration': 0.0
900
+ })
901
+ else:
902
+ padding_info_batch.append({
903
+ 'left_padding_duration': 0.0,
904
+ 'right_padding_duration': 0.0
905
+ })
906
+ if audio_duration is not None and audio_duration > 0:
907
+ batch_target_wavs = self.create_target_wavs(audio_duration)
908
+ else:
909
+ import random
910
+ random_duration = random.uniform(10.0, 120.0)
911
+ batch_target_wavs = self.create_target_wavs(random_duration)
912
+ target_wavs_batch.append(batch_target_wavs)
913
+
914
+ # Stack target_wavs into batch tensor
915
+ # Ensure all tensors have the same shape by padding to max length
916
+ max_frames = max(wav.shape[-1] for wav in target_wavs_batch)
917
+ padded_target_wavs = []
918
+ for wav in target_wavs_batch:
919
+ if wav.shape[-1] < max_frames:
920
+ pad_frames = max_frames - wav.shape[-1]
921
+ padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), 'constant', 0)
922
+ padded_target_wavs.append(padded_wav)
923
+ else:
924
+ padded_target_wavs.append(wav)
925
+
926
+ target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) # [batch_size, 2, frames]
927
+
928
+ if can_use_repainting:
929
+ # Repaint task: Set repainting parameters
930
+ if repainting_start is None:
931
+ repainting_start_batch = None
932
+ elif isinstance(repainting_start, (int, float)):
933
+ if processed_src_audio is not None:
934
+ adjusted_start = repainting_start + padding_info_batch[0]['left_padding_duration']
935
+ repainting_start_batch = [adjusted_start] * actual_batch_size
936
+ else:
937
+ repainting_start_batch = [repainting_start] * actual_batch_size
938
+ else:
939
+ # List input - adjust each item
940
+ repainting_start_batch = []
941
+ for i in range(actual_batch_size):
942
+ if processed_src_audio is not None:
943
+ adjusted_start = repainting_start[i] + padding_info_batch[i]['left_padding_duration']
944
+ repainting_start_batch.append(adjusted_start)
945
+ else:
946
+ repainting_start_batch.append(repainting_start[i])
947
+
948
+ # Handle repainting_end - use src audio duration if not specified or negative
949
+ if processed_src_audio is not None:
950
+ # If src audio is provided, use its duration as default end
951
+ src_audio_duration = processed_src_audio.shape[-1] / 48000.0
952
+ if repainting_end is None or repainting_end < 0:
953
+ # Use src audio duration (before padding), then adjust for padding
954
+ adjusted_end = src_audio_duration + padding_info_batch[0]['left_padding_duration']
955
+ repainting_end_batch = [adjusted_end] * actual_batch_size
956
+ else:
957
+ # Adjust repainting_end to be relative to padded audio
958
+ adjusted_end = repainting_end + padding_info_batch[0]['left_padding_duration']
959
+ repainting_end_batch = [adjusted_end] * actual_batch_size
960
+ else:
961
+ # No src audio - repainting doesn't make sense without it
962
+ if repainting_end is None or repainting_end < 0:
963
+ repainting_end_batch = None
964
+ elif isinstance(repainting_end, (int, float)):
965
+ repainting_end_batch = [repainting_end] * actual_batch_size
966
+ else:
967
+ # List input - adjust each item
968
+ repainting_end_batch = []
969
+ for i in range(actual_batch_size):
970
+ if processed_src_audio is not None:
971
+ adjusted_end = repainting_end[i] + padding_info_batch[i]['left_padding_duration']
972
+ repainting_end_batch.append(adjusted_end)
973
+ else:
974
+ repainting_end_batch.append(repainting_end[i])
975
+ else:
976
+ # All other tasks (cover, text2music, extract, complete): No repainting
977
+ # Only repaint and lego tasks should have repainting parameters
978
+ repainting_start_batch = None
979
+ repainting_end_batch = None
980
+
981
+ return repainting_start_batch, repainting_end_batch, target_wavs_tensor
982
+
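For repaint/lego tasks, a negative repainting_start means "outpaint to the left" and an end beyond the source duration means "outpaint to the right"; the source audio is zero-padded accordingly and the repaint window is shifted by the left padding. A worked numeric example of that adjustment (values are illustrative):

src_duration = 20.0      # seconds of provided source audio
repainting_start = -5.0  # request 5 s of new audio before the source
repainting_end = 30.0    # request audio up to 30 s, i.e. 10 s past the source end

left_pad = max(0, -repainting_start)               # 5.0 s  -> 240,000 zero frames prepended
right_pad = max(0, repainting_end - src_duration)  # 10.0 s -> 480,000 zero frames appended

# In the padded 35 s waveform, the repaint window shifts right by the left padding:
adjusted_start = repainting_start + left_pad  # 0.0 s
adjusted_end = repainting_end + left_pad      # 35.0 s
# At the 25 Hz latent rate this becomes a chunk mask over latent frames [0, 875).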
983
+ def _prepare_batch(
984
+ self,
985
+ captions: List[str],
986
+ lyrics: List[str],
987
+ keys: Optional[List[str]] = None,
988
+ target_wavs: Optional[torch.Tensor] = None,
989
+ refer_audios: Optional[List[List[torch.Tensor]]] = None,
990
+ metas: Optional[List[Union[str, Dict[str, Any]]]] = None,
991
+ vocal_languages: Optional[List[str]] = None,
992
+ repainting_start: Optional[List[float]] = None,
993
+ repainting_end: Optional[List[float]] = None,
994
+ instructions: Optional[List[str]] = None,
995
+ audio_code_hints: Optional[List[Optional[str]]] = None,
996
+ audio_cover_strength: float = 1.0,
997
+ ) -> Dict[str, Any]:
998
+ """
999
+ Prepare batch data with fallbacks for missing inputs.
1000
+
1001
+ Args:
1002
+ captions: List of text captions (optional, can be empty strings)
1003
+ lyrics: List of lyrics (optional, can be empty strings)
1004
+ keys: List of unique identifiers (optional)
1005
+ target_wavs: Target audio tensors (optional, will use silence if not provided)
1006
+ refer_audios: Reference audio tensors (optional, will use silence if not provided)
1007
+ metas: Metadata (optional, will use defaults if not provided)
1008
+ vocal_languages: Vocal languages (optional, will default to 'en')
1009
+
1010
+ Returns:
1011
+ Batch dictionary ready for model input
1012
+ """
1013
+ batch_size = len(captions)
1014
+
1015
+ # Ensure audio_code_hints is a list of the correct length
1016
+ if audio_code_hints is None:
1017
+ audio_code_hints = [None] * batch_size
1018
+ elif len(audio_code_hints) != batch_size:
1019
+ if len(audio_code_hints) == 1:
1020
+ audio_code_hints = audio_code_hints * batch_size
1021
+ else:
1022
+ audio_code_hints = audio_code_hints[:batch_size]
1023
+ while len(audio_code_hints) < batch_size:
1024
+ audio_code_hints.append(None)
1025
+
1026
+ for ii, refer_audio_list in enumerate(refer_audios):
1027
+ if isinstance(refer_audio_list, list):
1028
+ for idx, refer_audio in enumerate(refer_audio_list):
1029
+ refer_audio_list[idx] = refer_audio_list[idx].to(self.device).to(torch.bfloat16)
1030
+ elif isinstance(refer_audio_list, torch.Tensor):
1031
+ refer_audios[ii] = refer_audios[ii].to(self.device)
1032
+
1033
+ if vocal_languages is None:
1034
+ vocal_languages = self._create_fallback_vocal_languages(batch_size)
1035
+
1036
+ # Normalize audio_code_hints to batch list
1037
+ if audio_code_hints is None:
1038
+ audio_code_hints = [None] * batch_size
1039
+ elif not isinstance(audio_code_hints, list):
1040
+ audio_code_hints = [audio_code_hints] * batch_size
1041
+ elif len(audio_code_hints) == 1 and batch_size > 1:
1042
+ audio_code_hints = audio_code_hints * batch_size
1043
+ else:
1044
+ audio_code_hints = (audio_code_hints + [None] * batch_size)[:batch_size]
1045
+ audio_code_hints = [hint if isinstance(hint, str) and hint.strip() else None for hint in audio_code_hints]
1046
+
1047
+ # Parse metas with fallbacks
1048
+ parsed_metas = self._parse_metas(metas)
1049
+
1050
+ # Encode target_wavs to get target_latents
1051
+ with torch.no_grad():
1052
+ target_latents_list = []
1053
+ latent_lengths = []
1054
+ # Use per-item wavs (may be adjusted if audio_code_hints are provided)
1055
+ target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)]
1056
+ if target_wavs.device != self.device:
1057
+ target_wavs = target_wavs.to(self.device)
1058
+ for i in range(batch_size):
1059
+ code_hint = audio_code_hints[i]
1060
+ # Prefer decoding from provided audio codes
1061
+ if code_hint:
1062
+ decoded_latents = self._decode_audio_codes_to_latents(code_hint)
1063
+ if decoded_latents is not None:
1064
+ decoded_latents = decoded_latents.squeeze(0)
1065
+ target_latents_list.append(decoded_latents)
1066
+ latent_lengths.append(decoded_latents.shape[0])
1067
+ # Create a silent wav matching the latent length for downstream scaling
1068
+ frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920))
1069
+ target_wavs_list[i] = torch.zeros(2, frames_from_codes)
1070
+ continue
1071
+ # Fallback to VAE encode from audio
1072
+ current_wav = target_wavs_list[i].to(self.device).unsqueeze(0)
1073
+ if self.is_silence(current_wav):
1074
+ expected_latent_length = current_wav.shape[-1] // 1920
1075
+ target_latent = self.silence_latent[0, :expected_latent_length, :]
1076
+ else:
1077
+ target_latent = self.vae.encode(current_wav)
1078
+ target_latent = target_latent.squeeze(0).transpose(0, 1)
1079
+ target_latents_list.append(target_latent)
1080
+ latent_lengths.append(target_latent.shape[0])
1081
+
1082
+ # Pad target_wavs to consistent length for outputs
1083
+ max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
1084
+ padded_target_wavs = []
1085
+ for wav in target_wavs_list:
1086
+ if wav.shape[-1] < max_target_frames:
1087
+ pad_frames = max_target_frames - wav.shape[-1]
1088
+ wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0)
1089
+ padded_target_wavs.append(wav)
1090
+ target_wavs = torch.stack(padded_target_wavs)
1091
+ wav_lengths = torch.tensor([target_wavs.shape[-1]] * batch_size, dtype=torch.long)
1092
+
1093
+ # Pad latents to same length
1094
+ max_latent_length = max(latent.shape[0] for latent in target_latents_list)
1095
+ max_latent_length = max(128, max_latent_length)
1096
+
1097
+ padded_latents = []
1098
+ for latent in target_latents_list:
1099
+ latent_length = latent.shape[0]
1100
+
1101
+ if latent.shape[0] < max_latent_length:
1102
+ pad_length = max_latent_length - latent.shape[0]
1103
+ latent = torch.cat([latent, self.silence_latent[0, :pad_length, :]], dim=0)
1104
+ padded_latents.append(latent)
1105
+
1106
+ target_latents = torch.stack(padded_latents)
1107
+ latent_masks = torch.stack([
1108
+ torch.cat([
1109
+ torch.ones(l, dtype=torch.long, device=self.device),
1110
+ torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device)
1111
+ ])
1112
+ for l in latent_lengths
1113
+ ])
1114
+
1115
+ # Process instructions early so we can use them for task type detection
1116
+ # Use custom instructions if provided, otherwise use default
1117
+ if instructions is None:
1118
+ instructions = ["Fill the audio semantic mask based on the given conditions:"] * batch_size
1119
+
1120
+ # Ensure instructions list has the same length as batch_size
1121
+ if len(instructions) != batch_size:
1122
+ if len(instructions) == 1:
1123
+ instructions = instructions * batch_size
1124
+ else:
1125
+ # Pad or truncate to match batch_size
1126
+ instructions = instructions[:batch_size]
1127
+ while len(instructions) < batch_size:
1128
+ instructions.append("Fill the audio semantic mask based on the given conditions:")
1129
+
1130
+ # Generate chunk_masks and spans based on repainting parameters
1131
+ # Also determine if this is a cover task (target audio provided without repainting)
1132
+ chunk_masks = []
1133
+ spans = []
1134
+ is_covers = []
1135
+ # Store repainting latent ranges for later use in src_latents creation
1136
+ repainting_ranges = {} # {batch_idx: (start_latent, end_latent)}
1137
+
1138
+ for i in range(batch_size):
1139
+ has_code_hint = audio_code_hints[i] is not None
1140
+ # Check if repainting is enabled for this batch item
1141
+ has_repainting = False
1142
+ if repainting_start is not None and repainting_end is not None:
1143
+ start_sec = repainting_start[i] if repainting_start[i] is not None else 0.0
1144
+ end_sec = repainting_end[i]
1145
+
1146
+ if end_sec is not None and end_sec > start_sec:
1147
+ # Repainting mode with outpainting support
1148
+ # The target_wavs may have been padded for outpainting
1149
+ # Need to calculate the actual position in the padded audio
1150
+
1151
+ # Calculate padding (if start < 0, there's left padding)
1152
+ left_padding_sec = max(0, -start_sec)
1153
+
1154
+ # Adjust positions to account for padding
1155
+ # In the padded audio, the original start is shifted by left_padding
1156
+ adjusted_start_sec = start_sec + left_padding_sec
1157
+ adjusted_end_sec = end_sec + left_padding_sec
1158
+
1159
+ # Convert seconds to latent frames (audio_frames / 1920 = latent_frames)
1160
+ start_latent = int(adjusted_start_sec * self.sample_rate // 1920)
1161
+ end_latent = int(adjusted_end_sec * self.sample_rate // 1920)
1162
+
1163
+ # Clamp to valid range
1164
+ start_latent = max(0, min(start_latent, max_latent_length - 1))
1165
+ end_latent = max(start_latent + 1, min(end_latent, max_latent_length))
1166
+ # Create mask: False = keep original, True = generate new
1167
+ mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device)
1168
+ mask[start_latent:end_latent] = True
1169
+ chunk_masks.append(mask)
1170
+ spans.append(("repainting", start_latent, end_latent))
1171
+ # Store repainting range for later use
1172
+ repainting_ranges[i] = (start_latent, end_latent)
1173
+ has_repainting = True
1174
+ is_covers.append(False) # Repainting is not cover task
1175
+ else:
1176
+ # Full generation (no valid repainting range)
1177
+ chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
1178
+ spans.append(("full", 0, max_latent_length))
1179
+ # Determine task type from instruction, not from target_wavs
1180
+ # Only cover task should have is_cover=True
1181
+ instruction_i = instructions[i] if instructions and i < len(instructions) else ""
1182
+ instruction_lower = instruction_i.lower()
1183
+ # Cover task instruction: "Generate audio semantic tokens based on the given conditions:"
1184
+ is_cover = ("generate audio semantic tokens" in instruction_lower and
1185
+ "based on the given conditions" in instruction_lower) or has_code_hint
1186
+ is_covers.append(is_cover)
1187
+ else:
1188
+ # Full generation (no repainting parameters)
1189
+ chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
1190
+ spans.append(("full", 0, max_latent_length))
1191
+ # Determine task type from instruction, not from target_wavs
1192
+ # Only cover task should have is_cover=True
1193
+ instruction_i = instructions[i] if instructions and i < len(instructions) else ""
1194
+ instruction_lower = instruction_i.lower()
1195
+ # Cover task instruction: "Generate audio semantic tokens based on the given conditions:"
1196
+ is_cover = ("generate audio semantic tokens" in instruction_lower and
1197
+ "based on the given conditions" in instruction_lower) or has_code_hint
1198
+ is_covers.append(is_cover)
1199
+
1200
+ chunk_masks = torch.stack(chunk_masks)
1201
+ is_covers = torch.BoolTensor(is_covers).to(self.device)
1202
+
1203
+ # Create src_latents based on task type
1204
+ # For cover/extract/complete/lego/repaint tasks: src_latents = target_latents.clone() (if target_wavs provided)
1205
+ # For text2music task: src_latents = silence_latent (if no target_wavs or silence)
1206
+ # For repaint task: additionally replace inpainting region with silence_latent
1207
+ src_latents_list = []
1208
+ silence_latent_tiled = self.silence_latent[0, :max_latent_length, :]
1209
+ for i in range(batch_size):
1210
+ # Check if target_wavs is provided and not silent (for extract/complete/lego/cover/repaint tasks)
1211
+ has_code_hint = audio_code_hints[i] is not None
1212
+ has_target_audio = has_code_hint or (target_wavs is not None and target_wavs[i].abs().sum() > 1e-6)
1213
+
1214
+ if has_target_audio:
1215
+ # For tasks that use input audio (cover/extract/complete/lego/repaint)
1216
+ # Check if this item has repainting
1217
+ item_has_repainting = (i in repainting_ranges)
1218
+
1219
+ if item_has_repainting:
1220
+ # Repaint task: src_latents = target_latents with inpainting region replaced by silence_latent
1221
+ # 1. Clone target_latents (encoded from src audio, preserving original audio)
1222
+ src_latent = target_latents[i].clone()
1223
+ # 2. Replace inpainting region with silence_latent
1224
+ start_latent, end_latent = repainting_ranges[i]
1225
+ src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent]
1226
+ src_latents_list.append(src_latent)
1227
+ else:
1228
+ # Cover/extract/complete/lego tasks: src_latents = target_latents.clone()
1229
+ # All of these tasks are conditioned on the input audio
1230
+ src_latents_list.append(target_latents[i].clone())
1231
+ else:
1232
+ # Text2music task: src_latents = silence_latent (no input audio)
1233
+ # Use silence_latent for the full length
1234
+ src_latents_list.append(silence_latent_tiled.clone())
1235
+
1236
+ src_latents = torch.stack(src_latents_list)
1237
+
1238
+ # Process audio_code_hints to generate precomputed_lm_hints_25Hz
1239
+ precomputed_lm_hints_25Hz_list = []
1240
+ for i in range(batch_size):
1241
+ if audio_code_hints[i] is not None:
1242
+ # Decode audio codes to 25Hz latents
1243
+ hints = self._decode_audio_codes_to_latents(audio_code_hints[i])
1244
+ if hints is not None:
1245
+ # Pad or crop to match max_latent_length
1246
+ if hints.shape[1] < max_latent_length:
1247
+ pad_length = max_latent_length - hints.shape[1]
1248
+ hints = torch.cat([
1249
+ hints,
1250
+ self.silence_latent[:, :pad_length, :]
1251
+ ], dim=1)
1252
+ elif hints.shape[1] > max_latent_length:
1253
+ hints = hints[:, :max_latent_length, :]
1254
+ precomputed_lm_hints_25Hz_list.append(hints[0]) # Remove batch dimension
1255
+ else:
1256
+ precomputed_lm_hints_25Hz_list.append(None)
1257
+ else:
1258
+ precomputed_lm_hints_25Hz_list.append(None)
1259
+
1260
+ # Stack precomputed hints if any exist, otherwise set to None
1261
+ if any(h is not None for h in precomputed_lm_hints_25Hz_list):
1262
+ # For items without hints, use silence_latent as placeholder
1263
+ precomputed_lm_hints_25Hz = torch.stack([
1264
+ h if h is not None else silence_latent_tiled
1265
+ for h in precomputed_lm_hints_25Hz_list
1266
+ ])
1267
+ else:
1268
+ precomputed_lm_hints_25Hz = None
1269
+
1270
+ # Format text_inputs
1271
+ text_inputs = []
1272
+ text_token_idss = []
1273
+ text_attention_masks = []
1274
+ lyric_token_idss = []
1275
+ lyric_attention_masks = []
1276
+
1277
+ for i in range(batch_size):
1278
+ # Use custom instruction for this batch item
1279
+ instruction = instructions[i] if instructions and i < len(instructions) else "Fill the audio semantic mask based on the given conditions:"
1280
+ # Ensure instruction ends with ":"
1281
+ if not instruction.endswith(":"):
1282
+ instruction = instruction + ":"
1283
+
1284
+ # Format text prompt with custom instruction
1285
+ text_prompt = SFT_GEN_PROMPT.format(instruction, captions[i], parsed_metas[i])
1286
+
1287
+ # Tokenize text
1288
+ text_inputs_dict = self.text_tokenizer(
1289
+ text_prompt,
1290
+ padding="longest",
1291
+ truncation=True,
1292
+ max_length=256,
1293
+ return_tensors="pt",
1294
+ )
1295
+ text_token_ids = text_inputs_dict.input_ids[0]
1296
+ text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1297
+
1298
+ # Format and tokenize lyrics
1299
+ lyrics_text = f"# Languages\n{vocal_languages[i]}\n\n# Lyric\n{lyrics[i]}<|endoftext|>"
1300
+ lyrics_inputs_dict = self.text_tokenizer(
1301
+ lyrics_text,
1302
+ padding="longest",
1303
+ truncation=True,
1304
+ max_length=2048,
1305
+ return_tensors="pt",
1306
+ )
1307
+ lyric_token_ids = lyrics_inputs_dict.input_ids[0]
1308
+ lyric_attention_mask = lyrics_inputs_dict.attention_mask[0].bool()
1309
+
1310
+ # Build full text input
1311
+ text_input = text_prompt + "\n\n" + lyrics_text
1312
+
1313
+ text_inputs.append(text_input)
1314
+ text_token_idss.append(text_token_ids)
1315
+ text_attention_masks.append(text_attention_mask)
1316
+ lyric_token_idss.append(lyric_token_ids)
1317
+ lyric_attention_masks.append(lyric_attention_mask)
1318
+
1319
+ # Pad tokenized sequences
1320
+ max_text_length = max(len(seq) for seq in text_token_idss)
1321
+ padded_text_token_idss = torch.stack([
1322
+ torch.nn.functional.pad(
1323
+ seq, (0, max_text_length - len(seq)), 'constant',
1324
+ self.text_tokenizer.pad_token_id
1325
+ )
1326
+ for seq in text_token_idss
1327
+ ])
1328
+
1329
+ padded_text_attention_masks = torch.stack([
1330
+ torch.nn.functional.pad(
1331
+ seq, (0, max_text_length - len(seq)), 'constant', 0
1332
+ )
1333
+ for seq in text_attention_masks
1334
+ ])
1335
+
1336
+ max_lyric_length = max(len(seq) for seq in lyric_token_idss)
1337
+ padded_lyric_token_idss = torch.stack([
1338
+ torch.nn.functional.pad(
1339
+ seq, (0, max_lyric_length - len(seq)), 'constant',
1340
+ self.text_tokenizer.pad_token_id
1341
+ )
1342
+ for seq in lyric_token_idss
1343
+ ])
1344
+
1345
+ padded_lyric_attention_masks = torch.stack([
1346
+ torch.nn.functional.pad(
1347
+ seq, (0, max_lyric_length - len(seq)), 'constant', 0
1348
+ )
1349
+ for seq in lyric_attention_masks
1350
+ ])
1351
+
1352
+ padded_non_cover_text_input_ids = None
1353
+ padded_non_cover_text_attention_masks = None
1354
+ if audio_cover_strength < 1.0 and is_covers is not None and is_covers.any():
1355
+ non_cover_text_input_ids = []
1356
+ non_cover_text_attention_masks = []
1357
+ for i in range(batch_size):
1358
+ # Use the default (non-cover) instruction for this batch item
1359
+ instruction = "Fill the audio semantic mask based on the given conditions:"
1360
+
1361
+ # Format text prompt with custom instruction
1362
+ text_prompt = SFT_GEN_PROMPT.format(instruction, captions[i], parsed_metas[i])
1363
+
1364
+ # Tokenize text
1365
+ text_inputs_dict = self.text_tokenizer(
1366
+ text_prompt,
1367
+ padding="longest",
1368
+ truncation=True,
1369
+ max_length=256,
1370
+ return_tensors="pt",
1371
+ )
1372
+ text_token_ids = text_inputs_dict.input_ids[0]
1373
+ non_cover_text_input_ids.append(text_token_ids)
1374
+ non_cover_text_attention_masks.append(text_inputs_dict.attention_mask[0].bool())
1375
+
1376
+ padded_non_cover_text_input_ids = torch.stack([
1377
+ torch.nn.functional.pad(
1378
+ seq, (0, max_text_length - len(seq)), 'constant',
1379
+ self.text_tokenizer.pad_token_id
1380
+ )
1381
+ for seq in non_cover_text_input_ids
1382
+ ])
1383
+ padded_non_cover_text_attention_masks = torch.stack([
1384
+ torch.nn.functional.pad(
1385
+ seq, (0, max_text_length - len(seq)), 'constant', 0
1386
+ )
1387
+ for seq in non_cover_text_attention_masks
1388
+ ])
1389
+
1390
+ # Prepare batch
1391
+ batch = {
1392
+ "keys": keys,
1393
+ "target_wavs": target_wavs.to(self.device),
1394
+ "refer_audioss": refer_audios,
1395
+ "wav_lengths": wav_lengths.to(self.device),
1396
+ "captions": captions,
1397
+ "lyrics": lyrics,
1398
+ "metas": parsed_metas,
1399
+ "vocal_languages": vocal_languages,
1400
+ "target_latents": target_latents,
1401
+ "src_latents": src_latents,
1402
+ "latent_masks": latent_masks,
1403
+ "chunk_masks": chunk_masks,
1404
+ "spans": spans,
1405
+ "text_inputs": text_inputs,
1406
+ "text_token_idss": padded_text_token_idss,
1407
+ "text_attention_masks": padded_text_attention_masks,
1408
+ "lyric_token_idss": padded_lyric_token_idss,
1409
+ "lyric_attention_masks": padded_lyric_attention_masks,
1410
+ "is_covers": is_covers,
1411
+ "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz,
1412
+ "non_cover_text_input_ids": padded_non_cover_text_input_ids,
1413
+ "non_cover_text_attention_masks": padded_non_cover_text_attention_masks,
1414
+ }
1415
+ # to device
1416
+ for k, v in batch.items():
1417
+ if isinstance(v, torch.Tensor):
1418
+ batch[k] = v.to(self.device)
1419
+ if torch.is_floating_point(v):
1420
+ batch[k] = v.to(self.dtype)
1421
+ return batch
1422
+
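The masking convention assembled above is easiest to see for a single repaint item: `chunk_masks` flags only the frames to regenerate (True = generate, False = keep), and `src_latents` keeps the original latents outside that span while a silence placeholder fills the span itself. A minimal sketch with toy shapes; the zero tensor merely stands in for the real silence latent:

import torch

max_latent_length, latent_dim = 100, 8
target_latent = torch.randn(max_latent_length, latent_dim)   # latents encoded from the source audio
silence_stub = torch.zeros(max_latent_length, latent_dim)    # placeholder for the model's silence latent

start_latent, end_latent = 25, 60                             # repainting span in latent frames
mask = torch.zeros(max_latent_length, dtype=torch.bool)       # False = keep original
mask[start_latent:end_latent] = True                          # True = generate new

src_latent = target_latent.clone()                            # preserve audio outside the span
src_latent[start_latent:end_latent] = silence_stub[start_latent:end_latent]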
1423
+ def infer_refer_latent(self, refer_audioss):
1424
+ refer_audio_order_mask = []
1425
+ refer_audio_latents = []
1426
+ for batch_idx, refer_audios in enumerate(refer_audioss):
1427
+ if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
1428
+ refer_audio_latent = self.silence_latent[:, :750, :]
1429
+ refer_audio_latents.append(refer_audio_latent)
1430
+ refer_audio_order_mask.append(batch_idx)
1431
+ else:
1432
+ for refer_audio in refer_audios:
1433
+ refer_audio_latent = self.vae.encode(refer_audio.unsqueeze(0), chunked=False)
1434
+ refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
1435
+ refer_audio_order_mask.append(batch_idx)
1436
+
1437
+ refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
1438
+ refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=self.device, dtype=torch.long)
1439
+ return refer_audio_latents, refer_audio_order_mask
1440
+
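`infer_refer_latent` packs a variable number of reference clips per batch item into a single tensor, and `refer_audio_order_mask[j]` records which batch item packed latent `j` came from. A small sketch of reading the packed layout back per item, with made-up sizes:

import torch

# Hypothetical packed result: two reference clips for item 0, one for item 1.
refer_audio_latents = torch.randn(3, 750, 64)       # [num_refs_total, T, C]
refer_audio_order_mask = torch.tensor([0, 0, 1])     # packed index j -> batch index

per_item = [refer_audio_latents[refer_audio_order_mask == b] for b in range(2)]
# per_item[0].shape == torch.Size([2, 750, 64]); per_item[1].shape == torch.Size([1, 750, 64])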
1441
+ def infer_text_embeddings(self, text_token_idss):
1442
+ with torch.no_grad():
1443
+ text_embeddings = self.text_encoder(input_ids=text_token_idss, lyric_attention_mask=None).last_hidden_state
1444
+ return text_embeddings
1445
+
1446
+ def infer_lyric_embeddings(self, lyric_token_ids):
1447
+ with torch.no_grad():
1448
+ lyric_embeddings = self.text_encoder.embed_tokens(lyric_token_ids)
1449
+ return lyric_embeddings
1450
+
1451
+ def preprocess_batch(self, batch):
1452
+
1453
+ # step 1: VAE encode latents, target_latents: N x T x d
1454
+ # target_latents: N x T x d
1455
+ target_latents = batch["target_latents"]
1456
+ src_latents = batch["src_latents"]
1457
+ attention_mask = batch["latent_masks"]
1458
+ audio_codes = batch.get("audio_codes", None)
1459
+ audio_attention_mask = attention_mask
1460
+
1461
+ dtype = target_latents.dtype
1462
+ bs = target_latents.shape[0]
1463
+ device = target_latents.device
1464
+
1465
+ # step 2: refer_audio timbre
1466
+ keys = batch["keys"]
1467
+ refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"])
1468
+ if refer_audio_acoustic_hidden_states_packed.dtype != dtype:
1469
+ refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype)
1470
+
1471
+ # step 4: chunk mask, N x T x d
1472
+ chunk_mask = batch["chunk_masks"]
1473
+ chunk_mask = chunk_mask.to(device).unsqueeze(-1).repeat(1, 1, target_latents.shape[2])
1474
+
1475
+ spans = batch["spans"]
1476
+
1477
+ text_token_idss = batch["text_token_idss"]
1478
+ text_attention_mask = batch["text_attention_masks"]
1479
+ lyric_token_idss = batch["lyric_token_idss"]
1480
+ lyric_attention_mask = batch["lyric_attention_masks"]
1481
+ text_inputs = batch["text_inputs"]
1482
+
1483
+ text_hidden_states = self.infer_text_embeddings(text_token_idss)
1484
+ lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
1485
+
1486
+ is_covers = batch["is_covers"]
1487
+
1488
+ # Get precomputed hints from batch if available
1489
+ precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None)
1490
+
1491
+ # Get non-cover text input ids and attention masks from batch if available
1492
+ non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None)
1493
+ non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
1494
+ non_cover_text_hidden_states = None
1495
+ if non_cover_text_input_ids is not None:
1496
+ non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
1497
+
1498
+ return (
1499
+ keys,
1500
+ text_inputs,
1501
+ src_latents,
1502
+ target_latents,
1503
+ # model inputs
1504
+ text_hidden_states,
1505
+ text_attention_mask,
1506
+ lyric_hidden_states,
1507
+ lyric_attention_mask,
1508
+ audio_attention_mask,
1509
+ refer_audio_acoustic_hidden_states_packed,
1510
+ refer_audio_order_mask,
1511
+ chunk_mask,
1512
+ spans,
1513
+ is_covers,
1514
+ audio_codes,
1515
+ lyric_token_idss,
1516
+ precomputed_lm_hints_25Hz,
1517
+ non_cover_text_hidden_states,
1518
+ non_cover_text_attention_masks,
1519
+ )
1520
+
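`preprocess_batch` broadcasts the per-frame boolean mask over the feature dimension so it can gate latents elementwise, mirroring the `unsqueeze(-1).repeat(...)` call above. A toy sketch of that expansion, with assumed shapes:

import torch

chunk_mask = torch.zeros(2, 100, dtype=torch.bool)             # [N, T], True = frames to generate
chunk_mask[:, 25:60] = True
latent_dim = 64
expanded = chunk_mask.unsqueeze(-1).repeat(1, 1, latent_dim)   # [N, T, d], matches target_latents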
1521
+ @torch.no_grad()
1522
+ def service_generate(
1523
+ self,
1524
+ captions: Union[str, List[str]],
1525
+ lyrics: Union[str, List[str]],
1526
+ keys: Optional[Union[str, List[str]]] = None,
1527
+ target_wavs: Optional[torch.Tensor] = None,
1528
+ refer_audios: Optional[List[List[torch.Tensor]]] = None,
1529
+ metas: Optional[Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]]] = None,
1530
+ vocal_languages: Optional[Union[str, List[str]]] = None,
1531
+ infer_steps: int = 60,
1532
+ guidance_scale: float = 7.0,
1533
+ seed: Optional[Union[int, List[int]]] = None,
1534
+ return_intermediate: bool = False,
1535
+ repainting_start: Optional[Union[float, List[float]]] = None,
1536
+ repainting_end: Optional[Union[float, List[float]]] = None,
1537
+ instructions: Optional[Union[str, List[str]]] = None,
1538
+ audio_cover_strength: float = 1.0,
1539
+ use_adg: bool = False,
1540
+ cfg_interval_start: float = 0.0,
1541
+ cfg_interval_end: float = 1.0,
1542
+ audio_code_hints: Optional[Union[str, List[str]]] = None,
1543
+ infer_method: str = "ode",
1544
+ ) -> Dict[str, Any]:
1545
+
1546
+ """
1547
+ Generate music from text inputs.
1548
+
1549
+ Args:
1550
+ captions: Text caption(s) describing the music (optional, can be empty strings)
1551
+ lyrics: Lyric text(s) (optional, can be empty strings)
1552
+ keys: Unique identifier(s) (optional)
1553
+ target_wavs: Target audio tensor(s) for conditioning (optional)
1554
+ refer_audios: Reference audio tensor(s) for style transfer (optional)
1555
+ metas: Metadata dict(s) or string(s) (optional)
1556
+ vocal_languages: Language code(s) for lyrics (optional, defaults to 'en')
1557
+ infer_steps: Number of inference steps (default: 60)
1558
+ guidance_scale: Guidance scale for generation (default: 7.0)
1559
+ seed: Random seed (optional)
1560
+ return_intermediate: Whether to return intermediate results (default: False)
1561
+ repainting_start: Start time(s) for repainting region in seconds (optional)
1562
+ repainting_end: End time(s) for repainting region in seconds (optional)
1563
+ instructions: Instruction text(s) for generation (optional)
1564
+ audio_cover_strength: Strength of audio cover mode (default: 1.0)
1565
+ use_adg: Whether to use ADG (Adaptive Diffusion Guidance) (default: False)
1566
+ cfg_interval_start: Start of CFG interval (0.0-1.0, default: 0.0)
1567
+ cfg_interval_end: End of CFG interval (0.0-1.0, default: 1.0)
+ audio_code_hints: Audio code string(s) decoded into semantic hints for cover-style generation (optional)
+ infer_method: Sampling method passed to the model (default: "ode")
1568
+
1569
+ Returns:
1570
+ Dictionary containing:
1571
+ - pred_wavs: Generated audio tensors
1572
+ - target_wavs: Input target audio (if provided)
1573
+ - vqvae_recon_wavs: VAE reconstruction of target
1574
+ - keys: Identifiers used
1575
+ - text_inputs: Formatted text inputs
1576
+ - sr: Sample rate
1577
+ - spans: Generation spans
1578
+ - time_costs: Timing information
1579
+ - seed_num: Seed used
1580
+ """
1581
+ if self.config.is_turbo:
1582
+ # Limit inference steps to maximum 8
1583
+ if infer_steps > 8:
1584
+ logger.warning(f"dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
1585
+ infer_steps = 8
1586
+ # CFG parameters are not adjustable for dmd_gan (they will be ignored)
1587
+ # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
1588
+
1589
+ # Convert single inputs to lists
1590
+ if isinstance(captions, str):
1591
+ captions = [captions]
1592
+ if isinstance(lyrics, str):
1593
+ lyrics = [lyrics]
1594
+ if isinstance(keys, str):
1595
+ keys = [keys]
1596
+ if isinstance(vocal_languages, str):
1597
+ vocal_languages = [vocal_languages]
1598
+ if isinstance(metas, (str, dict)):
1599
+ metas = [metas]
1600
+
1601
+ # Convert repainting parameters to lists
1602
+ if isinstance(repainting_start, (int, float)):
1603
+ repainting_start = [repainting_start]
1604
+ if isinstance(repainting_end, (int, float)):
1605
+ repainting_end = [repainting_end]
1606
+
1607
+ # Convert instructions to list
1608
+ if isinstance(instructions, str):
1609
+ instructions = [instructions]
1610
+ elif instructions is None:
1611
+ instructions = None
1612
+
1613
+ # Convert audio_code_hints to list
1614
+ if isinstance(audio_code_hints, str):
1615
+ audio_code_hints = [audio_code_hints]
1616
+ elif audio_code_hints is None:
1617
+ audio_code_hints = None
1618
+
1619
+ # Get batch size from captions
1620
+ batch_size = len(captions)
1621
+
1622
+ # Ensure audio_code_hints matches batch size
1623
+ if audio_code_hints is not None:
1624
+ if len(audio_code_hints) != batch_size:
1625
+ if len(audio_code_hints) == 1:
1626
+ audio_code_hints = audio_code_hints * batch_size
1627
+ else:
1628
+ audio_code_hints = audio_code_hints[:batch_size]
1629
+ while len(audio_code_hints) < batch_size:
1630
+ audio_code_hints.append(None)
1631
+
1632
+ # Convert seed to list format
1633
+ if seed is None:
1634
+ seed_list = None
1635
+ elif isinstance(seed, list):
1636
+ seed_list = seed
1637
+ # Ensure we have enough seeds for batch size
1638
+ if len(seed_list) < batch_size:
1639
+ # Pad with random seeds
1640
+ import random
1641
+ while len(seed_list) < batch_size:
1642
+ seed_list.append(random.randint(0, 2**32 - 1))
1643
+ elif len(seed_list) > batch_size:
1644
+ # Truncate to batch size
1645
+ seed_list = seed_list[:batch_size]
1646
+ else:
1647
+ # Single seed value - use for all batch items
1648
+ seed_list = [int(seed)] * batch_size
1649
+
1650
+ # Don't set global random seed here - each item will use its own seed
1651
+
1652
+ # Prepare batch
1653
+ batch = self._prepare_batch(
1654
+ captions=captions,
1655
+ lyrics=lyrics,
1656
+ keys=keys,
1657
+ target_wavs=target_wavs,
1658
+ refer_audios=refer_audios,
1659
+ metas=metas,
1660
+ vocal_languages=vocal_languages,
1661
+ repainting_start=repainting_start,
1662
+ repainting_end=repainting_end,
1663
+ instructions=instructions,
1664
+ audio_code_hints=audio_code_hints,
1665
+ )
1666
+
1667
+ processed_data = self.preprocess_batch(batch)
1668
+
1669
+ (
1670
+ keys,
1671
+ text_inputs,
1672
+ src_latents,
1673
+ target_latents,
1674
+ # model inputs
1675
+ text_hidden_states,
1676
+ text_attention_mask,
1677
+ lyric_hidden_states,
1678
+ lyric_attention_mask,
1679
+ audio_attention_mask,
1680
+ refer_audio_acoustic_hidden_states_packed,
1681
+ refer_audio_order_mask,
1682
+ chunk_mask,
1683
+ spans,
1684
+ is_covers,
1685
+ audio_codes,
1686
+ lyric_token_idss,
1687
+ precomputed_lm_hints_25Hz,
1688
+ non_cover_text_hidden_states,
1689
+ non_cover_text_attention_masks,
1690
+ ) = processed_data
1691
+
1692
+ # Set generation parameters
1693
+ # Use seed_list if available, otherwise generate a single seed
1694
+ if seed_list is not None:
1695
+ # Pass seed list to model (will be handled there)
1696
+ seed_param = seed_list
1697
+ else:
1698
+ seed_param = random.randint(0, 2**32 - 1)
1699
+
1700
+ generate_kwargs = {
1701
+ "text_hidden_states": text_hidden_states,
1702
+ "text_attention_mask": text_attention_mask,
1703
+ "lyric_hidden_states": lyric_hidden_states,
1704
+ "lyric_attention_mask": lyric_attention_mask,
1705
+ "refer_audio_acoustic_hidden_states_packed": refer_audio_acoustic_hidden_states_packed,
1706
+ "refer_audio_order_mask": refer_audio_order_mask,
1707
+ "src_latents": src_latents,
1708
+ "chunk_masks": chunk_mask,
1709
+ "is_covers": is_covers,
1710
+ "silence_latent": self.silence_latent,
1711
+ "seed": seed_param,
1712
+ "non_cover_text_hidden_states": non_cover_text_hidden_states,
1713
+ "non_cover_text_attention_masks": non_cover_text_attention_masks,
1714
+ "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz,
1715
+ "audio_cover_strength": audio_cover_strength,
1716
+ "infer_method": infer_method,
1717
+ "infer_steps": infer_steps,
1718
+ "diffusion_guidance_sale": guidance_scale,
1719
+ "use_adg": use_adg,
1720
+ "cfg_interval_start": cfg_interval_start,
1721
+ "cfg_interval_end": cfg_interval_end,
1722
+ }
1723
+
1724
+ outputs = self.model.generate_audio(**generate_kwargs)
1725
+ return outputs
1726
+
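A minimal call sketch for `service_generate`; `handler` is assumed to be a fully initialized AceStepHandler (model, VAE, tokenizer, and text encoder loaded), and the values below are illustrative rather than recommended settings:

# Hypothetical usage sketch.
outputs = handler.service_generate(
    captions=["an upbeat synth-pop track with bright pads"],
    lyrics=["[verse]\nCity lights are calling out my name"],
    metas=[{"bpm": 120}],
    vocal_languages=["en"],
    infer_steps=8,            # turbo checkpoints clamp this to at most 8
    guidance_scale=7.0,
    seed=[42],                # an int is broadcast; a short list is padded with random seeds
)
latents = outputs["target_latents"]   # [batch, latent_length, latent_dim], decoded to audio downstream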
1727
  def generate_music(
1728
  self,
1729
  captions: str,
 
1763
  """
1764
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
1765
  return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
1766
+
1767
+ # Auto-detect task type based on audio_code_string
1768
+ # If audio_code_string is provided and not empty, use cover task
1769
+ # Otherwise, use text2music task (or keep current task_type if not text2music)
1770
+ if task_type == "text2music":
1771
+ if audio_code_string and str(audio_code_string).strip():
1772
+ # User has provided audio codes, switch to cover task
1773
  task_type = "cover"
1774
+ # Update instruction for cover task
1775
  instruction = "Generate audio semantic tokens based on the given conditions:"
1776
 
1777
+ print("[generate_music] Starting generation...")
1778
+ if progress:
1779
+ progress(0.05, desc="Preparing inputs...")
1780
+ print("[generate_music] Preparing inputs...")
1781
 
1782
+ # Caption and lyrics are optional - can be empty
1783
+ # Use provided batch_size or default
1784
+ actual_batch_size = batch_size if batch_size is not None else self.batch_size
1785
+ actual_batch_size = max(1, actual_batch_size) # Ensure at least 1
1786
+
1787
+ actual_seed_list, seed_value_for_ui = self.prepare_seeds(actual_batch_size, seed, use_random_seed)
1788
+
1789
+ # Convert special values to None
1790
+ if audio_duration is not None and audio_duration <= 0:
1791
+ audio_duration = None
1792
+ # if seed is not None and seed < 0:
1793
+ # seed = None
1794
+ if repainting_end is not None and repainting_end < 0:
1795
+ repainting_end = None
1796
 
1797
+ try:
1798
+ progress(0.1, desc="Preparing inputs...")
1799
+
1800
+ # 1. Process reference audio
1801
+ refer_audios = None
1802
+ if reference_audio is not None:
1803
+ processed_ref_audio = self.process_reference_audio(reference_audio)
1804
+ if processed_ref_audio is not None:
1805
+ # Convert to the format expected by the service: List[List[torch.Tensor]]
1806
+ # Each batch item has a list of reference audios
1807
+ refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
1808
+ else:
1809
+ refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
1810
 
1811
+ # 2. Process source audio
1812
+ processed_src_audio = None
1813
+ if src_audio is not None:
1814
+ processed_src_audio = self.process_src_audio(src_audio)
1815
 
1816
+ # 3. Prepare batch data
1817
+ captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
1818
+ actual_batch_size,
1819
+ processed_src_audio,
1820
+ audio_duration,
1821
+ captions,
1822
+ lyrics,
1823
+ vocal_language,
1824
+ instruction,
1825
+ bpm,
1826
+ key_scale,
1827
+ time_signature
1828
+ )
1829
 
1830
+ is_repaint_task, is_lego_task, is_cover_task, can_use_repainting = self.determine_task_type(task_type, audio_code_string)
1831
+
1832
+ repainting_start_batch, repainting_end_batch, target_wavs_tensor = self.prepare_padding_info(
1833
+ actual_batch_size,
1834
+ processed_src_audio,
1835
+ audio_duration,
1836
+ repainting_start,
1837
+ repainting_end,
1838
+ is_repaint_task,
1839
+ is_lego_task,
1840
+ is_cover_task,
1841
+ can_use_repainting
1842
+ )
1843
 
1844
+ progress(0.3, desc=f"Generating music (batch size: {actual_batch_size})...")
1845
 
1846
+ # Prepare audio_code_hints - use if audio_code_string is provided
1847
+ # This works for both text2music (auto-switched to cover) and cover tasks
1848
+ audio_code_hints_batch = None
1849
+ if audio_code_string and str(audio_code_string).strip():
1850
+ # Audio codes provided, use as hints (will trigger cover mode in inference service)
1851
+ audio_code_hints_batch = [audio_code_string] * actual_batch_size
1852
+
1853
+ should_return_intermediate = (task_type == "text2music")
1854
+ outputs = self.service_generate(
1855
+ captions=captions_batch,
1856
+ lyrics=lyrics_batch,
1857
+ metas=metas_batch, # Pass as dict, service will convert to string
1858
+ vocal_languages=vocal_languages_batch,
1859
+ refer_audios=refer_audios, # Already in List[List[torch.Tensor]] format
1860
+ target_wavs=target_wavs_tensor, # Shape: [batch_size, 2, frames]
1861
+ infer_steps=inference_steps,
1862
+ guidance_scale=guidance_scale,
1863
+ seed=actual_seed_list, # Pass list of seeds, one per batch item
1864
+ repainting_start=repainting_start_batch,
1865
+ repainting_end=repainting_end_batch,
1866
+ instructions=instructions_batch, # Pass instructions to service
1867
+ audio_cover_strength=audio_cover_strength, # Pass audio cover strength
1868
+ use_adg=use_adg, # Pass use_adg parameter
1869
+ cfg_interval_start=cfg_interval_start, # Pass CFG interval start
1870
+ cfg_interval_end=cfg_interval_end, # Pass CFG interval end
1871
+ audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
1872
+ return_intermediate=should_return_intermediate
1873
+ )
1874
 
1875
  print("[generate_music] Model generation completed. Decoding latents...")
1876
  pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
 
1902
 
1903
  saved_files = []
1904
  for i in range(actual_batch_size):
1905
+ audio_file = os.path.join(self.temp_dir, f"generated_{i}_{actual_seed_list[i]}.{audio_format_lower}")
1906
  # Convert to numpy: [channels, samples] -> [samples, channels]
1907
  audio_np = pred_wavs[i].cpu().float().numpy().T
1908
  sf.write(audio_file, audio_np, self.sample_rate)
 
1926
 
1927
  generation_info = f"""**🎵 Generation Complete**
1928
 
1929
+ **Seeds:** {seed_value_for_ui}
1930
+ **Steps:** {inference_steps}
1931
+ **Files:** {len(saved_files)} audio(s){time_costs_str}"""
 
1932
  status_message = f"✅ Generation completed successfully!"
1933
  print(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
1934
 
 
1954
  align_text_2,
1955
  align_plot_2,
1956
  )
1957
+
1958
  except Exception as e:
1959
+ error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
1960
+ return None, None, [], "", error_msg, seed_value_for_ui, "", "", None, "", "", None
1961