xushengyuan committed
Commit 447806b · 1 Parent(s): e35f9c5

all handler log printing using loguru
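
The diff below swaps every `print(...)` in `AceStepHandler` for loguru calls: `logger.info` for progress and status messages, `logger.debug` for prompt/output dumps, and `logger.error` for exception paths. These calls rely on loguru's module-level `logger` being imported in `acestep/handler.py`; the import itself is not visible in the hunks. A minimal sketch of the assumed setup (the `logger.remove()`/`logger.add()` configuration is illustrative, not part of this commit):

```python
import sys

from loguru import logger  # loguru ships a ready-made module-level logger

# Illustrative configuration (assumption, not in this commit): replace the
# default stderr sink so the logger.debug(...) calls below are emitted too.
logger.remove()
logger.add(sys.stderr, level="DEBUG")

logger.info("loading 5Hz LM tokenizer...")         # INFO: progress/status
logger.debug("[debug] formatted_prompt: ...")      # DEBUG: verbose dumps
logger.error("CUDA is not available. Please check your GPU setup.")  # ERROR
```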

Files changed (1)
  1. acestep/handler.py +34 -32
acestep/handler.py CHANGED
@@ -282,18 +282,18 @@ class AceStepHandler:
         if init_llm:
             full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
             if os.path.exists(full_lm_model_path):
-                print("loading 5Hz LM tokenizer...")
+                logger.info("loading 5Hz LM tokenizer...")
                 start_time = time.time()
                 llm_tokenizer = deepcopy(self.text_tokenizer)
                 max_audio_length = 2**16 - 1
                 semantic_tokens = [f"<|audio_code_{i}|>" for i in range(max_audio_length)]
                 # 217204
                 llm_tokenizer.add_special_tokens({"additional_special_tokens": semantic_tokens})
-                print(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
+                logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
                 self.llm_tokenizer = llm_tokenizer
                 if device == "cuda":
                     status_msg = self._initialize_5hz_lm_cuda(full_lm_model_path)
-                    print(f"5Hz LM status message: {status_msg}")
+                    logger.info(f"5Hz LM status message: {status_msg}")
                     # Check if initialization failed (status_msg starts with ❌)
                     if status_msg.startswith("❌"):
                         # vllm initialization failed, fallback to PyTorch
@@ -304,6 +304,7 @@ class AceStepHandler:
                             self.llm.eval()
                             self.llm_backend = "pt"
                             self.llm_initialized = True
+                            logger.info("5Hz LM initialized successfully on CUDA device using Transformers backend")
                         except Exception as e:
                             return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
                     # If vllm initialization succeeded, self.llm_initialized should already be True
@@ -316,6 +317,7 @@ class AceStepHandler:
                         self.llm.eval()
                         self.llm_backend = "pt"
                         self.llm_initialized = True
+                        logger.info("5Hz LM initialized successfully on non-CUDA device using Transformers backend")
                     except Exception as e:
                         return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
 
@@ -465,13 +467,13 @@ class AceStepHandler:
         """Initialize 5Hz LM model"""
         if not torch.cuda.is_available():
             self.llm_initialized = False
-            print("CUDA is not available. Please check your GPU setup.")
+            logger.error("CUDA is not available. Please check your GPU setup.")
             return "❌ CUDA is not available. Please check your GPU setup."
         try:
             from nanovllm import LLM, SamplingParams
         except ImportError:
             self.llm_initialized = False
-            print("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
+            logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
             return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."
 
         try:
@@ -489,7 +491,7 @@ class AceStepHandler:
             else:
                 self.max_model_len = 2048
 
-            print(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
+            logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
             start_time = time.time()
             self.llm = LLM(
                 model=model_path,
@@ -498,7 +500,7 @@ class AceStepHandler:
                 max_model_len=self.max_model_len,
                 gpu_memory_utilization=gpu_memory_utilization,
             )
-            print(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
+            logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
             self.llm.tokenizer = self.llm_tokenizer
             self.llm_initialized = True
             self.llm_backend = "vllm"
@@ -523,7 +525,7 @@ class AceStepHandler:
             tokenize=False,
             add_generation_prompt=True,
         )
-        print("[debug] formatted_prompt: ", formatted_prompt)
+        logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
 
         sampling_params = SamplingParams(max_tokens=self.max_model_len, temperature=temperature)
         outputs = self.llm.generate([formatted_prompt], sampling_params)
@@ -649,7 +651,7 @@ class AceStepHandler:
             Tuple of (metadata_dict, audio_codes_string)
         """
         debug_output_text = output_text.split("</think>")[0]
-        print(f"Debug output text: {debug_output_text}")
+        logger.debug(f"Debug output text: {debug_output_text}")
         metadata = {}
         audio_codes = ""
 
@@ -741,7 +743,7 @@ class AceStepHandler:
 
             return audio
         except Exception as e:
-            print(f"Error processing target audio: {e}")
+            logger.error(f"Error processing target audio: {e}")
             return None
 
     def _parse_audio_code_string(self, code_str: str) -> List[int]:
@@ -882,7 +884,7 @@ class AceStepHandler:
                 return match.group(1).strip()
             return caption
         except Exception as e:
-            print(f"Error extracting caption: {e}")
+            logger.error(f"Error extracting caption: {e}")
            return caption
 
     def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
@@ -1073,7 +1075,7 @@ class AceStepHandler:
             return audio
 
         except Exception as e:
-            print(f"Error processing reference audio: {e}")
+            logger.error(f"Error processing reference audio: {e}")
             return None
 
     def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
@@ -1101,7 +1103,7 @@ class AceStepHandler:
             return audio
 
         except Exception as e:
-            print(f"Error processing target audio: {e}")
+            logger.error(f"Error processing target audio: {e}")
             return None
 
     def prepare_batch_data(
@@ -1178,7 +1180,7 @@ class AceStepHandler:
                 target_wavs = torch.zeros(2, frames)
             return target_wavs
         except Exception as e:
-            print(f"Error creating target audio: {e}")
+            logger.error(f"Error creating target audio: {e}")
             # Fallback to 30 seconds if error
             return torch.zeros(2, 30 * 48000)
 
@@ -1408,7 +1410,7 @@ class AceStepHandler:
             code_hint = audio_code_hints[i]
             # Prefer decoding from provided audio codes
             if code_hint:
-                print(f"[generate_music] Decoding audio codes for item {i}...")
+                logger.info(f"[generate_music] Decoding audio codes for item {i}...")
                 decoded_latents = self._decode_audio_codes_to_latents(code_hint)
                 if decoded_latents is not None:
                     decoded_latents = decoded_latents.squeeze(0)
@@ -1425,7 +1427,7 @@ class AceStepHandler:
                 target_latent = self.silence_latent[0, :expected_latent_length, :]
             else:
                 # Ensure input is in VAE's dtype
-                print(f"[generate_music] Encoding target audio to latents for item {i}...")
+                logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
                 vae_input = current_wav.to(self.device).to(self.vae.dtype)
                 target_latent = self.vae.encode(vae_input).latent_dist.sample()
                 # Cast back to model dtype
@@ -1595,7 +1597,7 @@ class AceStepHandler:
         for i in range(batch_size):
             if audio_code_hints[i] is not None:
                 # Decode audio codes to 25Hz latents
-                print(f"[generate_music] Decoding audio codes for LM hints for item {i}...")
+                logger.info(f"[generate_music] Decoding audio codes for LM hints for item {i}...")
                 hints = self._decode_audio_codes_to_latents(audio_code_hints[i])
                 if hints is not None:
                     # Pad or crop to match max_latent_length
@@ -1841,10 +1843,10 @@ class AceStepHandler:
         lyric_attention_mask = batch["lyric_attention_masks"]
         text_inputs = batch["text_inputs"]
 
-        print("[preprocess_batch] Inferring prompt embeddings...")
+        logger.info("[preprocess_batch] Inferring prompt embeddings...")
         with self._load_model_context("text_encoder"):
             text_hidden_states = self.infer_text_embeddings(text_token_idss)
-        print("[preprocess_batch] Inferring lyric embeddings...")
+        logger.info("[preprocess_batch] Inferring lyric embeddings...")
         lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
 
         is_covers = batch["is_covers"]
@@ -1857,7 +1859,7 @@ class AceStepHandler:
         non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
         non_cover_text_hidden_states = None
         if non_cover_text_input_ids is not None:
-            print("[preprocess_batch] Inferring non-cover text embeddings...")
+            logger.info("[preprocess_batch] Inferring non-cover text embeddings...")
             non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
 
         return (
@@ -2085,7 +2087,7 @@ class AceStepHandler:
             "cfg_interval_start": cfg_interval_start,
             "cfg_interval_end": cfg_interval_end,
         }
-        print("[service_generate] Generating audio...")
+        logger.info("[service_generate] Generating audio...")
         with self._load_model_context("model"):
             outputs = self.model.generate_audio(**generate_kwargs)
         return outputs
@@ -2213,10 +2215,10 @@ class AceStepHandler:
             # Update instruction for cover task
             instruction = "Generate audio semantic tokens based on the given conditions:"
 
-        print("[generate_music] Starting generation...")
+        logger.info("[generate_music] Starting generation...")
         if progress:
             progress(0.05, desc="Preparing inputs...")
-        print("[generate_music] Preparing inputs...")
+        logger.info("[generate_music] Preparing inputs...")
 
         # Reset offload cost
         self.current_offload_cost = 0.0
@@ -2242,7 +2244,7 @@ class AceStepHandler:
         # 1. Process reference audio
         refer_audios = None
         if reference_audio is not None:
-            print("[generate_music] Processing reference audio...")
+            logger.info("[generate_music] Processing reference audio...")
             processed_ref_audio = self.process_reference_audio(reference_audio)
             if processed_ref_audio is not None:
                 # Convert to the format expected by the service: List[List[torch.Tensor]]
@@ -2254,7 +2256,7 @@ class AceStepHandler:
         # 2. Process source audio
         processed_src_audio = None
         if src_audio is not None:
-            print("[generate_music] Processing source audio...")
+            logger.info("[generate_music] Processing source audio...")
             processed_src_audio = self.process_src_audio(src_audio)
 
         # 3. Prepare batch data
@@ -2316,15 +2318,15 @@ class AceStepHandler:
             return_intermediate=should_return_intermediate
         )
 
-        print("[generate_music] Model generation completed. Decoding latents...")
+        logger.info("[generate_music] Model generation completed. Decoding latents...")
         pred_latents = outputs["target_latents"]  # [batch, latent_length, latent_dim]
         time_costs = outputs["time_costs"]
         time_costs["offload_time_cost"] = self.current_offload_cost
-        print(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
-        print(f" - time_costs: {time_costs}")
+        logger.info(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
+        logger.info(f" - time_costs: {time_costs}")
         if progress:
             progress(0.8, desc="Decoding audio...")
-        print("[generate_music] Decoding latents with VAE...")
+        logger.info("[generate_music] Decoding latents with VAE...")
 
         # Decode latents to audio
         start_time = time.time()
@@ -2336,7 +2338,7 @@ class AceStepHandler:
         pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
 
         if use_tiled_decode:
-            print("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
+            logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
            pred_wavs = self.tiled_decode(pred_latents_for_decode)  # [batch, channels, samples]
         else:
             pred_wavs = self.vae.decode(pred_latents_for_decode).sample
@@ -2350,7 +2352,7 @@ class AceStepHandler:
         # Update offload cost one last time to include VAE offloading
         time_costs["offload_time_cost"] = self.current_offload_cost
 
-        print("[generate_music] VAE decode completed. Saving audio files...")
+        logger.info("[generate_music] VAE decode completed. Saving audio files...")
         if progress:
             progress(0.9, desc="Saving audio files...")
 
@@ -2389,7 +2391,7 @@ class AceStepHandler:
             **Steps:** {inference_steps}
             **Files:** {len(saved_files)} audio(s){time_costs_str}"""
         status_message = f"✅ Generation completed successfully!"
-        print(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
+        logger.info(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
 
         # Alignment scores and plots (placeholder for now)
         align_score_1 = ""
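
A note on the error paths: inside `except` blocks, loguru's `logger.exception(...)` would record the message at ERROR level and attach the active traceback automatically. This commit instead keeps plain `logger.error(f"... {e}")` and leaves the pre-existing `traceback.format_exc()` strings in the returned status messages. A hedged sketch of that alternative (the function and its forced failure are hypothetical, for illustration only):

```python
from loguru import logger

def process_src_audio_sketch(audio_file):
    """Hypothetical variant of the handler's audio error path."""
    try:
        raise RuntimeError(f"cannot decode {audio_file}")  # stand-in failure
    except Exception:
        # logger.exception logs at ERROR level and appends the traceback,
        # so no manual traceback.format_exc() string is needed.
        logger.exception("Error processing target audio")
        return None

process_src_audio_sketch("song.wav")
```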