Commit · 1d23edb
Parent(s): 497ca57
cpu offloading
Files changed:
- acestep/gradio_ui.py +19 -2
- acestep/handler.py +335 -162
- test.py +3 -1
acestep/gradio_ui.py
CHANGED
@@ -216,6 +216,16 @@ def create_generation_section(handler) -> dict:
                 interactive=flash_attn_available,
                 info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
             )
+            offload_to_cpu_checkbox = gr.Checkbox(
+                label="Offload to CPU",
+                value=False,
+                info="Offload models to CPU when not in use to save GPU memory"
+            )
+            offload_dit_to_cpu_checkbox = gr.Checkbox(
+                label="Offload DiT to CPU",
+                value=False,
+                info="Offload DiT model to CPU when not in use (only effective if Offload to CPU is checked)"
+            )
 
             init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
             init_status = gr.Textbox(label="Status", interactive=False, lines=3)
@@ -487,6 +497,7 @@ def create_generation_section(handler) -> dict:
         "lm_model_path": lm_model_path,
         "init_llm_checkbox": init_llm_checkbox,
         "use_flash_attention_checkbox": use_flash_attention_checkbox,
+        "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
         "task_type": task_type,
         "instruction_display_gen": instruction_display_gen,
         "track_name": track_name,
@@ -655,9 +666,13 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
     )
 
     # Service initialization
-    def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention):
+    def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
         """Wrapper for service initialization, returns status and button state"""
-        status, enable = handler.initialize_service(
+        status, enable = handler.initialize_service(
+            checkpoint, config_path, device, init_llm, lm_model_path,
+            use_flash_attention, compile_model=False,
+            offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
+        )
         return status, gr.update(interactive=enable)
 
     generation_section["init_btn"].click(
@@ -669,6 +684,8 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
             generation_section["init_llm_checkbox"],
            generation_section["lm_model_path"],
            generation_section["use_flash_attention_checkbox"],
+            generation_section["offload_to_cpu_checkbox"],
+            generation_section["offload_dit_to_cpu_checkbox"],
        ],
        outputs=[generation_section["init_status"], generation_section["generate_btn"]]
    )
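For reference, a minimal, self-contained sketch of the UI pattern these hunks follow: two checkboxes whose values are threaded through a click handler into a backend initializer. This is not the app's actual layout; fake_initialize_service is a hypothetical stand-in for handler.initialize_service, and widget names are illustrative only.

# Minimal sketch (assumes gradio is installed; all names here are illustrative stand-ins)
import gradio as gr

def fake_initialize_service(offload_to_cpu: bool, offload_dit_to_cpu: bool):
    # Pretend to initialize and report what was requested.
    status = f"Initialized (offload_to_cpu={offload_to_cpu}, offload_dit_to_cpu={offload_dit_to_cpu})"
    return status, True

with gr.Blocks() as demo:
    offload_cb = gr.Checkbox(label="Offload to CPU", value=False,
                             info="Offload models to CPU when not in use")
    offload_dit_cb = gr.Checkbox(label="Offload DiT to CPU", value=False,
                                 info="Only effective if Offload to CPU is checked")
    init_btn = gr.Button("Initialize Service", variant="primary")
    status_box = gr.Textbox(label="Status", interactive=False)
    generate_btn = gr.Button("Generate", interactive=False)

    def init_wrapper(offload_to_cpu, offload_dit_to_cpu):
        status, enable = fake_initialize_service(offload_to_cpu, offload_dit_to_cpu)
        # Enable the generate button only when initialization succeeded.
        return status, gr.update(interactive=enable)

    init_btn.click(init_wrapper,
                   inputs=[offload_cb, offload_dit_cb],
                   outputs=[status_box, generate_btn])

if __name__ == "__main__":
    demo.launch()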
acestep/handler.py
CHANGED
@@ -8,6 +8,7 @@ import tempfile
 import traceback
 import re
 import random
+from contextlib import contextmanager
 from typing import Optional, Dict, Any, Tuple, List, Union
 
 import torch
@@ -81,6 +82,9 @@ class AceStepHandler:
             5: [8, 9, 11],
             6: [8]
         }
+        self.offload_to_cpu = False
+        self.offload_dit_to_cpu = False
+        self.current_offload_cost = 0.0
 
     def get_available_checkpoints(self) -> str:
         """Return project root directory path"""
@@ -146,6 +150,8 @@ class AceStepHandler:
         lm_model_path: str = "acestep-5Hz-lm-0.6B",
         use_flash_attention: bool = False,
         compile_model: bool = False,
+        offload_to_cpu: bool = False,
+        offload_dit_to_cpu: bool = False,
     ) -> Tuple[str, bool]:
         """
         Initialize model service
@@ -158,6 +164,8 @@ class AceStepHandler:
             lm_model_path: 5Hz LM model path
             use_flash_attention: Whether to use flash attention (requires flash_attn package)
             compile_model: Whether to use torch.compile to optimize the model
+            offload_to_cpu: Whether to offload models to CPU when not in use
+            offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
 
         Returns:
             (status_message, enable_generate_button)
@@ -167,6 +175,8 @@ class AceStepHandler:
                 device = "cuda" if torch.cuda.is_available() else "cpu"
 
             self.device = device
+            self.offload_to_cpu = offload_to_cpu
+            self.offload_dit_to_cpu = offload_dit_to_cpu
             # Set dtype based on device: bfloat16 for cuda, float32 for cpu
             self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
 
@@ -211,7 +221,15 @@ class AceStepHandler:
             self.model.config._attn_implementation = attn_implementation
             self.config = self.model.config
             # Move model to device and set dtype
-            [… 1 deleted line not rendered in this diff view …]
+            if not self.offload_to_cpu:
+                self.model = self.model.to(device).to(self.dtype)
+            else:
+                # If offload_to_cpu is True, check if we should keep DiT on GPU
+                if not self.offload_dit_to_cpu:
+                    logger.info(f"Keeping main model on {device} (persistent)")
+                    self.model = self.model.to(device).to(self.dtype)
+                else:
+                    self.model = self.model.to("cpu").to(self.dtype)
             self.model.eval()
 
             if compile_model:
@@ -221,7 +239,11 @@ class AceStepHandler:
             silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
             if os.path.exists(silence_latent_path):
                 self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
-                [… 1 deleted line not rendered in this diff view …]
+                # If DiT is on GPU, silence_latent should also be on GPU
+                if not self.offload_to_cpu or not self.offload_dit_to_cpu:
+                    self.silence_latent = self.silence_latent.to(device).to(self.dtype)
+                else:
+                    self.silence_latent = self.silence_latent.to("cpu").to(self.dtype)
             else:
                 raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
         else:
@@ -233,7 +255,10 @@ class AceStepHandler:
                 self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
                 # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
                 vae_dtype = torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
-                [… 1 deleted line not rendered in this diff view …]
+                if not self.offload_to_cpu:
+                    self.vae = self.vae.to(device).to(vae_dtype)
+                else:
+                    self.vae = self.vae.to("cpu").to(vae_dtype)
                 self.vae.eval()
             else:
                 raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
@@ -243,7 +268,10 @@ class AceStepHandler:
             if os.path.exists(text_encoder_path):
                 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
                 self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
-                [… 1 deleted line not rendered in this diff view …]
+                if not self.offload_to_cpu:
+                    self.text_encoder = self.text_encoder.to(device).to(self.dtype)
+                else:
+                    self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
                 self.text_encoder.eval()
             else:
                 raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
@@ -252,12 +280,11 @@ class AceStepHandler:
             if init_llm:
                 full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
                 if os.path.exists(full_lm_model_path):
-                    [… 5 deleted lines not rendered in this diff view …]
-                    self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path)
+                    status_msg = self._initialize_5hz_lm(full_lm_model_path)
+                    if not self.llm_initialized:
+                        print(f"Error initializing 5Hz LM: {status_msg}")
+                        return status_msg, False
+                    print(status_msg)
                 else:
                     # 5Hz LM path not found
                     return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
@@ -275,7 +302,9 @@ class AceStepHandler:
                 status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
             status_msg += f"Dtype: {self.dtype}\n"
             status_msg += f"Attention: {actual_attn}\n"
-            status_msg += f"Compiled: {compile_model}"
+            status_msg += f"Compiled: {compile_model}\n"
+            status_msg += f"Offload to CPU: {self.offload_to_cpu}\n"
+            status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}"
 
             return status_msg, True
 
@@ -283,6 +312,86 @@ class AceStepHandler:
             error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
             return error_msg, False
 
+    @contextmanager
+    def _load_model_context(self, model_name: str):
+        """
+        Context manager to load a model to GPU and offload it back to CPU after use.
+
+        Args:
+            model_name: Name of the model to load ("text_encoder", "vae", "model", "llm")
+        """
+        if not self.offload_to_cpu:
+            yield
+            return
+
+        # If model is DiT ("model") and offload_dit_to_cpu is False, do not offload
+        if model_name == "model" and not self.offload_dit_to_cpu:
+            # Ensure it's on device if not already (should be handled by init, but safe to check)
+            model = getattr(self, model_name, None)
+            if model is not None:
+                # Check if model is on CPU, if so move to device (one-time move if it was somehow on CPU)
+                # We check the first parameter's device
+                try:
+                    param = next(model.parameters())
+                    if param.device.type == "cpu":
+                        logger.info(f"Moving {model_name} to {self.device} (persistent)")
+                        model.to(self.device).to(self.dtype)
+                        if hasattr(self, "silence_latent"):
+                            self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
+                except StopIteration:
+                    pass
+            yield
+            return
+
+        # If model is LLM and using nanovllm, do not offload (it stays on GPU)
+        if model_name == "llm" and getattr(self, "llm_type", None) == "nanovllm":
+            yield
+            return
+
+        model = getattr(self, model_name, None)
+        if model is None:
+            yield
+            return
+
+        # Load to GPU
+        logger.info(f"Loading {model_name} to {self.device}")
+        start_time = time.time()
+        if model_name == "vae":
+            vae_dtype = torch.bfloat16 if self.device in ["cuda", "xpu"] else self.dtype
+            model.to(self.device).to(vae_dtype)
+        elif model_name == "llm" and hasattr(model, "to"):
+            # Special handling for nanovllm LLM which might have custom to() method or structure
+            # Assuming it has a .to() method based on our previous edits to nanovllm
+            model.to(self.device)
+        else:
+            model.to(self.device).to(self.dtype)
+
+        if model_name == "model" and hasattr(self, "silence_latent"):
+            self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
+
+        load_time = time.time() - start_time
+        self.current_offload_cost += load_time
+        logger.info(f"Loaded {model_name} to {self.device} in {load_time:.4f}s")
+
+        try:
+            yield
+        finally:
+            # Offload to CPU
+            logger.info(f"Offloading {model_name} to CPU")
+            start_time = time.time()
+            if model_name == "llm" and hasattr(model, "to"):
+                model.to("cpu")
+            else:
+                model.to("cpu")
+
+            if model_name == "model" and hasattr(self, "silence_latent"):
+                self.silence_latent = self.silence_latent.to("cpu")
+
+            torch.cuda.empty_cache()
+            offload_time = time.time() - start_time
+            self.current_offload_cost += offload_time
+            logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")
+
     def import_dataset(self, dataset_type: str) -> str:
         """Import dataset (temporarily disabled)"""
         self.dataset_imported = False
@@ -314,36 +423,66 @@ class AceStepHandler:
         except Exception as e:
             return 0.9
 
-    [… 1 deleted line truncated in this diff view …]
+    def _initialize_5hz_lm(self, model_path: str) -> str:
         """Initialize 5Hz LM model"""
         try:
-            [… 4 deleted lines not rendered in this diff view …]
+            # Try to use nanovllm if on CUDA
+            use_nanovllm = False
+            if self.device == "cuda":
+                try:
+                    from nanovllm import LLM, SamplingParams
+                    use_nanovllm = True
+                except ImportError:
+                    pass
 
-            [… 2 deleted lines not rendered in this diff view …]
+            if use_nanovllm:
+                try:
+                    current_device = torch.cuda.current_device()
+                    device_name = torch.cuda.get_device_name(current_device)
+
+                    torch.cuda.empty_cache()
+                    gpu_memory_utilization = self.get_gpu_memory_utilization(
+                        minimal_gpu=8,
+                        min_ratio=0.2,
+                        max_ratio=0.9
+                    )
+
+                    self.llm = LLM(
+                        model=model_path,
+                        enforce_eager=False,
+                        tensor_parallel_size=1,
+                        max_model_len=4096,
+                        gpu_memory_utilization=gpu_memory_utilization,
+                    )
+                    self.llm_tokenizer = self.llm.tokenizer
+                    self.llm_initialized = True
+                    self.llm_type = "nanovllm"
+                    return f"✅ 5Hz LM initialized successfully (nanovllm)\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
+                except Exception as e:
+                    logger.warning(f"nanovllm initialization failed: {e}, falling back to transformers")
+
+            # Fallback to transformers
+            from transformers import AutoModelForCausalLM
 
-            [… 4 deleted lines not rendered in this diff view …]
-                max_ratio=0.9
+            self.llm = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=self.dtype,
+                trust_remote_code=True
             )
 
-            [… 6 deleted lines not rendered in this diff view …]
-            )
-            self.llm_tokenizer = self.llm.tokenizer
+            if not self.offload_to_cpu:
+                self.llm.to(self.device)
+            else:
+                self.llm.to("cpu")
+
+            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
             self.llm_initialized = True
-            [… 1 deleted line not rendered in this diff view …]
+            self.llm_type = "transformers"
+            return f"✅ 5Hz LM initialized successfully (transformers)\nModel: {model_path}\nDevice: {self.device}"
+
         except Exception as e:
             self.llm_initialized = False
+            self.llm_type = None
             error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
             return error_msg
 
@@ -353,35 +492,54 @@ class AceStepHandler:
             return {}, "", "❌ 5Hz LM not initialized. Please initialize it first."
 
         try:
-            [… 21 deleted lines not rendered in this diff view …]
+            with self._load_model_context("llm"):
+                prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
+
+                formatted_prompt = self.lm_tokenizer.apply_chat_template(
+                    [
+                        {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
+                        {"role": "user", "content": prompt}
+                    ],
+                    tokenize=False,
+                    add_generation_prompt=True,
+                )
+
+                if getattr(self, "llm_type", "nanovllm") == "nanovllm":
+                    from nanovllm import SamplingParams
+                    sampling_params = SamplingParams(max_tokens=3072, temperature=temperature)
+                    outputs = self.llm.generate([formatted_prompt], sampling_params)
+
+                    if isinstance(outputs, list) and len(outputs) > 0:
+                        if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
+                            output_text = outputs[0].outputs[0].text
+                        elif hasattr(outputs[0], 'text'):
+                            output_text = outputs[0].text
+                        else:
+                            output_text = str(outputs[0])
+                    else:
+                        output_text = str(outputs)
                 else:
-            [… 7 deleted lines not rendered in this diff view …]
+                    # Transformers generation
+                    inputs = self.llm_tokenizer(formatted_prompt, return_tensors="pt").to(self.llm.device)
+
+                    # Generate
+                    with torch.no_grad():
+                        outputs = self.llm.generate(
+                            **inputs,
+                            max_new_tokens=3072,
+                            temperature=temperature,
+                            do_sample=True,
+                            pad_token_id=self.llm_tokenizer.pad_token_id,
+                            eos_token_id=self.llm_tokenizer.eos_token_id
+                        )
+
+                    # Decode
+                    generated_ids = outputs[0][inputs.input_ids.shape[1]:]
+                    output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
+
+                metadata, audio_codes = self.parse_lm_output(output_text)
+                codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
+                return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
 
         except Exception as e:
             error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -495,24 +653,25 @@ class AceStepHandler:
         if len(code_ids) == 0:
             return None
 
-        [… 18 deleted lines not rendered in this diff view …]
+        with self._load_model_context("model"):
+            quantizer = self.model.tokenizer.quantizer
+            detokenizer = self.model.detokenizer
+
+            num_quantizers = getattr(quantizer, "num_quantizers", 1)
+            indices = torch.tensor(code_ids, device=self.device, dtype=torch.long).unsqueeze(0)  # [1, T_5Hz]
+
+            # Expand to include quantizer dimension: [1, T_5Hz, num_quantizers]
+            if indices.dim() == 2:
+                indices = indices.unsqueeze(-1).expand(-1, -1, num_quantizers)
+
+            # Get quantized representation from indices: [1, T_5Hz, dim]
+            quantized = quantizer.get_output_from_indices(indices)
+            if quantized.dtype != self.dtype:
+                quantized = quantized.to(self.dtype)
+
+            # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
+            lm_hints_25hz = detokenizer(quantized)
+            return lm_hints_25hz
 
     def _create_default_meta(self) -> str:
         """Create default metadata string."""
@@ -577,30 +736,31 @@ class AceStepHandler:
         if self.text_tokenizer is None or self.text_encoder is None:
             raise ValueError("Text encoder not initialized")
 
-        [… 24 deleted lines not rendered in this diff view …]
+        with self._load_model_context("text_encoder"):
+            # Tokenize
+            text_inputs = self.text_tokenizer(
+                text_prompt,
+                padding="longest",
+                truncation=True,
+                max_length=256,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids.to(self.device)
+            text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
+
+            # Encode
+            with torch.no_grad():
+                text_outputs = self.text_encoder(text_input_ids)
+                if hasattr(text_outputs, 'last_hidden_state'):
+                    text_hidden_states = text_outputs.last_hidden_state
+                elif isinstance(text_outputs, tuple):
+                    text_hidden_states = text_outputs[0]
+                else:
+                    text_hidden_states = text_outputs
+
+            text_hidden_states = text_hidden_states.to(self.dtype)
+
+            return text_hidden_states, text_attention_mask
 
     def extract_caption_from_sft_format(self, caption: str) -> str:
         try:
@@ -1103,7 +1263,7 @@ class AceStepHandler:
         if isinstance(refer_audio_list, list):
             for idx, refer_audio in enumerate(refer_audio_list):
                 refer_audio_list[idx] = refer_audio_list[idx].to(self.device).to(torch.bfloat16)
-        elif isinstance(refer_audio_list, torch.
+        elif isinstance(refer_audio_list, torch.Tensor):
             refer_audios[ii] = refer_audios[ii].to(self.device)
 
         if vocal_languages is None:
@@ -1131,35 +1291,37 @@ class AceStepHandler:
         target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)]
         if target_wavs.device != self.device:
             target_wavs = target_wavs.to(self.device)
-        [… 7 deleted lines not rendered in this diff view …]
-        decoded_latents =
-        [… 21 deleted lines not rendered in this diff view …]
+
+        with self._load_model_context("vae"):
+            for i in range(batch_size):
+                code_hint = audio_code_hints[i]
+                # Prefer decoding from provided audio codes
+                if code_hint:
+                    print(f"[generate_music] Decoding audio codes for item {i}...")
+                    decoded_latents = self._decode_audio_codes_to_latents(code_hint)
+                    if decoded_latents is not None:
+                        decoded_latents = decoded_latents.squeeze(0)
+                        target_latents_list.append(decoded_latents)
+                        latent_lengths.append(decoded_latents.shape[0])
+                        # Create a silent wav matching the latent length for downstream scaling
+                        frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920))
+                        target_wavs_list[i] = torch.zeros(2, frames_from_codes)
+                        continue
+                # Fallback to VAE encode from audio
+                current_wav = target_wavs_list[i].to(self.device).unsqueeze(0)
+                if self.is_silence(current_wav):
+                    expected_latent_length = current_wav.shape[-1] // 1920
+                    target_latent = self.silence_latent[0, :expected_latent_length, :]
+                else:
+                    # Ensure input is in VAE's dtype
+                    print(f"[generate_music] Encoding target audio to latents for item {i}...")
+                    vae_input = current_wav.to(self.device).to(self.vae.dtype)
+                    target_latent = self.vae.encode(vae_input).latent_dist.sample()
+                    # Cast back to model dtype
+                    target_latent = target_latent.to(self.dtype)
+                    target_latent = target_latent.squeeze(0).transpose(0, 1)
+                target_latents_list.append(target_latent)
+                latent_lengths.append(target_latent.shape[0])
 
         # Pad target_wavs to consistent length for outputs
         max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
@@ -1551,7 +1713,8 @@ class AceStepHandler:
 
         # step 2: refer_audio timbre
         keys = batch["keys"]
-        [… 1 deleted line not rendered in this diff view …]
+        with self._load_model_context("vae"):
+            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"])
        if refer_audio_acoustic_hidden_states_packed.dtype != dtype:
            refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype)
 
@@ -1568,22 +1731,23 @@ class AceStepHandler:
         text_inputs = batch["text_inputs"]
 
         print("[preprocess_batch] Inferring prompt embeddings...")
-        [… 3 deleted lines not rendered in this diff view …]
+        with self._load_model_context("text_encoder"):
+            text_hidden_states = self.infer_text_embeddings(text_token_idss)
+            print("[preprocess_batch] Inferring lyric embeddings...")
+            lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
 
-        [… 12 deleted lines not rendered in this diff view …]
+            is_covers = batch["is_covers"]
+
+            # Get precomputed hints from batch if available
+            precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None)
+
+            # Get non-cover text input ids and attention masks from batch if available
+            non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None)
+            non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
+            non_cover_text_hidden_states = None
+            if non_cover_text_input_ids is not None:
+                print("[preprocess_batch] Inferring non-cover text embeddings...")
+                non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
 
         return (
             keys,
@@ -1811,7 +1975,8 @@ class AceStepHandler:
             "cfg_interval_end": cfg_interval_end,
         }
         print("[service_generate] Generating audio...")
-        [… 1 deleted line not rendered in this diff view …]
+        with self._load_model_context("model"):
+            outputs = self.model.generate_audio(**generate_kwargs)
         return outputs
 
     def tiled_decode(self, latents, chunk_size=512, overlap=64):
@@ -1941,6 +2106,9 @@ class AceStepHandler:
         if progress:
             progress(0.05, desc="Preparing inputs...")
         print("[generate_music] Preparing inputs...")
+
+        # Reset offload cost
+        self.current_offload_cost = 0.0
 
         # Caption and lyrics are optional - can be empty
         # Use provided batch_size or default
@@ -2040,6 +2208,7 @@ class AceStepHandler:
         print("[generate_music] Model generation completed. Decoding latents...")
         pred_latents = outputs["target_latents"]  # [batch, latent_length, latent_dim]
         time_costs = outputs["time_costs"]
+        time_costs["offload_time_cost"] = self.current_offload_cost
         print(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
         print(f" - time_costs: {time_costs}")
         if progress:
@@ -2049,23 +2218,27 @@ class AceStepHandler:
         # Decode latents to audio
         start_time = time.time()
         with torch.no_grad():
-            [… 13 deleted lines not rendered in this diff view …]
+            with self._load_model_context("vae"):
+                # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
+                pred_latents_for_decode = pred_latents.transpose(1, 2)
+                # Ensure input is in VAE's dtype
+                pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
+
+                if use_tiled_decode:
+                    print("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
+                    pred_wavs = self.tiled_decode(pred_latents_for_decode)  # [batch, channels, samples]
+                else:
+                    pred_wavs = self.vae.decode(pred_latents_for_decode).sample
+
+                # Cast output to float32 for audio processing/saving
+                pred_wavs = pred_wavs.to(torch.float32)
         end_time = time.time()
         time_costs["vae_decode_time_cost"] = end_time - start_time
         time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
 
+        # Update offload cost one last time to include VAE offloading
+        time_costs["offload_time_cost"] = self.current_offload_cost
+
         print("[generate_music] VAE decode completed. Saving audio files...")
         if progress:
             progress(0.9, desc="Saving audio files...")
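The core of _load_model_context is a standard PyTorch offload pattern: move a module onto the accelerator on entry, move it back to CPU and release the cached allocator memory on exit, and accumulate the transfer time. A stripped-down, self-contained sketch of that pattern follows; the OffloadDemo class and its attribute names are illustrative stand-ins, not the handler's API.

# Minimal sketch of the load-to-GPU / offload-to-CPU context-manager pattern.
# All names here are illustrative; falls back to CPU when no CUDA device is available.
import time
from contextlib import contextmanager

import torch
import torch.nn as nn


class OffloadDemo:
    def __init__(self, device: str = "cuda"):
        self.device = device
        self.offload_to_cpu = True
        self.current_offload_cost = 0.0      # accumulated transfer time in seconds
        self.encoder = nn.Linear(16, 16)     # stand-in for a real sub-model, kept on CPU

    @contextmanager
    def load_model_context(self, name: str):
        if not self.offload_to_cpu:
            yield
            return
        model = getattr(self, name)
        t0 = time.time()
        model.to(self.device)                # bring weights onto the accelerator
        self.current_offload_cost += time.time() - t0
        try:
            yield
        finally:
            t0 = time.time()
            model.to("cpu")                  # return weights to host memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()     # release the freed VRAM back to the allocator
            self.current_offload_cost += time.time() - t0


if __name__ == "__main__":
    demo = OffloadDemo(device="cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(4, 16, device=demo.device)
    with demo.load_model_context("encoder"):
        y = demo.encoder(x)                  # runs on the accelerator while inside the context
    print(y.shape, f"offload overhead: {demo.current_offload_cost:.4f}s")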
test.py
CHANGED
@@ -41,7 +41,9 @@ def main():
         device=device,
         init_llm=True,
         use_flash_attention=False,  # Default in UI
-        compile_model=
+        compile_model=False,
+        offload_to_cpu=True,
+        offload_dit_to_cpu=False,  # Keep DiT on GPU
     )
 
     if not enabled:
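As exercised in test.py, offload_to_cpu=True combined with offload_dit_to_cpu=False keeps the DiT resident on the GPU while the VAE, text encoder and (transformers-backed) 5Hz LM are moved to CPU between stages; the cumulative transfer overhead of a run is reported in time_costs["offload_time_cost"].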