Spaces: Running on A100
Gong Junmin committed
Commit · 12f9f66
Parent(s): 11a221a
test ok
acestep/handler.py +1293 -432
acestep/handler.py
CHANGED
|
@@ -12,16 +12,33 @@ import random
|
|
| 12 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 13 |
|
| 14 |
import torch
|
|
|
|
| 15 |
import matplotlib.pyplot as plt
|
| 16 |
import numpy as np
|
| 17 |
import scipy.io.wavfile as wavfile
|
|
|
|
| 18 |
import soundfile as sf
|
| 19 |
import time
|
|
|
|
|
|
|
| 20 |
|
| 21 |
from transformers import AutoTokenizer, AutoModel
|
| 22 |
from diffusers.models import AutoencoderOobleck
|
| 23 |
|
| 24 |
|
| 25 |
class AceStepHandler:
|
| 26 |
"""ACE-Step Business Logic Handler"""
|
| 27 |
|
|
@@ -166,18 +183,17 @@ class AceStepHandler:
|
|
| 166 |
if os.path.exists(acestep_v15_checkpoint_path):
|
| 167 |
# Determine attention implementation
|
| 168 |
attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
)
|
| 174 |
self.config = self.model.config
|
| 175 |
# Move model to device and set dtype
|
| 176 |
self.model = self.model.to(device).to(self.dtype)
|
| 177 |
self.model.eval()
|
| 178 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
| 179 |
if os.path.exists(silence_latent_path):
|
| 180 |
-
self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
|
| 181 |
self.silence_latent = self.silence_latent.to(device).to(self.dtype)
|
| 182 |
else:
|
| 183 |
raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
|
|
@@ -395,50 +411,7 @@ class AceStepHandler:
|
|
| 395 |
metadata[key] = value
|
| 396 |
|
| 397 |
return metadata, audio_codes
|
| 398 |
-
|
| 399 |
-
def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 400 |
-
"""Process reference audio"""
|
| 401 |
-
if audio_file is None:
|
| 402 |
-
return None
|
| 403 |
-
|
| 404 |
-
try:
|
| 405 |
-
# Load audio using soundfile
|
| 406 |
-
audio_np, sr = sf.read(audio_file, dtype='float32')
|
| 407 |
-
# Convert to torch: [samples, channels] or [samples] -> [channels, samples]
|
| 408 |
-
if audio_np.ndim == 1:
|
| 409 |
-
audio = torch.from_numpy(audio_np).unsqueeze(0)
|
| 410 |
-
else:
|
| 411 |
-
audio = torch.from_numpy(audio_np.T)
|
| 412 |
-
|
| 413 |
-
if audio.shape[0] == 1:
|
| 414 |
-
audio = torch.cat([audio, audio], dim=0)
|
| 415 |
-
|
| 416 |
-
audio = audio[:2]
|
| 417 |
-
|
| 418 |
-
# Resample if needed
|
| 419 |
-
if sr != 48000:
|
| 420 |
-
import torch.nn.functional as F
|
| 421 |
-
# Simple resampling using interpolate
|
| 422 |
-
ratio = 48000 / sr
|
| 423 |
-
new_length = int(audio.shape[-1] * ratio)
|
| 424 |
-
audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
|
| 425 |
-
|
| 426 |
-
audio = torch.clamp(audio, -1.0, 1.0)
|
| 427 |
-
|
| 428 |
-
target_frames = 30 * 48000
|
| 429 |
-
if audio.shape[-1] > target_frames:
|
| 430 |
-
start_frame = (audio.shape[-1] - target_frames) // 2
|
| 431 |
-
audio = audio[:, start_frame:start_frame + target_frames]
|
| 432 |
-
elif audio.shape[-1] < target_frames:
|
| 433 |
-
audio = torch.nn.functional.pad(
|
| 434 |
-
audio, (0, target_frames - audio.shape[-1]), 'constant', 0
|
| 435 |
-
)
|
| 436 |
-
|
| 437 |
-
return audio
|
| 438 |
-
except Exception as e:
|
| 439 |
-
print(f"Error processing reference audio: {e}")
|
| 440 |
-
return None
|
| 441 |
-
|
| 442 |
def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 443 |
"""Process target audio"""
|
| 444 |
if audio_file is None:
|
|
@@ -541,18 +514,32 @@ class AceStepHandler:
|
|
| 541 |
)
|
| 542 |
|
| 543 |
def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
|
| 544 |
-
"""
|
| 545 |
parsed_metas = []
|
| 546 |
for meta in metas:
|
| 547 |
if meta is None:
|
|
|
|
| 548 |
parsed_meta = self._create_default_meta()
|
| 549 |
elif isinstance(meta, str):
|
|
|
|
| 550 |
parsed_meta = meta
|
| 551 |
elif isinstance(meta, dict):
|
|
|
|
| 552 |
parsed_meta = self._dict_to_meta_string(meta)
|
| 553 |
else:
|
|
|
|
| 554 |
parsed_meta = self._create_default_meta()
|
|
|
|
| 555 |
parsed_metas.append(parsed_meta)
|
|
|
|
| 556 |
return parsed_metas
|
| 557 |
|
| 558 |
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
@@ -586,12 +573,1157 @@ class AceStepHandler:
|
|
| 586 |
return text_hidden_states, text_attention_mask
|
| 587 |
|
| 588 |
def extract_caption_from_sft_format(self, caption: str) -> str:
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
|
| 595 |
def generate_music(
|
| 596 |
self,
|
| 597 |
captions: str,
|
|
@@ -631,384 +1763,114 @@ class AceStepHandler:
|
|
| 631 |
"""
|
| 632 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 633 |
return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
# Determine actual batch size
|
| 642 |
-
actual_batch_size = batch_size if batch_size is not None else self.batch_size
|
| 643 |
-
actual_batch_size = max(1, min(actual_batch_size, 8)) # Limit to 8 for memory safety
|
| 644 |
-
|
| 645 |
-
# Process seeds
|
| 646 |
-
if use_random_seed:
|
| 647 |
-
seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
|
| 648 |
-
else:
|
| 649 |
-
# Parse seed input
|
| 650 |
-
if isinstance(seed, str):
|
| 651 |
-
seed_parts = [s.strip() for s in seed.split(",")]
|
| 652 |
-
seed_list = [int(float(s)) if s != "-1" and s else random.randint(0, 2**32 - 1) for s in seed_parts[:actual_batch_size]]
|
| 653 |
-
elif isinstance(seed, (int, float)) and seed >= 0:
|
| 654 |
-
seed_list = [int(seed)] * actual_batch_size
|
| 655 |
-
else:
|
| 656 |
-
seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
|
| 657 |
-
|
| 658 |
-
# Pad if needed
|
| 659 |
-
while len(seed_list) < actual_batch_size:
|
| 660 |
-
seed_list.append(random.randint(0, 2**32 - 1))
|
| 661 |
-
|
| 662 |
-
seed_value_for_ui = ", ".join(str(s) for s in seed_list)
|
| 663 |
-
|
| 664 |
-
# Process audio inputs
|
| 665 |
-
processed_ref_audio = self.process_reference_audio(reference_audio) if reference_audio else None
|
| 666 |
-
processed_src_audio = self.process_target_audio(src_audio) if src_audio else None
|
| 667 |
-
|
| 668 |
-
# Extract caption
|
| 669 |
-
pure_caption = self.extract_caption_from_sft_format(captions)
|
| 670 |
-
|
| 671 |
-
# Determine task type and update instruction if needed
|
| 672 |
-
if task_type == "text2music" and audio_code_string and str(audio_code_string).strip():
|
| 673 |
task_type = "cover"
|
|
|
|
| 674 |
instruction = "Generate audio semantic tokens based on the given conditions:"
|
| 675 |
-
|
| 676 |
-
# Build metadata
|
| 677 |
-
metadata_dict = {
|
| 678 |
-
"bpm": bpm if bpm else "N/A",
|
| 679 |
-
"keyscale": key_scale if key_scale else "N/A",
|
| 680 |
-
"timesignature": time_signature if time_signature else "N/A",
|
| 681 |
-
}
|
| 682 |
-
|
| 683 |
-
# Calculate duration
|
| 684 |
-
if processed_src_audio is not None:
|
| 685 |
-
calculated_duration = processed_src_audio.shape[-1] / self.sample_rate
|
| 686 |
-
elif audio_duration is not None and audio_duration > 0:
|
| 687 |
-
calculated_duration = audio_duration
|
| 688 |
-
else:
|
| 689 |
-
calculated_duration = 30.0 # Default 30 seconds
|
| 690 |
-
|
| 691 |
-
metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
|
| 692 |
-
|
| 693 |
-
if progress:
|
| 694 |
-
progress(0.1, desc="Processing audio inputs...")
|
| 695 |
-
print("[generate_music] Processing audio inputs...")
|
| 696 |
-
|
| 697 |
-
# Prepare batch data
|
| 698 |
-
captions_batch = [pure_caption] * actual_batch_size
|
| 699 |
-
lyrics_batch = [lyrics] * actual_batch_size
|
| 700 |
-
vocal_languages_batch = [vocal_language] * actual_batch_size
|
| 701 |
-
instructions_batch = [instruction] * actual_batch_size
|
| 702 |
-
metas_batch = [metadata_dict.copy()] * actual_batch_size
|
| 703 |
-
audio_code_hints_batch = [audio_code_string if audio_code_string else None] * actual_batch_size
|
| 704 |
-
|
| 705 |
-
# Process reference audios
|
| 706 |
-
if processed_ref_audio is not None:
|
| 707 |
-
refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
|
| 708 |
-
else:
|
| 709 |
-
# Create silence as fallback
|
| 710 |
-
silence_frames = 30 * self.sample_rate
|
| 711 |
-
silence = torch.zeros(2, silence_frames)
|
| 712 |
-
refer_audios = [[silence] for _ in range(actual_batch_size)]
|
| 713 |
-
|
| 714 |
-
# Process target wavs (src_audio)
|
| 715 |
-
if processed_src_audio is not None:
|
| 716 |
-
target_wavs_list = [processed_src_audio.clone() for _ in range(actual_batch_size)]
|
| 717 |
-
else:
|
| 718 |
-
# Create silence based on duration
|
| 719 |
-
target_frames = int(calculated_duration * self.sample_rate)
|
| 720 |
-
silence = torch.zeros(2, target_frames)
|
| 721 |
-
target_wavs_list = [silence for _ in range(actual_batch_size)]
|
| 722 |
-
|
| 723 |
-
# Pad target_wavs to consistent length
|
| 724 |
-
max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
|
| 725 |
-
target_wavs = torch.stack([
|
| 726 |
-
torch.nn.functional.pad(wav, (0, max_target_frames - wav.shape[-1]), 'constant', 0)
|
| 727 |
-
for wav in target_wavs_list
|
| 728 |
-
])
|
| 729 |
-
|
| 730 |
-
if progress:
|
| 731 |
-
progress(0.2, desc="Encoding audio to latents...")
|
| 732 |
-
print("[generate_music] Encoding audio to latents...")
|
| 733 |
-
|
| 734 |
-
# Encode target_wavs to latents using VAE
|
| 735 |
-
target_latents_list = []
|
| 736 |
-
latent_lengths = []
|
| 737 |
-
|
| 738 |
-
with torch.no_grad():
|
| 739 |
-
for i in range(actual_batch_size):
|
| 740 |
-
# Check if audio codes are provided
|
| 741 |
-
code_hint = audio_code_hints_batch[i]
|
| 742 |
-
if code_hint:
|
| 743 |
-
decoded_latents = self._decode_audio_codes_to_latents(code_hint)
|
| 744 |
-
if decoded_latents is not None:
|
| 745 |
-
decoded_latents = decoded_latents.squeeze(0) # Remove batch dim
|
| 746 |
-
target_latents_list.append(decoded_latents)
|
| 747 |
-
latent_lengths.append(decoded_latents.shape[0])
|
| 748 |
-
continue
|
| 749 |
-
|
| 750 |
-
# If no src_audio provided, use silence_latent directly (skip VAE)
|
| 751 |
-
if processed_src_audio is None:
|
| 752 |
-
# Calculate required latent length based on duration
|
| 753 |
-
# VAE downsample ratio is 1920 (2*4*4*6*10), so latent rate is 48000/1920 = 25Hz
|
| 754 |
-
latent_length = int(calculated_duration * 25) # 25Hz latent rate
|
| 755 |
-
latent_length = max(128, latent_length) # Minimum 128
|
| 756 |
-
|
| 757 |
-
# Tile silence_latent to required length
|
| 758 |
-
if self.silence_latent.shape[0] >= latent_length:
|
| 759 |
-
target_latent = self.silence_latent[:latent_length].to(self.device).to(self.dtype)
|
| 760 |
-
else:
|
| 761 |
-
repeat_times = (latent_length // self.silence_latent.shape[0]) + 1
|
| 762 |
-
target_latent = self.silence_latent.repeat(repeat_times, 1)[:latent_length].to(self.device).to(self.dtype)
|
| 763 |
-
target_latents_list.append(target_latent)
|
| 764 |
-
latent_lengths.append(target_latent.shape[0])
|
| 765 |
-
continue
|
| 766 |
-
|
| 767 |
-
# Encode from audio using VAE
|
| 768 |
-
current_wav = target_wavs[i].unsqueeze(0).to(self.device).to(self.dtype)
|
| 769 |
-
target_latent = self.vae.encode(current_wav)
|
| 770 |
-
target_latent = target_latent.squeeze(0).transpose(0, 1) # [latent_length, latent_dim]
|
| 771 |
-
target_latents_list.append(target_latent)
|
| 772 |
-
latent_lengths.append(target_latent.shape[0])
|
| 773 |
-
|
| 774 |
-
# Pad latents to same length
|
| 775 |
-
max_latent_length = max(latent_lengths)
|
| 776 |
-
max_latent_length = max(128, max_latent_length) # Minimum 128
|
| 777 |
-
|
| 778 |
-
padded_latents = []
|
| 779 |
-
for i, latent in enumerate(target_latents_list):
|
| 780 |
-
if latent.shape[0] < max_latent_length:
|
| 781 |
-
pad_length = max_latent_length - latent.shape[0]
|
| 782 |
-
# Tile silence_latent to pad_length (silence_latent is [L, C])
|
| 783 |
-
if self.silence_latent.shape[0] >= pad_length:
|
| 784 |
-
pad_latent = self.silence_latent[:pad_length]
|
| 785 |
-
else:
|
| 786 |
-
repeat_times = (pad_length // self.silence_latent.shape[0]) + 1
|
| 787 |
-
pad_latent = self.silence_latent.repeat(repeat_times, 1)[:pad_length]
|
| 788 |
-
latent = torch.cat([latent, pad_latent.to(self.device).to(self.dtype)], dim=0)
|
| 789 |
-
padded_latents.append(latent)
|
| 790 |
-
|
| 791 |
-
target_latents = torch.stack(padded_latents).to(self.device).to(self.dtype)
|
| 792 |
-
latent_masks = torch.stack([
|
| 793 |
-
torch.cat([
|
| 794 |
-
torch.ones(l, dtype=torch.long, device=self.device),
|
| 795 |
-
torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device)
|
| 796 |
-
])
|
| 797 |
-
for l in latent_lengths
|
| 798 |
-
])
|
| 799 |
-
|
| 800 |
-
if progress:
|
| 801 |
-
progress(0.3, desc="Preparing conditions...")
|
| 802 |
-
print("[generate_music] Preparing conditions...")
|
| 803 |
-
|
| 804 |
-
# Determine task type and create chunk masks
|
| 805 |
-
is_covers = []
|
| 806 |
-
chunk_masks = []
|
| 807 |
-
repainting_ranges = {}
|
| 808 |
-
|
| 809 |
-
for i in range(actual_batch_size):
|
| 810 |
-
has_code_hint = audio_code_hints_batch[i] is not None
|
| 811 |
-
has_repainting = (repainting_end is not None and repainting_end > repainting_start)
|
| 812 |
-
|
| 813 |
-
if has_repainting:
|
| 814 |
-
# Repainting mode
|
| 815 |
-
start_sec = max(0, repainting_start)
|
| 816 |
-
end_sec = repainting_end if repainting_end is not None else calculated_duration
|
| 817 |
-
|
| 818 |
-
start_latent = int(start_sec * self.sample_rate // 1920)
|
| 819 |
-
end_latent = int(end_sec * self.sample_rate // 1920)
|
| 820 |
-
start_latent = max(0, min(start_latent, max_latent_length - 1))
|
| 821 |
-
end_latent = max(start_latent + 1, min(end_latent, max_latent_length))
|
| 822 |
-
|
| 823 |
-
mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device)
|
| 824 |
-
mask[start_latent:end_latent] = True
|
| 825 |
-
chunk_masks.append(mask)
|
| 826 |
-
repainting_ranges[i] = (start_latent, end_latent)
|
| 827 |
-
is_covers.append(False)
|
| 828 |
-
else:
|
| 829 |
-
# Full generation or cover
|
| 830 |
-
chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
|
| 831 |
-
# Check if cover task
|
| 832 |
-
instruction_lower = instructions_batch[i].lower()
|
| 833 |
-
is_cover = ("generate audio semantic tokens" in instruction_lower and
|
| 834 |
-
"based on the given conditions" in instruction_lower) or has_code_hint
|
| 835 |
-
is_covers.append(is_cover)
|
| 836 |
-
|
| 837 |
-
chunk_masks = torch.stack(chunk_masks).unsqueeze(-1).expand(-1, -1, 64) # [batch, length, 64]
|
| 838 |
-
is_covers = torch.tensor(is_covers, dtype=torch.bool, device=self.device)
|
| 839 |
-
|
| 840 |
-
# Create src_latents
|
| 841 |
-
# Tile silence_latent to max_latent_length (silence_latent is now [L, C])
|
| 842 |
-
if self.silence_latent.shape[0] >= max_latent_length:
|
| 843 |
-
silence_latent_tiled = self.silence_latent[:max_latent_length].to(self.device).to(self.dtype)
|
| 844 |
-
else:
|
| 845 |
-
repeat_times = (max_latent_length // self.silence_latent.shape[0]) + 1
|
| 846 |
-
silence_latent_tiled = self.silence_latent.repeat(repeat_times, 1)[:max_latent_length].to(self.device).to(self.dtype)
|
| 847 |
-
src_latents_list = []
|
| 848 |
-
|
| 849 |
-
for i in range(actual_batch_size):
|
| 850 |
-
has_target_audio = (target_wavs[i].abs().sum() > 1e-6) or (audio_code_hints_batch[i] is not None)
|
| 851 |
-
|
| 852 |
-
if has_target_audio:
|
| 853 |
-
if i in repainting_ranges:
|
| 854 |
-
# Repaint: replace inpainting region with silence
|
| 855 |
-
src_latent = target_latents[i].clone()
|
| 856 |
-
start_latent, end_latent = repainting_ranges[i]
|
| 857 |
-
src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent]
|
| 858 |
-
src_latents_list.append(src_latent)
|
| 859 |
-
else:
|
| 860 |
-
# Cover/extract/complete/lego: use target_latents
|
| 861 |
-
src_latents_list.append(target_latents[i].clone())
|
| 862 |
-
else:
|
| 863 |
-
# Text2music: use silence
|
| 864 |
-
src_latents_list.append(silence_latent_tiled.clone())
|
| 865 |
-
|
| 866 |
-
src_latents = torch.stack(src_latents_list) # [batch, length, channels]
|
| 867 |
-
|
| 868 |
-
if progress:
|
| 869 |
-
progress(0.4, desc="Tokenizing text inputs...")
|
| 870 |
-
print("[generate_music] Tokenizing text inputs...")
|
| 871 |
-
|
| 872 |
-
# Prepare text and lyric hidden states
|
| 873 |
-
SFT_GEN_PROMPT = """# Instruction
|
| 874 |
-
{}
|
| 875 |
|
| 876 |
-
|
| 877 |
-
|
|
|
|
|
|
|
| 878 |
|
| 879 |
-
#
|
| 880 |
-
|
| 881 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 882 |
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 887 |
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
if not inst.endswith(":"):
|
| 893 |
-
inst = inst + ":"
|
| 894 |
-
|
| 895 |
-
meta_str = self._dict_to_meta_string(metas_batch[i])
|
| 896 |
-
text_prompt = SFT_GEN_PROMPT.format(inst, captions_batch[i], meta_str)
|
| 897 |
-
|
| 898 |
-
# Tokenize and encode text
|
| 899 |
-
text_hidden, text_mask = self._get_text_hidden_states(text_prompt)
|
| 900 |
-
text_hidden_states_list.append(text_hidden.squeeze(0))
|
| 901 |
-
text_attention_masks_list.append(text_mask.squeeze(0))
|
| 902 |
-
|
| 903 |
-
# Format and tokenize lyrics
|
| 904 |
-
lyrics_text = f"# Languages\n{vocal_languages_batch[i]}\n\n# Lyric\n{lyrics_batch[i]}<|endoftext|>"
|
| 905 |
-
lyric_hidden, lyric_mask = self._get_text_hidden_states(lyrics_text)
|
| 906 |
-
lyric_hidden_states_list.append(lyric_hidden.squeeze(0))
|
| 907 |
-
lyric_attention_masks_list.append(lyric_mask.squeeze(0))
|
| 908 |
-
|
| 909 |
-
# Pad sequences
|
| 910 |
-
max_text_length = max(h.shape[0] for h in text_hidden_states_list)
|
| 911 |
-
max_lyric_length = max(h.shape[0] for h in lyric_hidden_states_list)
|
| 912 |
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
]).to(self.device).to(self.dtype)
|
| 927 |
-
|
| 928 |
-
lyric_attention_mask = torch.stack([
|
| 929 |
-
torch.nn.functional.pad(m, (0, max_lyric_length - m.shape[0]), 'constant', 0)
|
| 930 |
-
for m in lyric_attention_masks_list
|
| 931 |
-
]).to(self.device)
|
| 932 |
-
|
| 933 |
-
if progress:
|
| 934 |
-
progress(0.5, desc="Processing reference audio...")
|
| 935 |
-
print("[generate_music] Processing reference audio...")
|
| 936 |
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
|
|
|
|
|
|
|
|
|
|
| 943 |
|
| 944 |
-
|
| 945 |
-
for i, ref_audio_list in enumerate(refer_audios):
|
| 946 |
-
if ref_audio_list and len(ref_audio_list) > 0 and ref_audio_list[0].abs().sum() > 1e-6:
|
| 947 |
-
# Encode reference audio: [channels, samples] -> [1, latent_dim, T] -> [T, latent_dim]
|
| 948 |
-
ref_audio = ref_audio_list[0].unsqueeze(0).to(self.device).to(self.dtype)
|
| 949 |
-
ref_latent = self.vae.encode(ref_audio).latent_dist.sample() # [1, latent_dim, T]
|
| 950 |
-
ref_latent = ref_latent.squeeze(0).transpose(0, 1) # [T, latent_dim]
|
| 951 |
-
# Ensure dimension matches audio_acoustic_hidden_dim (64)
|
| 952 |
-
if ref_latent.shape[-1] != self.config.audio_acoustic_hidden_dim:
|
| 953 |
-
ref_latent = ref_latent[:, :self.config.audio_acoustic_hidden_dim]
|
| 954 |
-
# Pad or truncate to timbre_fix_frame
|
| 955 |
-
if ref_latent.shape[0] < timbre_fix_frame:
|
| 956 |
-
pad_length = timbre_fix_frame - ref_latent.shape[0]
|
| 957 |
-
padding = torch.zeros(pad_length, ref_latent.shape[1], device=self.device, dtype=self.dtype)
|
| 958 |
-
ref_latent = torch.cat([ref_latent, padding], dim=0)
|
| 959 |
-
else:
|
| 960 |
-
ref_latent = ref_latent[:timbre_fix_frame]
|
| 961 |
-
refer_audio_acoustic_hidden_states_packed_list.append(ref_latent)
|
| 962 |
-
refer_audio_order_mask_list.append(i)
|
| 963 |
-
else:
|
| 964 |
-
# Use silence_latent directly instead of running VAE
|
| 965 |
-
if self.silence_latent.shape[0] >= timbre_fix_frame:
|
| 966 |
-
silence_ref = self.silence_latent[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
|
| 967 |
-
else:
|
| 968 |
-
repeat_times = (timbre_fix_frame // self.silence_latent.shape[0]) + 1
|
| 969 |
-
silence_ref = self.silence_latent.repeat(repeat_times, 1)[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
|
| 970 |
-
refer_audio_acoustic_hidden_states_packed_list.append(silence_ref.to(self.device).to(self.dtype))
|
| 971 |
-
refer_audio_order_mask_list.append(i)
|
| 972 |
-
|
| 973 |
-
# Stack all reference audios: [N, timbre_fix_frame, audio_acoustic_hidden_dim]
|
| 974 |
-
refer_audio_acoustic_hidden_states_packed = torch.stack(refer_audio_acoustic_hidden_states_packed_list, dim=0).to(self.device).to(self.dtype)
|
| 975 |
-
# Order mask: [N] indicating which batch item each reference belongs to
|
| 976 |
-
refer_audio_order_mask = torch.tensor(refer_audio_order_mask_list, dtype=torch.long, device=self.device)
|
| 977 |
|
| 978 |
-
if
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
silence_latent=self.silence_latent.unsqueeze(0), # [1, L, C]
|
| 1007 |
-
seed=seed_list[0] if len(seed_list) > 0 else None,
|
| 1008 |
-
fix_nfe=inference_steps,
|
| 1009 |
-
infer_method="ode",
|
| 1010 |
-
use_cache=True,
|
| 1011 |
-
)
|
| 1012 |
|
| 1013 |
print("[generate_music] Model generation completed. Decoding latents...")
|
| 1014 |
pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
|
|
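A quick sanity check of the latent-length arithmetic referenced in the removed lines above (48 kHz audio, VAE downsample ratio 1920, so 48000 / 1920 = 25 latent frames per second; the helper name below is illustrative, not part of the handler):

def duration_to_latent_length(duration_seconds: float, min_length: int = 128) -> int:
    # 25 latent frames per second; clamp to the 128-frame minimum used above
    return max(min_length, int(duration_seconds * 25))

# Example: a 30-second clip maps to 750 latent frames.
assert duration_to_latent_length(30.0) == 750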
@@ -1040,7 +1902,7 @@ class AceStepHandler:
|
|
| 1040 |
|
| 1041 |
saved_files = []
|
| 1042 |
for i in range(actual_batch_size):
|
| 1043 |
-
audio_file = os.path.join(self.temp_dir, f"generated_{i}_{
|
| 1044 |
# Convert to numpy: [channels, samples] -> [samples, channels]
|
| 1045 |
audio_np = pred_wavs[i].cpu().float().numpy().T
|
| 1046 |
sf.write(audio_file, audio_np, self.sample_rate)
|
|
@@ -1064,10 +1926,9 @@ class AceStepHandler:
|
|
| 1064 |
|
| 1065 |
generation_info = f"""**🎵 Generation Complete**
|
| 1066 |
|
| 1067 |
-
**Seeds:** {seed_value_for_ui}
|
| 1068 |
-
**
|
| 1069 |
-
**
|
| 1070 |
-
**Files:** {len(saved_files)} audio(s){time_costs_str}"""
|
| 1071 |
status_message = f"✅ Generation completed successfully!"
|
| 1072 |
print(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
|
| 1073 |
|
|
@@ -1093,8 +1954,8 @@ class AceStepHandler:
|
|
| 1093 |
align_text_2,
|
| 1094 |
align_plot_2,
|
| 1095 |
)
|
| 1096 |
-
|
| 1097 |
except Exception as e:
|
| 1098 |
-
error_msg = f"❌ Error
|
| 1099 |
-
return None, None, [], "", error_msg,
|
| 1100 |
|
|
|
|
| 12 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 13 |
|
| 14 |
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
import numpy as np
|
| 18 |
import scipy.io.wavfile as wavfile
|
| 19 |
+
import torchaudio
|
| 20 |
import soundfile as sf
|
| 21 |
import time
|
| 22 |
+
from loguru import logger
|
| 23 |
+
import warnings
|
| 24 |
|
| 25 |
from transformers import AutoTokenizer, AutoModel
|
| 26 |
from diffusers.models import AutoencoderOobleck
|
| 27 |
|
| 28 |
|
| 29 |
+
warnings.filterwarnings("ignore")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
SFT_GEN_PROMPT = """# Instruction
|
| 33 |
+
{}
|
| 34 |
+
|
| 35 |
+
# Caption
|
| 36 |
+
{}
|
| 37 |
+
|
| 38 |
+
# Metas
|
| 39 |
+
{}<|endoftext|>
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
class AceStepHandler:
|
| 43 |
"""ACE-Step Business Logic Handler"""
|
| 44 |
|
|
|
|
| 183 |
if os.path.exists(acestep_v15_checkpoint_path):
|
| 184 |
# Determine attention implementation
|
| 185 |
attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
|
| 186 |
+
if use_flash_attention and self.is_flash_attention_available():
|
| 187 |
+
self.dtype = torch.bfloat16
|
| 188 |
+
self.model = AutoModel.from_pretrained(acestep_v15_checkpoint_path, trust_remote_code=True, dtype=self.dtype)
|
| 189 |
+
self.model.config._attn_implementation = attn_implementation
|
|
|
|
| 190 |
self.config = self.model.config
|
| 191 |
# Move model to device and set dtype
|
| 192 |
self.model = self.model.to(device).to(self.dtype)
|
| 193 |
self.model.eval()
|
| 194 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
| 195 |
if os.path.exists(silence_latent_path):
|
| 196 |
+
self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
|
| 197 |
self.silence_latent = self.silence_latent.to(device).to(self.dtype)
|
| 198 |
else:
|
| 199 |
raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
|
|
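A minimal sketch of the attention/dtype selection rule added just above; is_flash_attention_available is the handler's own helper, and the float32 fallback below is an assumption (the committed code simply keeps the existing self.dtype when flash attention is unavailable):

import torch

def pick_attention_and_dtype(use_flash_attention: bool, flash_available: bool):
    # flash_attention_2 goes hand in hand with bfloat16; otherwise fall back to eager attention
    if use_flash_attention and flash_available:
        return "flash_attention_2", torch.bfloat16
    return "eager", torch.float32  # assumed fallback, see note above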
|
|
| 411 |
metadata[key] = value
|
| 412 |
|
| 413 |
return metadata, audio_codes
|
| 414 |
+
|
|
|
| 415 |
def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 416 |
"""Process target audio"""
|
| 417 |
if audio_file is None:
|
|
|
|
| 514 |
)
|
| 515 |
|
| 516 |
def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
|
| 517 |
+
"""
|
| 518 |
+
Parse and normalize metadata with fallbacks.
|
| 519 |
+
|
| 520 |
+
Args:
|
| 521 |
+
metas: List of metadata (can be strings, dicts, or None)
|
| 522 |
+
|
| 523 |
+
Returns:
|
| 524 |
+
List of formatted metadata strings
|
| 525 |
+
"""
|
| 526 |
parsed_metas = []
|
| 527 |
for meta in metas:
|
| 528 |
if meta is None:
|
| 529 |
+
# Default fallback metadata
|
| 530 |
parsed_meta = self._create_default_meta()
|
| 531 |
elif isinstance(meta, str):
|
| 532 |
+
# Already formatted string
|
| 533 |
parsed_meta = meta
|
| 534 |
elif isinstance(meta, dict):
|
| 535 |
+
# Convert dict to formatted string
|
| 536 |
parsed_meta = self._dict_to_meta_string(meta)
|
| 537 |
else:
|
| 538 |
+
# Fallback for any other type
|
| 539 |
parsed_meta = self._create_default_meta()
|
| 540 |
+
|
| 541 |
parsed_metas.append(parsed_meta)
|
| 542 |
+
|
| 543 |
return parsed_metas
|
| 544 |
|
| 545 |
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
| 573 |
return text_hidden_states, text_attention_mask
|
| 574 |
|
| 575 |
def extract_caption_from_sft_format(self, caption: str) -> str:
|
| 576 |
+
try:
|
| 577 |
+
if "# Instruction" in caption and "# Caption" in caption:
|
| 578 |
+
pattern = r'#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)'
|
| 579 |
+
match = re.search(pattern, caption, re.DOTALL)
|
| 580 |
+
if match:
|
| 581 |
+
return match.group(1).strip()
|
| 582 |
+
return caption
|
| 583 |
+
except Exception as e:
|
| 584 |
+
print(f"Error extracting caption: {e}")
|
| 585 |
+
return caption
|
| 586 |
+
|
| 587 |
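An illustrative call showing what the regex in extract_caption_from_sft_format pulls out of an SFT-formatted prompt (the sample string is made up):

sample = (
    "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"
    "# Caption\nAn upbeat synth-pop track with bright pads\n\n"
    "# Metas\nbpm: 120"
)
# handler.extract_caption_from_sft_format(sample)
# -> "An upbeat synth-pop track with bright pads"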
+
def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
|
| 588 |
+
actual_seed_list: List[int] = []
|
| 589 |
+
seed_value_for_ui = ""
|
| 590 |
+
|
| 591 |
+
if use_random_seed:
|
| 592 |
+
# Generate brand new seeds and expose them back to the UI
|
| 593 |
+
actual_seed_list = [random.randint(0, 2 ** 32 - 1) for _ in range(actual_batch_size)]
|
| 594 |
+
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
|
| 595 |
+
else:
|
| 596 |
+
# Parse seed input: can be a single number, comma-separated numbers, or -1
|
| 597 |
+
# If seed is a string, try to parse it as comma-separated values
|
| 598 |
+
seed_list = []
|
| 599 |
+
if isinstance(seed, str):
|
| 600 |
+
# Handle string input (e.g., "123,456" or "-1")
|
| 601 |
+
seed_str_list = [s.strip() for s in seed.split(",")]
|
| 602 |
+
for s in seed_str_list:
|
| 603 |
+
if s == "-1" or s == "":
|
| 604 |
+
seed_list.append(-1)
|
| 605 |
+
else:
|
| 606 |
+
try:
|
| 607 |
+
seed_list.append(int(float(s)))
|
| 608 |
+
except (ValueError, TypeError):
|
| 609 |
+
seed_list.append(-1)
|
| 610 |
+
elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
|
| 611 |
+
# If seed is None or negative, use -1 for all items
|
| 612 |
+
seed_list = [-1] * actual_batch_size
|
| 613 |
+
elif isinstance(seed, (int, float)):
|
| 614 |
+
# Single seed value
|
| 615 |
+
seed_list = [int(seed)]
|
| 616 |
+
else:
|
| 617 |
+
# Fallback: use -1
|
| 618 |
+
seed_list = [-1] * actual_batch_size
|
| 619 |
+
|
| 620 |
+
# Process seed list according to rules:
|
| 621 |
+
# 1. If all are -1, generate different random seeds for each batch item
|
| 622 |
+
# 2. If one non-negative seed is provided and batch_size > 1, first uses that seed, rest are random
|
| 623 |
+
# 3. If more seeds than batch_size, use first batch_size seeds
|
| 624 |
+
# Check if user provided only one non-negative seed (not -1)
|
| 625 |
+
has_single_non_negative_seed = (len(seed_list) == 1 and seed_list[0] != -1)
|
| 626 |
+
|
| 627 |
+
for i in range(actual_batch_size):
|
| 628 |
+
if i < len(seed_list):
|
| 629 |
+
seed_val = seed_list[i]
|
| 630 |
+
else:
|
| 631 |
+
# If not enough seeds provided, use -1 (will generate random)
|
| 632 |
+
seed_val = -1
|
| 633 |
+
|
| 634 |
+
# Special case: if only one non-negative seed was provided and batch_size > 1,
|
| 635 |
+
# only the first item uses that seed, others are random
|
| 636 |
+
if has_single_non_negative_seed and actual_batch_size > 1 and i > 0:
|
| 637 |
+
# Generate random seed for remaining items
|
| 638 |
+
actual_seed_list.append(random.randint(0, 2 ** 32 - 1))
|
| 639 |
+
elif seed_val == -1:
|
| 640 |
+
# Generate a random seed for this item
|
| 641 |
+
actual_seed_list.append(random.randint(0, 2 ** 32 - 1))
|
| 642 |
+
else:
|
| 643 |
+
actual_seed_list.append(int(seed_val))
|
| 644 |
+
|
| 645 |
+
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
|
| 646 |
+
return actual_seed_list, seed_value_for_ui
|
| 647 |
+
|
| 648 |
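A hedged illustration of how the seed rules above behave (the numbers are invented for the example):

# batch_size=3, seed="123", use_random_seed=False
#   -> item 0 uses 123; items 1-2 get fresh random seeds (single non-negative seed rule)
# batch_size=3, seed="-1", use_random_seed=False
#   -> all three items get fresh random seeds
# batch_size=2, seed="7, 9", use_random_seed=False
#   -> seeds [7, 9] are used as given
# seeds, ui_text = handler.prepare_seeds(3, "123", use_random_seed=False)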
+
def prepare_metadata(self, bpm, key_scale, time_signature):
|
| 649 |
+
# Build metadata dict - use "N/A" as default for empty fields
|
| 650 |
+
metadata_dict = {}
|
| 651 |
+
if bpm:
|
| 652 |
+
metadata_dict["bpm"] = bpm
|
| 653 |
+
else:
|
| 654 |
+
metadata_dict["bpm"] = "N/A"
|
| 655 |
+
|
| 656 |
+
if key_scale.strip():
|
| 657 |
+
metadata_dict["keyscale"] = key_scale
|
| 658 |
+
else:
|
| 659 |
+
metadata_dict["keyscale"] = "N/A"
|
| 660 |
+
|
| 661 |
+
if time_signature.strip() and time_signature != "N/A" and time_signature:
|
| 662 |
+
metadata_dict["timesignature"] = time_signature
|
| 663 |
+
else:
|
| 664 |
+
metadata_dict["timesignature"] = "N/A"
|
| 665 |
+
return metadata_dict
|
| 666 |
+
|
| 667 |
+
def is_silence(self, audio):
|
| 668 |
+
return torch.all(audio.abs() < 1e-6)
|
| 669 |
+
|
| 670 |
+
def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 671 |
+
if audio_file is None:
|
| 672 |
+
return None
|
| 673 |
+
|
| 674 |
+
try:
|
| 675 |
+
# Load audio file
|
| 676 |
+
audio, sr = torchaudio.load(audio_file)
|
| 677 |
+
|
| 678 |
+
# Convert to stereo (duplicate channel if mono)
|
| 679 |
+
if audio.shape[0] == 1:
|
| 680 |
+
audio = torch.cat([audio, audio], dim=0)
|
| 681 |
+
|
| 682 |
+
# Keep only first 2 channels
|
| 683 |
+
audio = audio[:2]
|
| 684 |
+
|
| 685 |
+
# Resample to 48kHz if needed
|
| 686 |
+
if sr != 48000:
|
| 687 |
+
audio = torchaudio.transforms.Resample(sr, 48000)(audio)
|
| 688 |
+
|
| 689 |
+
# Clamp values to [-1.0, 1.0]
|
| 690 |
+
audio = torch.clamp(audio, -1.0, 1.0)
|
| 691 |
+
|
| 692 |
+
is_silence = self.is_silence(audio)
|
| 693 |
+
if is_silence:
|
| 694 |
+
return None
|
| 695 |
+
|
| 696 |
+
# Target length: 30 seconds at 48kHz
|
| 697 |
+
target_frames = 30 * 48000
|
| 698 |
+
segment_frames = 10 * 48000 # 10 seconds per segment
|
| 699 |
+
|
| 700 |
+
# If audio is less than 30 seconds, repeat to at least 30 seconds
|
| 701 |
+
if audio.shape[-1] < target_frames:
|
| 702 |
+
repeat_times = math.ceil(target_frames / audio.shape[-1])
|
| 703 |
+
audio = audio.repeat(1, repeat_times)
|
| 704 |
+
# If audio is greater than or equal to 30 seconds, no operation needed
|
| 705 |
+
|
| 706 |
+
# For all cases, select random 10-second segments from front, middle, and back
|
| 707 |
+
# then concatenate them to form 30 seconds
|
| 708 |
+
total_frames = audio.shape[-1]
|
| 709 |
+
segment_size = total_frames // 3
|
| 710 |
+
|
| 711 |
+
# Front segment: [0, segment_size]
|
| 712 |
+
front_start = random.randint(0, max(0, segment_size - segment_frames))
|
| 713 |
+
front_audio = audio[:, front_start:front_start + segment_frames]
|
| 714 |
+
|
| 715 |
+
# Middle segment: [segment_size, 2*segment_size]
|
| 716 |
+
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
|
| 717 |
+
middle_audio = audio[:, middle_start:middle_start + segment_frames]
|
| 718 |
+
|
| 719 |
+
# Back segment: [2*segment_size, total_frames]
|
| 720 |
+
back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
|
| 721 |
+
back_audio = audio[:, back_start:back_start + segment_frames]
|
| 722 |
+
|
| 723 |
+
# Concatenate three segments to form 30 seconds
|
| 724 |
+
audio = torch.cat([front_audio, middle_audio, back_audio], dim=-1)
|
| 725 |
+
|
| 726 |
+
return audio
|
| 727 |
+
|
| 728 |
+
except Exception as e:
|
| 729 |
+
print(f"Error processing reference audio: {e}")
|
| 730 |
+
return None
|
| 731 |
+
|
| 732 |
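A usage sketch for the reference-audio path above; the file path is hypothetical, and a non-silent result is always stereo, 48 kHz, 30 seconds (three concatenated 10-second segments):

# ref = handler.process_reference_audio("/path/to/reference.wav")  # hypothetical path
# if ref is not None:
#     assert ref.shape == (2, 30 * 48000)  # [channels, samples]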
+
def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 733 |
+
if audio_file is None:
|
| 734 |
+
return None
|
| 735 |
+
|
| 736 |
+
try:
|
| 737 |
+
# Load audio file
|
| 738 |
+
audio, sr = torchaudio.load(audio_file)
|
| 739 |
+
|
| 740 |
+
# Convert to stereo (duplicate channel if mono)
|
| 741 |
+
if audio.shape[0] == 1:
|
| 742 |
+
audio = torch.cat([audio, audio], dim=0)
|
| 743 |
+
|
| 744 |
+
# Keep only first 2 channels
|
| 745 |
+
audio = audio[:2]
|
| 746 |
+
|
| 747 |
+
# Resample to 48kHz if needed
|
| 748 |
+
if sr != 48000:
|
| 749 |
+
audio = torchaudio.transforms.Resample(sr, 48000)(audio)
|
| 750 |
+
|
| 751 |
+
# Clamp values to [-1.0, 1.0]
|
| 752 |
+
audio = torch.clamp(audio, -1.0, 1.0)
|
| 753 |
+
|
| 754 |
+
return audio
|
| 755 |
+
|
| 756 |
+
except Exception as e:
|
| 757 |
+
print(f"Error processing target audio: {e}")
|
| 758 |
+
return None
|
| 759 |
+
|
| 760 |
+
def prepare_batch_data(
|
| 761 |
+
self,
|
| 762 |
+
actual_batch_size,
|
| 763 |
+
processed_src_audio,
|
| 764 |
+
audio_duration,
|
| 765 |
+
captions,
|
| 766 |
+
lyrics,
|
| 767 |
+
vocal_language,
|
| 768 |
+
instruction,
|
| 769 |
+
bpm,
|
| 770 |
+
key_scale,
|
| 771 |
+
time_signature
|
| 772 |
+
):
|
| 773 |
+
pure_caption = self.extract_caption_from_sft_format(captions)
|
| 774 |
+
captions_batch = [pure_caption] * actual_batch_size
|
| 775 |
+
instructions_batch = [instruction] * actual_batch_size
|
| 776 |
+
lyrics_batch = [lyrics] * actual_batch_size
|
| 777 |
+
vocal_languages_batch = [vocal_language] * actual_batch_size
|
| 778 |
+
# Calculate duration for metadata
|
| 779 |
+
calculated_duration = None
|
| 780 |
+
if processed_src_audio is not None:
|
| 781 |
+
calculated_duration = processed_src_audio.shape[-1] / 48000.0
|
| 782 |
+
elif audio_duration is not None and audio_duration > 0:
|
| 783 |
+
calculated_duration = audio_duration
|
| 784 |
+
|
| 785 |
+
# Build metadata dict - use "N/A" as default for empty fields
|
| 786 |
+
metadata_dict = {}
|
| 787 |
+
if bpm:
|
| 788 |
+
metadata_dict["bpm"] = bpm
|
| 789 |
+
else:
|
| 790 |
+
metadata_dict["bpm"] = "N/A"
|
| 791 |
+
|
| 792 |
+
if key_scale.strip():
|
| 793 |
+
metadata_dict["keyscale"] = key_scale
|
| 794 |
+
else:
|
| 795 |
+
metadata_dict["keyscale"] = "N/A"
|
| 796 |
+
|
| 797 |
+
if time_signature.strip() and time_signature != "N/A" and time_signature:
|
| 798 |
+
metadata_dict["timesignature"] = time_signature
|
| 799 |
+
else:
|
| 800 |
+
metadata_dict["timesignature"] = "N/A"
|
| 801 |
+
|
| 802 |
+
# Add duration to metadata if available (inference service format: "30 seconds")
|
| 803 |
+
if calculated_duration is not None:
|
| 804 |
+
metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
|
| 805 |
+
# If duration not set, inference service will use default (30 seconds)
|
| 806 |
+
|
| 807 |
+
# Format metadata - inference service accepts dict and will convert to string
|
| 808 |
+
# Create a copy for each batch item (in case we modify it)
|
| 809 |
+
metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)]
|
| 810 |
+
return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch
|
| 811 |
|
| 812 |
+
def determine_task_type(self, task_type, audio_code_string):
|
| 813 |
+
# Determine task type - repaint and lego tasks can have repainting parameters
|
| 814 |
+
# Other tasks (cover, text2music, extract, complete) should NOT have repainting
|
| 815 |
+
is_repaint_task = (task_type == "repaint")
|
| 816 |
+
is_lego_task = (task_type == "lego")
|
| 817 |
+
is_cover_task = (task_type == "cover")
|
| 818 |
+
if audio_code_string and str(audio_code_string).strip():
|
| 819 |
+
is_cover_task = True
|
| 820 |
+
# Both repaint and lego tasks can use repainting parameters for chunk mask
|
| 821 |
+
can_use_repainting = is_repaint_task or is_lego_task
|
| 822 |
+
return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting
|
| 823 |
+
|
| 824 |
+
def create_target_wavs(self, duration_seconds: float) -> torch.Tensor:
|
| 825 |
+
try:
|
| 826 |
+
# Ensure minimum precision of 100ms
|
| 827 |
+
duration_seconds = max(0.1, round(duration_seconds, 1))
|
| 828 |
+
# Calculate frames for 48kHz stereo
|
| 829 |
+
frames = int(duration_seconds * 48000)
|
| 830 |
+
# Create silent stereo audio
|
| 831 |
+
target_wavs = torch.zeros(2, frames)
|
| 832 |
+
return target_wavs
|
| 833 |
+
except Exception as e:
|
| 834 |
+
print(f"Error creating target audio: {e}")
|
| 835 |
+
# Fallback to 30 seconds if error
|
| 836 |
+
return torch.zeros(2, 30 * 48000)
|
| 837 |
+
|
| 838 |
+
def prepare_padding_info(
|
| 839 |
+
self,
|
| 840 |
+
actual_batch_size,
|
| 841 |
+
processed_src_audio,
|
| 842 |
+
audio_duration,
|
| 843 |
+
repainting_start,
|
| 844 |
+
repainting_end,
|
| 845 |
+
is_repaint_task,
|
| 846 |
+
is_lego_task,
|
| 847 |
+
is_cover_task,
|
| 848 |
+
can_use_repainting,
|
| 849 |
+
):
|
| 850 |
+
target_wavs_batch = []
|
| 851 |
+
# Store padding info for each batch item to adjust repainting coordinates
|
| 852 |
+
padding_info_batch = []
|
| 853 |
+
for i in range(actual_batch_size):
|
| 854 |
+
if processed_src_audio is not None:
|
| 855 |
+
if is_cover_task:
|
| 856 |
+
# Cover task: Use src_audio directly without padding
|
| 857 |
+
batch_target_wavs = processed_src_audio
|
| 858 |
+
padding_info_batch.append({
|
| 859 |
+
'left_padding_duration': 0.0,
|
| 860 |
+
'right_padding_duration': 0.0
|
| 861 |
+
})
|
| 862 |
+
elif is_repaint_task or is_lego_task:
|
| 863 |
+
# Repaint/lego task: May need padding for outpainting
|
| 864 |
+
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
|
| 865 |
+
|
| 866 |
+
# Determine actual end time
|
| 867 |
+
if repainting_end is None or repainting_end < 0:
|
| 868 |
+
actual_end = src_audio_duration
|
| 869 |
+
else:
|
| 870 |
+
actual_end = repainting_end
|
| 871 |
+
|
| 872 |
+
left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0
|
| 873 |
+
right_padding_duration = max(0, actual_end - src_audio_duration)
|
| 874 |
+
|
| 875 |
+
# Create padded audio
|
| 876 |
+
left_padding_frames = int(left_padding_duration * 48000)
|
| 877 |
+
right_padding_frames = int(right_padding_duration * 48000)
|
| 878 |
+
|
| 879 |
+
if left_padding_frames > 0 or right_padding_frames > 0:
|
| 880 |
+
# Pad the src audio
|
| 881 |
+
batch_target_wavs = torch.nn.functional.pad(
|
| 882 |
+
processed_src_audio,
|
| 883 |
+
(left_padding_frames, right_padding_frames),
|
| 884 |
+
'constant', 0
|
| 885 |
+
)
|
| 886 |
+
else:
|
| 887 |
+
batch_target_wavs = processed_src_audio
|
| 888 |
+
|
| 889 |
+
# Store padding info for coordinate adjustment
|
| 890 |
+
padding_info_batch.append({
|
| 891 |
+
'left_padding_duration': left_padding_duration,
|
| 892 |
+
'right_padding_duration': right_padding_duration
|
| 893 |
+
})
|
| 894 |
+
else:
|
| 895 |
+
# Other tasks: Use src_audio directly without padding
|
| 896 |
+
batch_target_wavs = processed_src_audio
|
| 897 |
+
padding_info_batch.append({
|
| 898 |
+
'left_padding_duration': 0.0,
|
| 899 |
+
'right_padding_duration': 0.0
|
| 900 |
+
})
|
| 901 |
+
else:
|
| 902 |
+
padding_info_batch.append({
|
| 903 |
+
'left_padding_duration': 0.0,
|
| 904 |
+
'right_padding_duration': 0.0
|
| 905 |
+
})
|
| 906 |
+
if audio_duration is not None and audio_duration > 0:
|
| 907 |
+
batch_target_wavs = self.create_target_wavs(audio_duration)
|
| 908 |
+
else:
|
| 909 |
+
import random
|
| 910 |
+
random_duration = random.uniform(10.0, 120.0)
|
| 911 |
+
batch_target_wavs = self.create_target_wavs(random_duration)
|
| 912 |
+
target_wavs_batch.append(batch_target_wavs)
|
| 913 |
+
|
| 914 |
+
# Stack target_wavs into batch tensor
|
| 915 |
+
# Ensure all tensors have the same shape by padding to max length
|
| 916 |
+
max_frames = max(wav.shape[-1] for wav in target_wavs_batch)
|
| 917 |
+
padded_target_wavs = []
|
| 918 |
+
for wav in target_wavs_batch:
|
| 919 |
+
if wav.shape[-1] < max_frames:
|
| 920 |
+
pad_frames = max_frames - wav.shape[-1]
|
| 921 |
+
padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), 'constant', 0)
|
| 922 |
+
padded_target_wavs.append(padded_wav)
|
| 923 |
+
else:
|
| 924 |
+
padded_target_wavs.append(wav)
|
| 925 |
+
|
| 926 |
+
target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) # [batch_size, 2, frames]
|
| 927 |
+
|
| 928 |
+
if can_use_repainting:
|
| 929 |
+
# Repaint task: Set repainting parameters
|
| 930 |
+
if repainting_start is None:
|
| 931 |
+
repainting_start_batch = None
|
| 932 |
+
elif isinstance(repainting_start, (int, float)):
|
| 933 |
+
if processed_src_audio is not None:
|
| 934 |
+
adjusted_start = repainting_start + padding_info_batch[0]['left_padding_duration']
|
| 935 |
+
repainting_start_batch = [adjusted_start] * actual_batch_size
|
| 936 |
+
else:
|
| 937 |
+
repainting_start_batch = [repainting_start] * actual_batch_size
|
| 938 |
+
else:
|
| 939 |
+
# List input - adjust each item
|
| 940 |
+
repainting_start_batch = []
|
| 941 |
+
for i in range(actual_batch_size):
|
| 942 |
+
if processed_src_audio is not None:
|
| 943 |
+
adjusted_start = repainting_start[i] + padding_info_batch[i]['left_padding_duration']
|
| 944 |
+
repainting_start_batch.append(adjusted_start)
|
| 945 |
+
else:
|
| 946 |
+
repainting_start_batch.append(repainting_start[i])
|
| 947 |
+
|
| 948 |
+
# Handle repainting_end - use src audio duration if not specified or negative
|
| 949 |
+
if processed_src_audio is not None:
|
| 950 |
+
# If src audio is provided, use its duration as default end
|
| 951 |
+
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
|
| 952 |
+
if repainting_end is None or repainting_end < 0:
|
| 953 |
+
# Use src audio duration (before padding), then adjust for padding
|
| 954 |
+
adjusted_end = src_audio_duration + padding_info_batch[0]['left_padding_duration']
|
| 955 |
+
repainting_end_batch = [adjusted_end] * actual_batch_size
|
| 956 |
+
else:
|
| 957 |
+
# Adjust repainting_end to be relative to padded audio
|
| 958 |
+
adjusted_end = repainting_end + padding_info_batch[0]['left_padding_duration']
|
| 959 |
+
repainting_end_batch = [adjusted_end] * actual_batch_size
|
| 960 |
+
else:
|
| 961 |
+
# No src audio - repainting doesn't make sense without it
|
| 962 |
+
if repainting_end is None or repainting_end < 0:
|
| 963 |
+
repainting_end_batch = None
|
| 964 |
+
elif isinstance(repainting_end, (int, float)):
|
| 965 |
+
repainting_end_batch = [repainting_end] * actual_batch_size
|
| 966 |
+
else:
|
| 967 |
+
# List input - adjust each item
|
| 968 |
+
repainting_end_batch = []
|
| 969 |
+
for i in range(actual_batch_size):
|
| 970 |
+
if processed_src_audio is not None:
|
| 971 |
+
adjusted_end = repainting_end[i] + padding_info_batch[i]['left_padding_duration']
|
| 972 |
+
repainting_end_batch.append(adjusted_end)
|
| 973 |
+
else:
|
| 974 |
+
repainting_end_batch.append(repainting_end[i])
|
| 975 |
+
else:
|
| 976 |
+
# All other tasks (cover, text2music, extract, complete): No repainting
|
| 977 |
+
# Only repaint and lego tasks should have repainting parameters
|
| 978 |
+
repainting_start_batch = None
|
| 979 |
+
repainting_end_batch = None
|
| 980 |
+
|
| 981 |
+
return repainting_start_batch, repainting_end_batch, target_wavs_tensor
|
| 982 |
+
|
| 983 |
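A worked example of the outpainting padding arithmetic above (numbers invented): with a 20-second source clip, repainting_start = -5 and repainting_end = 25,

# left_padding_duration  = max(0, -(-5))   = 5 s  -> 5 * 48000 frames padded on the left
# right_padding_duration = max(0, 25 - 20) = 5 s  -> 5 * 48000 frames padded on the right
# adjusted window relative to the padded audio: start = -5 + 5 = 0 s, end = 25 + 5 = 30 s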
+
def _prepare_batch(
|
| 984 |
+
self,
|
| 985 |
+
captions: List[str],
|
| 986 |
+
lyrics: List[str],
|
| 987 |
+
keys: Optional[List[str]] = None,
|
| 988 |
+
target_wavs: Optional[torch.Tensor] = None,
|
| 989 |
+
refer_audios: Optional[List[List[torch.Tensor]]] = None,
|
| 990 |
+
metas: Optional[List[Union[str, Dict[str, Any]]]] = None,
|
| 991 |
+
vocal_languages: Optional[List[str]] = None,
|
| 992 |
+
repainting_start: Optional[List[float]] = None,
|
| 993 |
+
repainting_end: Optional[List[float]] = None,
|
| 994 |
+
instructions: Optional[List[str]] = None,
|
| 995 |
+
audio_code_hints: Optional[List[Optional[str]]] = None,
|
| 996 |
+
audio_cover_strength: float = 1.0,
|
| 997 |
+
) -> Dict[str, Any]:
|
| 998 |
+
"""
|
| 999 |
+
Prepare batch data with fallbacks for missing inputs.
|
| 1000 |
+
|
| 1001 |
+
Args:
|
| 1002 |
+
captions: List of text captions (optional, can be empty strings)
|
| 1003 |
+
lyrics: List of lyrics (optional, can be empty strings)
|
| 1004 |
+
keys: List of unique identifiers (optional)
|
| 1005 |
+
target_wavs: Target audio tensors (optional, will use silence if not provided)
|
| 1006 |
+
refer_audios: Reference audio tensors (optional, will use silence if not provided)
|
| 1007 |
+
metas: Metadata (optional, will use defaults if not provided)
|
| 1008 |
+
vocal_languages: Vocal languages (optional, will default to 'en')
|
| 1009 |
+
|
| 1010 |
+
Returns:
|
| 1011 |
+
Batch dictionary ready for model input
|
| 1012 |
+
"""
|
| 1013 |
+
batch_size = len(captions)
|
| 1014 |
+
|
| 1015 |
+
# Ensure audio_code_hints is a list of the correct length
|
| 1016 |
+
if audio_code_hints is None:
|
| 1017 |
+
audio_code_hints = [None] * batch_size
|
| 1018 |
+
elif len(audio_code_hints) != batch_size:
|
| 1019 |
+
if len(audio_code_hints) == 1:
|
| 1020 |
+
audio_code_hints = audio_code_hints * batch_size
|
| 1021 |
+
else:
|
| 1022 |
+
audio_code_hints = audio_code_hints[:batch_size]
|
| 1023 |
+
while len(audio_code_hints) < batch_size:
|
| 1024 |
+
audio_code_hints.append(None)
|
| 1025 |
+
|
| 1026 |
+
for ii, refer_audio_list in enumerate(refer_audios):
|
| 1027 |
+
if isinstance(refer_audio_list, list):
|
| 1028 |
+
for idx, refer_audio in enumerate(refer_audio_list):
|
| 1029 |
+
refer_audio_list[idx] = refer_audio_list[idx].to(self.device).to(torch.bfloat16)
|
| 1030 |
+
elif isinstance(refer_audio_list, torch.tensor):
|
| 1031 |
+
refer_audios[ii] = refer_audios[ii].to(self.device)
|
| 1032 |
+
|
| 1033 |
+
if vocal_languages is None:
|
| 1034 |
+
vocal_languages = self._create_fallback_vocal_languages(batch_size)
|
| 1035 |
+
|
| 1036 |
+
# Normalize audio_code_hints to batch list
|
| 1037 |
+
if audio_code_hints is None:
|
| 1038 |
+
audio_code_hints = [None] * batch_size
|
| 1039 |
+
elif not isinstance(audio_code_hints, list):
|
| 1040 |
+
audio_code_hints = [audio_code_hints] * batch_size
|
| 1041 |
+
elif len(audio_code_hints) == 1 and batch_size > 1:
|
| 1042 |
+
audio_code_hints = audio_code_hints * batch_size
|
| 1043 |
+
else:
|
| 1044 |
+
audio_code_hints = (audio_code_hints + [None] * batch_size)[:batch_size]
|
| 1045 |
+
audio_code_hints = [hint if isinstance(hint, str) and hint.strip() else None for hint in audio_code_hints]
|
| 1046 |
+
|
| 1047 |
+
# Parse metas with fallbacks
|
| 1048 |
+
parsed_metas = self._parse_metas(metas)
|
| 1049 |
+
|
| 1050 |
+
# Encode target_wavs to get target_latents
|
| 1051 |
+
with torch.no_grad():
|
| 1052 |
+
target_latents_list = []
|
| 1053 |
+
latent_lengths = []
|
| 1054 |
+
# Use per-item wavs (may be adjusted if audio_code_hints are provided)
|
| 1055 |
+
target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)]
|
| 1056 |
+
if target_wavs.device != self.device:
|
| 1057 |
+
target_wavs = target_wavs.to(self.device)
|
| 1058 |
+
for i in range(batch_size):
|
| 1059 |
+
code_hint = audio_code_hints[i]
|
| 1060 |
+
# Prefer decoding from provided audio codes
|
| 1061 |
+
if code_hint:
|
| 1062 |
+
decoded_latents = self._decode_audio_codes_to_latents(code_hint)
|
| 1063 |
+
if decoded_latents is not None:
|
| 1064 |
+
decoded_latents = decoded_latents.squeeze(0)
|
| 1065 |
+
target_latents_list.append(decoded_latents)
|
| 1066 |
+
latent_lengths.append(decoded_latents.shape[0])
|
| 1067 |
+
# Create a silent wav matching the latent length for downstream scaling
|
| 1068 |
+
frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920))
|
| 1069 |
+
target_wavs_list[i] = torch.zeros(2, frames_from_codes)
|
| 1070 |
+
continue
|
| 1071 |
+
# Fallback to VAE encode from audio
|
| 1072 |
+
current_wav = target_wavs_list[i].to(self.device).unsqueeze(0)
|
| 1073 |
+
if self.is_silence(current_wav):
|
| 1074 |
+
expected_latent_length = current_wav.shape[-1] // 1920
|
| 1075 |
+
target_latent = self.silence_latent[0, :expected_latent_length, :]
|
| 1076 |
+
else:
|
| 1077 |
+
target_latent = self.vae.encode(current_wav)
|
| 1078 |
+
target_latent = target_latent.squeeze(0).transpose(0, 1)
|
| 1079 |
+
target_latents_list.append(target_latent)
|
| 1080 |
+
latent_lengths.append(target_latent.shape[0])
|
| 1081 |
+
|
| 1082 |
+
# Pad target_wavs to consistent length for outputs
|
| 1083 |
+
max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
|
| 1084 |
+
padded_target_wavs = []
|
| 1085 |
+
for wav in target_wavs_list:
|
| 1086 |
+
if wav.shape[-1] < max_target_frames:
|
| 1087 |
+
pad_frames = max_target_frames - wav.shape[-1]
|
| 1088 |
+
wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0)
|
| 1089 |
+
padded_target_wavs.append(wav)
|
| 1090 |
+
target_wavs = torch.stack(padded_target_wavs)
|
| 1091 |
+
wav_lengths = torch.tensor([target_wavs.shape[-1]] * batch_size, dtype=torch.long)
|
| 1092 |
+
|
| 1093 |
+
# Pad latents to same length
|
| 1094 |
+
max_latent_length = max(latent.shape[0] for latent in target_latents_list)
|
| 1095 |
+
max_latent_length = max(128, max_latent_length)
|
| 1096 |
+
|
| 1097 |
+
padded_latents = []
|
| 1098 |
+
for latent in target_latents_list:
|
| 1099 |
+
latent_length = latent.shape[0]
|
| 1100 |
+
|
| 1101 |
+
if latent.shape[0] < max_latent_length:
|
| 1102 |
+
pad_length = max_latent_length - latent.shape[0]
|
| 1103 |
+
latent = torch.cat([latent, self.silence_latent[0, :pad_length, :]], dim=0)
|
| 1104 |
+
padded_latents.append(latent)
|
| 1105 |
+
|
| 1106 |
+
target_latents = torch.stack(padded_latents)
|
| 1107 |
+
latent_masks = torch.stack([
|
| 1108 |
+
torch.cat([
|
| 1109 |
+
torch.ones(l, dtype=torch.long, device=self.device),
|
| 1110 |
+
torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device)
|
| 1111 |
+
])
|
| 1112 |
+
for l in latent_lengths
|
| 1113 |
+
])
|
| 1114 |
+
|
| 1115 |
+
# Process instructions early so we can use them for task type detection
|
| 1116 |
+
# Use custom instructions if provided, otherwise use default
|
| 1117 |
+
if instructions is None:
|
| 1118 |
+
instructions = ["Fill the audio semantic mask based on the given conditions:"] * batch_size
|
| 1119 |
+
|
| 1120 |
+
# Ensure instructions list has the same length as batch_size
|
| 1121 |
+
if len(instructions) != batch_size:
|
| 1122 |
+
if len(instructions) == 1:
|
| 1123 |
+
instructions = instructions * batch_size
|
| 1124 |
+
else:
|
| 1125 |
+
# Pad or truncate to match batch_size
|
| 1126 |
+
instructions = instructions[:batch_size]
|
| 1127 |
+
while len(instructions) < batch_size:
|
| 1128 |
+
instructions.append("Fill the audio semantic mask based on the given conditions:")
|
| 1129 |
+
|
| 1130 |
+
# Generate chunk_masks and spans based on repainting parameters
|
| 1131 |
+
# Also determine if this is a cover task (target audio provided without repainting)
|
| 1132 |
+
chunk_masks = []
|
| 1133 |
+
spans = []
|
| 1134 |
+
is_covers = []
|
| 1135 |
+
# Store repainting latent ranges for later use in src_latents creation
|
| 1136 |
+
repainting_ranges = {} # {batch_idx: (start_latent, end_latent)}
|
| 1137 |
+
|
| 1138 |
+
for i in range(batch_size):
|
| 1139 |
+
has_code_hint = audio_code_hints[i] is not None
|
| 1140 |
+
# Check if repainting is enabled for this batch item
|
| 1141 |
+
has_repainting = False
|
| 1142 |
+
if repainting_start is not None and repainting_end is not None:
|
| 1143 |
+
start_sec = repainting_start[i] if repainting_start[i] is not None else 0.0
|
| 1144 |
+
end_sec = repainting_end[i]
|
| 1145 |
+
|
| 1146 |
+
if end_sec is not None and end_sec > start_sec:
|
| 1147 |
+
# Repainting mode with outpainting support
|
| 1148 |
+
# The target_wavs may have been padded for outpainting
|
| 1149 |
+
# Need to calculate the actual position in the padded audio
|
| 1150 |
+
|
| 1151 |
+
# Calculate padding (if start < 0, there's left padding)
|
| 1152 |
+
left_padding_sec = max(0, -start_sec)
|
| 1153 |
+
|
| 1154 |
+
# Adjust positions to account for padding
|
| 1155 |
+
# In the padded audio, the original start is shifted by left_padding
|
| 1156 |
+
adjusted_start_sec = start_sec + left_padding_sec
|
| 1157 |
+
adjusted_end_sec = end_sec + left_padding_sec
|
| 1158 |
+
|
| 1159 |
+
# Convert seconds to latent frames (audio_frames / 1920 = latent_frames)
|
| 1160 |
+
start_latent = int(adjusted_start_sec * self.sample_rate // 1920)
|
| 1161 |
+
end_latent = int(adjusted_end_sec * self.sample_rate // 1920)
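# Worked example (illustrative values): with sample_rate=48000 and 1920 samples per
# latent frame, one second of audio maps to 25 latent frames, so a 10.0s-16.0s
# repainting window becomes latent frames 250..400.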
|
| 1162 |
+
|
| 1163 |
+
# Clamp to valid range
|
| 1164 |
+
start_latent = max(0, min(start_latent, max_latent_length - 1))
|
| 1165 |
+
end_latent = max(start_latent + 1, min(end_latent, max_latent_length))
|
| 1166 |
+
# Create mask: False = keep original, True = generate new
|
| 1167 |
+
mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device)
|
| 1168 |
+
mask[start_latent:end_latent] = True
|
| 1169 |
+
chunk_masks.append(mask)
|
| 1170 |
+
spans.append(("repainting", start_latent, end_latent))
|
| 1171 |
+
# Store repainting range for later use
|
| 1172 |
+
repainting_ranges[i] = (start_latent, end_latent)
|
| 1173 |
+
has_repainting = True
|
| 1174 |
+
is_covers.append(False) # Repainting is not cover task
|
| 1175 |
+
else:
|
| 1176 |
+
# Full generation (no valid repainting range)
|
| 1177 |
+
chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
|
| 1178 |
+
spans.append(("full", 0, max_latent_length))
|
| 1179 |
+
# Determine task type from instruction, not from target_wavs
|
| 1180 |
+
# Only cover task should have is_cover=True
|
| 1181 |
+
instruction_i = instructions[i] if instructions and i < len(instructions) else ""
|
| 1182 |
+
instruction_lower = instruction_i.lower()
|
| 1183 |
+
# Cover task instruction: "Generate audio semantic tokens based on the given conditions:"
|
| 1184 |
+
is_cover = ("generate audio semantic tokens" in instruction_lower and
|
| 1185 |
+
"based on the given conditions" in instruction_lower) or has_code_hint
|
| 1186 |
+
is_covers.append(is_cover)
|
| 1187 |
+
else:
|
| 1188 |
+
# Full generation (no repainting parameters)
|
| 1189 |
+
chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
|
| 1190 |
+
spans.append(("full", 0, max_latent_length))
|
| 1191 |
+
# Determine task type from instruction, not from target_wavs
|
| 1192 |
+
# Only cover task should have is_cover=True
|
| 1193 |
+
instruction_i = instructions[i] if instructions and i < len(instructions) else ""
|
| 1194 |
+
instruction_lower = instruction_i.lower()
|
| 1195 |
+
# Cover task instruction: "Generate audio semantic tokens based on the given conditions:"
|
| 1196 |
+
is_cover = ("generate audio semantic tokens" in instruction_lower and
|
| 1197 |
+
"based on the given conditions" in instruction_lower) or has_code_hint
|
| 1198 |
+
is_covers.append(is_cover)
|
| 1199 |
+
|
| 1200 |
+
chunk_masks = torch.stack(chunk_masks)
|
| 1201 |
+
is_covers = torch.BoolTensor(is_covers).to(self.device)
|
| 1202 |
+
|
| 1203 |
+
# Create src_latents based on task type
|
| 1204 |
+
# For cover/extract/complete/lego/repaint tasks: src_latents = target_latents.clone() (if target_wavs provided)
|
| 1205 |
+
# For text2music task: src_latents = silence_latent (if no target_wavs or silence)
|
| 1206 |
+
# For repaint task: additionally replace inpainting region with silence_latent
|
| 1207 |
+
src_latents_list = []
|
| 1208 |
+
silence_latent_tiled = self.silence_latent[0, :max_latent_length, :]
|
| 1209 |
+
for i in range(batch_size):
|
| 1210 |
+
# Check if target_wavs is provided and not silent (for extract/complete/lego/cover/repaint tasks)
|
| 1211 |
+
has_code_hint = audio_code_hints[i] is not None
|
| 1212 |
+
has_target_audio = has_code_hint or (target_wavs is not None and target_wavs[i].abs().sum() > 1e-6)
|
| 1213 |
+
|
| 1214 |
+
if has_target_audio:
|
| 1215 |
+
# For tasks that use input audio (cover/extract/complete/lego/repaint)
|
| 1216 |
+
# Check if this item has repainting
|
| 1217 |
+
item_has_repainting = (i in repainting_ranges)
|
| 1218 |
+
|
| 1219 |
+
if item_has_repainting:
|
| 1220 |
+
# Repaint task: src_latents = target_latents with inpainting region replaced by silence_latent
|
| 1221 |
+
# 1. Clone target_latents (encoded from src audio, preserving original audio)
|
| 1222 |
+
src_latent = target_latents[i].clone()
|
| 1223 |
+
# 2. Replace inpainting region with silence_latent
|
| 1224 |
+
start_latent, end_latent = repainting_ranges[i]
|
| 1225 |
+
src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent]
|
| 1226 |
+
src_latents_list.append(src_latent)
|
| 1227 |
+
else:
|
| 1228 |
+
# Cover/extract/complete/lego tasks: src_latents = target_latents.clone()
|
| 1229 |
+
# All these tasks need to base on input audio
|
| 1230 |
+
src_latents_list.append(target_latents[i].clone())
|
| 1231 |
+
else:
|
| 1232 |
+
# Text2music task: src_latents = silence_latent (no input audio)
|
| 1233 |
+
# Use silence_latent for the full length
|
| 1234 |
+
src_latents_list.append(silence_latent_tiled.clone())
|
| 1235 |
+
|
| 1236 |
+
src_latents = torch.stack(src_latents_list)
|
| 1237 |
+
|
| 1238 |
+
# Process audio_code_hints to generate precomputed_lm_hints_25Hz
|
| 1239 |
+
precomputed_lm_hints_25Hz_list = []
|
| 1240 |
+
for i in range(batch_size):
|
| 1241 |
+
if audio_code_hints[i] is not None:
|
| 1242 |
+
# Decode audio codes to 25Hz latents
|
| 1243 |
+
hints = self._decode_audio_codes_to_latents(audio_code_hints[i])
|
| 1244 |
+
if hints is not None:
|
| 1245 |
+
# Pad or crop to match max_latent_length
|
| 1246 |
+
if hints.shape[1] < max_latent_length:
|
| 1247 |
+
pad_length = max_latent_length - hints.shape[1]
|
| 1248 |
+
hints = torch.cat([
|
| 1249 |
+
hints,
|
| 1250 |
+
self.silence_latent[:, :pad_length, :]  # keep the batch dim so it matches hints' [1, T, d] shape
|
| 1251 |
+
], dim=1)
|
| 1252 |
+
elif hints.shape[1] > max_latent_length:
|
| 1253 |
+
hints = hints[:, :max_latent_length, :]
|
| 1254 |
+
precomputed_lm_hints_25Hz_list.append(hints[0]) # Remove batch dimension
|
| 1255 |
+
else:
|
| 1256 |
+
precomputed_lm_hints_25Hz_list.append(None)
|
| 1257 |
+
else:
|
| 1258 |
+
precomputed_lm_hints_25Hz_list.append(None)
|
| 1259 |
+
|
| 1260 |
+
# Stack precomputed hints if any exist, otherwise set to None
|
| 1261 |
+
if any(h is not None for h in precomputed_lm_hints_25Hz_list):
|
| 1262 |
+
# For items without hints, use silence_latent as placeholder
|
| 1263 |
+
precomputed_lm_hints_25Hz = torch.stack([
|
| 1264 |
+
h if h is not None else silence_latent_tiled
|
| 1265 |
+
for h in precomputed_lm_hints_25Hz_list
|
| 1266 |
+
])
|
| 1267 |
+
else:
|
| 1268 |
+
precomputed_lm_hints_25Hz = None
|
| 1269 |
+
|
| 1270 |
+
# Format text_inputs
|
| 1271 |
+
text_inputs = []
|
| 1272 |
+
text_token_idss = []
|
| 1273 |
+
text_attention_masks = []
|
| 1274 |
+
lyric_token_idss = []
|
| 1275 |
+
lyric_attention_masks = []
|
| 1276 |
+
|
| 1277 |
+
for i in range(batch_size):
|
| 1278 |
+
# Use custom instruction for this batch item
|
| 1279 |
+
instruction = instructions[i] if i < len(instructions) else "Fill the audio semantic mask based on the given conditions:"
|
| 1280 |
+
# Ensure instruction ends with ":"
|
| 1281 |
+
if not instruction.endswith(":"):
|
| 1282 |
+
instruction = instruction + ":"
|
| 1283 |
+
|
| 1284 |
+
# Format text prompt with custom instruction
|
| 1285 |
+
text_prompt = SFT_GEN_PROMPT.format(instruction, captions[i], parsed_metas[i])
|
| 1286 |
+
|
| 1287 |
+
# Tokenize text
|
| 1288 |
+
text_inputs_dict = self.text_tokenizer(
|
| 1289 |
+
text_prompt,
|
| 1290 |
+
padding="longest",
|
| 1291 |
+
truncation=True,
|
| 1292 |
+
max_length=256,
|
| 1293 |
+
return_tensors="pt",
|
| 1294 |
+
)
|
| 1295 |
+
text_token_ids = text_inputs_dict.input_ids[0]
|
| 1296 |
+
text_attention_mask = text_inputs_dict.attention_mask[0].bool()
|
| 1297 |
+
|
| 1298 |
+
# Format and tokenize lyrics
|
| 1299 |
+
lyrics_text = f"# Languages\n{vocal_languages[i]}\n\n# Lyric\n{lyrics[i]}<|endoftext|>"
|
| 1300 |
+
lyrics_inputs_dict = self.text_tokenizer(
|
| 1301 |
+
lyrics_text,
|
| 1302 |
+
padding="longest",
|
| 1303 |
+
truncation=True,
|
| 1304 |
+
max_length=2048,
|
| 1305 |
+
return_tensors="pt",
|
| 1306 |
+
)
|
| 1307 |
+
lyric_token_ids = lyrics_inputs_dict.input_ids[0]
|
| 1308 |
+
lyric_attention_mask = lyrics_inputs_dict.attention_mask[0].bool()
|
| 1309 |
+
|
| 1310 |
+
# Build full text input
|
| 1311 |
+
text_input = text_prompt + "\n\n" + lyrics_text
|
| 1312 |
+
|
| 1313 |
+
text_inputs.append(text_input)
|
| 1314 |
+
text_token_idss.append(text_token_ids)
|
| 1315 |
+
text_attention_masks.append(text_attention_mask)
|
| 1316 |
+
lyric_token_idss.append(lyric_token_ids)
|
| 1317 |
+
lyric_attention_masks.append(lyric_attention_mask)
|
| 1318 |
+
|
| 1319 |
+
# Pad tokenized sequences
|
| 1320 |
+
max_text_length = max(len(seq) for seq in text_token_idss)
|
| 1321 |
+
padded_text_token_idss = torch.stack([
|
| 1322 |
+
torch.nn.functional.pad(
|
| 1323 |
+
seq, (0, max_text_length - len(seq)), 'constant',
|
| 1324 |
+
self.text_tokenizer.pad_token_id
|
| 1325 |
+
)
|
| 1326 |
+
for seq in text_token_idss
|
| 1327 |
+
])
|
| 1328 |
+
|
| 1329 |
+
padded_text_attention_masks = torch.stack([
|
| 1330 |
+
torch.nn.functional.pad(
|
| 1331 |
+
seq, (0, max_text_length - len(seq)), 'constant', 0
|
| 1332 |
+
)
|
| 1333 |
+
for seq in text_attention_masks
|
| 1334 |
+
])
|
| 1335 |
+
|
| 1336 |
+
max_lyric_length = max(len(seq) for seq in lyric_token_idss)
|
| 1337 |
+
padded_lyric_token_idss = torch.stack([
|
| 1338 |
+
torch.nn.functional.pad(
|
| 1339 |
+
seq, (0, max_lyric_length - len(seq)), 'constant',
|
| 1340 |
+
self.text_tokenizer.pad_token_id
|
| 1341 |
+
)
|
| 1342 |
+
for seq in lyric_token_idss
|
| 1343 |
+
])
|
| 1344 |
+
|
| 1345 |
+
padded_lyric_attention_masks = torch.stack([
|
| 1346 |
+
torch.nn.functional.pad(
|
| 1347 |
+
seq, (0, max_lyric_length - len(seq)), 'constant', 0
|
| 1348 |
+
)
|
| 1349 |
+
for seq in lyric_attention_masks
|
| 1350 |
+
])
|
| 1351 |
+
|
| 1352 |
+
padded_non_cover_text_input_ids = None
|
| 1353 |
+
padded_non_cover_text_attention_masks = None
|
| 1354 |
+
if audio_cover_strength < 1.0 and is_covers is not None and is_covers.any():
|
| 1355 |
+
non_cover_text_input_ids = []
|
| 1356 |
+
non_cover_text_attention_masks = []
|
| 1357 |
+
for i in range(batch_size):
|
| 1358 |
+
# Use the default non-cover instruction for this batch item
|
| 1359 |
+
instruction = "Fill the audio semantic mask based on the given conditions:"
|
| 1360 |
+
|
| 1361 |
+
# Format text prompt with custom instruction
|
| 1362 |
+
text_prompt = SFT_GEN_PROMPT.format(instruction, captions[i], parsed_metas[i])
|
| 1363 |
+
|
| 1364 |
+
# Tokenize text
|
| 1365 |
+
text_inputs_dict = self.text_tokenizer(
|
| 1366 |
+
text_prompt,
|
| 1367 |
+
padding="longest",
|
| 1368 |
+
truncation=True,
|
| 1369 |
+
max_length=256,
|
| 1370 |
+
return_tensors="pt",
|
| 1371 |
+
)
|
| 1372 |
+
text_token_ids = text_inputs_dict.input_ids[0]
text_attention_mask = text_inputs_dict.attention_mask[0].bool()
|
| 1373 |
+
non_cover_text_input_ids.append(text_token_ids)
|
| 1374 |
+
non_cover_text_attention_masks.append(text_attention_mask)
|
| 1375 |
+
|
| 1376 |
+
padded_non_cover_text_input_ids = torch.stack([
|
| 1377 |
+
torch.nn.functional.pad(
|
| 1378 |
+
seq, (0, max_text_length - len(seq)), 'constant',
|
| 1379 |
+
self.text_tokenizer.pad_token_id
|
| 1380 |
+
)
|
| 1381 |
+
for seq in non_cover_text_input_ids
|
| 1382 |
+
])
|
| 1383 |
+
padded_non_cover_text_attention_masks = torch.stack([
|
| 1384 |
+
torch.nn.functional.pad(
|
| 1385 |
+
seq, (0, max_text_length - len(seq)), 'constant', 0
|
| 1386 |
+
)
|
| 1387 |
+
for seq in non_cover_text_attention_masks
|
| 1388 |
+
])
|
| 1389 |
+
|
| 1390 |
+
# Prepare batch
|
| 1391 |
+
batch = {
|
| 1392 |
+
"keys": keys,
|
| 1393 |
+
"target_wavs": target_wavs.to(self.device),
|
| 1394 |
+
"refer_audioss": refer_audios,
|
| 1395 |
+
"wav_lengths": wav_lengths.to(self.device),
|
| 1396 |
+
"captions": captions,
|
| 1397 |
+
"lyrics": lyrics,
|
| 1398 |
+
"metas": parsed_metas,
|
| 1399 |
+
"vocal_languages": vocal_languages,
|
| 1400 |
+
"target_latents": target_latents,
|
| 1401 |
+
"src_latents": src_latents,
|
| 1402 |
+
"latent_masks": latent_masks,
|
| 1403 |
+
"chunk_masks": chunk_masks,
|
| 1404 |
+
"spans": spans,
|
| 1405 |
+
"text_inputs": text_inputs,
|
| 1406 |
+
"text_token_idss": padded_text_token_idss,
|
| 1407 |
+
"text_attention_masks": padded_text_attention_masks,
|
| 1408 |
+
"lyric_token_idss": padded_lyric_token_idss,
|
| 1409 |
+
"lyric_attention_masks": padded_lyric_attention_masks,
|
| 1410 |
+
"is_covers": is_covers,
|
| 1411 |
+
"precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz,
|
| 1412 |
+
"non_cover_text_input_ids": padded_non_cover_text_input_ids,
|
| 1413 |
+
"non_cover_text_attention_masks": padded_non_cover_text_attention_masks,
|
| 1414 |
+
}
|
| 1415 |
+
# to device
|
| 1416 |
+
for k, v in batch.items():
|
| 1417 |
+
if isinstance(v, torch.Tensor):
|
| 1418 |
+
batch[k] = v.to(self.device)
|
| 1419 |
+
if torch.is_floating_point(v):
|
| 1420 |
+
batch[k] = batch[k].to(self.dtype)
|
| 1421 |
+
return batch
|
| 1422 |
+
|
| 1423 |
+
def infer_refer_latent(self, refer_audioss):
|
| 1424 |
+
refer_audio_order_mask = []
|
| 1425 |
+
refer_audio_latents = []
|
| 1426 |
+
for batch_idx, refer_audios in enumerate(refer_audioss):
|
| 1427 |
+
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
|
| 1428 |
+
refer_audio_latent = self.silence_latent[:, :750, :]
|
| 1429 |
+
refer_audio_latents.append(refer_audio_latent)
|
| 1430 |
+
refer_audio_order_mask.append(batch_idx)
|
| 1431 |
+
else:
|
| 1432 |
+
for refer_audio in refer_audios:
|
| 1433 |
+
refer_audio_latent = self.vae.encode(refer_audio.unsqueeze(0), chunked=False)
|
| 1434 |
+
refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
|
| 1435 |
+
refer_audio_order_mask.append(batch_idx)
|
| 1436 |
+
|
| 1437 |
+
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
|
| 1438 |
+
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=self.device, dtype=torch.long)
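# refer_audio_latents is packed along dim 0 (one row per reference clip across the whole
# batch); refer_audio_order_mask records the batch index each packed row belongs to.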
|
| 1439 |
+
return refer_audio_latents, refer_audio_order_mask
|
| 1440 |
+
|
| 1441 |
+
def infer_text_embeddings(self, text_token_idss):
|
| 1442 |
+
with torch.no_grad():
|
| 1443 |
+
text_embeddings = self.text_encoder(input_ids=text_token_idss, lyric_attention_mask=None).last_hidden_state
|
| 1444 |
+
return text_embeddings
|
| 1445 |
+
|
| 1446 |
+
def infer_lyric_embeddings(self, lyric_token_ids):
|
| 1447 |
+
with torch.no_grad():
|
| 1448 |
+
lyric_embeddings = self.text_encoder.embed_tokens(lyric_token_ids)
|
| 1449 |
+
return lyric_embeddings
|
| 1450 |
+
|
| 1451 |
+
def preprocess_batch(self, batch):
|
| 1452 |
+
|
| 1453 |
+
# step 1: VAE encode latents, target_latents: N x T x d
|
| 1454 |
+
# target_latents: N x T x d
|
| 1455 |
+
target_latents = batch["target_latents"]
|
| 1456 |
+
src_latents = batch["src_latents"]
|
| 1457 |
+
attention_mask = batch["latent_masks"]
|
| 1458 |
+
audio_codes = batch.get("audio_codes", None)
|
| 1459 |
+
audio_attention_mask = attention_mask
|
| 1460 |
+
|
| 1461 |
+
dtype = target_latents.dtype
|
| 1462 |
+
bs = target_latents.shape[0]
|
| 1463 |
+
device = target_latents.device
|
| 1464 |
+
|
| 1465 |
+
# step 2: refer_audio timbre
|
| 1466 |
+
keys = batch["keys"]
|
| 1467 |
+
refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"])
|
| 1468 |
+
if refer_audio_acoustic_hidden_states_packed.dtype != dtype:
|
| 1469 |
+
refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype)
|
| 1470 |
+
|
| 1471 |
+
# step 4: chunk mask, N x T x d
|
| 1472 |
+
chunk_mask = batch["chunk_masks"]
|
| 1473 |
+
chunk_mask = chunk_mask.to(device).unsqueeze(-1).repeat(1, 1, target_latents.shape[2])
|
| 1474 |
+
|
| 1475 |
+
spans = batch["spans"]
|
| 1476 |
+
|
| 1477 |
+
text_token_idss = batch["text_token_idss"]
|
| 1478 |
+
text_attention_mask = batch["text_attention_masks"]
|
| 1479 |
+
lyric_token_idss = batch["lyric_token_idss"]
|
| 1480 |
+
lyric_attention_mask = batch["lyric_attention_masks"]
|
| 1481 |
+
text_inputs = batch["text_inputs"]
|
| 1482 |
+
|
| 1483 |
+
text_hidden_states = self.infer_text_embeddings(text_token_idss)
|
| 1484 |
+
lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
|
| 1485 |
+
|
| 1486 |
+
is_covers = batch["is_covers"]
|
| 1487 |
+
|
| 1488 |
+
# Get precomputed hints from batch if available
|
| 1489 |
+
precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None)
|
| 1490 |
+
|
| 1491 |
+
# Get non-cover text input ids and attention masks from batch if available
|
| 1492 |
+
non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None)
|
| 1493 |
+
non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
|
| 1494 |
+
non_cover_text_hidden_states = None
|
| 1495 |
+
if non_cover_text_input_ids is not None:
|
| 1496 |
+
non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
|
| 1497 |
+
|
| 1498 |
+
return (
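# Note: keep this tuple order in sync with the unpacking in service_generate below.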
|
| 1499 |
+
keys,
|
| 1500 |
+
text_inputs,
|
| 1501 |
+
src_latents,
|
| 1502 |
+
target_latents,
|
| 1503 |
+
# model inputs
|
| 1504 |
+
text_hidden_states,
|
| 1505 |
+
text_attention_mask,
|
| 1506 |
+
lyric_hidden_states,
|
| 1507 |
+
lyric_attention_mask,
|
| 1508 |
+
audio_attention_mask,
|
| 1509 |
+
refer_audio_acoustic_hidden_states_packed,
|
| 1510 |
+
refer_audio_order_mask,
|
| 1511 |
+
chunk_mask,
|
| 1512 |
+
spans,
|
| 1513 |
+
is_covers,
|
| 1514 |
+
audio_codes,
|
| 1515 |
+
lyric_token_idss,
|
| 1516 |
+
precomputed_lm_hints_25Hz,
|
| 1517 |
+
non_cover_text_hidden_states,
|
| 1518 |
+
non_cover_text_attention_masks,
|
| 1519 |
+
)
|
| 1520 |
+
|
| 1521 |
+
@torch.no_grad()
|
| 1522 |
+
def service_generate(
|
| 1523 |
+
self,
|
| 1524 |
+
captions: Union[str, List[str]],
|
| 1525 |
+
lyrics: Union[str, List[str]],
|
| 1526 |
+
keys: Optional[Union[str, List[str]]] = None,
|
| 1527 |
+
target_wavs: Optional[torch.Tensor] = None,
|
| 1528 |
+
refer_audios: Optional[List[List[torch.Tensor]]] = None,
|
| 1529 |
+
metas: Optional[Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]]] = None,
|
| 1530 |
+
vocal_languages: Optional[Union[str, List[str]]] = None,
|
| 1531 |
+
infer_steps: int = 60,
|
| 1532 |
+
guidance_scale: float = 7.0,
|
| 1533 |
+
seed: Optional[Union[int, List[int]]] = None,
|
| 1534 |
+
return_intermediate: bool = False,
|
| 1535 |
+
repainting_start: Optional[Union[float, List[float]]] = None,
|
| 1536 |
+
repainting_end: Optional[Union[float, List[float]]] = None,
|
| 1537 |
+
instructions: Optional[Union[str, List[str]]] = None,
|
| 1538 |
+
audio_cover_strength: float = 1.0,
|
| 1539 |
+
use_adg: bool = False,
|
| 1540 |
+
cfg_interval_start: float = 0.0,
|
| 1541 |
+
cfg_interval_end: float = 1.0,
|
| 1542 |
+
audio_code_hints: Optional[Union[str, List[str]]] = None,
|
| 1543 |
+
infer_method: str = "ode",
|
| 1544 |
+
) -> Dict[str, Any]:
|
| 1545 |
+
|
| 1546 |
+
"""
|
| 1547 |
+
Generate music from text inputs.
|
| 1548 |
+
|
| 1549 |
+
Args:
|
| 1550 |
+
captions: Text caption(s) describing the music (optional, can be empty strings)
|
| 1551 |
+
lyrics: Lyric text(s) (optional, can be empty strings)
|
| 1552 |
+
keys: Unique identifier(s) (optional)
|
| 1553 |
+
target_wavs: Target audio tensor(s) for conditioning (optional)
|
| 1554 |
+
refer_audios: Reference audio tensor(s) for style transfer (optional)
|
| 1555 |
+
metas: Metadata dict(s) or string(s) (optional)
|
| 1556 |
+
vocal_languages: Language code(s) for lyrics (optional, defaults to 'en')
|
| 1557 |
+
infer_steps: Number of inference steps (default: 60)
|
| 1558 |
+
guidance_scale: Guidance scale for generation (default: 7.0)
|
| 1559 |
+
seed: Random seed (optional)
|
| 1560 |
+
return_intermediate: Whether to return intermediate results (default: False)
|
| 1561 |
+
repainting_start: Start time(s) for repainting region in seconds (optional)
|
| 1562 |
+
repainting_end: End time(s) for repainting region in seconds (optional)
|
| 1563 |
+
instructions: Instruction text(s) for generation (optional)
|
| 1564 |
+
audio_cover_strength: Strength of audio cover mode (default: 1.0)
|
| 1565 |
+
use_adg: Whether to use ADG (Adaptive Diffusion Guidance) (default: False)
|
| 1566 |
+
cfg_interval_start: Start of CFG interval (0.0-1.0, default: 0.0)
|
| 1567 |
+
cfg_interval_end: End of CFG interval (0.0-1.0, default: 1.0)
|
| 1568 |
+
|
| 1569 |
+
Returns:
|
| 1570 |
+
Dictionary containing:
|
| 1571 |
+
- pred_wavs: Generated audio tensors
|
| 1572 |
+
- target_wavs: Input target audio (if provided)
|
| 1573 |
+
- vqvae_recon_wavs: VAE reconstruction of target
|
| 1574 |
+
- keys: Identifiers used
|
| 1575 |
+
- text_inputs: Formatted text inputs
|
| 1576 |
+
- sr: Sample rate
|
| 1577 |
+
- spans: Generation spans
|
| 1578 |
+
- time_costs: Timing information
|
| 1579 |
+
- seed_num: Seed used
|
| 1580 |
+
"""
|
| 1581 |
+
if self.config.is_turbo:
|
| 1582 |
+
# Limit inference steps to maximum 8
|
| 1583 |
+
if infer_steps > 8:
|
| 1584 |
+
logger.warning(f"dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
|
| 1585 |
+
infer_steps = 8
|
| 1586 |
+
# CFG parameters are not adjustable for dmd_gan (they will be ignored)
|
| 1587 |
+
# Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
|
| 1588 |
+
|
| 1589 |
+
# Convert single inputs to lists
|
| 1590 |
+
if isinstance(captions, str):
|
| 1591 |
+
captions = [captions]
|
| 1592 |
+
if isinstance(lyrics, str):
|
| 1593 |
+
lyrics = [lyrics]
|
| 1594 |
+
if isinstance(keys, str):
|
| 1595 |
+
keys = [keys]
|
| 1596 |
+
if isinstance(vocal_languages, str):
|
| 1597 |
+
vocal_languages = [vocal_languages]
|
| 1598 |
+
if isinstance(metas, (str, dict)):
|
| 1599 |
+
metas = [metas]
|
| 1600 |
+
|
| 1601 |
+
# Convert repainting parameters to lists
|
| 1602 |
+
if isinstance(repainting_start, (int, float)):
|
| 1603 |
+
repainting_start = [repainting_start]
|
| 1604 |
+
if isinstance(repainting_end, (int, float)):
|
| 1605 |
+
repainting_end = [repainting_end]
|
| 1606 |
+
|
| 1607 |
+
# Convert instructions to list
|
| 1608 |
+
if isinstance(instructions, str):
|
| 1609 |
+
instructions = [instructions]
|
| 1610 |
+
elif instructions is None:
|
| 1611 |
+
instructions = None
|
| 1612 |
+
|
| 1613 |
+
# Convert audio_code_hints to list
|
| 1614 |
+
if isinstance(audio_code_hints, str):
|
| 1615 |
+
audio_code_hints = [audio_code_hints]
|
| 1616 |
+
elif audio_code_hints is None:
|
| 1617 |
+
audio_code_hints = None
|
| 1618 |
+
|
| 1619 |
+
# Get batch size from captions
|
| 1620 |
+
batch_size = len(captions)
|
| 1621 |
+
|
| 1622 |
+
# Ensure audio_code_hints matches batch size
|
| 1623 |
+
if audio_code_hints is not None:
|
| 1624 |
+
if len(audio_code_hints) != batch_size:
|
| 1625 |
+
if len(audio_code_hints) == 1:
|
| 1626 |
+
audio_code_hints = audio_code_hints * batch_size
|
| 1627 |
+
else:
|
| 1628 |
+
audio_code_hints = audio_code_hints[:batch_size]
|
| 1629 |
+
while len(audio_code_hints) < batch_size:
|
| 1630 |
+
audio_code_hints.append(None)
|
| 1631 |
+
|
| 1632 |
+
# Convert seed to list format
|
| 1633 |
+
if seed is None:
|
| 1634 |
+
seed_list = None
|
| 1635 |
+
elif isinstance(seed, list):
|
| 1636 |
+
seed_list = seed
|
| 1637 |
+
# Ensure we have enough seeds for batch size
|
| 1638 |
+
if len(seed_list) < batch_size:
|
| 1639 |
+
# Pad with random seeds until there is one seed per batch item
|
| 1640 |
+
import random
|
| 1641 |
+
while len(seed_list) < batch_size:
|
| 1642 |
+
seed_list.append(random.randint(0, 2**32 - 1))
|
| 1643 |
+
elif len(seed_list) > batch_size:
|
| 1644 |
+
# Truncate to batch size
|
| 1645 |
+
seed_list = seed_list[:batch_size]
|
| 1646 |
+
else:
|
| 1647 |
+
# Single seed value - use for all batch items
|
| 1648 |
+
seed_list = [int(seed)] * batch_size
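# e.g. seed=42 with batch_size=3 -> [42, 42, 42]; seed=[1, 2] with batch_size=3 ->
# [1, 2, <random>], since missing entries are filled with random seeds above.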
|
| 1649 |
+
|
| 1650 |
+
# Don't set global random seed here - each item will use its own seed
|
| 1651 |
+
|
| 1652 |
+
# Prepare batch
|
| 1653 |
+
batch = self._prepare_batch(
|
| 1654 |
+
captions=captions,
|
| 1655 |
+
lyrics=lyrics,
|
| 1656 |
+
keys=keys,
|
| 1657 |
+
target_wavs=target_wavs,
|
| 1658 |
+
refer_audios=refer_audios,
|
| 1659 |
+
metas=metas,
|
| 1660 |
+
vocal_languages=vocal_languages,
|
| 1661 |
+
repainting_start=repainting_start,
|
| 1662 |
+
repainting_end=repainting_end,
|
| 1663 |
+
instructions=instructions,
|
| 1664 |
+
audio_code_hints=audio_code_hints,
|
| 1665 |
+
)
|
| 1666 |
+
|
| 1667 |
+
processed_data = self.preprocess_batch(batch)
|
| 1668 |
+
|
| 1669 |
+
(
|
| 1670 |
+
keys,
|
| 1671 |
+
text_inputs,
|
| 1672 |
+
src_latents,
|
| 1673 |
+
target_latents,
|
| 1674 |
+
# model inputs
|
| 1675 |
+
text_hidden_states,
|
| 1676 |
+
text_attention_mask,
|
| 1677 |
+
lyric_hidden_states,
|
| 1678 |
+
lyric_attention_mask,
|
| 1679 |
+
audio_attention_mask,
|
| 1680 |
+
refer_audio_acoustic_hidden_states_packed,
|
| 1681 |
+
refer_audio_order_mask,
|
| 1682 |
+
chunk_mask,
|
| 1683 |
+
spans,
|
| 1684 |
+
is_covers,
|
| 1685 |
+
audio_codes,
|
| 1686 |
+
lyric_token_idss,
|
| 1687 |
+
precomputed_lm_hints_25Hz,
|
| 1688 |
+
non_cover_text_hidden_states,
|
| 1689 |
+
non_cover_text_attention_masks,
|
| 1690 |
+
) = processed_data
|
| 1691 |
+
|
| 1692 |
+
# Set generation parameters
|
| 1693 |
+
# Use seed_list if available, otherwise generate a single seed
|
| 1694 |
+
if seed_list is not None:
|
| 1695 |
+
# Pass seed list to model (will be handled there)
|
| 1696 |
+
seed_param = seed_list
|
| 1697 |
+
else:
|
| 1698 |
+
seed_param = random.randint(0, 2**32 - 1)
|
| 1699 |
+
|
| 1700 |
+
generate_kwargs = {
|
| 1701 |
+
"text_hidden_states": text_hidden_states,
|
| 1702 |
+
"text_attention_mask": text_attention_mask,
|
| 1703 |
+
"lyric_hidden_states": lyric_hidden_states,
|
| 1704 |
+
"lyric_attention_mask": lyric_attention_mask,
|
| 1705 |
+
"refer_audio_acoustic_hidden_states_packed": refer_audio_acoustic_hidden_states_packed,
|
| 1706 |
+
"refer_audio_order_mask": refer_audio_order_mask,
|
| 1707 |
+
"src_latents": src_latents,
|
| 1708 |
+
"chunk_masks": chunk_mask,
|
| 1709 |
+
"is_covers": is_covers,
|
| 1710 |
+
"silence_latent": self.silence_latent,
|
| 1711 |
+
"seed": seed_param,
|
| 1712 |
+
"non_cover_text_hidden_states": non_cover_text_hidden_states,
|
| 1713 |
+
"non_cover_text_attention_masks": non_cover_text_attention_masks,
|
| 1714 |
+
"precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz,
|
| 1715 |
+
"audio_cover_strength": audio_cover_strength,
|
| 1716 |
+
"infer_method": infer_method,
|
| 1717 |
+
"infer_steps": infer_steps,
|
| 1718 |
+
"diffusion_guidance_sale": guidance_scale,
|
| 1719 |
+
"use_adg": use_adg,
|
| 1720 |
+
"cfg_interval_start": cfg_interval_start,
|
| 1721 |
+
"cfg_interval_end": cfg_interval_end,
|
| 1722 |
+
}
|
| 1723 |
+
|
| 1724 |
+
outputs = self.model.generate_audio(**generate_kwargs)
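# outputs is the dict returned by the model; generate_music below reads
# outputs["target_latents"] from it before decoding to audio.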
|
| 1725 |
+
return outputs
|
| 1726 |
+
|
| 1727 |
def generate_music(
|
| 1728 |
self,
|
| 1729 |
captions: str,
|
|
|
|
| 1763 |
"""
|
| 1764 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 1765 |
return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
|
| 1766 |
+
|
| 1767 |
+
# Auto-detect task type based on audio_code_string
|
| 1768 |
+
# If audio_code_string is provided and not empty, use cover task
|
| 1769 |
+
# Otherwise, use text2music task (or keep current task_type if not text2music)
|
| 1770 |
+
if task_type == "text2music":
|
| 1771 |
+
if audio_code_string and str(audio_code_string).strip():
|
| 1772 |
+
# User has provided audio codes, switch to cover task
|
|
|
|
|
|
|
|
|
|
| 1773 |
task_type = "cover"
|
| 1774 |
+
# Update instruction for cover task
|
| 1775 |
instruction = "Generate audio semantic tokens based on the given conditions:"
|
|
|
|
|
|
|
|
|
|
|
|
| 1776 |
|
| 1777 |
+
print("[generate_music] Starting generation...")
|
| 1778 |
+
if progress:
|
| 1779 |
+
progress(0.05, desc="Preparing inputs...")
|
| 1780 |
+
print("[generate_music] Preparing inputs...")
|
| 1781 |
|
| 1782 |
+
# Caption and lyrics are optional - can be empty
|
| 1783 |
+
# Use provided batch_size or default
|
| 1784 |
+
actual_batch_size = batch_size if batch_size is not None else self.batch_size
|
| 1785 |
+
actual_batch_size = max(1, actual_batch_size) # Ensure at least 1
|
| 1786 |
+
|
| 1787 |
+
actual_seed_list, seed_value_for_ui = self.prepare_seeds(actual_batch_size, seed, use_random_seed)
|
| 1788 |
+
|
| 1789 |
+
# Convert special values to None
|
| 1790 |
+
if audio_duration is not None and audio_duration <= 0:
|
| 1791 |
+
audio_duration = None
|
| 1792 |
+
# if seed is not None and seed < 0:
|
| 1793 |
+
# seed = None
|
| 1794 |
+
if repainting_end is not None and repainting_end < 0:
|
| 1795 |
+
repainting_end = None
|
| 1796 |
|
| 1797 |
+
try:
|
| 1798 |
+
progress(0.1, desc="Preparing inputs...")
|
| 1799 |
+
|
| 1800 |
+
# 1. Process reference audio
|
| 1801 |
+
refer_audios = None
|
| 1802 |
+
if reference_audio is not None:
|
| 1803 |
+
processed_ref_audio = self.process_reference_audio(reference_audio)
|
| 1804 |
+
if processed_ref_audio is not None:
|
| 1805 |
+
# Convert to the format expected by the service: List[List[torch.Tensor]]
|
| 1806 |
+
# Each batch item has a list of reference audios
|
| 1807 |
+
refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
|
| 1808 |
+
else:
|
| 1809 |
+
refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
|
| 1810 |
|
| 1811 |
+
# 2. Process source audio
|
| 1812 |
+
processed_src_audio = None
|
| 1813 |
+
if src_audio is not None:
|
| 1814 |
+
processed_src_audio = self.process_src_audio(src_audio)
|
|
|
|
|
|
|
|
|
|
| 1815 |
|
| 1816 |
+
# 3. Prepare batch data
|
| 1817 |
+
captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
|
| 1818 |
+
actual_batch_size,
|
| 1819 |
+
processed_src_audio,
|
| 1820 |
+
audio_duration,
|
| 1821 |
+
captions,
|
| 1822 |
+
lyrics,
|
| 1823 |
+
vocal_language,
|
| 1824 |
+
instruction,
|
| 1825 |
+
bpm,
|
| 1826 |
+
key_scale,
|
| 1827 |
+
time_signature
|
| 1828 |
+
)
|
|
|
|
|
|
|
|
|
|
| 1829 |
|
| 1830 |
+
is_repaint_task, is_lego_task, is_cover_task, can_use_repainting = self.determine_task_type(task_type, audio_code_string)
|
| 1831 |
+
|
| 1832 |
+
repainting_start_batch, repainting_end_batch, target_wavs_tensor = self.prepare_padding_info(
|
| 1833 |
+
actual_batch_size,
|
| 1834 |
+
processed_src_audio,
|
| 1835 |
+
audio_duration,
|
| 1836 |
+
repainting_start,
|
| 1837 |
+
repainting_end,
|
| 1838 |
+
is_repaint_task,
|
| 1839 |
+
is_lego_task,
|
| 1840 |
+
is_cover_task,
|
| 1841 |
+
can_use_repainting
|
| 1842 |
+
)
|
| 1843 |
|
| 1844 |
+
progress(0.3, desc=f"Generating music (batch size: {actual_batch_size})...")
|
|
|
|
|
|
|
|
|
|
|
|
| 1845 |
|
| 1846 |
+
# Prepare audio_code_hints - use if audio_code_string is provided
|
| 1847 |
+
# This works for both text2music (auto-switched to cover) and cover tasks
|
| 1848 |
+
audio_code_hints_batch = None
|
| 1849 |
+
if audio_code_string and str(audio_code_string).strip():
|
| 1850 |
+
# Audio codes provided, use as hints (will trigger cover mode in inference service)
|
| 1851 |
+
audio_code_hints_batch = [audio_code_string] * actual_batch_size
|
| 1852 |
+
|
| 1853 |
+
should_return_intermediate = (task_type == "text2music")
|
| 1854 |
+
outputs = self.service_generate(
|
| 1855 |
+
captions=captions_batch,
|
| 1856 |
+
lyrics=lyrics_batch,
|
| 1857 |
+
metas=metas_batch, # Pass as dict, service will convert to string
|
| 1858 |
+
vocal_languages=vocal_languages_batch,
|
| 1859 |
+
refer_audios=refer_audios, # Already in List[List[torch.Tensor]] format
|
| 1860 |
+
target_wavs=target_wavs_tensor, # Shape: [batch_size, 2, frames]
|
| 1861 |
+
infer_steps=inference_steps,
|
| 1862 |
+
guidance_scale=guidance_scale,
|
| 1863 |
+
seed=actual_seed_list, # Pass list of seeds, one per batch item
|
| 1864 |
+
repainting_start=repainting_start_batch,
|
| 1865 |
+
repainting_end=repainting_end_batch,
|
| 1866 |
+
instructions=instructions_batch, # Pass instructions to service
|
| 1867 |
+
audio_cover_strength=audio_cover_strength, # Pass audio cover strength
|
| 1868 |
+
use_adg=use_adg, # Pass use_adg parameter
|
| 1869 |
+
cfg_interval_start=cfg_interval_start, # Pass CFG interval start
|
| 1870 |
+
cfg_interval_end=cfg_interval_end, # Pass CFG interval end
|
| 1871 |
+
audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
|
| 1872 |
+
return_intermediate=should_return_intermediate
|
| 1873 |
+
)
|
|
|
|
|
|
|
| 1874 |
|
| 1875 |
print("[generate_music] Model generation completed. Decoding latents...")
|
| 1876 |
pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
|
|
|
|
| 1902 |
|
| 1903 |
saved_files = []
|
| 1904 |
for i in range(actual_batch_size):
|
| 1905 |
+
audio_file = os.path.join(self.temp_dir, f"generated_{i}_{actual_seed_list[i]}.{audio_format_lower}")
|
| 1906 |
# Convert to numpy: [channels, samples] -> [samples, channels]
|
| 1907 |
audio_np = pred_wavs[i].cpu().float().numpy().T
|
| 1908 |
sf.write(audio_file, audio_np, self.sample_rate)
|
|
|
|
| 1926 |
|
| 1927 |
generation_info = f"""**🎵 Generation Complete**
|
| 1928 |
|
| 1929 |
+
**Seeds:** {seed_value_for_ui}
|
| 1930 |
+
**Steps:** {inference_steps}
|
| 1931 |
+
**Files:** {len(saved_files)} audio(s){time_costs_str}"""
|
|
|
|
| 1932 |
status_message = f"✅ Generation completed successfully!"
|
| 1933 |
print(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
|
| 1934 |
|
|
|
|
| 1954 |
align_text_2,
|
| 1955 |
align_plot_2,
|
| 1956 |
)
|
| 1957 |
+
|
| 1958 |
except Exception as e:
|
| 1959 |
+
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
|
| 1960 |
+
return None, None, [], "", error_msg, seed_value_for_ui, "", "", None, "", "", None
|
| 1961 |
|