keylxiao commited on
Commit
680bf81
·
1 Parent(s): 5ab4485

fix :bug: : fix auto score

Browse files
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -532,7 +532,37 @@ def generate_with_progress(
532
  score_str = "Done!"
533
  if auto_score:
534
  auto_score_start = time_module.time()
535
- score_str = calculate_score_handler(llm_handler, code_str, captions, lyrics, lm_generated_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  auto_score_end = time_module.time()
537
  total_auto_score_time += (auto_score_end - auto_score_start)
538
  scores_ui_updates[i] = score_str
 
532
  score_str = "Done!"
533
  if auto_score:
534
  auto_score_start = time_module.time()
535
+
536
+ sample_tensor_data = None
537
+ try:
538
+ full_pred = result.extra_outputs.get("pred_latents")
539
+
540
+ if full_pred is not None and i < full_pred.shape[0]:
541
+ sample_tensor_data = {
542
+ "pred_latent": full_pred[i:i + 1],
543
+ "encoder_hidden_states": result.extra_outputs.get("encoder_hidden_states")[
544
+ i:i + 1] if result.extra_outputs.get(
545
+ "encoder_hidden_states") is not None else None,
546
+ "encoder_attention_mask": result.extra_outputs.get("encoder_attention_mask")[
547
+ i:i + 1] if result.extra_outputs.get(
548
+ "encoder_attention_mask") is not None else None,
549
+ "context_latents": result.extra_outputs.get("context_latents")[
550
+ i:i + 1] if result.extra_outputs.get(
551
+ "context_latents") is not None else None,
552
+ "lyric_token_ids": result.extra_outputs.get("lyric_token_idss")[
553
+ i:i + 1] if result.extra_outputs.get(
554
+ "lyric_token_idss") is not None else None,
555
+ }
556
+
557
+ # 简单校验完整性
558
+ if any(v is None for v in sample_tensor_data.values()):
559
+ sample_tensor_data = None
560
+
561
+ except Exception as e:
562
+ print(f"[Auto Score] Failed to prepare tensor data for sample {i}: {e}")
563
+ sample_tensor_data = None
564
+
565
+ score_str = calculate_score_handler(llm_handler, code_str, captions, lyrics, lm_generated_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale, dit_handler, sample_tensor_data, inference_steps)
566
  auto_score_end = time_module.time()
567
  total_auto_score_time += (auto_score_end - auto_score_start)
568
  scores_ui_updates[i] = score_str