Spaces: Running on A100

Commit · 447806b
1 Parent(s): e35f9c5

all handler log printing using loguru

Browse files
acestep/handler.py +34 -32
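The commit routes the handler's status printing through loguru at INFO/DEBUG/ERROR level. Below is a minimal sketch of the timing-plus-logging pattern the diff applies throughout; the module-level `from loguru import logger` import is assumed to already exist in acestep/handler.py (it is not visible in these hunks), and the function and variable names here are illustrative only, not part of the commit.

from loguru import logger
import time

def load_tokenizer_sketch():
    # Illustrative only: mirrors the "log start, time the work, log completion,
    # log errors" pattern used by the changes below.
    logger.info("loading 5Hz LM tokenizer...")
    start_time = time.time()
    try:
        tokenizer = object()  # placeholder for the real tokenizer setup
        logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer: {e}")
        return None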
acestep/handler.py CHANGED
@@ -282,18 +282,18 @@ class AceStepHandler:
        if init_llm:
            full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
            if os.path.exists(full_lm_model_path):
-
+                logger.info("loading 5Hz LM tokenizer...")
                start_time = time.time()
                llm_tokenizer = deepcopy(self.text_tokenizer)
                max_audio_length = 2**16 - 1
                semantic_tokens = [f"<|audio_code_{i}|>" for i in range(max_audio_length)]
                # 217204
                llm_tokenizer.add_special_tokens({"additional_special_tokens": semantic_tokens})
-
+                logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
                self.llm_tokenizer = llm_tokenizer
                if device == "cuda":
                    status_msg = self._initialize_5hz_lm_cuda(full_lm_model_path)
-
+                    logger.info(f"5Hz LM status message: {status_msg}")
                    # Check if initialization failed (status_msg starts with ❌)
                    if status_msg.startswith("❌"):
                        # vllm initialization failed, fallback to PyTorch

@@ -304,6 +304,7 @@ class AceStepHandler:
                            self.llm.eval()
                            self.llm_backend = "pt"
                            self.llm_initialized = True
+                            logger.info("5Hz LM initialized successfully on CUDA device using Transformers backend")
                        except Exception as e:
                            return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
                    # If vllm initialization succeeded, self.llm_initialized should already be True

@@ -316,6 +317,7 @@ class AceStepHandler:
                        self.llm.eval()
                        self.llm_backend = "pt"
                        self.llm_initialized = True
+                        logger.info("5Hz LM initialized successfully on non-CUDA device using Transformers backend")
                    except Exception as e:
                        return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False

@@ -465,13 +467,13 @@ class AceStepHandler:
        """Initialize 5Hz LM model"""
        if not torch.cuda.is_available():
            self.llm_initialized = False
-
+            logger.error("CUDA is not available. Please check your GPU setup.")
            return "❌ CUDA is not available. Please check your GPU setup."
        try:
            from nanovllm import LLM, SamplingParams
        except ImportError:
            self.llm_initialized = False
-
+            logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
            return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."

        try:

@@ -489,7 +491,7 @@ class AceStepHandler:
            else:
                self.max_model_len = 2048

-
+            logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
            start_time = time.time()
            self.llm = LLM(
                model=model_path,

@@ -498,7 +500,7 @@ class AceStepHandler:
                max_model_len=self.max_model_len,
                gpu_memory_utilization=gpu_memory_utilization,
            )
-
+            logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
            self.llm.tokenizer = self.llm_tokenizer
            self.llm_initialized = True
            self.llm_backend = "vllm"

@@ -523,7 +525,7 @@ class AceStepHandler:
            tokenize=False,
            add_generation_prompt=True,
        )
-
+        logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")

        sampling_params = SamplingParams(max_tokens=self.max_model_len, temperature=temperature)
        outputs = self.llm.generate([formatted_prompt], sampling_params)

@@ -649,7 +651,7 @@ class AceStepHandler:
            Tuple of (metadata_dict, audio_codes_string)
        """
        debug_output_text = output_text.split("</think>")[0]
-
+        logger.debug(f"Debug output text: {debug_output_text}")
        metadata = {}
        audio_codes = ""

@@ -741,7 +743,7 @@ class AceStepHandler:

            return audio
        except Exception as e:
-
+            logger.error(f"Error processing target audio: {e}")
            return None

    def _parse_audio_code_string(self, code_str: str) -> List[int]:

@@ -882,7 +884,7 @@ class AceStepHandler:
                return match.group(1).strip()
            return caption
        except Exception as e:
-
+            logger.error(f"Error extracting caption: {e}")
            return caption

    def prepare_seeds(self, actual_batch_size, seed, use_random_seed):

@@ -1073,7 +1075,7 @@ class AceStepHandler:
            return audio

        except Exception as e:
-
+            logger.error(f"Error processing reference audio: {e}")
            return None

    def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:

@@ -1101,7 +1103,7 @@ class AceStepHandler:
            return audio

        except Exception as e:
-
+            logger.error(f"Error processing target audio: {e}")
            return None

    def prepare_batch_data(

@@ -1178,7 +1180,7 @@ class AceStepHandler:
                target_wavs = torch.zeros(2, frames)
            return target_wavs
        except Exception as e:
-
+            logger.error(f"Error creating target audio: {e}")
            # Fallback to 30 seconds if error
            return torch.zeros(2, 30 * 48000)

@@ -1408,7 +1410,7 @@ class AceStepHandler:
            code_hint = audio_code_hints[i]
            # Prefer decoding from provided audio codes
            if code_hint:
-
+                logger.info(f"[generate_music] Decoding audio codes for item {i}...")
                decoded_latents = self._decode_audio_codes_to_latents(code_hint)
                if decoded_latents is not None:
                    decoded_latents = decoded_latents.squeeze(0)

@@ -1425,7 +1427,7 @@ class AceStepHandler:
                target_latent = self.silence_latent[0, :expected_latent_length, :]
            else:
                # Ensure input is in VAE's dtype
-
+                logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
                vae_input = current_wav.to(self.device).to(self.vae.dtype)
                target_latent = self.vae.encode(vae_input).latent_dist.sample()
                # Cast back to model dtype

@@ -1595,7 +1597,7 @@ class AceStepHandler:
        for i in range(batch_size):
            if audio_code_hints[i] is not None:
                # Decode audio codes to 25Hz latents
-
+                logger.info(f"[generate_music] Decoding audio codes for LM hints for item {i}...")
                hints = self._decode_audio_codes_to_latents(audio_code_hints[i])
                if hints is not None:
                    # Pad or crop to match max_latent_length

@@ -1841,10 +1843,10 @@ class AceStepHandler:
        lyric_attention_mask = batch["lyric_attention_masks"]
        text_inputs = batch["text_inputs"]

-
+        logger.info("[preprocess_batch] Inferring prompt embeddings...")
        with self._load_model_context("text_encoder"):
            text_hidden_states = self.infer_text_embeddings(text_token_idss)
-
+            logger.info("[preprocess_batch] Inferring lyric embeddings...")
            lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)

        is_covers = batch["is_covers"]

@@ -1857,7 +1859,7 @@ class AceStepHandler:
        non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
        non_cover_text_hidden_states = None
        if non_cover_text_input_ids is not None:
-
+            logger.info("[preprocess_batch] Inferring non-cover text embeddings...")
            non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)

        return (

@@ -2085,7 +2087,7 @@ class AceStepHandler:
            "cfg_interval_start": cfg_interval_start,
            "cfg_interval_end": cfg_interval_end,
        }
-
+        logger.info("[service_generate] Generating audio...")
        with self._load_model_context("model"):
            outputs = self.model.generate_audio(**generate_kwargs)
        return outputs

@@ -2213,10 +2215,10 @@ class AceStepHandler:
            # Update instruction for cover task
            instruction = "Generate audio semantic tokens based on the given conditions:"

-
+        logger.info("[generate_music] Starting generation...")
        if progress:
            progress(0.05, desc="Preparing inputs...")
-
+        logger.info("[generate_music] Preparing inputs...")

        # Reset offload cost
        self.current_offload_cost = 0.0

@@ -2242,7 +2244,7 @@ class AceStepHandler:
        # 1. Process reference audio
        refer_audios = None
        if reference_audio is not None:
-
+            logger.info("[generate_music] Processing reference audio...")
            processed_ref_audio = self.process_reference_audio(reference_audio)
            if processed_ref_audio is not None:
                # Convert to the format expected by the service: List[List[torch.Tensor]]

@@ -2254,7 +2256,7 @@ class AceStepHandler:
        # 2. Process source audio
        processed_src_audio = None
        if src_audio is not None:
-
+            logger.info("[generate_music] Processing source audio...")
            processed_src_audio = self.process_src_audio(src_audio)

        # 3. Prepare batch data

@@ -2316,15 +2318,15 @@ class AceStepHandler:
            return_intermediate=should_return_intermediate
        )

-
+        logger.info("[generate_music] Model generation completed. Decoding latents...")
        pred_latents = outputs["target_latents"]  # [batch, latent_length, latent_dim]
        time_costs = outputs["time_costs"]
        time_costs["offload_time_cost"] = self.current_offload_cost
-
-
+        logger.info(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
+        logger.info(f" - time_costs: {time_costs}")
        if progress:
            progress(0.8, desc="Decoding audio...")
-
+        logger.info("[generate_music] Decoding latents with VAE...")

        # Decode latents to audio
        start_time = time.time()

@@ -2336,7 +2338,7 @@ class AceStepHandler:
        pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)

        if use_tiled_decode:
-
+            logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
            pred_wavs = self.tiled_decode(pred_latents_for_decode)  # [batch, channels, samples]
        else:
            pred_wavs = self.vae.decode(pred_latents_for_decode).sample

@@ -2350,7 +2352,7 @@ class AceStepHandler:
        # Update offload cost one last time to include VAE offloading
        time_costs["offload_time_cost"] = self.current_offload_cost

-
+        logger.info("[generate_music] VAE decode completed. Saving audio files...")
        if progress:
            progress(0.9, desc="Saving audio files...")

@@ -2389,7 +2391,7 @@ class AceStepHandler:
**Steps:** {inference_steps}
**Files:** {len(saved_files)} audio(s){time_costs_str}"""
        status_message = f"✅ Generation completed successfully!"
-
+        logger.info(f"[generate_music] Done! Generated {len(saved_files)} audio files.")

        # Alignment scores and plots (placeholder for now)
        align_score_1 = ""