Gong Junmin committed on
Commit 77c327b · unverified · 2 parent(s): 24f370e ba6c5ba

Merge pull request #1 from ace-step/refact_add_inference

acestep/api_server.py CHANGED
@@ -44,6 +44,12 @@ from acestep.constants import (
44
  DEFAULT_DIT_INSTRUCTION,
45
  DEFAULT_LM_INSTRUCTION,
46
  )
47
 
48
 
49
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
@@ -387,6 +393,10 @@ def create_app() -> FastAPI:
387
  app.state.executor = executor
388
  app.state.job_store = store
389
  app.state._python_executable = sys.executable
390
 
391
  async def _ensure_initialized() -> None:
392
  h: AceStepHandler = app.state.handler
@@ -443,131 +453,10 @@ def create_app() -> FastAPI:
443
  job_store.mark_running(job_id)
444
 
445
  def _blocking_generate() -> Dict[str, Any]:
446
- def _normalize_optional_int(v: Any) -> Optional[int]:
447
- if v is None:
448
- return None
449
- try:
450
- iv = int(v)
451
- except Exception:
452
- return None
453
- return None if iv == 0 else iv
454
-
455
- def _normalize_optional_float(v: Any) -> Optional[float]:
456
- if v is None:
457
- return None
458
- try:
459
- fv = float(v)
460
- except Exception:
461
- return None
462
- # gradio treats 1.0 as disabled for top_p
463
- return None if fv >= 1.0 else fv
464
-
465
- def _maybe_fill_from_metadata(current: GenerateMusicRequest, meta: Dict[str, Any]) -> tuple[Optional[int], str, str, Optional[float]]:
466
- def _parse_first_float(v: Any) -> Optional[float]:
467
- if v is None:
468
- return None
469
- if isinstance(v, (int, float)):
470
- return float(v)
471
- s = str(v).strip()
472
- if not s or s.upper() == "N/A":
473
- return None
474
- try:
475
- return float(s)
476
- except Exception:
477
- pass
478
- m = re.search(r"[-+]?\d*\.?\d+", s)
479
- if not m:
480
- return None
481
- try:
482
- return float(m.group(0))
483
- except Exception:
484
- return None
485
-
486
- def _parse_first_int(v: Any) -> Optional[int]:
487
- fv = _parse_first_float(v)
488
- if fv is None:
489
- return None
490
- try:
491
- return int(round(fv))
492
- except Exception:
493
- return None
494
-
495
- # Fill only when user did not provide values
496
- bpm_val = current.bpm
497
- if bpm_val is None:
498
- m = meta.get("bpm")
499
- parsed = _parse_first_int(m)
500
- if parsed is not None and parsed > 0:
501
- bpm_val = parsed
502
-
503
- key_scale_val = current.key_scale
504
- if not key_scale_val:
505
- m = meta.get("keyscale", meta.get("key_scale", ""))
506
- if m not in (None, "", "N/A"):
507
- key_scale_val = str(m)
508
-
509
- time_sig_val = current.time_signature
510
- if not time_sig_val:
511
- m = meta.get("timesignature", meta.get("time_signature", ""))
512
- if m not in (None, "", "N/A"):
513
- time_sig_val = str(m)
514
-
515
- dur_val = current.audio_duration
516
- if dur_val is None:
517
- m = meta.get("duration", meta.get("audio_duration"))
518
- parsed = _parse_first_float(m)
519
- if parsed is not None:
520
- dur_val = float(parsed)
521
- if dur_val <= 0:
522
- dur_val = None
523
-
524
- # Avoid truncating lyrical songs when LM predicts a very short duration.
525
- # (Users can still force a short duration by explicitly setting `audio_duration`.)
526
- if dur_val is not None and (current.lyrics or "").strip():
527
- min_dur = float(os.getenv("ACESTEP_LM_MIN_DURATION_SECONDS", "30"))
528
- if dur_val < min_dur:
529
- dur_val = None
530
-
531
- return bpm_val, key_scale_val, time_sig_val, dur_val
532
-
533
- def _estimate_duration_from_lyrics(lyrics: str) -> Optional[float]:
534
- lyrics = (lyrics or "").strip()
535
- if not lyrics:
536
- return None
537
-
538
- # Best-effort heuristic: singing rate ~ 2.2 words/sec for English-like lyrics.
539
- # For languages without spaces, fall back to non-space char count.
540
- words = re.findall(r"[A-Za-z0-9']+", lyrics)
541
- if len(words) >= 8:
542
- words_per_sec = float(os.getenv("ACESTEP_LYRICS_WORDS_PER_SEC", "2.2"))
543
- est = len(words) / max(0.5, words_per_sec)
544
- else:
545
- non_space = len(re.sub(r"\s+", "", lyrics))
546
- chars_per_sec = float(os.getenv("ACESTEP_LYRICS_CHARS_PER_SEC", "12"))
547
- est = non_space / max(4.0, chars_per_sec)
548
-
549
- min_dur = float(os.getenv("ACESTEP_LYRICS_MIN_DURATION_SECONDS", "45"))
550
- max_dur = float(os.getenv("ACESTEP_LYRICS_MAX_DURATION_SECONDS", "180"))
551
- return float(min(max(est, min_dur), max_dur))
552
-
553
- def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
554
- """Ensure a stable `metas` dict (keys always present)."""
555
- meta = meta or {}
556
- out: Dict[str, Any] = dict(meta)
557
-
558
- # Normalize key aliases
559
- if "keyscale" not in out and "key_scale" in out:
560
- out["keyscale"] = out.get("key_scale")
561
- if "timesignature" not in out and "time_signature" in out:
562
- out["timesignature"] = out.get("time_signature")
563
-
564
- # Ensure required keys exist
565
- for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
566
- if out.get(k) in (None, ""):
567
- out[k] = "N/A"
568
- return out
569
-
570
  def _ensure_llm_ready() -> None:
 
571
  with app.state._llm_init_lock:
572
  initialized = getattr(app.state, "_llm_initialized", False)
573
  had_error = getattr(app.state, "_llm_init_error", None)
@@ -597,269 +486,207 @@ def create_app() -> FastAPI:
597
  else:
598
  app.state._llm_initialized = True
599
 
600
- # Optional: generate 5Hz LM codes server-side
601
- audio_code_string = req.audio_code_string
602
- bpm_val = req.bpm
603
- key_scale_val = req.key_scale
604
- time_sig_val = req.time_signature
605
- audio_duration_val = req.audio_duration
606
-
607
- thinking = bool(getattr(req, "thinking", False))
608
-
609
- print(
610
- "[api_server] parsed req: "
611
- f"thinking={thinking}, caption_len={len((req.caption or '').strip())}, lyrics_len={len((req.lyrics or '').strip())}, "
612
- f"bpm={req.bpm}, audio_duration={req.audio_duration}, key_scale={req.key_scale!r}, time_signature={req.time_signature!r}"
613
- )
614
 
615
- # If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
616
- # We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
617
- audio_cover_strength_val = float(req.audio_cover_strength)
618
 
619
- lm_meta: Optional[Dict[str, Any]] = None
620
 
621
- sample_mode = bool(getattr(req, "sample_mode", False))
622
- if sample_mode:
623
  _ensure_llm_ready()
624
  if getattr(app.state, "_llm_init_error", None):
625
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
626
 
627
  sample_metadata, sample_status = llm.understand_audio_from_codes(
628
  audio_codes="NO USER INPUT",
629
- temperature=float(getattr(req, "lm_temperature", _LM_DEFAULT_TEMPERATURE)),
630
- cfg_scale=max(1.0, float(getattr(req, "lm_cfg_scale", _LM_DEFAULT_CFG_SCALE))),
631
- negative_prompt=str(getattr(req, "lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
632
- top_k=_normalize_optional_int(getattr(req, "lm_top_k", None)),
633
- top_p=_normalize_optional_float(getattr(req, "lm_top_p", None)),
634
- repetition_penalty=float(getattr(req, "lm_repetition_penalty", 1.0)),
635
- use_constrained_decoding=bool(getattr(req, "constrained_decoding", True)),
636
- constrained_decoding_debug=bool(getattr(req, "constrained_decoding_debug", False)),
637
  )
638
 
639
  if not sample_metadata or str(sample_status).startswith("❌"):
640
  raise RuntimeError(f"Sample generation failed: {sample_status}")
641
 
642
- req.caption = str(sample_metadata.get("caption", "") or "")
643
- req.lyrics = str(sample_metadata.get("lyrics", "") or "")
644
- req.bpm = _to_int(sample_metadata.get("bpm"), req.bpm)
645
-
646
- sample_keyscale = sample_metadata.get("keyscale", sample_metadata.get("key_scale", ""))
647
- if sample_keyscale:
648
- req.key_scale = str(sample_keyscale)
649
-
650
- sample_timesig = sample_metadata.get("timesignature", sample_metadata.get("time_signature", ""))
651
- if sample_timesig:
652
- req.time_signature = str(sample_timesig)
653
-
654
- sample_duration = _to_float(sample_metadata.get("duration"), None)
655
- if sample_duration is not None and sample_duration > 0:
656
- req.audio_duration = sample_duration
657
-
658
- lm_meta = sample_metadata
659
-
660
- fallback_values: Dict[str, Any] = {}
661
- default_bpm = _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120) or 120
662
- default_duration = _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0) or 120.0
663
- default_key = os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major") or "C Major"
664
- default_timesig = os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4") or "4/4"
665
-
666
- if req.bpm is None or req.bpm <= 0:
667
- req.bpm = default_bpm
668
- fallback_values["bpm"] = default_bpm
669
-
670
- if req.audio_duration is None or req.audio_duration <= 0:
671
- req.audio_duration = default_duration
672
- fallback_values["audio_duration"] = default_duration
673
-
674
- if not (req.key_scale or "").strip():
675
- req.key_scale = default_key
676
- fallback_values["key_scale"] = default_key
677
-
678
- if not (req.time_signature or "").strip():
679
- req.time_signature = default_timesig
680
- fallback_values["time_signature"] = default_timesig
681
-
682
- if fallback_values:
683
- print("[api_server] sample mode fallback values:", fallback_values)
684
-
685
- print(
686
- "[api_server] sample mode metadata:",
687
- {
688
- "caption_len": len(req.caption),
689
- "lyrics_len": len(req.lyrics),
690
- "bpm": req.bpm,
691
- "audio_duration": req.audio_duration,
692
- "key_scale": req.key_scale,
693
- "time_signature": req.time_signature,
694
- },
695
- )
696
-
697
- # Determine effective batch size (used for per-sample LM code diversity)
698
- effective_batch_size = req.batch_size
699
- if effective_batch_size is None:
700
- try:
701
- effective_batch_size = int(getattr(h, "batch_size", 1))
702
- except Exception:
703
- effective_batch_size = 1
704
- effective_batch_size = max(1, int(effective_batch_size))
705
-
706
- has_codes = bool(audio_code_string and str(audio_code_string).strip())
707
- need_lm_codes = bool(thinking) and (not has_codes)
708
-
709
- use_constrained_decoding = bool(getattr(req, "constrained_decoding", True))
710
- constrained_decoding_debug = bool(getattr(req, "constrained_decoding_debug", False))
711
- use_cot_caption = bool(getattr(req, "use_cot_caption", True))
712
- use_cot_language = bool(getattr(req, "use_cot_language", True))
713
- is_format_caption = bool(getattr(req, "is_format_caption", False))
714
-
715
- # pass them into constrained decoding so LM injects them directly
716
- # (i.e. does not re-infer / override those fields).
717
- user_metadata: Dict[str, Optional[str]] = {}
718
-
719
- def _set_user_meta(field: str, value: Optional[Any]) -> None:
720
- if value is None:
721
- return
722
- s = str(value).strip()
723
- if not s or s.upper() == "N/A":
724
- return
725
- user_metadata[field] = s
726
-
727
- _set_user_meta("bpm", int(bpm_val) if bpm_val is not None else None)
728
- _set_user_meta("duration", float(audio_duration_val) if audio_duration_val is not None else None)
729
- _set_user_meta("keyscale", key_scale_val if (key_scale_val or "").strip() else None)
730
- _set_user_meta("timesignature", time_sig_val if (time_sig_val or "").strip() else None)
731
-
732
- def _has_meta(field: str) -> bool:
733
- v = user_metadata.get(field)
734
- return bool((v or "").strip())
735
-
736
- need_lm_metas = not (
737
- _has_meta("bpm")
738
- and _has_meta("duration")
739
- and _has_meta("keyscale")
740
- and _has_meta("timesignature")
741
  )
742
 
743
- lm_target_duration: Optional[float] = None
744
- if need_lm_codes:
745
- # If user specified a duration, constrain codes generation length accordingly.
746
- if audio_duration_val is not None and float(audio_duration_val) > 0:
747
- lm_target_duration = float(audio_duration_val)
748
-
749
- print(
750
- "[api_server] LM调用参数: "
751
- f"user_metadata_keys={sorted(user_metadata.keys())}, target_duration={lm_target_duration}, "
752
- f"need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}, "
753
- f"use_constrained_decoding={use_constrained_decoding}, use_cot_caption={use_cot_caption}, "
754
- f"use_cot_language={use_cot_language}, is_format_caption={is_format_caption}"
755
  )
756
 
757
- if need_lm_metas or need_lm_codes:
758
- _ensure_llm_ready()
759
-
760
- if getattr(app.state, "_llm_init_error", None):
761
- # If codes generation is required, fail hard.
762
- if need_lm_codes:
763
- raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
764
- # Otherwise, skip LM best-effort (fallback to default/meta-less behavior)
765
- else:
766
- lm_infer = "llm_dit" if need_lm_codes else "dit"
767
-
768
- def _lm_call() -> tuple[Dict[str, Any], str, str]:
769
- return llm.generate_with_stop_condition(
770
- caption=req.caption,
771
- lyrics=req.lyrics,
772
- infer_type=lm_infer,
773
- temperature=float(req.lm_temperature),
774
- cfg_scale=max(1.0, float(req.lm_cfg_scale)),
775
- negative_prompt=str(req.lm_negative_prompt or "NO USER INPUT"),
776
- top_k=_normalize_optional_int(req.lm_top_k),
777
- top_p=_normalize_optional_float(req.lm_top_p),
778
- repetition_penalty=float(req.lm_repetition_penalty),
779
- target_duration=lm_target_duration,
780
- user_metadata=(user_metadata or None),
781
- use_constrained_decoding=use_constrained_decoding,
782
- constrained_decoding_debug=constrained_decoding_debug,
783
- use_cot_caption=use_cot_caption,
784
- use_cot_language=use_cot_language,
785
- is_format_caption=is_format_caption,
786
- )
787
-
788
- meta, codes, status = _lm_call()
789
- lm_meta = meta
790
-
791
- if need_lm_codes:
792
- if not codes:
793
- raise RuntimeError(f"5Hz LM generation failed: {status}")
794
-
795
- # LM once per job; rely on DiT seeds for batch diversity.
796
- # For convenience, replicate the same codes across the batch.
797
- if effective_batch_size > 1:
798
- audio_code_string = [codes] * effective_batch_size
799
- else:
800
- audio_code_string = codes
801
-
802
- # Fill only missing fields (user-provided values win)
803
- bpm_val, key_scale_val, time_sig_val, audio_duration_val = _maybe_fill_from_metadata(req, meta)
804
-
805
- # If user provided lyrics but LM didn't provide a usable duration, estimate a longer duration.
806
- if audio_duration_val is None and (req.audio_duration is None):
807
- est = _estimate_duration_from_lyrics(req.lyrics)
808
- if est is not None:
809
- audio_duration_val = est
810
-
811
- # Optional: auto-tune LM cover strength (opt-in) to avoid suppressing lyric/vocal conditioning.
812
- if thinking and audio_cover_strength_val >= 0.999 and (req.lyrics or "").strip():
813
- tuned = os.getenv("ACESTEP_LM_COVER_STRENGTH")
814
- if tuned is not None and tuned.strip() != "":
815
- audio_cover_strength_val = float(tuned)
816
-
817
- # Align behavior:
818
- # - thinking=False: metas only (ignore audio codes), keep text2music.
819
- # - thinking=True: metas + audio codes, run in cover mode with LM instruction.
820
- instruction_val = req.instruction
821
- task_type_val = (req.task_type or "").strip() or "text2music"
822
-
823
- if not thinking:
824
- audio_code_string = ""
825
- if task_type_val == "cover":
826
- task_type_val = "text2music"
827
- if (instruction_val or "").strip() in {"", _DEFAULT_LM_INSTRUCTION}:
828
- instruction_val = _DEFAULT_DIT_INSTRUCTION
829
-
830
- if thinking:
831
- task_type_val = "cover"
832
- if (instruction_val or "").strip() in {"", _DEFAULT_DIT_INSTRUCTION}:
833
- instruction_val = _DEFAULT_LM_INSTRUCTION
834
-
835
- if not (audio_code_string and str(audio_code_string).strip()):
836
- # thinking=True requires codes generation.
837
- raise RuntimeError("thinking=true requires non-empty audio codes (LM generation failed).")
838
-
839
- # Response metas MUST reflect the actual values used by DiT.
840
- metas_out = _normalize_metas(lm_meta or {})
841
- if bpm_val is not None and int(bpm_val) > 0:
842
- metas_out["bpm"] = int(bpm_val)
843
- if audio_duration_val is not None and float(audio_duration_val) > 0:
844
- metas_out["duration"] = float(audio_duration_val)
845
- if (key_scale_val or "").strip():
846
- metas_out["keyscale"] = str(key_scale_val)
847
- if (time_sig_val or "").strip():
848
- metas_out["timesignature"] = str(time_sig_val)
849
-
850
- def _ensure_text_meta(field: str, fallback: Optional[str]) -> None:
851
- existing = metas_out.get(field)
852
- if isinstance(existing, str):
853
- stripped = existing.strip()
854
- if stripped and stripped.upper() != "N/A":
855
- return
856
- if fallback is None:
857
- return
858
- if fallback.strip():
859
- metas_out[field] = fallback
860
-
861
- _ensure_text_meta("caption", req.caption)
862
- _ensure_text_meta("lyrics", req.lyrics)
863
 
864
  def _none_if_na_str(v: Any) -> Optional[str]:
865
  if v is None:
@@ -868,44 +695,17 @@ def create_app() -> FastAPI:
868
  if s in {"", "N/A"}:
869
  return None
870
  return s
871
- first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
872
- captions=req.caption,
873
- lyrics=req.lyrics,
874
- bpm=bpm_val,
875
- key_scale=key_scale_val,
876
- time_signature=time_sig_val,
877
- vocal_language=req.vocal_language,
878
- inference_steps=req.inference_steps,
879
- guidance_scale=req.guidance_scale,
880
- use_random_seed=req.use_random_seed,
881
- seed=("-1" if (req.use_random_seed and int(req.seed) < 0) else str(req.seed)),
882
- reference_audio=req.reference_audio_path,
883
- audio_duration=audio_duration_val,
884
- batch_size=req.batch_size,
885
- src_audio=req.src_audio_path,
886
- audio_code_string=audio_code_string,
887
- repainting_start=req.repainting_start,
888
- repainting_end=req.repainting_end,
889
- instruction=instruction_val,
890
- audio_cover_strength=audio_cover_strength_val,
891
- task_type=task_type_val,
892
- use_adg=req.use_adg,
893
- cfg_interval_start=req.cfg_interval_start,
894
- cfg_interval_end=req.cfg_interval_end,
895
- audio_format=req.audio_format,
896
- use_tiled_decode=req.use_tiled_decode,
897
- progress=None,
898
- )
899
  return {
900
- "first_audio_path": _path_to_audio_url(first) if first else None,
901
- "second_audio_path": _path_to_audio_url(second) if second else None,
902
- "audio_paths": [_path_to_audio_url(p) for p in (paths or [])],
903
- "generation_info": gen_info,
904
- "status_message": status_msg,
905
  "seed_value": seed_value,
906
  "metas": metas_out,
907
- "bpm": int(bpm_val) if bpm_val is not None else None,
908
- "duration": float(audio_duration_val) if audio_duration_val is not None else None,
909
  "genres": _none_if_na_str(metas_out.get("genres")),
910
  "keyscale": _none_if_na_str(metas_out.get("keyscale")),
911
  "timesignature": _none_if_na_str(metas_out.get("timesignature")),
@@ -1010,53 +810,6 @@ def create_app() -> FastAPI:
1010
 
1011
  return default
1012
 
1013
- # Debug: print what keys we actually received (helps explain empty parsed values)
1014
- try:
1015
- top_keys = list(getattr(mapping, "keys", lambda: [])())
1016
- except Exception:
1017
- top_keys = []
1018
- try:
1019
- nested_probe = (
1020
- get("metas", None)
1021
- or get("meta", None)
1022
- or get("metadata", None)
1023
- or get("user_metadata", None)
1024
- or get("userMetadata", None)
1025
- )
1026
- if isinstance(nested_probe, str):
1027
- sp = nested_probe.strip()
1028
- if sp.startswith("{") and sp.endswith("}"):
1029
- try:
1030
- nested_probe = json.loads(sp)
1031
- except Exception:
1032
- nested_probe = None
1033
- nested_keys = list(nested_probe.keys()) if isinstance(nested_probe, dict) else []
1034
- except Exception:
1035
- nested_keys = []
1036
- print(f"[api_server] request keys: top={sorted(top_keys)}, nested={sorted(nested_keys)}")
1037
-
1038
- # Debug: print raw values/types for common meta fields (top-level + common aliases)
1039
- try:
1040
- probe_keys = [
1041
- "thinking",
1042
- "bpm",
1043
- "audio_duration",
1044
- "duration",
1045
- "audioDuration",
1046
- "key_scale",
1047
- "keyscale",
1048
- "keyScale",
1049
- "time_signature",
1050
- "timesignature",
1051
- "timeSignature",
1052
- ]
1053
- raw = {k: get(k, None) for k in probe_keys}
1054
- raw_types = {k: (type(v).__name__ if v is not None else None) for k, v in raw.items()}
1055
- print(f"[api_server] request raw: {raw}")
1056
- print(f"[api_server] request raw types: {raw_types}")
1057
- except Exception:
1058
- pass
1059
-
1060
  normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
1061
  normalized_bpm = _to_int(_get_any("bpm"), None)
1062
  normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
@@ -1066,12 +819,6 @@ def create_app() -> FastAPI:
1066
  if normalized_audio_duration is None:
1067
  normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)
1068
 
1069
- print(
1070
- "[api_server] normalized: "
1071
- f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
1072
- f"audio_duration={normalized_audio_duration}, key_scale={normalized_keyscale!r}, time_signature={normalized_timesig!r}"
1073
- )
1074
-
1075
  return GenerateMusicRequest(
1076
  caption=str(get("caption", "") or ""),
1077
  lyrics=str(get("lyrics", "") or ""),
@@ -1110,7 +857,6 @@ def create_app() -> FastAPI:
1110
  lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
1111
  constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
1112
  constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
1113
- # Accept common aliases, including hyphenated keys from some clients.
1114
  use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
1115
  use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
1116
  is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),
 
44
  DEFAULT_DIT_INSTRUCTION,
45
  DEFAULT_LM_INSTRUCTION,
46
  )
47
+ from acestep.inference import (
48
+ GenerationParams,
49
+ GenerationConfig,
50
+ generate_music,
51
+ )
52
+ from acestep.gradio_ui.events.results_handlers import _build_generation_info
53
 
54
 
55
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
 
393
  app.state.executor = executor
394
  app.state.job_store = store
395
  app.state._python_executable = sys.executable
396
+
397
+ # Temporary directory for saving generated audio files
398
+ app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
399
+ os.makedirs(app.state.temp_audio_dir, exist_ok=True)
400
 
401
  async def _ensure_initialized() -> None:
402
  h: AceStepHandler = app.state.handler
 
453
  job_store.mark_running(job_id)
454
 
455
  def _blocking_generate() -> Dict[str, Any]:
456
+ """Generate music using unified inference logic from acestep.inference"""
457
+
458
  def _ensure_llm_ready() -> None:
459
+ """Ensure LLM handler is initialized when needed"""
460
  with app.state._llm_init_lock:
461
  initialized = getattr(app.state, "_llm_initialized", False)
462
  had_error = getattr(app.state, "_llm_init_error", None)
 
486
  else:
487
  app.state._llm_initialized = True
488
 
489
+ def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
490
+ """Ensure a stable `metas` dict (keys always present)."""
491
+ meta = meta or {}
492
+ out: Dict[str, Any] = dict(meta)
493
 
494
+ # Normalize key aliases
495
+ if "keyscale" not in out and "key_scale" in out:
496
+ out["keyscale"] = out.get("key_scale")
497
+ if "timesignature" not in out and "time_signature" in out:
498
+ out["timesignature"] = out.get("time_signature")
499
 
500
+ # Ensure required keys exist
501
+ for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
502
+ if out.get(k) in (None, ""):
503
+ out[k] = "N/A"
504
+ return out
505
 
506
+ # Normalize LM sampling parameters
507
+ lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
508
+ lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
509
+
510
+ # Determine if LLM is needed
511
+ thinking = bool(req.thinking)
512
+ sample_mode = bool(req.sample_mode)
513
+ need_llm = thinking or sample_mode
514
+
515
+ print(f"[api_server] Request params: req.thinking={req.thinking}, req.sample_mode={req.sample_mode}")
516
+ print(f"[api_server] Determined: thinking={thinking}, sample_mode={sample_mode}, need_llm={need_llm}")
517
+
518
+ # Ensure LLM is ready if needed
519
+ if need_llm:
520
  _ensure_llm_ready()
521
  if getattr(app.state, "_llm_init_error", None):
522
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
523
 
524
+ # Handle sample mode: generate random caption/lyrics first
525
+ caption = req.caption
526
+ lyrics = req.lyrics
527
+ bpm = req.bpm
528
+ key_scale = req.key_scale
529
+ time_signature = req.time_signature
530
+ audio_duration = req.audio_duration
531
+
532
+ if sample_mode:
533
+ print("[api_server] Sample mode: generating random caption/lyrics via LM")
534
  sample_metadata, sample_status = llm.understand_audio_from_codes(
535
  audio_codes="NO USER INPUT",
536
+ temperature=req.lm_temperature,
537
+ cfg_scale=max(1.0, req.lm_cfg_scale),
538
+ negative_prompt=req.lm_negative_prompt,
539
+ top_k=lm_top_k if lm_top_k > 0 else None,
540
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
541
+ repetition_penalty=req.lm_repetition_penalty,
542
+ use_constrained_decoding=req.constrained_decoding,
543
+ constrained_decoding_debug=req.constrained_decoding_debug,
544
  )
545
 
546
  if not sample_metadata or str(sample_status).startswith("❌"):
547
  raise RuntimeError(f"Sample generation failed: {sample_status}")
548
 
549
+ # Use generated values with fallback defaults
550
+ caption = sample_metadata.get("caption", "")
551
+ lyrics = sample_metadata.get("lyrics", "")
552
+ bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
553
+ key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
554
+ time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
555
+ audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
556
+
557
+ print(f"[api_server] Sample generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}")
558
+
559
+ print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
560
+ print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
561
+
562
+ # Build GenerationParams using unified interface
563
+ # Note: thinking controls LM code generation, sample_mode only affects CoT metas
564
+ params = GenerationParams(
565
+ task_type=req.task_type,
566
+ instruction=req.instruction,
567
+ reference_audio=req.reference_audio_path,
568
+ src_audio=req.src_audio_path,
569
+ audio_codes=req.audio_code_string,
570
+ caption=caption,
571
+ lyrics=lyrics,
572
+ instrumental=False,
573
+ vocal_language=req.vocal_language,
574
+ bpm=bpm,
575
+ keyscale=key_scale,
576
+ timesignature=time_signature,
577
+ duration=audio_duration if audio_duration else -1.0,
578
+ inference_steps=req.inference_steps,
579
+ seed=req.seed,
580
+ guidance_scale=req.guidance_scale,
581
+ use_adg=req.use_adg,
582
+ cfg_interval_start=req.cfg_interval_start,
583
+ cfg_interval_end=req.cfg_interval_end,
584
+ repainting_start=req.repainting_start,
585
+ repainting_end=req.repainting_end if req.repainting_end else -1,
586
+ audio_cover_strength=req.audio_cover_strength,
587
+ # LM parameters
588
+ thinking=thinking, # Use LM for code generation when thinking=True
589
+ lm_temperature=req.lm_temperature,
590
+ lm_cfg_scale=req.lm_cfg_scale,
591
+ lm_top_k=lm_top_k,
592
+ lm_top_p=lm_top_p,
593
+ lm_negative_prompt=req.lm_negative_prompt,
594
+ use_cot_metas=not sample_mode, # Sample mode already generated metas, don't regenerate
595
+ use_cot_caption=req.use_cot_caption,
596
+ use_cot_language=req.use_cot_language,
597
+ use_constrained_decoding=req.constrained_decoding,
598
  )
599
 
600
+ # Build GenerationConfig - default to 2 audios like gradio_ui
601
+ batch_size = req.batch_size if req.batch_size is not None else 2
602
+ config = GenerationConfig(
603
+ batch_size=batch_size,
604
+ use_random_seed=req.use_random_seed,
605
+ seeds=None, # Let unified logic handle seed generation
606
+ audio_format=req.audio_format,
607
+ constrained_decoding_debug=req.constrained_decoding_debug,
608
  )
609
 
610
+ # Check LLM initialization status
611
+ llm_is_initialized = getattr(app.state, "_llm_initialized", False)
612
+ llm_to_pass = llm if llm_is_initialized else None
613
+
614
+ print(f"[api_server] Generating music with unified interface:")
615
+ print(f" - thinking={params.thinking}")
616
+ print(f" - batch_size={batch_size}")
617
+ print(f" - llm_initialized={llm_is_initialized}")
618
+ print(f" - llm_handler={'Available' if llm_to_pass else 'None'}")
619
+
620
+ # Generate music using unified interface
621
+ result = generate_music(
622
+ dit_handler=h,
623
+ llm_handler=llm_to_pass,
624
+ params=params,
625
+ config=config,
626
+ save_dir=app.state.temp_audio_dir,
627
+ progress=None,
628
+ )
629
+
630
+ print(f"[api_server] Generation completed. Success={result.success}, Audios={len(result.audios)}")
631
+ print(f"[api_server] Time costs keys: {list(result.extra_outputs.get('time_costs', {}).keys())}")
632
+
633
+ if not result.success:
634
+ raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
635
+
636
+ # Extract results
637
+ audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
638
+ first_audio = audio_paths[0] if len(audio_paths) > 0 else None
639
+ second_audio = audio_paths[1] if len(audio_paths) > 1 else None
640
+
641
+ # Get metadata from LM or CoT results
642
+ lm_metadata = result.extra_outputs.get("lm_metadata", {})
643
+ metas_out = _normalize_metas(lm_metadata)
644
+
645
+ # Update metas with actual values used
646
+ if params.cot_bpm:
647
+ metas_out["bpm"] = params.cot_bpm
648
+ elif bpm:
649
+ metas_out["bpm"] = bpm
650
+
651
+ if params.cot_duration:
652
+ metas_out["duration"] = params.cot_duration
653
+ elif audio_duration:
654
+ metas_out["duration"] = audio_duration
655
+
656
+ if params.cot_keyscale:
657
+ metas_out["keyscale"] = params.cot_keyscale
658
+ elif key_scale:
659
+ metas_out["keyscale"] = key_scale
660
+
661
+ if params.cot_timesignature:
662
+ metas_out["timesignature"] = params.cot_timesignature
663
+ elif time_signature:
664
+ metas_out["timesignature"] = time_signature
665
+
666
+ # Ensure caption and lyrics are in metas
667
+ if caption:
668
+ metas_out["caption"] = caption
669
+ if lyrics:
670
+ metas_out["lyrics"] = lyrics
671
+
672
+ # Extract seed values for response (comma-separated for multiple audios)
673
+ seed_values = []
674
+ for audio in result.audios:
675
+ audio_params = audio.get("params", {})
676
+ seed = audio_params.get("seed")
677
+ if seed is not None:
678
+ seed_values.append(str(seed))
679
+ seed_value = ",".join(seed_values) if seed_values else ""
680
+
681
+ # Build generation_info using the helper function (like gradio_ui)
682
+ time_costs = result.extra_outputs.get("time_costs", {})
683
+ generation_info = _build_generation_info(
684
+ lm_metadata=lm_metadata,
685
+ time_costs=time_costs,
686
+ seed_value=seed_value,
687
+ inference_steps=req.inference_steps,
688
+ num_audios=len(result.audios),
689
+ )
690
 
691
  def _none_if_na_str(v: Any) -> Optional[str]:
692
  if v is None:
 
695
  if s in {"", "N/A"}:
696
  return None
697
  return s
698
+
699
  return {
700
+ "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
701
+ "second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
702
+ "audio_paths": [_path_to_audio_url(p) for p in audio_paths],
703
+ "generation_info": generation_info,
704
+ "status_message": result.status_message,
705
  "seed_value": seed_value,
706
  "metas": metas_out,
707
+ "bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
708
+ "duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
709
  "genres": _none_if_na_str(metas_out.get("genres")),
710
  "keyscale": _none_if_na_str(metas_out.get("keyscale")),
711
  "timesignature": _none_if_na_str(metas_out.get("timesignature")),
 
810
 
811
  return default
812
 
813
  normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
814
  normalized_bpm = _to_int(_get_any("bpm"), None)
815
  normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
 
819
  if normalized_audio_duration is None:
820
  normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)
821
 
822
  return GenerateMusicRequest(
823
  caption=str(get("caption", "") or ""),
824
  lyrics=str(get("lyrics", "") or ""),
 
857
  lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
858
  constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
859
  constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
 
860
  use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
861
  use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
862
  is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),
acestep/audio_utils.py ADDED
@@ -0,0 +1,320 @@
1
+ """
2
+ Audio saving and transcoding utility module
3
+
4
+ Independent audio file operations outside of handler, supporting:
5
+ - Save audio tensor/numpy to files (default FLAC format, fast)
6
+ - Format conversion (FLAC/WAV/MP3)
7
+ - Batch processing
8
+ """
9
+
10
+ import os
11
+ import hashlib
12
+ import json
13
+ from pathlib import Path
14
+ from typing import Union, Optional, List, Tuple
15
+ import torch
16
+ import numpy as np
17
+ import torchaudio
18
+ from loguru import logger
19
+
20
+
21
+ class AudioSaver:
22
+ """Audio saving and transcoding utility class"""
23
+
24
+ def __init__(self, default_format: str = "flac"):
25
+ """
26
+ Initialize audio saver
27
+
28
+ Args:
29
+ default_format: Default save format ('flac', 'wav', 'mp3')
30
+ """
31
+ self.default_format = default_format.lower()
32
+ if self.default_format not in ["flac", "wav", "mp3"]:
33
+ logger.warning(f"Unsupported format {default_format}, using 'flac'")
34
+ self.default_format = "flac"
35
+
36
+ def save_audio(
37
+ self,
38
+ audio_data: Union[torch.Tensor, np.ndarray],
39
+ output_path: Union[str, Path],
40
+ sample_rate: int = 48000,
41
+ format: Optional[str] = None,
42
+ channels_first: bool = True,
43
+ ) -> str:
44
+ """
45
+ Save audio data to file
46
+
47
+ Args:
48
+ audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
49
+ output_path: Output file path (extension can be omitted)
50
+ sample_rate: Sample rate
51
+ format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
52
+ channels_first: If True, tensor format is [channels, samples], else [samples, channels]
53
+
54
+ Returns:
55
+ Actual saved file path
56
+ """
57
+ format = (format or self.default_format).lower()
58
+ if format not in ["flac", "wav", "mp3"]:
59
+ logger.warning(f"Unsupported format {format}, using {self.default_format}")
60
+ format = self.default_format
61
+
62
+ # Ensure output path has correct extension
63
+ output_path = Path(output_path)
64
+ if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
65
+ output_path = output_path.with_suffix(f'.{format}')
66
+
67
+ # Convert to torch tensor
68
+ if isinstance(audio_data, np.ndarray):
69
+ if channels_first:
70
+ # numpy [samples, channels] -> tensor [channels, samples]
71
+ audio_tensor = torch.from_numpy(audio_data.T).float()
72
+ else:
73
+ # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
74
+ audio_tensor = torch.from_numpy(audio_data).float()
75
+ if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
76
+ audio_tensor = audio_tensor.T
77
+ else:
78
+ # torch tensor
79
+ audio_tensor = audio_data.cpu().float()
80
+ if not channels_first and audio_tensor.dim() == 2:
81
+ # [samples, channels] -> [channels, samples]
82
+ if audio_tensor.shape[0] > audio_tensor.shape[1]:
83
+ audio_tensor = audio_tensor.T
84
+
85
+ # Ensure memory is contiguous
86
+ audio_tensor = audio_tensor.contiguous()
87
+
88
+ # Select backend and save
89
+ try:
90
+ if format == "mp3":
91
+ # MP3 uses ffmpeg backend
92
+ torchaudio.save(
93
+ str(output_path),
94
+ audio_tensor,
95
+ sample_rate,
96
+ channels_first=True,
97
+ backend='ffmpeg',
98
+ )
99
+ elif format in ["flac", "wav"]:
100
+ # FLAC and WAV use soundfile backend (fastest)
101
+ torchaudio.save(
102
+ str(output_path),
103
+ audio_tensor,
104
+ sample_rate,
105
+ channels_first=True,
106
+ backend='soundfile',
107
+ )
108
+ else:
109
+ # Other formats use default backend
110
+ torchaudio.save(
111
+ str(output_path),
112
+ audio_tensor,
113
+ sample_rate,
114
+ channels_first=True,
115
+ )
116
+
117
+ logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
118
+ return str(output_path)
119
+
120
+ except Exception as e:
121
+ logger.error(f"[AudioSaver] Failed to save audio: {e}")
122
+ raise
123
+
124
+ def convert_audio(
125
+ self,
126
+ input_path: Union[str, Path],
127
+ output_path: Union[str, Path],
128
+ output_format: str,
129
+ remove_input: bool = False,
130
+ ) -> str:
131
+ """
132
+ Convert audio format
133
+
134
+ Args:
135
+ input_path: Input audio file path
136
+ output_path: Output audio file path
137
+ output_format: Target format ('flac', 'wav', 'mp3')
138
+ remove_input: Whether to delete input file
139
+
140
+ Returns:
141
+ Output file path
142
+ """
143
+ input_path = Path(input_path)
144
+ output_path = Path(output_path)
145
+
146
+ if not input_path.exists():
147
+ raise FileNotFoundError(f"Input file not found: {input_path}")
148
+
149
+ # Load audio
150
+ audio_tensor, sample_rate = torchaudio.load(str(input_path))
151
+
152
+ # Save as new format
153
+ output_path = self.save_audio(
154
+ audio_tensor,
155
+ output_path,
156
+ sample_rate=sample_rate,
157
+ format=output_format,
158
+ channels_first=True
159
+ )
160
+
161
+ # Delete input file if needed
162
+ if remove_input:
163
+ input_path.unlink()
164
+ logger.debug(f"[AudioSaver] Removed input file: {input_path}")
165
+
166
+ return output_path
167
+
168
+ def save_batch(
169
+ self,
170
+ audio_batch: Union[List[torch.Tensor], torch.Tensor],
171
+ output_dir: Union[str, Path],
172
+ file_prefix: str = "audio",
173
+ sample_rate: int = 48000,
174
+ format: Optional[str] = None,
175
+ channels_first: bool = True,
176
+ ) -> List[str]:
177
+ """
178
+ Save audio batch
179
+
180
+ Args:
181
+ audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
182
+ output_dir: Output directory
183
+ file_prefix: File prefix
184
+ sample_rate: Sample rate
185
+ format: Audio format
186
+ channels_first: Tensor format flag
187
+
188
+ Returns:
189
+ List of saved file paths
190
+ """
191
+ output_dir = Path(output_dir)
192
+ output_dir.mkdir(parents=True, exist_ok=True)
193
+
194
+ # Process batch
195
+ if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
196
+ # [batch, channels, samples]
197
+ audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
198
+ elif isinstance(audio_batch, list):
199
+ audio_list = audio_batch
200
+ else:
201
+ audio_list = [audio_batch]
202
+
203
+ saved_paths = []
204
+ for i, audio in enumerate(audio_list):
205
+ output_path = output_dir / f"{file_prefix}_{i:04d}"
206
+ saved_path = self.save_audio(
207
+ audio,
208
+ output_path,
209
+ sample_rate=sample_rate,
210
+ format=format,
211
+ channels_first=channels_first
212
+ )
213
+ saved_paths.append(saved_path)
214
+
215
+ return saved_paths
216
+
217
+
218
+ def get_audio_file_hash(audio_file) -> str:
219
+ """
220
+ Get hash identifier for an audio file.
221
+
222
+ Args:
223
+ audio_file: Path to audio file (str) or file-like object
224
+
225
+ Returns:
226
+ Hash string or empty string
227
+ """
228
+ if audio_file is None:
229
+ return ""
230
+
231
+ try:
232
+ if isinstance(audio_file, str):
233
+ if os.path.exists(audio_file):
234
+ with open(audio_file, 'rb') as f:
235
+ return hashlib.md5(f.read()).hexdigest()
236
+ return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
237
+ elif hasattr(audio_file, 'name'):
238
+ return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
239
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
240
+ except Exception:
241
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
242
+
243
+
244
+ def generate_uuid_from_params(params_dict) -> str:
245
+ """
246
+ Generate deterministic UUID from generation parameters.
247
+ Same parameters will always generate the same UUID.
248
+
249
+ Args:
250
+ params_dict: Dictionary of parameters
251
+
252
+ Returns:
253
+ UUID string
254
+ """
255
+
256
+ params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
257
+ hash_obj = hashlib.sha256(params_json.encode('utf-8'))
258
+ hash_hex = hash_obj.hexdigest()
259
+ uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
260
+ return uuid_str
261
+
262
+
263
+ def generate_uuid_from_audio_data(
264
+ audio_data: Union[torch.Tensor, np.ndarray],
265
+ seed: Optional[int] = None
266
+ ) -> str:
267
+ """
268
+ Generate UUID from audio data (for caching/deduplication)
269
+
270
+ Args:
271
+ audio_data: Audio data
272
+ seed: Optional seed value
273
+
274
+ Returns:
275
+ UUID string
276
+ """
277
+ if isinstance(audio_data, torch.Tensor):
278
+ # Convert to numpy and calculate hash
279
+ audio_np = audio_data.cpu().numpy()
280
+ else:
281
+ audio_np = audio_data
282
+
283
+ # Calculate data hash
284
+ data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
285
+
286
+ if seed is not None:
287
+ combined = f"{data_hash}_{seed}"
288
+ return hashlib.md5(combined.encode()).hexdigest()
289
+
290
+ return data_hash
291
+
292
+
293
+ # Global default instance
294
+ _default_saver = AudioSaver(default_format="flac")
295
+
296
+
297
+ def save_audio(
298
+ audio_data: Union[torch.Tensor, np.ndarray],
299
+ output_path: Union[str, Path],
300
+ sample_rate: int = 48000,
301
+ format: Optional[str] = None,
302
+ channels_first: bool = True,
303
+ ) -> str:
304
+ """
305
+ Convenience function: save audio (using default configuration)
306
+
307
+ Args:
308
+ audio_data: Audio data
309
+ output_path: Output path
310
+ sample_rate: Sample rate
311
+ format: Format (default flac)
312
+ channels_first: Tensor format flag
313
+
314
+ Returns:
315
+ Saved file path
316
+ """
317
+ return _default_saver.save_audio(
318
+ audio_data, output_path, sample_rate, format, channels_first
319
+ )
320
+
acestep/constrained_logits_processor.py CHANGED
@@ -571,6 +571,33 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
571
  if self.debug:
572
  logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
573
 
574
  def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
575
  """
576
  Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
@@ -1484,10 +1511,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1484
  if self.debug:
1485
  logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
1486
  else:
1487
- # Force EOS token when target codes count is reached
1488
- mask = torch.full_like(scores, float('-inf'))
1489
- mask[:, self.eos_token_id] = 0
1490
- scores = scores + mask
1491
  if self.debug:
1492
  logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
1493
  return self._apply_temperature_scaling(scores)
@@ -1609,20 +1636,15 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1609
  input_ids: torch.LongTensor,
1610
  scores: torch.FloatTensor,
1611
  ) -> torch.FloatTensor:
1612
- """Process a single sequence and return modified scores."""
1613
 
1614
  # Check if we have tokens in queue for user-provided field
1615
  # If so, inject the next token directly
1616
  if self.user_field_token_queue:
1617
- mask = torch.full_like(scores, float('-inf'))
1618
  next_token = self.user_field_token_queue[0]
1619
- mask[0, next_token] = 0
1620
- scores = scores + mask
1621
  return scores
1622
 
1623
- # Create mask (all -inf initially)
1624
- mask = torch.full_like(scores, float('-inf'))
1625
-
1626
  if self.state in self.fixed_strings:
1627
  # Fixed string state: force specific tokens
1628
  fixed_str = self.fixed_strings[self.state]
@@ -1633,28 +1655,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1633
  # This happens when we're about to complete the </think> tag
1634
  if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
1635
  # Check if the next token would complete the fixed string
1636
- # We check if position_in_state + length of next token would complete it
1637
- # Since we don't know which token will be selected, we check if we're close to completion
1638
- # Actually, a better approach: check if this is the last character(s) of the fixed string
1639
  remaining_chars = len(fixed_str) - self.position_in_state
1640
  # If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
1641
  if remaining_chars <= 10:
1642
  # Force EOS token to stop generation
1643
  if self.eos_token_id is not None:
1644
- mask[0, self.eos_token_id] = 0
1645
- scores = scores + mask
1646
  if self.debug:
1647
  logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
1648
  return scores
1649
 
1650
- for t in allowed:
1651
- mask[0, t] = 0
1652
- # Apply mask
1653
- scores = scores + mask
1654
-
1655
- # Update position tracking
1656
- # We need to check if the selected token completes the fixed string
1657
- # This will be done in update_state() after token selection
1658
  else:
1659
  # Position exceeds string, move to next state
1660
  # If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,
@@ -1662,8 +1674,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1662
  if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
1663
  # Force EOS token to stop generation
1664
  if self.eos_token_id is not None:
1665
- mask[0, self.eos_token_id] = 0
1666
- scores = scores + mask
1667
  if self.debug:
1668
  logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
1669
  return scores
@@ -1676,7 +1687,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1676
  if self.debug:
1677
  logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
1678
  return scores
1679
- return self._process_single_sequence(input_ids, torch.zeros_like(scores))
1680
 
1681
  elif self.state == FSMState.BPM_VALUE:
1682
  # Check if field is user-provided and we haven't started injecting yet
@@ -1690,22 +1703,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1690
  self.user_field_token_queue = value_tokens
1691
  self.current_user_field = "bpm"
1692
  # Inject first token
1693
- mask[0, value_tokens[0]] = 0
1694
- scores = scores + mask
1695
  return scores
1696
 
1697
  # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
1698
  allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
1699
- for t in allowed:
1700
- mask[0, t] = 0
1701
 
1702
  # Also allow newline if current token sequence prefix allows it
1703
- # Check if current token sequence is in prefix tree and allows newline
1704
  token_prefix = tuple(self.accumulated_token_ids)
1705
  if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
1706
- mask[0, self.newline_token] = 0
1707
 
1708
- scores = scores + mask
1709
 
1710
  elif self.state == FSMState.CAPTION_VALUE:
1711
  # Caption field generation with YAML format support:
@@ -1724,8 +1733,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1724
  self.user_field_token_queue = value_tokens
1725
  self.current_user_field = "caption"
1726
  # Inject first token
1727
- mask[0, value_tokens[0]] = 0
1728
- scores = scores + mask
1729
  return scores
1730
 
1731
  # Check if we should transition after a newline (non-indented line = new field)
@@ -1757,7 +1765,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1757
  # The field name detection will happen in update_state()
1758
  return scores
1759
 
1760
- # Block backticks (code blocks)
1761
  if self.backtick_token is not None:
1762
  scores[0, self.backtick_token] = float('-inf')
1763
 
@@ -1773,8 +1781,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1773
  if self.caption_token_count >= 512:
1774
  # Force end by only allowing newline
1775
  if self.newline_token is not None:
1776
- mask[0, self.newline_token] = 0
1777
- scores = scores + mask
1778
  return scores
1779
 
1780
  # Allow natural generation (with blocked audio codes and backticks)
@@ -1791,8 +1798,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1791
  self.user_field_token_queue = value_tokens
1792
  self.current_user_field = "duration"
1793
  # Inject first token
1794
- mask[0, value_tokens[0]] = 0
1795
- scores = scores + mask
1796
  return scores
1797
 
1798
  # If target_duration is set, force generate that exact value
@@ -1804,26 +1810,22 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1804
  # Force the next digit
1805
  next_digit = int(target_str[current_pos])
1806
  if next_digit in self.digit_tokens:
1807
- mask[0, self.digit_tokens[next_digit]] = 0
1808
  else:
1809
  # All digits generated, force newline
1810
  if self.newline_token:
1811
- mask[0, self.newline_token] = 0
1812
-
1813
- scores = scores + mask
1814
  else:
1815
  # Normal duration generation with range constraint
1816
  # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
1817
  allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
1818
- for t in allowed:
1819
- mask[0, t] = 0
1820
 
1821
  # Also allow newline if current token sequence prefix allows it
1822
  token_prefix = tuple(self.accumulated_token_ids)
1823
  if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
1824
- mask[0, self.newline_token] = 0
1825
 
1826
- scores = scores + mask
1827
 
1828
  elif self.state == FSMState.GENRES_VALUE:
1829
  # Check if field is user-provided and we haven't started injecting yet
@@ -1836,8 +1838,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1836
  self.user_field_token_queue = value_tokens
1837
  self.current_user_field = "genres"
1838
  # Inject first token
1839
- mask[0, value_tokens[0]] = 0
1840
- scores = scores + mask
1841
  return scores
1842
 
1843
  # Try to hot-reload genres vocab if file has changed
@@ -1848,24 +1849,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1848
 
1849
  if allowed:
1850
  # Use vocabulary-constrained decoding
1851
- for t in allowed:
1852
- mask[0, t] = 0
1853
- scores = scores + mask
1854
  elif self.genres_vocab:
1855
  # Vocab is loaded but no valid continuation found
1856
  # Force newline to end the field
1857
  if self.newline_token:
1858
- mask[0, self.newline_token] = 0
1859
  if self.debug:
1860
  logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
1861
- scores = scores + mask
1862
  else:
1863
  # Fallback: no vocab loaded, use probability-based ending
1864
  if self._should_end_text_field(scores):
1865
  if self.newline_token:
1866
- mask[0, self.newline_token] = 0
1867
  self._transition_to_next_state()
1868
- scores = scores + mask
1869
  else:
1870
  # Allow any token except newline if we don't have content yet
1871
  if not self.accumulated_value.strip():
@@ -1884,8 +1881,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1884
  self.user_field_token_queue = value_tokens
1885
  self.current_user_field = "keyscale"
1886
  # Inject first token
1887
- mask[0, value_tokens[0]] = 0
1888
- scores = scores + mask
1889
  return scores
1890
 
1891
  # Check if current token sequence is complete (allows newline)
@@ -1893,21 +1889,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1893
  if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
1894
  # Complete keyscale, allow newline
1895
  if self.newline_token:
1896
- mask[0, self.newline_token] = 0
1897
- scores = scores + mask
1898
  else:
1899
  # Not complete, allow valid continuation tokens
1900
  allowed = self._get_allowed_keyscale_tokens()
1901
  if allowed:
1902
- for t in allowed:
1903
- mask[0, t] = 0
1904
- scores = scores + mask
1905
  else:
1906
  # No valid tokens found - force newline to end field
1907
  # This handles edge cases where keyscale format is unexpected
1908
  if self.newline_token:
1909
- mask[0, self.newline_token] = 0
1910
- scores = scores + mask
1911
 
1912
  elif self.state == FSMState.LANGUAGE_VALUE:
1913
  # Language field: Use top-1 probability language (greedy selection)
@@ -1925,8 +1917,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1925
  self.user_field_token_queue = value_tokens
1926
  self.current_user_field = "language"
1927
  # Inject first token
1928
- mask[0, value_tokens[0]] = 0
1929
- scores = scores + mask
1930
  return scores
1931
 
1932
  # If we haven't started generating language yet (empty accumulated_token_ids),
@@ -1938,19 +1929,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1938
  candidate_tokens = list(self.language_prefix_tree[empty_prefix])
1939
 
1940
  if candidate_tokens:
1941
- # Find the token with highest probability (top-1)
1942
- # Create a mask that blocks all tokens except candidates
1943
- temp_mask = torch.full_like(scores, float('-inf'))
1944
- for t in candidate_tokens:
1945
- temp_mask[0, t] = 0
1946
- temp_scores = scores + temp_mask
1947
 
1948
  # Get the highest probability token among candidates
1949
- top_token_id = torch.argmax(temp_scores[0]).item()
 
1950
 
1951
- # Only allow this top-1 token, block all others (including other language tokens)
1952
- mask[0, top_token_id] = 0
1953
- scores = scores + mask
1954
 
1955
  if self.debug:
1956
  top_token_text = self.tokenizer.decode([top_token_id])
@@ -1958,13 +1947,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1958
  else:
1959
  # No valid first tokens found - force newline
1960
  if self.newline_token:
1961
- mask[0, self.newline_token] = 0
1962
- scores = scores + mask
1963
  else:
1964
  # Empty prefix not in tree - force newline
1965
  if self.newline_token:
1966
- mask[0, self.newline_token] = 0
1967
- scores = scores + mask
1968
  else:
1969
  # We've started generating a language, continue with prefix tree constraints
1970
  # Check if current token sequence is complete (allows newline)
@@ -1972,20 +1959,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1972
  if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
1973
  # Complete language, allow newline
1974
  if self.newline_token:
1975
- mask[0, self.newline_token] = 0
1976
- scores = scores + mask
1977
  else:
1978
  # Not complete, allow valid continuation tokens
1979
  allowed = self._get_allowed_language_tokens()
1980
  if allowed:
1981
- for t in allowed:
1982
- mask[0, t] = 0
1983
- scores = scores + mask
1984
  else:
1985
  # No valid tokens found - force newline to end field
1986
  if self.newline_token:
1987
- mask[0, self.newline_token] = 0
1988
- scores = scores + mask
1989
 
1990
  elif self.state == FSMState.TIMESIG_VALUE:
1991
  # Check if field is user-provided and we haven't started injecting yet
@@ -1998,8 +1981,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1998
  self.user_field_token_queue = value_tokens
1999
  self.current_user_field = "timesignature"
2000
  # Inject first token
2001
- mask[0, value_tokens[0]] = 0
2002
- scores = scores + mask
2003
  return scores
2004
 
2005
  # Check if current token sequence is complete (allows newline)
@@ -2007,14 +1989,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
2007
  if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
2008
  # Complete value, allow newline
2009
  if self.newline_token:
2010
- mask[0, self.newline_token] = 0
2011
- scores = scores + mask
2012
  else:
2013
  # Not complete, allow valid continuation tokens
2014
  allowed = self._get_allowed_timesig_tokens()
2015
- for t in allowed:
2016
- mask[0, t] = 0
2017
- scores = scores + mask
2018
 
2019
  return scores
2020
 
 
571
  if self.debug:
572
  logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
573
 
574
+ def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None:
575
+ """
576
+ Apply whitelist constraint inplace: only allow specified tokens, block all others.
577
+
578
+ This is more efficient than creating a mask tensor because:
579
+ 1. No memory allocation for mask
580
+ 2. No tensor addition operation
581
+
582
+ Args:
583
+ scores: [1, vocab_size] scores tensor to modify inplace
584
+ allowed_tokens: List of token IDs to allow (all others will be set to -inf)
585
+ """
586
+ if not allowed_tokens:
587
+ # No tokens allowed, set all to -inf
588
+ scores.fill_(float('-inf'))
589
+ return
590
+
591
+ # Save the original values of allowed tokens
592
+ allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
593
+ saved_values = scores[0, allowed_indices].clone()
594
+
595
+ # Set all scores to -inf
596
+ scores.fill_(float('-inf'))
597
+
598
+ # Restore allowed token values
599
+ scores[0, allowed_indices] = saved_values
600
+
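For reference, a minimal standalone sketch of the in-place whitelist trick defined above (hypothetical vocab size; not part of this diff): stash the logits of the allowed tokens, flood the whole row with -inf, then write the saved values back, so no full-vocabulary mask tensor is ever built or added.

import torch
from typing import List

def apply_whitelist_inplace(scores: torch.Tensor, allowed_tokens: List[int]) -> None:
    # scores: [1, vocab_size]; keep only allowed_tokens, block everything else
    if not allowed_tokens:
        scores.fill_(float("-inf"))
        return
    idx = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
    saved = scores[0, idx].clone()   # stash the allowed logits
    scores.fill_(float("-inf"))      # block the whole vocabulary in place
    scores[0, idx] = saved           # restore only the whitelisted tokens

scores = torch.randn(1, 10)          # hypothetical tiny vocab
apply_whitelist_inplace(scores, [2, 5])
# only positions 2 and 5 keep finite logits; no [1, vocab_size] mask is allocated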
601
  def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
602
  """
603
  Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
 
1511
  if self.debug:
1512
  logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
1513
  else:
1514
+ # Force EOS token when target codes count is reached - inplace
1515
+ eos_scores = scores[:, self.eos_token_id].clone()
1516
+ scores.fill_(float('-inf'))
1517
+ scores[:, self.eos_token_id] = eos_scores
1518
  if self.debug:
1519
  logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
1520
  return self._apply_temperature_scaling(scores)
 
1636
  input_ids: torch.LongTensor,
1637
  scores: torch.FloatTensor,
1638
  ) -> torch.FloatTensor:
1639
+ """Process a single sequence and return modified scores (inplace when possible)."""
1640
 
1641
  # Check if we have tokens in queue for user-provided field
1642
  # If so, inject the next token directly
1643
  if self.user_field_token_queue:
 
1644
  next_token = self.user_field_token_queue[0]
1645
+ self._apply_whitelist_inplace(scores, [next_token])
 
1646
  return scores
1647
 
 
 
 
1648
  if self.state in self.fixed_strings:
1649
  # Fixed string state: force specific tokens
1650
  fixed_str = self.fixed_strings[self.state]
 
1655
  # This happens when we're about to complete the </think> tag
1656
  if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
1657
  # Check if the next token would complete the fixed string
 
 
 
1658
  remaining_chars = len(fixed_str) - self.position_in_state
1659
  # If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
1660
  if remaining_chars <= 10:
1661
  # Force EOS token to stop generation
1662
  if self.eos_token_id is not None:
1663
+ self._apply_whitelist_inplace(scores, [self.eos_token_id])
 
1664
  if self.debug:
1665
  logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
1666
  return scores
1667
 
1668
+ # Apply whitelist constraint inplace
1669
+ self._apply_whitelist_inplace(scores, allowed)
 
 
 
 
 
 
1670
  else:
1671
  # Position exceeds string, move to next state
1672
  # If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,
 
1674
  if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
1675
  # Force EOS token to stop generation
1676
  if self.eos_token_id is not None:
1677
+ self._apply_whitelist_inplace(scores, [self.eos_token_id])
 
1678
  if self.debug:
1679
  logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
1680
  return scores
 
1687
  if self.debug:
1688
  logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
1689
  return scores
1690
+ # For recursion, reset scores to zero (no constraints from previous state)
1691
+ scores.zero_()
1692
+ return self._process_single_sequence(input_ids, scores)
1693
 
1694
  elif self.state == FSMState.BPM_VALUE:
1695
  # Check if field is user-provided and we haven't started injecting yet
 
1703
  self.user_field_token_queue = value_tokens
1704
  self.current_user_field = "bpm"
1705
  # Inject first token
1706
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
 
1707
  return scores
1708
 
1709
  # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
1710
  allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
 
 
1711
 
1712
  # Also allow newline if current token sequence prefix allows it
 
1713
  token_prefix = tuple(self.accumulated_token_ids)
1714
  if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
1715
+ allowed = allowed + [self.newline_token]
1716
 
1717
+ self._apply_whitelist_inplace(scores, allowed)
1718
 
1719
  elif self.state == FSMState.CAPTION_VALUE:
1720
  # Caption field generation with YAML format support:
 
1733
  self.user_field_token_queue = value_tokens
1734
  self.current_user_field = "caption"
1735
  # Inject first token
1736
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
 
1737
  return scores
1738
 
1739
  # Check if we should transition after a newline (non-indented line = new field)
 
1765
  # The field name detection will happen in update_state()
1766
  return scores
1767
 
1768
+ # Block backticks (code blocks) - inplace
1769
  if self.backtick_token is not None:
1770
  scores[0, self.backtick_token] = float('-inf')
1771
 
 
1781
  if self.caption_token_count >= 512:
1782
  # Force end by only allowing newline
1783
  if self.newline_token is not None:
1784
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1785
  return scores
1786
 
1787
  # Allow natural generation (with blocked audio codes and backticks)
 
1798
  self.user_field_token_queue = value_tokens
1799
  self.current_user_field = "duration"
1800
  # Inject first token
1801
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
 
1802
  return scores
1803
 
1804
  # If target_duration is set, force generate that exact value
 
1810
  # Force the next digit
1811
  next_digit = int(target_str[current_pos])
1812
  if next_digit in self.digit_tokens:
1813
+ self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
1814
  else:
1815
  # All digits generated, force newline
1816
  if self.newline_token:
1817
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
 
1818
  else:
1819
  # Normal duration generation with range constraint
1820
  # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
1821
  allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
 
 
1822
 
1823
  # Also allow newline if current token sequence prefix allows it
1824
  token_prefix = tuple(self.accumulated_token_ids)
1825
  if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
1826
+ allowed = allowed + [self.newline_token]
1827
 
1828
+ self._apply_whitelist_inplace(scores, allowed)
1829
 
1830
  elif self.state == FSMState.GENRES_VALUE:
1831
  # Check if field is user-provided and we haven't started injecting yet
 
1838
  self.user_field_token_queue = value_tokens
1839
  self.current_user_field = "genres"
1840
  # Inject first token
1841
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
 
1842
  return scores
1843
 
1844
  # Try to hot-reload genres vocab if file has changed
 
1849
 
1850
  if allowed:
1851
  # Use vocabulary-constrained decoding
1852
+ self._apply_whitelist_inplace(scores, allowed)
 
 
1853
  elif self.genres_vocab:
1854
  # Vocab is loaded but no valid continuation found
1855
  # Force newline to end the field
1856
  if self.newline_token:
 
1857
  if self.debug:
1858
  logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
1859
+ self._apply_whitelist_inplace(scores, [self.newline_token])
1860
  else:
1861
  # Fallback: no vocab loaded, use probability-based ending
1862
  if self._should_end_text_field(scores):
1863
  if self.newline_token:
1864
+ self._apply_whitelist_inplace(scores, [self.newline_token])
1865
  self._transition_to_next_state()
 
1866
  else:
1867
  # Allow any token except newline if we don't have content yet
1868
  if not self.accumulated_value.strip():
 
1881
  self.user_field_token_queue = value_tokens
1882
  self.current_user_field = "keyscale"
1883
  # Inject first token
1884
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
 
1885
  return scores
1886
 
1887
  # Check if current token sequence is complete (allows newline)
 
1889
  if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
1890
  # Complete keyscale, allow newline
1891
  if self.newline_token:
1892
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1893
  else:
1894
  # Not complete, allow valid continuation tokens
1895
  allowed = self._get_allowed_keyscale_tokens()
1896
  if allowed:
1897
+ self._apply_whitelist_inplace(scores, allowed)
 
 
1898
  else:
1899
  # No valid tokens found - force newline to end field
1900
  # This handles edge cases where keyscale format is unexpected
1901
  if self.newline_token:
1902
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1903
 
1904
  elif self.state == FSMState.LANGUAGE_VALUE:
1905
  # Language field: Use top-1 probability language (greedy selection)
 
1917
  self.user_field_token_queue = value_tokens
1918
  self.current_user_field = "language"
1919
  # Inject first token
1920
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
 
1921
  return scores
1922
 
1923
  # If we haven't started generating language yet (empty accumulated_token_ids),
 
1929
  candidate_tokens = list(self.language_prefix_tree[empty_prefix])
1930
 
1931
  if candidate_tokens:
1932
+ # Find the token with highest probability (top-1) among candidates
1933
+ # Use tensor indexing to get scores of candidate tokens directly
1934
+ candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
1935
+ candidate_scores = scores[0, candidate_indices]
 
 
1936
 
1937
  # Get the highest probability token among candidates
1938
+ best_idx = torch.argmax(candidate_scores).item()
1939
+ top_token_id = candidate_tokens[best_idx]
1940
 
1941
+ # Only allow this top-1 token, block all others
1942
+ self._apply_whitelist_inplace(scores, [top_token_id])
 
1943
 
1944
  if self.debug:
1945
  top_token_text = self.tokenizer.decode([top_token_id])
 
1947
  else:
1948
  # No valid first tokens found - force newline
1949
  if self.newline_token:
1950
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1951
  else:
1952
  # Empty prefix not in tree - force newline
1953
  if self.newline_token:
1954
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1955
  else:
1956
  # We've started generating a language, continue with prefix tree constraints
1957
  # Check if current token sequence is complete (allows newline)
 
1959
  if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
1960
  # Complete language, allow newline
1961
  if self.newline_token:
1962
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1963
  else:
1964
  # Not complete, allow valid continuation tokens
1965
  allowed = self._get_allowed_language_tokens()
1966
  if allowed:
1967
+ self._apply_whitelist_inplace(scores, allowed)
 
 
1968
  else:
1969
  # No valid tokens found - force newline to end field
1970
  if self.newline_token:
1971
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1972
 
1973
  elif self.state == FSMState.TIMESIG_VALUE:
1974
  # Check if field is user-provided and we haven't started injecting yet
 
1981
  self.user_field_token_queue = value_tokens
1982
  self.current_user_field = "timesignature"
1983
  # Inject first token
1984
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
 
1985
  return scores
1986
 
1987
  # Check if current token sequence is complete (allows newline)
 
1989
  if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
1990
  # Complete value, allow newline
1991
  if self.newline_token:
1992
+ self._apply_whitelist_inplace(scores, [self.newline_token])
 
1993
  else:
1994
  # Not complete, allow valid continuation tokens
1995
  allowed = self._get_allowed_timesig_tokens()
1996
+ self._apply_whitelist_inplace(scores, allowed)
 
 
1997
 
1998
  return scores
1999
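The LANGUAGE_VALUE branch above picks the single best first token by gathering only the candidate logits instead of building a temporary full-vocabulary mask. A minimal sketch of that selection (token IDs and vocab size are made up for illustration):

import torch

scores = torch.randn(1, 32000)        # [1, vocab_size] logits, hypothetical size
candidate_tokens = [312, 4587, 9906]  # made-up first-token IDs from the language prefix tree

idx = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
candidate_scores = scores[0, idx]                      # gather candidates only
top_token_id = candidate_tokens[int(torch.argmax(candidate_scores))]
# the processor then calls self._apply_whitelist_inplace(scores, [top_token_id])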
 
acestep/gradio_ui/event.py DELETED
The diff for this file is too large to render. See raw diff
 
acestep/gradio_ui/events/__init__.py CHANGED
@@ -254,48 +254,84 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
254
  ]
255
  )
256
 
257
- # Save buttons for audio 1 and 2
258
- for btn_idx, btn_key in [(1, "save_btn_1"), (2, "save_btn_2")]:
259
- results_section[btn_key].click(
260
- fn=res_h.save_audio_and_metadata,
 
261
  inputs=[
262
  results_section[f"generated_audio_{btn_idx}"],
263
- generation_section["task_type"],
264
- generation_section["captions"],
265
- generation_section["lyrics"],
266
- generation_section["vocal_language"],
267
- generation_section["bpm"],
268
- generation_section["key_scale"],
269
- generation_section["time_signature"],
270
- generation_section["audio_duration"],
271
- generation_section["batch_size_input"],
272
- generation_section["inference_steps"],
273
- generation_section["guidance_scale"],
274
- generation_section["seed"],
275
- generation_section["random_seed_checkbox"],
276
- generation_section["use_adg"],
277
- generation_section["cfg_interval_start"],
278
- generation_section["cfg_interval_end"],
279
- generation_section["audio_format"],
280
- generation_section["lm_temperature"],
281
- generation_section["lm_cfg_scale"],
282
- generation_section["lm_top_k"],
283
- generation_section["lm_top_p"],
284
- generation_section["lm_negative_prompt"],
285
- generation_section["use_cot_caption"],
286
- generation_section["use_cot_language"],
287
- generation_section["audio_cover_strength"],
288
- generation_section["think_checkbox"],
289
- generation_section["text2music_audio_code_string"],
290
- generation_section["repainting_start"],
291
- generation_section["repainting_end"],
292
- generation_section["track_name"],
293
- generation_section["complete_track_classes"],
294
- results_section["lm_metadata_state"],
295
  ],
296
- outputs=[gr.File(label="Download Package", visible=False)]
297
- )
298
-
299
  # ========== Send to SRC Handlers ==========
300
  for btn_idx in range(1, 9):
301
  results_section[f"send_to_src_btn_{btn_idx}"].click(
@@ -331,10 +367,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
331
  ],
332
  outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
333
  )
334
-
 
335
  # ========== Generation Handler ==========
336
  generation_section["generate_btn"].click(
337
- fn=lambda *args: res_h.generate_with_batch_management(dit_handler, llm_handler, *args),
338
  inputs=[
339
  generation_section["captions"],
340
  generation_section["lyrics"],
 
254
  ]
255
  )
256
 
257
+ # Save buttons for all 8 audio outputs
258
+ download_existing_js = """(current_audio, batch_files) => {
259
+ // Debug: print what the input actually is
260
+ console.log("👉 [Debug] Current Audio Input:", current_audio);
261
+
262
+ // 1. Safety check
263
+ if (!current_audio) {
264
+ console.warn("⚠️ No audio selected or audio is empty.");
265
+ return;
266
+ }
267
+ if (!batch_files || !Array.isArray(batch_files)) {
268
+ console.warn("⚠️ Batch file list is empty/not ready.");
269
+ return;
270
+ }
271
+
272
+ // 2. Smartly extract path string
273
+ let pathString = "";
274
+
275
+ if (typeof current_audio === "string") {
276
+ // Case A: direct path string received
277
+ pathString = current_audio;
278
+ } else if (typeof current_audio === "object") {
279
+ // Case B: an object is received, try common properties
280
+ // Gradio file objects usually have path, url, or name
281
+ pathString = current_audio.path || current_audio.name || current_audio.url || "";
282
+ }
283
+
284
+ if (!pathString) {
285
+ console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
286
+ return;
287
+ }
288
+
289
+ // 3. Extract Key (UUID)
290
+ // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
291
+ let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
292
+ let key = filename.split('.')[0]; // get UUID without extension
293
+
294
+ console.log(`🔑 Key extracted: ${key}`);
295
+
296
+ // 4. Find matching file(s) in the list
297
+ let targets = batch_files.filter(f => {
298
+ // Also extract names from batch_files objects
299
+ // f usually contains name (backend path) and orig_name (download name)
300
+ const fPath = f.name || f.path || "";
301
+ return fPath.includes(key);
302
+ });
303
+
304
+ if (targets.length === 0) {
305
+ console.warn("❌ No matching files found in batch list for key:", key);
306
+ alert("Batch list does not contain this file yet. Please wait for generation to finish.");
307
+ return;
308
+ }
309
+
310
+ // 5. Trigger download(s)
311
+ console.log(`🎯 Found ${targets.length} files to download.`);
312
+ targets.forEach((f, index) => {
313
+ setTimeout(() => {
314
+ const a = document.createElement('a');
315
+ // Prefer url (frontend-accessible link), otherwise try data
316
+ a.href = f.url || f.data;
317
+ a.download = f.orig_name || "download";
318
+ a.style.display = 'none';
319
+ document.body.appendChild(a);
320
+ a.click();
321
+ document.body.removeChild(a);
322
+ }, index * 1000); // 1s interval between downloads to avoid browser blocking
323
+ });
324
+ }
325
+ """
326
+ for btn_idx in range(1, 9):
327
+ results_section[f"save_btn_{btn_idx}"].click(
328
+ fn=None,
329
  inputs=[
330
  results_section[f"generated_audio_{btn_idx}"],
331
+ results_section["generated_audio_batch"],
 
332
  ],
333
+ js=download_existing_js # run the client-side download script above (fn=None, so no server round-trip)
334
+ )
 
335
  # ========== Send to SRC Handlers ==========
336
  for btn_idx in range(1, 9):
337
  results_section[f"send_to_src_btn_{btn_idx}"].click(
 
367
  ],
368
  outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
369
  )
370
+ def generation_wrapper(*args):
371
+ yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
372
  # ========== Generation Handler ==========
373
  generation_section["generate_btn"].click(
374
+ fn=generation_wrapper,
375
  inputs=[
376
  generation_section["captions"],
377
  generation_section["lyrics"],
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -10,9 +10,123 @@ import tempfile
10
  import shutil
11
  import zipfile
12
  import time as time_module
 
13
  import gradio as gr
14
  from loguru import logger
15
  from acestep.gradio_ui.i18n import t
 
16
 
17
 
18
  def store_batch_in_queue(
@@ -66,99 +180,6 @@ def update_navigation_buttons(current_batch, total_batches):
66
  can_go_next = current_batch < total_batches - 1
67
  return can_go_previous, can_go_next
68
 
69
-
70
- def save_audio_and_metadata(
71
- audio_path, task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
72
- batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
73
- use_adg, cfg_interval_start, cfg_interval_end, audio_format,
74
- lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
75
- use_cot_caption, use_cot_language, audio_cover_strength,
76
- think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
77
- track_name, complete_track_classes, lm_metadata
78
- ):
79
- """Save audio file and its metadata as a zip package"""
80
- if audio_path is None:
81
- gr.Warning(t("messages.no_audio_to_save"))
82
- return None
83
-
84
- try:
85
- # Create metadata dictionary
86
- metadata = {
87
- "saved_at": datetime.datetime.now().isoformat(),
88
- "task_type": task_type,
89
- "caption": captions or "",
90
- "lyrics": lyrics or "",
91
- "vocal_language": vocal_language,
92
- "bpm": bpm if bpm is not None else None,
93
- "keyscale": key_scale or "",
94
- "timesignature": time_signature or "",
95
- "duration": audio_duration if audio_duration is not None else -1,
96
- "batch_size": batch_size_input,
97
- "inference_steps": inference_steps,
98
- "guidance_scale": guidance_scale,
99
- "seed": seed,
100
- "random_seed": False, # Disable random seed for reproducibility
101
- "use_adg": use_adg,
102
- "cfg_interval_start": cfg_interval_start,
103
- "cfg_interval_end": cfg_interval_end,
104
- "audio_format": audio_format,
105
- "lm_temperature": lm_temperature,
106
- "lm_cfg_scale": lm_cfg_scale,
107
- "lm_top_k": lm_top_k,
108
- "lm_top_p": lm_top_p,
109
- "lm_negative_prompt": lm_negative_prompt,
110
- "use_cot_caption": use_cot_caption,
111
- "use_cot_language": use_cot_language,
112
- "audio_cover_strength": audio_cover_strength,
113
- "think": think_checkbox,
114
- "audio_codes": text2music_audio_code_string or "",
115
- "repainting_start": repainting_start,
116
- "repainting_end": repainting_end,
117
- "track_name": track_name,
118
- "complete_track_classes": complete_track_classes or [],
119
- }
120
-
121
- # Add LM-generated metadata if available
122
- if lm_metadata:
123
- metadata["lm_generated_metadata"] = lm_metadata
124
-
125
- # Generate timestamp and base name
126
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
127
-
128
- # Extract audio filename extension
129
- audio_ext = os.path.splitext(audio_path)[1]
130
-
131
- # Create temporary directory for packaging
132
- temp_dir = tempfile.mkdtemp()
133
-
134
- # Save JSON metadata
135
- json_path = os.path.join(temp_dir, f"metadata_{timestamp}.json")
136
- with open(json_path, 'w', encoding='utf-8') as f:
137
- json.dump(metadata, f, indent=2, ensure_ascii=False)
138
-
139
- # Copy audio file
140
- audio_copy_path = os.path.join(temp_dir, f"audio_{timestamp}{audio_ext}")
141
- shutil.copy2(audio_path, audio_copy_path)
142
-
143
- # Create zip file
144
- zip_path = os.path.join(tempfile.gettempdir(), f"music_package_{timestamp}.zip")
145
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
146
- zipf.write(audio_copy_path, os.path.basename(audio_copy_path))
147
- zipf.write(json_path, os.path.basename(json_path))
148
-
149
- # Clean up temp directory
150
- shutil.rmtree(temp_dir)
151
-
152
- gr.Info(t("messages.save_success", filename=os.path.basename(zip_path)))
153
- return zip_path
154
-
155
- except Exception as e:
156
- gr.Warning(t("messages.save_failed", error=str(e)))
157
- import traceback
158
- traceback.print_exc()
159
- return None
160
-
161
-
162
  def send_audio_to_src_with_metadata(audio_file, lm_metadata):
163
  """Send generated audio file to src_audio input and populate metadata fields
164
 
@@ -254,366 +275,209 @@ def generate_with_progress(
254
  auto_score,
255
  score_scale,
256
  lm_batch_chunk_size,
257
- progress=gr.Progress(track_tqdm=True)
258
  ):
259
  """Generate audio with progress tracking"""
260
- # If think is enabled (llm_dit mode) and use_cot_metas is True, generate audio codes using LM first
261
- audio_code_string_to_use = text2music_audio_code_string
262
- lm_generated_metadata = None # Store LM-generated metadata for display
263
- lm_generated_audio_codes = None # Store LM-generated audio codes for display
264
- lm_generated_audio_codes_list = [] # Store list of audio codes for batch processing
265
-
266
- # Determine if we should use batch LM generation
267
- should_use_lm_batch = (
268
- think_checkbox and
269
- llm_handler.llm_initialized and
270
- use_cot_metas and
271
- allow_lm_batch and
272
- batch_size_input >= 2
 
273
  )
274
 
275
- if think_checkbox and llm_handler.llm_initialized and use_cot_metas:
276
- # Convert top_k: 0 means None (disabled)
277
- top_k_value = None if lm_top_k == 0 else int(lm_top_k)
278
- # Convert top_p: 1.0 means None (disabled)
279
- top_p_value = None if lm_top_p >= 1.0 else lm_top_p
280
-
281
- # Build user_metadata from user-provided values (only include non-empty values)
282
- user_metadata = {}
283
- # Handle bpm: gr.Number can be None, int, float, or string
284
- if bpm is not None:
285
- try:
286
- bpm_value = float(bpm)
287
- if bpm_value > 0:
288
- user_metadata['bpm'] = str(int(bpm_value))
289
- except (ValueError, TypeError):
290
- # If bpm is not a valid number, skip it
291
- pass
292
- if key_scale and key_scale.strip():
293
- key_scale_clean = key_scale.strip()
294
- if key_scale_clean.lower() not in ["n/a", ""]:
295
- user_metadata['keyscale'] = key_scale_clean
296
- if time_signature and time_signature.strip():
297
- time_sig_clean = time_signature.strip()
298
- if time_sig_clean.lower() not in ["n/a", ""]:
299
- user_metadata['timesignature'] = time_sig_clean
300
- if audio_duration is not None:
301
- try:
302
- duration_value = float(audio_duration)
303
- if duration_value > 0:
304
- user_metadata['duration'] = str(int(duration_value))
305
- except (ValueError, TypeError):
306
- # If audio_duration is not a valid number, skip it
307
- pass
308
-
309
- # Only pass user_metadata if user provided any values, otherwise let LM generate
310
- user_metadata_to_pass = user_metadata if user_metadata else None
311
-
312
- if should_use_lm_batch:
313
- # BATCH LM GENERATION
314
- logger.info(f"Using LM batch generation for {batch_size_input} items...")
315
-
316
- # Prepare seeds for batch items
317
- actual_seed_list, _ = dit_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
318
-
319
- # Split batch into chunks (GPU memory constraint)
320
- max_inference_batch_size = int(lm_batch_chunk_size)
321
- num_chunks = math.ceil(batch_size_input / max_inference_batch_size)
322
-
323
- all_metadata_list = []
324
- all_audio_codes_list = []
325
-
326
- for chunk_idx in range(num_chunks):
327
- chunk_start = chunk_idx * max_inference_batch_size
328
- chunk_end = min(chunk_start + max_inference_batch_size, batch_size_input)
329
- chunk_size = chunk_end - chunk_start
330
- chunk_seeds = actual_seed_list[chunk_start:chunk_end]
331
-
332
- logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
333
-
334
- # Generate batch
335
- metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
336
- caption=captions or "",
337
- lyrics=lyrics or "",
338
- batch_size=chunk_size,
339
- infer_type="llm_dit",
340
- temperature=lm_temperature,
341
- cfg_scale=lm_cfg_scale,
342
- negative_prompt=lm_negative_prompt,
343
- top_k=top_k_value,
344
- top_p=top_p_value,
345
- user_metadata=user_metadata_to_pass,
346
- use_cot_caption=use_cot_caption,
347
- use_cot_language=use_cot_language,
348
- is_format_caption=is_format_caption,
349
- constrained_decoding_debug=constrained_decoding_debug,
350
- seeds=chunk_seeds,
351
- )
352
-
353
- all_metadata_list.extend(metadata_list)
354
- all_audio_codes_list.extend(audio_codes_list)
355
-
356
- # Use first metadata as representative (all are same)
357
- lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
358
 
359
- # Store audio codes list for later use
360
- lm_generated_audio_codes_list = all_audio_codes_list
361
 
362
- # Prepare audio codes for DiT (list of codes, one per batch item)
363
- audio_code_string_to_use = all_audio_codes_list
 
 
 
 
 
 
 
364
 
365
- # Update metadata fields from LM if not provided by user
366
- if lm_generated_metadata:
367
- if bpm is None and lm_generated_metadata.get('bpm'):
368
- bpm_value = lm_generated_metadata.get('bpm')
369
- if bpm_value != "N/A" and bpm_value != "":
370
- try:
371
- bpm = int(bpm_value)
372
- except:
373
- pass
374
- if not key_scale and lm_generated_metadata.get('keyscale'):
375
- key_scale_value = lm_generated_metadata.get('keyscale', lm_generated_metadata.get('key_scale', ""))
376
- if key_scale_value != "N/A":
377
- key_scale = key_scale_value
378
- if not time_signature and lm_generated_metadata.get('timesignature'):
379
- time_signature_value = lm_generated_metadata.get('timesignature', lm_generated_metadata.get('time_signature', ""))
380
- if time_signature_value != "N/A":
381
- time_signature = time_signature_value
382
- if audio_duration is None or audio_duration <= 0:
383
- audio_duration_value = lm_generated_metadata.get('duration', -1)
384
- if audio_duration_value != "N/A" and audio_duration_value != "":
385
- try:
386
- audio_duration = float(audio_duration_value)
387
- except:
388
- pass
389
- else:
390
- # SEQUENTIAL LM GENERATION (current behavior, when allow_lm_batch is False)
391
- # Phase 1: Generate CoT metadata
392
- phase1_start = time_module.time()
393
- metadata, _, status = llm_handler.generate_with_stop_condition(
394
- caption=captions or "",
395
- lyrics=lyrics or "",
396
- infer_type="dit", # Only generate metadata in Phase 1
397
- temperature=lm_temperature,
398
- cfg_scale=lm_cfg_scale,
399
- negative_prompt=lm_negative_prompt,
400
- top_k=top_k_value,
401
- top_p=top_p_value,
402
- user_metadata=user_metadata_to_pass,
403
- use_cot_caption=use_cot_caption,
404
- use_cot_language=use_cot_language,
405
- is_format_caption=is_format_caption,
406
- constrained_decoding_debug=constrained_decoding_debug,
407
- )
408
- lm_phase1_time = time_module.time() - phase1_start
409
- logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
410
-
411
- # Phase 2: Generate audio codes
412
- phase2_start = time_module.time()
413
- metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
414
- caption=captions or "",
415
- lyrics=lyrics or "",
416
- infer_type="llm_dit", # Generate both metadata and codes
417
- temperature=lm_temperature,
418
- cfg_scale=lm_cfg_scale,
419
- negative_prompt=lm_negative_prompt,
420
- top_k=top_k_value,
421
- top_p=top_p_value,
422
- user_metadata=user_metadata_to_pass,
423
- use_cot_caption=use_cot_caption,
424
- use_cot_language=use_cot_language,
425
- is_format_caption=is_format_caption,
426
- constrained_decoding_debug=constrained_decoding_debug,
427
  )
428
- lm_phase2_time = time_module.time() - phase2_start
429
- logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
430
-
431
- # Store LM-generated metadata and audio codes for display
432
- lm_generated_metadata = metadata
433
- if audio_codes:
434
- audio_code_string_to_use = audio_codes
435
- lm_generated_audio_codes = audio_codes
436
- # Update metadata fields only if they are empty/None (user didn't provide them)
437
- if bpm is None and metadata.get('bpm'):
438
- bpm_value = metadata.get('bpm')
439
- if bpm_value != "N/A" and bpm_value != "":
440
- try:
441
- bpm = int(bpm_value)
442
- except:
443
- pass
444
- if not key_scale and metadata.get('keyscale'):
445
- key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
446
- if key_scale_value != "N/A":
447
- key_scale = key_scale_value
448
- if not time_signature and metadata.get('timesignature'):
449
- time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
450
- if time_signature_value != "N/A":
451
- time_signature = time_signature_value
452
- if audio_duration is None or audio_duration <= 0:
453
- audio_duration_value = metadata.get('duration', -1)
454
- if audio_duration_value != "N/A" and audio_duration_value != "":
455
- try:
456
- audio_duration = float(audio_duration_value)
457
- except:
458
- pass
459
-
460
- # Call generate_music and get results
461
- result = dit_handler.generate_music(
462
- captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
463
- time_signature=time_signature, vocal_language=vocal_language,
464
- inference_steps=inference_steps, guidance_scale=guidance_scale,
465
- use_random_seed=random_seed_checkbox, seed=seed,
466
- reference_audio=reference_audio, audio_duration=audio_duration,
467
- batch_size=batch_size_input, src_audio=src_audio,
468
- audio_code_string=audio_code_string_to_use,
469
- repainting_start=repainting_start, repainting_end=repainting_end,
470
- instruction=instruction_display_gen, audio_cover_strength=audio_cover_strength,
471
- task_type=task_type, use_adg=use_adg,
472
- cfg_interval_start=cfg_interval_start, cfg_interval_end=cfg_interval_end,
473
- audio_format=audio_format, lm_temperature=lm_temperature,
474
- progress=progress
475
- )
476
-
477
- # Extract results
478
- first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
479
- align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
480
-
481
- # Extract LM timing from status if available and prepend to generation_info
482
- if status:
483
- import re
484
- # Try to extract timing info from status using regex
485
- # Expected format: "Phase1: X.XXs" and "Phase2: X.XXs"
486
- phase1_match = re.search(r'Phase1:\s*([\d.]+)s', status)
487
- phase2_match = re.search(r'Phase2:\s*([\d.]+)s', status)
488
-
489
- if phase1_match or phase2_match:
490
- lm_timing_section = "\n\n**🤖 LM Timing:**\n"
491
- lm_total = 0.0
492
- if phase1_match:
493
- phase1_time = float(phase1_match.group(1))
494
- lm_timing_section += f" - Phase 1 (CoT Metadata): {phase1_time:.2f}s\n"
495
- lm_total += phase1_time
496
- if phase2_match:
497
- phase2_time = float(phase2_match.group(1))
498
- lm_timing_section += f" - Phase 2 (Audio Codes): {phase2_time:.2f}s\n"
499
- lm_total += phase2_time
500
- if lm_total > 0:
501
- lm_timing_section += f" - Total LM Time: {lm_total:.2f}s\n"
502
- generation_info = lm_timing_section + "\n" + generation_info
503
-
504
- # Append LM-generated metadata to generation_info if available
505
- if lm_generated_metadata:
506
- metadata_lines = []
507
- if lm_generated_metadata.get('bpm'):
508
- metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
509
- if lm_generated_metadata.get('caption'):
510
- metadata_lines.append(f"- **User Query Rewritten Caption:** {lm_generated_metadata['caption']}")
511
- if lm_generated_metadata.get('duration'):
512
- metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
513
- if lm_generated_metadata.get('keyscale'):
514
- metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
515
- if lm_generated_metadata.get('language'):
516
- metadata_lines.append(f"- **Language:** {lm_generated_metadata['language']}")
517
- if lm_generated_metadata.get('timesignature'):
518
- metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
519
-
520
- if metadata_lines:
521
- metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
522
- generation_info = metadata_section + "\n\n" + generation_info
523
-
524
- # Update audio codes in UI if LM generated them
525
- codes_outputs = [""] * 8 # Codes for 8 components
526
- if should_use_lm_batch and lm_generated_audio_codes_list:
527
- # Batch mode: update individual codes inputs
528
- for idx in range(min(len(lm_generated_audio_codes_list), 8)):
529
- codes_outputs[idx] = lm_generated_audio_codes_list[idx]
530
- # For single codes input, show first one
531
- updated_audio_codes = lm_generated_audio_codes_list[0] if lm_generated_audio_codes_list else text2music_audio_code_string
532
- else:
533
- # Single mode: update main codes input
534
- updated_audio_codes = lm_generated_audio_codes if lm_generated_audio_codes else text2music_audio_code_string
535
-
536
- # AUTO-SCORING
537
- score_displays = [""] * 8 # Scores for 8 components
538
- if auto_score and all_audio_paths:
539
- logger.info(f"Auto-scoring enabled, calculating quality scores for {batch_size_input} generated audios...")
540
-
541
- # Determine which audio codes to use for scoring
542
- if should_use_lm_batch and lm_generated_audio_codes_list:
543
- codes_list = lm_generated_audio_codes_list
544
- elif audio_code_string_to_use and isinstance(audio_code_string_to_use, list):
545
- codes_list = audio_code_string_to_use
546
  else:
547
- # Single code string, replicate for all audios
548
- codes_list = [audio_code_string_to_use] * len(all_audio_paths)
549
-
550
- # Calculate scores only for actually generated audios (up to batch_size_input)
551
- # Don't score beyond the actual batch size to avoid duplicates
552
- actual_audios_to_score = min(len(all_audio_paths), int(batch_size_input))
553
- for idx in range(actual_audios_to_score):
554
- if idx < len(codes_list) and codes_list[idx]:
555
- try:
556
- score_display = calculate_score_handler(
557
- llm_handler,
558
- codes_list[idx],
559
- captions,
560
- lyrics,
561
- lm_generated_metadata,
562
- bpm, key_scale, time_signature, audio_duration, vocal_language,
563
- score_scale
564
- )
565
- score_displays[idx] = score_display
566
- logger.info(f"Auto-scored audio {idx+1}")
567
- except Exception as e:
568
- logger.error(f"Auto-scoring failed for audio {idx+1}: {e}")
569
- score_displays[idx] = f"❌ Auto-scoring failed: {str(e)}"
570
-
571
- # Prepare audio outputs (up to 8)
572
- audio_outputs = [None] * 8
573
- for idx in range(min(len(all_audio_paths), 8)):
574
- audio_outputs[idx] = all_audio_paths[idx]
575
 
576
- return (
577
- audio_outputs[0], # generated_audio_1
578
- audio_outputs[1], # generated_audio_2
579
- audio_outputs[2], # generated_audio_3
580
- audio_outputs[3], # generated_audio_4
581
- audio_outputs[4], # generated_audio_5
582
- audio_outputs[5], # generated_audio_6
583
- audio_outputs[6], # generated_audio_7
584
- audio_outputs[7], # generated_audio_8
585
- all_audio_paths, # generated_audio_batch
586
  generation_info,
587
- status_message,
588
  seed_value_for_ui,
589
- align_score_1,
590
- align_text_1,
591
- align_plot_1,
592
- align_score_2,
593
- align_text_2,
594
- align_plot_2,
595
- score_displays[0], # score_display_1
596
- score_displays[1], # score_display_2
597
- score_displays[2], # score_display_3
598
- score_displays[3], # score_display_4
599
- score_displays[4], # score_display_5
600
- score_displays[5], # score_display_6
601
- score_displays[6], # score_display_7
602
- score_displays[7], # score_display_8
603
- updated_audio_codes, # Update main audio codes in UI
604
- codes_outputs[0], # text2music_audio_code_string_1
605
- codes_outputs[1], # text2music_audio_code_string_2
606
- codes_outputs[2], # text2music_audio_code_string_3
607
- codes_outputs[3], # text2music_audio_code_string_4
608
- codes_outputs[4], # text2music_audio_code_string_5
609
- codes_outputs[5], # text2music_audio_code_string_6
610
- codes_outputs[6], # text2music_audio_code_string_7
611
- codes_outputs[7], # text2music_audio_code_string_8
612
- lm_generated_metadata, # Store metadata for "Send to src audio" buttons
613
- is_format_caption, # Keep is_format_caption unchanged
614
  )
615
 
616
 
 
617
  def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
618
  """
619
  Calculate PMI-based quality score for generated audio.
@@ -756,7 +620,9 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
756
  if stored_allow_lm_batch and isinstance(stored_codes, list):
757
  # Batch mode: use specific sample's codes
758
  if 0 <= sample_idx - 1 < len(stored_codes):
759
- audio_codes_str = stored_codes[sample_idx - 1]
 
 
760
  else:
761
  # Single mode: all samples use same codes
762
  audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
@@ -868,7 +734,7 @@ def generate_with_batch_management(
868
  Wrapper for generate_with_progress that adds batch queue management
869
  """
870
  # Call the original generation function
871
- result = generate_with_progress(
872
  dit_handler, llm_handler,
873
  captions, lyrics, bpm, key_scale, time_signature, vocal_language,
874
  inference_steps, guidance_scale, random_seed_checkbox, seed,
@@ -885,23 +751,41 @@ def generate_with_batch_management(
885
  lm_batch_chunk_size,
886
  progress
887
  )
888
-
889
- # Extract results from generation
890
- all_audio_paths = result[8] # generated_audio_batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
891
  generation_info = result[9]
892
  seed_value_for_ui = result[11]
893
- lm_generated_metadata = result[34] # Index 34 is lm_metadata_state
894
 
895
  # Extract codes
896
  generated_codes_single = result[26]
897
  generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
898
-
899
  # Determine which codes to store based on mode
900
  if allow_lm_batch and batch_size_input >= 2:
901
  codes_to_store = generated_codes_batch[:int(batch_size_input)]
902
  else:
903
  codes_to_store = generated_codes_single
904
-
905
  # Save parameters for history
906
  saved_params = {
907
  "captions": captions,
@@ -947,6 +831,7 @@ def generate_with_batch_management(
947
  }
948
 
949
  # Next batch parameters (with cleared codes & random seed)
 
950
  next_params = saved_params.copy()
951
  next_params["text2music_audio_code_string"] = ""
952
  next_params["random_seed_checkbox"] = True
@@ -979,9 +864,10 @@ def generate_with_batch_management(
979
  next_batch_status_text = ""
980
  if autogen_checkbox:
981
  next_batch_status_text = t("messages.autogen_enabled")
982
-
983
- # Return original results plus batch management state updates
984
- return result + (
 
985
  current_batch_index,
986
  total_batches,
987
  batch_queue,
@@ -1097,7 +983,8 @@ def generate_next_batch_background(
1097
  params.setdefault("complete_track_classes", [])
1098
 
1099
  # Call generate_with_progress with the saved parameters
1100
- result = generate_with_progress(
 
1101
  dit_handler,
1102
  llm_handler,
1103
  captions=params.get("captions"),
@@ -1142,15 +1029,20 @@ def generate_next_batch_background(
1142
  progress=progress
1143
  )
1144
 
1145
- # Extract results
1146
- all_audio_paths = result[8] # generated_audio_batch
1147
- generation_info = result[9]
1148
- seed_value_for_ui = result[11]
1149
- lm_generated_metadata = result[34] # Index 34 is lm_metadata_state
 
 
 
 
 
1150
 
1151
  # Extract codes
1152
- generated_codes_single = result[26]
1153
- generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
1154
 
1155
  # Determine which codes to store
1156
  batch_size = params.get("batch_size_input", 2)
@@ -1240,8 +1132,9 @@ def navigate_to_previous_batch(current_batch_index, batch_queue):
1240
 
1241
  # Prepare audio outputs (up to 8)
1242
  audio_outputs = [None] * 8
1243
- for idx in range(min(len(audio_paths), 8)):
1244
- audio_outputs[idx] = audio_paths[idx]
 
1245
 
1246
  # Update batch indicator
1247
  total_batches = len(batch_queue)
@@ -1286,8 +1179,9 @@ def navigate_to_next_batch(autogen_enabled, current_batch_index, total_batches,
1286
 
1287
  # Prepare audio outputs (up to 8)
1288
  audio_outputs = [None] * 8
1289
- for idx in range(min(len(audio_paths), 8)):
1290
- audio_outputs[idx] = audio_paths[idx]
 
1291
 
1292
  # Update batch indicator
1293
  batch_indicator_text = update_batch_indicator(new_batch_index, total_batches)
 
10
  import shutil
11
  import zipfile
12
  import time as time_module
13
+ from typing import Dict, Any, Optional
14
  import gradio as gr
15
  from loguru import logger
16
  from acestep.gradio_ui.i18n import t
17
+ from acestep.inference import generate_music, GenerationParams, GenerationConfig
18
+ from acestep.audio_utils import save_audio
19
+
20
+
21
+ def _build_generation_info(
22
+ lm_metadata: Optional[Dict[str, Any]],
23
+ time_costs: Dict[str, float],
24
+ seed_value: str,
25
+ inference_steps: int,
26
+ num_audios: int,
27
+ ) -> str:
28
+ """Build generation info string from result data.
29
+
30
+ Args:
31
+ lm_metadata: LM-generated metadata dictionary
32
+ time_costs: Unified time costs dictionary
33
+ seed_value: Seed value string
34
+ inference_steps: Number of inference steps
35
+ num_audios: Number of generated audios
36
+
37
+ Returns:
38
+ Formatted generation info string
39
+ """
40
+ info_parts = []
41
+
42
+ # Part 1: LM-generated metadata (if available)
43
+ if lm_metadata:
44
+ metadata_lines = []
45
+ if lm_metadata.get('bpm'):
46
+ metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
47
+ if lm_metadata.get('caption'):
48
+ metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
49
+ if lm_metadata.get('lyrics'):
50
+ metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
51
+ if lm_metadata.get('duration'):
52
+ metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
53
+ if lm_metadata.get('keyscale'):
54
+ metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
55
+ if lm_metadata.get('language'):
56
+ metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
57
+ if lm_metadata.get('timesignature'):
58
+ metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
59
+
60
+ if metadata_lines:
61
+ metadata_section = "**🤖 LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
62
+ info_parts.append(metadata_section)
63
+
64
+ # Part 2: Time costs (formatted and beautified)
65
+ if time_costs:
66
+ time_lines = []
67
+
68
+ # LM time costs
69
+ lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
70
+ lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
71
+ lm_total = time_costs.get('lm_total_time', 0.0)
72
+
73
+ if lm_total > 0:
74
+ time_lines.append("**🧠 LM Time:**")
75
+ if lm_phase1 > 0:
76
+ time_lines.append(f" - Phase 1 (CoT): {lm_phase1:.2f}s")
77
+ if lm_phase2 > 0:
78
+ time_lines.append(f" - Phase 2 (Codes): {lm_phase2:.2f}s")
79
+ time_lines.append(f" - Total: {lm_total:.2f}s")
80
+
81
+ # DiT time costs
82
+ dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
83
+ dit_model = time_costs.get('dit_model_time_cost', 0.0)
84
+ dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
85
+ dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
86
+ dit_total = time_costs.get('dit_total_time_cost', 0.0)
87
+ if dit_total > 0:
88
+ time_lines.append("\n**🎵 DiT Time:**")
89
+ if dit_encoder > 0:
90
+ time_lines.append(f" - Encoder: {dit_encoder:.2f}s")
91
+ if dit_model > 0:
92
+ time_lines.append(f" - Model: {dit_model:.2f}s")
93
+ if dit_vae_decode > 0:
94
+ time_lines.append(f" - VAE Decode: {dit_vae_decode:.2f}s")
95
+ if dit_offload > 0:
96
+ time_lines.append(f" - Offload: {dit_offload:.2f}s")
97
+ time_lines.append(f" - Total: {dit_total:.2f}s")
98
+
99
+ # Post-processing time costs
100
+ audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
101
+ auto_score_time = time_costs.get('auto_score_time', 0.0)
102
+
103
+ if audio_conversion_time > 0 or auto_score_time > 0:
104
+ time_lines.append("\n**🔧 Post-processing Time:**")
105
+ if audio_conversion_time > 0:
106
+ time_lines.append(f" - Audio Conversion: {audio_conversion_time:.2f}s")
107
+ if auto_score_time > 0:
108
+ time_lines.append(f" - Auto Score: {auto_score_time:.2f}s")
109
+
110
+ # Pipeline total
111
+ pipeline_total = time_costs.get('pipeline_total_time', 0.0)
112
+ if pipeline_total > 0:
113
+ time_lines.append(f"\n**⏱️ Pipeline Total: {pipeline_total:.2f}s**")
114
+
115
+ if time_lines:
116
+ time_section = "\n".join(time_lines)
117
+ info_parts.append(time_section)
118
+
119
+ # Part 3: Generation summary
120
+ summary_lines = [
121
+ "**🎵 Generation Complete**",
122
+ f" - **Seeds:** {seed_value}",
123
+ f" - **Steps:** {inference_steps}",
124
+ f" - **Audio Count:** {num_audios} audio(s)",
125
+ ]
126
+ info_parts.append("\n".join(summary_lines))
127
+
128
+ # Combine all parts
129
+ return "\n\n".join(info_parts)
130
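A quick, hypothetical call to _build_generation_info showing the shape of its output (all field values below are illustrative only):

info = _build_generation_info(
    lm_metadata={"bpm": "120", "keyscale": "C major", "duration": "90"},
    time_costs={"lm_total_time": 3.2, "dit_total_time_cost": 11.5, "pipeline_total_time": 14.7},
    seed_value="1234",
    inference_steps=8,
    num_audios=2,
)
# -> an "LM-Generated Metadata" block, then "LM Time" / "DiT Time" / "Pipeline Total"
#    timings, then the "Generation Complete" summary, joined by blank lines
print(info)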
 
131
 
132
  def store_batch_in_queue(
 
180
  can_go_next = current_batch < total_batches - 1
181
  return can_go_previous, can_go_next
182
 
183
  def send_audio_to_src_with_metadata(audio_file, lm_metadata):
184
  """Send generated audio file to src_audio input and populate metadata fields
185
 
 
275
  auto_score,
276
  score_scale,
277
  lm_batch_chunk_size,
278
+ progress=gr.Progress(track_tqdm=True),
279
  ):
280
  """Generate audio with progress tracking"""
281
+
282
+ # Step 1: prepare inputs for acestep.inference.generate_music
283
+ # by building GenerationParams and GenerationConfig
284
+ gen_params = GenerationParams(
285
+ task_type=task_type,
286
+ instruction=instruction_display_gen,
287
+ reference_audio=reference_audio,
288
+ src_audio=src_audio,
289
+ audio_codes=text2music_audio_code_string if not think_checkbox else "",
290
+ caption=captions or "",
291
+ lyrics=lyrics or "",
292
+ instrumental=False,
293
+ vocal_language=vocal_language,
294
+ bpm=bpm,
295
+ keyscale=key_scale,
296
+ timesignature=time_signature,
297
+ duration=audio_duration,
298
+ inference_steps=inference_steps,
299
+ guidance_scale=guidance_scale,
300
+ use_adg=use_adg,
301
+ cfg_interval_start=cfg_interval_start,
302
+ cfg_interval_end=cfg_interval_end,
303
+ repainting_start=repainting_start,
304
+ repainting_end=repainting_end,
305
+ audio_cover_strength=audio_cover_strength,
306
+ thinking=think_checkbox,
307
+ lm_temperature=lm_temperature,
308
+ lm_cfg_scale=lm_cfg_scale,
309
+ lm_top_k=lm_top_k,
310
+ lm_top_p=lm_top_p,
311
+ lm_negative_prompt=lm_negative_prompt,
312
+ use_cot_metas=use_cot_metas,
313
+ use_cot_caption=use_cot_caption,
314
+ use_cot_language=use_cot_language,
315
+ use_constrained_decoding=True,
316
+ )
317
+ # seed string to list
318
+ if isinstance(seed, str) and seed.strip():
319
+ if "," in seed:
320
+ seed_list = [int(s.strip()) for s in seed.split(",")]
321
+ else:
322
+ seed_list = [int(seed.strip())]
323
+ else:
324
+ seed_list = None
325
+ gen_config = GenerationConfig(
326
+ batch_size=batch_size_input,
327
+ allow_lm_batch=allow_lm_batch,
328
+ use_random_seed=random_seed_checkbox,
329
+ seeds=seed_list,
330
+ lm_batch_chunk_size=lm_batch_chunk_size,
331
+ constrained_decoding_debug=constrained_decoding_debug,
332
+ audio_format=audio_format,
333
+ )
334
+ result = generate_music(
335
+ dit_handler,
336
+ llm_handler,
337
+ params=gen_params,
338
+ config=gen_config,
339
+ progress=progress,
340
  )
341
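As a quick illustration of the seed handling just above, a hypothetical standalone helper that mirrors the inline logic (not part of the PR):

from typing import List, Optional

def parse_seeds(seed) -> Optional[List[int]]:
    # "42" -> [42]; "1, 2, 3" -> [1, 2, 3]; "" or non-string -> None (pipeline picks seeds)
    if isinstance(seed, str) and seed.strip():
        parts = seed.split(",") if "," in seed else [seed]
        return [int(s.strip()) for s in parts]
    return None

assert parse_seeds("42") == [42]
assert parse_seeds("1, 2, 3") == [1, 2, 3]
assert parse_seeds("") is None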
 
342
+ audio_outputs = [None] * 8
343
+ all_audio_paths = []
344
+ final_codes_list = [""] * 8
345
+ final_scores_list = [""] * 8
346
+
347
+ # Build generation_info from result data
348
+ status_message = result.status_message
349
+ seed_value_for_ui = result.extra_outputs.get("seed_value", "")
350
+ lm_generated_metadata = result.extra_outputs.get("lm_metadata", {})
351
+ time_costs = result.extra_outputs.get("time_costs", {}).copy()
352
+
353
+ # Initialize post-processing timing
354
+ audio_conversion_start_time = time_module.time()
355
+ total_auto_score_time = 0.0
356
+
357
+ align_score_1 = ""
358
+ align_text_1 = ""
359
+ align_plot_1 = None
360
+ align_score_2 = ""
361
+ align_text_2 = ""
362
+ align_plot_2 = None
363
+ updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
364
+
365
+ # Build initial generation_info (will be updated with post-processing times at the end)
366
+ generation_info = _build_generation_info(
367
+ lm_metadata=lm_generated_metadata,
368
+ time_costs=time_costs,
369
+ seed_value=seed_value_for_ui,
370
+ inference_steps=inference_steps,
371
+ num_audios=len(result.audios) if result.success else 0,
372
+ )
373
+
374
+ if not result.success:
375
+ yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 26
376
+ return
377
+
378
+ audios = result.audios
379
+ progress(0.99, f"Converting audio to {audio_format}...")
380
+ for i in range(8):
381
+ if i < len(audios):
382
+ key = audios[i]["key"]
383
+ audio_tensor = audios[i]["tensor"]
384
+ sample_rate = audios[i]["sample_rate"]
385
+ audio_params = audios[i]["params"]
386
+ temp_dir = tempfile.mkdtemp(prefix="acestep_gradio_results_")
387
+ os.makedirs(temp_dir, exist_ok=True)
388
+ json_path = os.path.join(temp_dir, f"{key}.json")
389
+ audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
390
+ save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
391
+ with open(json_path, 'w', encoding='utf-8') as f:
392
+ json.dump(audio_params, f, indent=2, ensure_ascii=False)
393
+ audio_outputs[i] = audio_path
394
+ all_audio_paths.append(audio_path)
395
+ all_audio_paths.append(json_path)
 
396
 
397
+ code_str = audio_params.get("audio_codes", "")
398
+ final_codes_list[i] = code_str
399
 
400
+ scores_ui_updates = [gr.skip()] * 8
401
+ score_str = "Done!"
402
+ if auto_score:
403
+ auto_score_start = time_module.time()
404
+ score_str = calculate_score_handler(llm_handler, code_str, captions, lyrics, lm_generated_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale)
405
+ auto_score_end = time_module.time()
406
+ total_auto_score_time += (auto_score_end - auto_score_start)
407
+ scores_ui_updates[i] = score_str
408
+ final_scores_list[i] = score_str
409
 
410
+ status_message = f"Encoding & Ready: {i+1}/{len(audios)}"
411
+ current_audio_updates = [gr.skip()] * 8
412
+ current_audio_updates[i] = audio_path
413
+
414
+ audio_codes_ui_updates = [gr.skip()] * 8
415
+ audio_codes_ui_updates[i] = code_str
416
+ yield (
417
+ current_audio_updates[0], current_audio_updates[1], current_audio_updates[2], current_audio_updates[3],
418
+ current_audio_updates[4], current_audio_updates[5], current_audio_updates[6], current_audio_updates[7],
419
+ all_audio_paths, # Real-time update of Batch File list
420
+ generation_info,
421
+ status_message,
422
+ seed_value_for_ui,
423
+ # Align plot placeholders (assume no need to update in real time)
424
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
425
+ # Scores
426
+ scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
427
+ updated_audio_codes,
428
+ # Codes
429
+ audio_codes_ui_updates[0], audio_codes_ui_updates[1], audio_codes_ui_updates[2], audio_codes_ui_updates[3],
430
+ audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
431
+ lm_generated_metadata,
432
+ is_format_caption,
 
433
  )
 
434
  else:
435
+ # If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
436
+ pass
437
+ time_module.sleep(0.1)
438
+
439
+ # Record audio conversion time
440
+ audio_conversion_end_time = time_module.time()
441
+ audio_conversion_time = audio_conversion_end_time - audio_conversion_start_time
442
+
443
+ # Add post-processing times to time_costs
444
+ if audio_conversion_time > 0:
445
+ time_costs['audio_conversion_time'] = audio_conversion_time
446
+ if total_auto_score_time > 0:
447
+ time_costs['auto_score_time'] = total_auto_score_time
448
+
449
+ # Update pipeline total time to include post-processing
450
+ if 'pipeline_total_time' in time_costs:
451
+ time_costs['pipeline_total_time'] += audio_conversion_time + total_auto_score_time
452
+
453
+ # Rebuild generation_info with complete timing information
454
+ generation_info = _build_generation_info(
455
+ lm_metadata=lm_generated_metadata,
456
+ time_costs=time_costs,
457
+ seed_value=seed_value_for_ui,
458
+ inference_steps=inference_steps,
459
+ num_audios=len(result.audios),
460
+ )
 
 
461
 
462
+ yield (
463
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), # Audio 1-4: SKIP
464
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), # Audio 5-8: SKIP
465
+ all_audio_paths,
 
466
  generation_info,
467
+ "Generation Complete",
468
  seed_value_for_ui,
469
+ align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2,
470
+ final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
471
+ final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
472
+ updated_audio_codes,
473
+ final_codes_list[0], final_codes_list[1], final_codes_list[2], final_codes_list[3],
474
+ final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
475
+ lm_generated_metadata,
476
+ is_format_caption,
 
477
  )
478
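The handler above streams its results to the UI as a generator: every `yield` must carry the same fixed-arity output tuple (the failure path yields 8 audio slots, 3 info fields and 26 skips, i.e. 37 outputs), and components that should not change on a given update are filled with `gr.skip()`. A minimal, hedged sketch of that pattern with made-up component names, not the actual output wiring of this app:

```python
import gradio as gr

def stream_results(n_items: int):
    """Generator-style Gradio handler: one fixed-arity tuple per yield."""
    paths = []
    for i in range(n_items):
        path = f"/tmp/audio_{i}.mp3"        # stands in for a freshly saved file
        paths.append(path)
        audio_updates = [gr.skip()] * 8     # leave the untouched players alone
        audio_updates[i] = path             # update only the slot that just finished
        yield (*audio_updates, paths, f"Encoding & Ready: {i + 1}/{n_items}")
    # Final update: keep the players as they are, refresh the file list and status
    yield (*([gr.skip()] * 8), paths, "Generation Complete")

# Wired up with e.g. outputs=[audio_1, ..., audio_8, batch_files, status_text]
```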
 
479
 
480
+
481
  def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
482
  """
483
  Calculate PMI-based quality score for generated audio.
 
620
  if stored_allow_lm_batch and isinstance(stored_codes, list):
621
  # Batch mode: use specific sample's codes
622
  if 0 <= sample_idx - 1 < len(stored_codes):
623
+ code_item = stored_codes[sample_idx - 1]
624
+ # Ensure it's a string (handle cases where dict was mistakenly stored)
625
+ audio_codes_str = code_item if isinstance(code_item, str) else ""
626
  else:
627
  # Single mode: all samples use same codes
628
  audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
 
734
  Wrapper for generate_with_progress that adds batch queue management
735
  """
736
  # Call the original generation function
737
+ generator = generate_with_progress(
738
  dit_handler, llm_handler,
739
  captions, lyrics, bpm, key_scale, time_signature, vocal_language,
740
  inference_steps, guidance_scale, random_seed_checkbox, seed,
 
751
  lm_batch_chunk_size,
752
  progress
753
  )
754
+ final_result_from_inner = None
755
+ for partial_result in generator:
756
+ final_result_from_inner = partial_result
757
+ # current_batch_index, total_batches, batch_queue, next_params,
758
+ # batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
759
+ yield partial_result + (
760
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(),
761
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
762
+ )
763
+ result = final_result_from_inner
764
+ all_audio_paths = result[8]
765
+
766
+ if all_audio_paths is None:
767
+
768
+ yield result + (
769
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(),
770
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
771
+ )
772
+ return
773
+
774
+ # Extract results from generation (access fields by index into the result tuple)
775
  generation_info = result[9]
776
  seed_value_for_ui = result[11]
777
+ lm_generated_metadata = result[35] # Fixed: lm_metadata is at index 35, not 34
778
 
779
  # Extract codes
780
  generated_codes_single = result[26]
781
  generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
782
+
783
  # Determine which codes to store based on mode
784
  if allow_lm_batch and batch_size_input >= 2:
785
  codes_to_store = generated_codes_batch[:int(batch_size_input)]
786
  else:
787
  codes_to_store = generated_codes_single
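The wrapper reads fields out of the inner generator's final tuple purely by position (indices 8, 9, 11, 26, 27–34 and 35 appear above, and the `# Fixed: lm_metadata is at index 35, not 34` comment shows how easy the positions are to get wrong). A hedged sketch of naming those positions once; the constant names are illustrative and not part of this commit:

```python
# Positions taken from the indices referenced above; the names are hypothetical.
IDX_BATCH_FILES = 8               # all_audio_paths
IDX_GENERATION_INFO = 9
IDX_SEED = 11
IDX_CODES_SINGLE = 26
IDX_CODES_BATCH = slice(27, 35)   # eight per-sample code strings
IDX_LM_METADATA = 35

def unpack_final_result(result: tuple) -> dict:
    """Pull the fields the batch wrapper needs out of the final yielded tuple."""
    return {
        "all_audio_paths": result[IDX_BATCH_FILES],
        "generation_info": result[IDX_GENERATION_INFO],
        "seed_value_for_ui": result[IDX_SEED],
        "codes_single": result[IDX_CODES_SINGLE],
        "codes_batch": list(result[IDX_CODES_BATCH]),
        "lm_metadata": result[IDX_LM_METADATA],
    }
```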
788
+
789
  # Save parameters for history
790
  saved_params = {
791
  "captions": captions,
 
831
  }
832
 
833
  # Next batch parameters (with cleared codes & random seed)
834
+ # Next batch parameters
835
  next_params = saved_params.copy()
836
  next_params["text2music_audio_code_string"] = ""
837
  next_params["random_seed_checkbox"] = True
 
864
  next_batch_status_text = ""
865
  if autogen_checkbox:
866
  next_batch_status_text = t("messages.autogen_enabled")
867
+
868
+ # 4. Yield final result (includes Batch UI updates)
869
+ # The result here is already a tuple structure
870
+ yield result + (
871
  current_batch_index,
872
  total_batches,
873
  batch_queue,
 
983
  params.setdefault("complete_track_classes", [])
984
 
985
  # Call generate_with_progress with the saved parameters
986
+ # Note: generate_with_progress is a generator; iterate it to completion to obtain the final result
987
+ generator = generate_with_progress(
988
  dit_handler,
989
  llm_handler,
990
  captions=params.get("captions"),
 
1029
  progress=progress
1030
  )
1031
 
1032
+ # Consume generator to get final result (similar to generate_with_batch_management)
1033
+ final_result = None
1034
+ for partial_result in generator:
1035
+ final_result = partial_result
1036
+
1037
+ # Extract results from final_result
1038
+ all_audio_paths = final_result[8] # generated_audio_batch
1039
+ generation_info = final_result[9]
1040
+ seed_value_for_ui = final_result[11]
1041
+ lm_generated_metadata = final_result[35] # Fixed: lm_metadata is at index 35, not 34
1042
 
1043
  # Extract codes
1044
+ generated_codes_single = final_result[26]
1045
+ generated_codes_batch = [final_result[27], final_result[28], final_result[29], final_result[30], final_result[31], final_result[32], final_result[33], final_result[34]]
1046
 
1047
  # Determine which codes to store
1048
  batch_size = params.get("batch_size_input", 2)
 
1132
 
1133
  # Prepare audio outputs (up to 8)
1134
  audio_outputs = [None] * 8
1135
+ real_audio_paths = [p for p in audio_paths if not p.lower().endswith('.json')]
1136
+ for idx in range(min(len(real_audio_paths), 8)):
1137
+ audio_outputs[idx] = real_audio_paths[idx]
1138
 
1139
  # Update batch indicator
1140
  total_batches = len(batch_queue)
 
1179
 
1180
  # Prepare audio outputs (up to 8)
1181
  audio_outputs = [None] * 8
1182
+ real_audio_paths = [p for p in audio_paths if not p.lower().endswith('.json')]
1183
+ for idx in range(min(len(real_audio_paths), 8)):
1184
+ audio_outputs[idx] = real_audio_paths[idx]
1185
 
1186
  # Update batch indicator
1187
  batch_indicator_text = update_batch_indicator(new_batch_index, total_batches)
acestep/gradio_ui/interfaces/result.py CHANGED
@@ -28,7 +28,8 @@ def create_results_section(dit_handler) -> dict:
28
  generated_audio_1 = gr.Audio(
29
  label=t("results.generated_music", n=1),
30
  type="filepath",
31
- interactive=False
 
32
  )
33
  with gr.Row(equal_height=True):
34
  send_to_src_btn_1 = gr.Button(
@@ -58,7 +59,8 @@ def create_results_section(dit_handler) -> dict:
58
  generated_audio_2 = gr.Audio(
59
  label=t("results.generated_music", n=2),
60
  type="filepath",
61
- interactive=False
 
62
  )
63
  with gr.Row(equal_height=True):
64
  send_to_src_btn_2 = gr.Button(
@@ -88,7 +90,8 @@ def create_results_section(dit_handler) -> dict:
88
  generated_audio_3 = gr.Audio(
89
  label=t("results.generated_music", n=3),
90
  type="filepath",
91
- interactive=False
 
92
  )
93
  with gr.Row(equal_height=True):
94
  send_to_src_btn_3 = gr.Button(
@@ -118,7 +121,8 @@ def create_results_section(dit_handler) -> dict:
118
  generated_audio_4 = gr.Audio(
119
  label=t("results.generated_music", n=4),
120
  type="filepath",
121
- interactive=False
 
122
  )
123
  with gr.Row(equal_height=True):
124
  send_to_src_btn_4 = gr.Button(
@@ -151,7 +155,8 @@ def create_results_section(dit_handler) -> dict:
151
  generated_audio_5 = gr.Audio(
152
  label=t("results.generated_music", n=5),
153
  type="filepath",
154
- interactive=False
 
155
  )
156
  with gr.Row(equal_height=True):
157
  send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -166,7 +171,8 @@ def create_results_section(dit_handler) -> dict:
166
  generated_audio_6 = gr.Audio(
167
  label=t("results.generated_music", n=6),
168
  type="filepath",
169
- interactive=False
 
170
  )
171
  with gr.Row(equal_height=True):
172
  send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -181,7 +187,8 @@ def create_results_section(dit_handler) -> dict:
181
  generated_audio_7 = gr.Audio(
182
  label=t("results.generated_music", n=7),
183
  type="filepath",
184
- interactive=False
 
185
  )
186
  with gr.Row(equal_height=True):
187
  send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -196,7 +203,8 @@ def create_results_section(dit_handler) -> dict:
196
  generated_audio_8 = gr.Audio(
197
  label=t("results.generated_music", n=8),
198
  type="filepath",
199
- interactive=False
 
200
  )
201
  with gr.Row(equal_height=True):
202
  send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
28
  generated_audio_1 = gr.Audio(
29
  label=t("results.generated_music", n=1),
30
  type="filepath",
31
+ interactive=False,
32
+ show_download_button=False
33
  )
34
  with gr.Row(equal_height=True):
35
  send_to_src_btn_1 = gr.Button(
 
59
  generated_audio_2 = gr.Audio(
60
  label=t("results.generated_music", n=2),
61
  type="filepath",
62
+ interactive=False,
63
+ show_download_button=False
64
  )
65
  with gr.Row(equal_height=True):
66
  send_to_src_btn_2 = gr.Button(
 
90
  generated_audio_3 = gr.Audio(
91
  label=t("results.generated_music", n=3),
92
  type="filepath",
93
+ interactive=False,
94
+ show_download_button=False
95
  )
96
  with gr.Row(equal_height=True):
97
  send_to_src_btn_3 = gr.Button(
 
121
  generated_audio_4 = gr.Audio(
122
  label=t("results.generated_music", n=4),
123
  type="filepath",
124
+ interactive=False,
125
+ show_download_button=False
126
  )
127
  with gr.Row(equal_height=True):
128
  send_to_src_btn_4 = gr.Button(
 
155
  generated_audio_5 = gr.Audio(
156
  label=t("results.generated_music", n=5),
157
  type="filepath",
158
+ interactive=False,
159
+ show_download_button=False
160
  )
161
  with gr.Row(equal_height=True):
162
  send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
171
  generated_audio_6 = gr.Audio(
172
  label=t("results.generated_music", n=6),
173
  type="filepath",
174
+ interactive=False,
175
+ show_download_button=False
176
  )
177
  with gr.Row(equal_height=True):
178
  send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
187
  generated_audio_7 = gr.Audio(
188
  label=t("results.generated_music", n=7),
189
  type="filepath",
190
+ interactive=False,
191
+ show_download_button=False
192
  )
193
  with gr.Row(equal_height=True):
194
  send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
 
203
  generated_audio_8 = gr.Audio(
204
  label=t("results.generated_music", n=8),
205
  type="filepath",
206
+ interactive=False,
207
+ show_download_button=False
208
  )
209
  with gr.Row(equal_height=True):
210
  send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
acestep/handler.py CHANGED
@@ -10,6 +10,8 @@ import traceback
10
  import re
11
  import random
12
  import uuid
 
 
13
  from contextlib import contextmanager
14
  from typing import Optional, Dict, Any, Tuple, List, Union
15
 
@@ -37,16 +39,12 @@ warnings.filterwarnings("ignore")
37
  class AceStepHandler:
38
  """ACE-Step Business Logic Handler"""
39
 
40
- def __init__(self, save_root = None):
41
  self.model = None
42
  self.config = None
43
  self.device = "cpu"
44
  self.dtype = torch.float32 # Will be set based on device in initialize_service
45
- if save_root is None:
46
- self.temp_dir = tempfile.mkdtemp()
47
- else:
48
- self.temp_dir = save_root
49
-
50
  # VAE for audio encoding/decoding
51
  self.vae = None
52
 
@@ -81,8 +79,7 @@ class AceStepHandler:
81
  def get_available_checkpoints(self) -> str:
82
  """Return project root directory path"""
83
  # Get project root (handler.py is in acestep/, so go up two levels to project root)
84
- current_file = os.path.abspath(__file__)
85
- project_root = os.path.dirname(os.path.dirname(current_file))
86
  # default checkpoints
87
  checkpoint_dir = os.path.join(project_root, "checkpoints")
88
  if os.path.exists(checkpoint_dir):
@@ -93,8 +90,7 @@ class AceStepHandler:
93
  def get_available_acestep_v15_models(self) -> List[str]:
94
  """Scan and return all model directory names starting with 'acestep-v15-'"""
95
  # Get project root
96
- current_file = os.path.abspath(__file__)
97
- project_root = os.path.dirname(os.path.dirname(current_file))
98
  checkpoint_dir = os.path.join(project_root, "checkpoints")
99
 
100
  models = []
@@ -171,8 +167,7 @@ class AceStepHandler:
171
 
172
 
173
  # Auto-detect project root (independent of passed project_root parameter)
174
- current_file = os.path.abspath(__file__)
175
- actual_project_root = os.path.dirname(os.path.dirname(current_file))
176
  checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
177
 
178
  # 1. Load main model
@@ -187,7 +182,7 @@ class AceStepHandler:
187
  attn_implementation = "sdpa"
188
 
189
  try:
190
- logger.info(f"Attempting to load model with attention implementation: {attn_implementation}")
191
  self.model = AutoModel.from_pretrained(
192
  acestep_v15_checkpoint_path,
193
  trust_remote_code=True,
@@ -195,9 +190,9 @@ class AceStepHandler:
195
  dtype="bfloat16"
196
  )
197
  except Exception as e:
198
- logger.warning(f"Failed to load model with {attn_implementation}: {e}")
199
  if attn_implementation == "sdpa":
200
- logger.info("Falling back to eager attention")
201
  attn_implementation = "eager"
202
  self.model = AutoModel.from_pretrained(
203
  acestep_v15_checkpoint_path,
@@ -215,7 +210,7 @@ class AceStepHandler:
215
  else:
216
  # If offload_to_cpu is True, check if we should keep DiT on GPU
217
  if not self.offload_dit_to_cpu:
218
- logger.info(f"Keeping main model on {device} (persistent)")
219
  self.model = self.model.to(device).to(self.dtype)
220
  else:
221
  self.model = self.model.to("cpu").to(self.dtype)
@@ -239,7 +234,7 @@ class AceStepHandler:
239
  raise ValueError(f"Unsupported quantization type: {self.quantization}")
240
 
241
  quantize_(self.model, quant_config)
242
- logger.info(f"DiT quantized with: {self.quantization}")
243
 
244
 
245
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
@@ -260,7 +255,7 @@ class AceStepHandler:
260
  if os.path.exists(vae_checkpoint_path):
261
  self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
262
  # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
263
- vae_dtype = torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
264
  if not self.offload_to_cpu:
265
  self.vae = self.vae.to(device).to(vae_dtype)
266
  else:
@@ -302,6 +297,7 @@ class AceStepHandler:
302
 
303
  except Exception as e:
304
  error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
 
305
  return error_msg, False
306
 
307
  @contextmanager
@@ -326,7 +322,7 @@ class AceStepHandler:
326
  try:
327
  param = next(model.parameters())
328
  if param.device.type == "cpu":
329
- logger.info(f"Moving {model_name} to {self.device} (persistent)")
330
  model.to(self.device).to(self.dtype)
331
  if hasattr(self, "silence_latent"):
332
  self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
@@ -341,10 +337,10 @@ class AceStepHandler:
341
  return
342
 
343
  # Load to GPU
344
- logger.info(f"Loading {model_name} to {self.device}")
345
  start_time = time.time()
346
  if model_name == "vae":
347
- vae_dtype = torch.bfloat16 if self.device in ["cuda", "xpu"] else self.dtype
348
  model.to(self.device).to(vae_dtype)
349
  else:
350
  model.to(self.device).to(self.dtype)
@@ -354,13 +350,13 @@ class AceStepHandler:
354
 
355
  load_time = time.time() - start_time
356
  self.current_offload_cost += load_time
357
- logger.info(f"Loaded {model_name} to {self.device} in {load_time:.4f}s")
358
 
359
  try:
360
  yield
361
  finally:
362
  # Offload to CPU
363
- logger.info(f"Offloading {model_name} to CPU")
364
  start_time = time.time()
365
  model.to("cpu")
366
 
@@ -370,7 +366,7 @@ class AceStepHandler:
370
  torch.cuda.empty_cache()
371
  offload_time = time.time() - start_time
372
  self.current_offload_cost += offload_time
373
- logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")
374
 
375
  def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
376
  """Process target audio"""
@@ -386,23 +382,12 @@ class AceStepHandler:
386
  else:
387
  audio = torch.from_numpy(audio_np.T)
388
 
389
- if audio.shape[0] == 1:
390
- audio = torch.cat([audio, audio], dim=0)
391
-
392
- audio = audio[:2]
393
-
394
- # Resample if needed
395
- if sr != 48000:
396
- import torch.nn.functional as F
397
- ratio = 48000 / sr
398
- new_length = int(audio.shape[-1] * ratio)
399
- audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
400
-
401
- audio = torch.clamp(audio, -1.0, 1.0)
402
 
403
  return audio
404
  except Exception as e:
405
- logger.error(f"Error processing target audio: {e}")
406
  return None
407
 
408
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
@@ -411,7 +396,8 @@ class AceStepHandler:
411
  return []
412
  try:
413
  return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
414
- except Exception:
 
415
  return []
416
 
417
  def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
@@ -538,9 +524,7 @@ class AceStepHandler:
538
  )
539
  """
540
  # Align instruction formatting with _prepare_batch
541
- final_instruction = instruction or DEFAULT_DIT_INSTRUCTION
542
- if not final_instruction.endswith(":"):
543
- final_instruction = final_instruction + ":"
544
 
545
  # Extract caption and language from metas if available (from LM CoT output)
546
  # Fallback to user-provided values if not in metas
@@ -571,7 +555,7 @@ class AceStepHandler:
571
 
572
  parsed_meta = self._parse_metas([metas])[0]
573
  caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
574
- lyrics_input = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
575
  return caption_input, lyrics_input
576
 
577
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -614,7 +598,7 @@ class AceStepHandler:
614
  return match.group(1).strip()
615
  return caption
616
  except Exception as e:
617
- logger.error(f"Error extracting caption: {e}")
618
  return caption
619
 
620
  def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
@@ -638,7 +622,8 @@ class AceStepHandler:
638
  else:
639
  try:
640
  seed_list.append(int(float(s)))
641
- except (ValueError, TypeError):
 
642
  seed_list.append(-1)
643
  elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
644
  # If seed is None or negative, use -1 for all items
@@ -679,7 +664,176 @@ class AceStepHandler:
679
  return actual_seed_list, seed_value_for_ui
680
 
681
  def prepare_metadata(self, bpm, key_scale, time_signature):
682
- # Build metadata dict - use "N/A" as default for empty fields
 
683
  metadata_dict = {}
684
  if bpm:
685
  metadata_dict["bpm"] = bpm
@@ -695,10 +849,12 @@ class AceStepHandler:
695
  metadata_dict["timesignature"] = time_signature
696
  else:
697
  metadata_dict["timesignature"] = "N/A"
 
698
  return metadata_dict
699
-
700
- def is_silence(self, audio):
701
- return torch.all(audio.abs() < 1e-6)
702
 
703
  def generate_instruction(
704
  self,
@@ -745,23 +901,12 @@ class AceStepHandler:
745
  # Load audio file
746
  audio, sr = torchaudio.load(audio_file)
747
 
748
- logger.info(f"Reference audio shape: {audio.shape}")
749
- logger.info(f"Reference audio sample rate: {sr}")
750
- logger.info(f"Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
751
-
752
- # Convert to stereo (duplicate channel if mono)
753
- if audio.shape[0] == 1:
754
- audio = torch.cat([audio, audio], dim=0)
755
-
756
- # Keep only first 2 channels
757
- audio = audio[:2]
758
 
759
- # Resample to 48kHz if needed
760
- if sr != 48000:
761
- audio = torchaudio.transforms.Resample(sr, 48000)(audio)
762
-
763
- # Clamp values to [-1.0, 1.0]
764
- audio = torch.clamp(audio, -1.0, 1.0)
765
 
766
  is_silence = self.is_silence(audio)
767
  if is_silence:
@@ -800,7 +945,7 @@ class AceStepHandler:
800
  return audio
801
 
802
  except Exception as e:
803
- logger.error(f"Error processing reference audio: {e}")
804
  return None
805
 
806
  def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
@@ -811,24 +956,13 @@ class AceStepHandler:
811
  # Load audio file
812
  audio, sr = torchaudio.load(audio_file)
813
 
814
- # Convert to stereo (duplicate channel if mono)
815
- if audio.shape[0] == 1:
816
- audio = torch.cat([audio, audio], dim=0)
817
-
818
- # Keep only first 2 channels
819
- audio = audio[:2]
820
-
821
- # Resample to 48kHz if needed
822
- if sr != 48000:
823
- audio = torchaudio.transforms.Resample(sr, 48000)(audio)
824
-
825
- # Clamp values to [-1.0, 1.0]
826
- audio = torch.clamp(audio, -1.0, 1.0)
827
 
828
  return audio
829
 
830
  except Exception as e:
831
- logger.error(f"Error processing target audio: {e}")
832
  return None
833
 
834
  def convert_src_audio_to_codes(self, audio_file) -> str:
@@ -856,19 +990,12 @@ class AceStepHandler:
856
  # Encode audio to latents using VAE
857
  with torch.no_grad():
858
  with self._load_model_context("vae"):
859
- # Prepare audio for VAE: [channels, samples] -> [1, channels, samples]
860
- vae_input = processed_audio.unsqueeze(0).to(self.device).to(self.vae.dtype)
861
-
862
  # Check if audio is silence
863
- if self.is_silence(vae_input):
864
  return "❌ Audio file appears to be silent"
865
 
866
- # Encode to latents
867
- latents = self.vae.encode(vae_input).latent_dist.sample()
868
- # Cast back to model dtype
869
- latents = latents.to(self.dtype)
870
- # Transpose: [1, d, T] -> [1, T, d] -> [T, d]
871
- latents = latents.squeeze(0).transpose(0, 1) # [T, d]
872
 
873
  # Create attention mask for latents
874
  attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
@@ -893,7 +1020,7 @@ class AceStepHandler:
893
 
894
  except Exception as e:
895
  error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
896
- logger.error(error_msg)
897
  return error_msg
898
 
899
  def prepare_batch_data(
@@ -922,26 +1049,7 @@ class AceStepHandler:
922
  calculated_duration = audio_duration
923
 
924
  # Build metadata dict - use "N/A" as default for empty fields
925
- metadata_dict = {}
926
- if bpm:
927
- metadata_dict["bpm"] = bpm
928
- else:
929
- metadata_dict["bpm"] = "N/A"
930
-
931
- if key_scale.strip():
932
- metadata_dict["keyscale"] = key_scale
933
- else:
934
- metadata_dict["keyscale"] = "N/A"
935
-
936
- if time_signature.strip() and time_signature != "N/A" and time_signature:
937
- metadata_dict["timesignature"] = time_signature
938
- else:
939
- metadata_dict["timesignature"] = "N/A"
940
-
941
- # Add duration to metadata if available (inference service format: "30 seconds")
942
- if calculated_duration is not None:
943
- metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
944
- # If duration not set, inference service will use default (30 seconds)
945
 
946
  # Format metadata - inference service accepts dict and will convert to string
947
  # Create a copy for each batch item (in case we modify it)
@@ -977,7 +1085,7 @@ class AceStepHandler:
977
  target_wavs = torch.zeros(2, frames)
978
  return target_wavs
979
  except Exception as e:
980
- logger.error(f"Error creating target audio: {e}")
981
  # Fallback to 30 seconds if error
982
  return torch.zeros(2, 30 * 48000)
983
 
@@ -1158,16 +1266,8 @@ class AceStepHandler:
1158
  """
1159
  batch_size = len(captions)
1160
 
1161
- # Ensure audio_code_hints is a list of the correct length
1162
- if audio_code_hints is None:
1163
- audio_code_hints = [None] * batch_size
1164
- elif len(audio_code_hints) != batch_size:
1165
- if len(audio_code_hints) == 1:
1166
- audio_code_hints = audio_code_hints * batch_size
1167
- else:
1168
- audio_code_hints = audio_code_hints[:batch_size]
1169
- while len(audio_code_hints) < batch_size:
1170
- audio_code_hints.append(None)
1171
 
1172
  for ii, refer_audio_list in enumerate(refer_audios):
1173
  if isinstance(refer_audio_list, list):
@@ -1179,17 +1279,6 @@ class AceStepHandler:
1179
  if vocal_languages is None:
1180
  vocal_languages = self._create_fallback_vocal_languages(batch_size)
1181
 
1182
- # Normalize audio_code_hints to batch list
1183
- if audio_code_hints is None:
1184
- audio_code_hints = [None] * batch_size
1185
- elif not isinstance(audio_code_hints, list):
1186
- audio_code_hints = [audio_code_hints] * batch_size
1187
- elif len(audio_code_hints) == 1 and batch_size > 1:
1188
- audio_code_hints = audio_code_hints * batch_size
1189
- else:
1190
- audio_code_hints = (audio_code_hints + [None] * batch_size)[:batch_size]
1191
- audio_code_hints = [hint if isinstance(hint, str) and hint.strip() else None for hint in audio_code_hints]
1192
-
1193
  # Parse metas with fallbacks
1194
  parsed_metas = self._parse_metas(metas)
1195
 
@@ -1223,13 +1312,9 @@ class AceStepHandler:
1223
  expected_latent_length = current_wav.shape[-1] // 1920
1224
  target_latent = self.silence_latent[0, :expected_latent_length, :]
1225
  else:
1226
- # Ensure input is in VAE's dtype
1227
  logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
1228
- vae_input = current_wav.to(self.device).to(self.vae.dtype)
1229
- target_latent = self.vae.encode(vae_input).latent_dist.sample()
1230
- # Cast back to model dtype
1231
- target_latent = target_latent.to(self.dtype)
1232
- target_latent = target_latent.squeeze(0).transpose(0, 1)
1233
  target_latents_list.append(target_latent)
1234
  latent_lengths.append(target_latent.shape[0])
1235
 
@@ -1268,18 +1353,7 @@ class AceStepHandler:
1268
 
1269
  # Process instructions early so we can use them for task type detection
1270
  # Use custom instructions if provided, otherwise use default
1271
- if instructions is None:
1272
- instructions = [DEFAULT_DIT_INSTRUCTION] * batch_size
1273
-
1274
- # Ensure instructions list has the same length as batch_size
1275
- if len(instructions) != batch_size:
1276
- if len(instructions) == 1:
1277
- instructions = instructions * batch_size
1278
- else:
1279
- # Pad or truncate to match batch_size
1280
- instructions = instructions[:batch_size]
1281
- while len(instructions) < batch_size:
1282
- instructions.append(DEFAULT_DIT_INSTRUCTION)
1283
 
1284
  # Generate chunk_masks and spans based on repainting parameters
1285
  # Also determine if this is a cover task (target audio provided without repainting)
@@ -1428,6 +1502,10 @@ class AceStepHandler:
1428
  else:
1429
  precomputed_lm_hints_25Hz = None
1430
 
 
1431
  # Format text_inputs
1432
  text_inputs = []
1433
  text_token_idss = []
@@ -1437,26 +1515,10 @@ class AceStepHandler:
1437
 
1438
  for i in range(batch_size):
1439
  # Use custom instruction for this batch item
1440
- instruction = instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION
1441
- # Ensure instruction ends with ":"
1442
- if not instruction.endswith(":"):
1443
- instruction = instruction + ":"
1444
-
1445
- # Extract caption and language from metas if available (from LM CoT output)
1446
- # Fallback to user-provided values if not in metas
1447
- actual_caption = captions[i]
1448
- actual_language = vocal_languages[i]
1449
-
1450
- # Check if metas contains caption/language from LM CoT
1451
- if i < len(parsed_metas) and parsed_metas[i]:
1452
- meta_dict = parsed_metas[i]
1453
- if isinstance(meta_dict, dict):
1454
- # Extract caption from metas if available
1455
- if 'caption' in meta_dict and meta_dict['caption']:
1456
- actual_caption = str(meta_dict['caption'])
1457
- # Extract language from metas if available
1458
- if 'language' in meta_dict and meta_dict['language']:
1459
- actual_language = str(meta_dict['language'])
1460
 
1461
  # Format text prompt with custom instruction (using LM-generated caption if available)
1462
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
@@ -1473,7 +1535,7 @@ class AceStepHandler:
1473
  text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1474
 
1475
  # Format and tokenize lyrics (using LM-generated language if available)
1476
- lyrics_text = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics[i]}<|endoftext|>"
1477
  lyrics_inputs_dict = self.text_tokenizer(
1478
  lyrics_text,
1479
  padding="longest",
@@ -1495,36 +1557,12 @@ class AceStepHandler:
1495
 
1496
  # Pad tokenized sequences
1497
  max_text_length = max(len(seq) for seq in text_token_idss)
1498
- padded_text_token_idss = torch.stack([
1499
- torch.nn.functional.pad(
1500
- seq, (0, max_text_length - len(seq)), 'constant',
1501
- self.text_tokenizer.pad_token_id
1502
- )
1503
- for seq in text_token_idss
1504
- ])
1505
-
1506
- padded_text_attention_masks = torch.stack([
1507
- torch.nn.functional.pad(
1508
- seq, (0, max_text_length - len(seq)), 'constant', 0
1509
- )
1510
- for seq in text_attention_masks
1511
- ])
1512
 
1513
  max_lyric_length = max(len(seq) for seq in lyric_token_idss)
1514
- padded_lyric_token_idss = torch.stack([
1515
- torch.nn.functional.pad(
1516
- seq, (0, max_lyric_length - len(seq)), 'constant',
1517
- self.text_tokenizer.pad_token_id
1518
- )
1519
- for seq in lyric_token_idss
1520
- ])
1521
-
1522
- padded_lyric_attention_masks = torch.stack([
1523
- torch.nn.functional.pad(
1524
- seq, (0, max_lyric_length - len(seq)), 'constant', 0
1525
- )
1526
- for seq in lyric_attention_masks
1527
- ])
1528
 
1529
  padded_non_cover_text_input_ids = None
1530
  padded_non_cover_text_attention_masks = None
@@ -1533,14 +1571,10 @@ class AceStepHandler:
1533
  non_cover_text_attention_masks = []
1534
  for i in range(batch_size):
1535
  # Use custom instruction for this batch item
1536
- instruction = DEFAULT_DIT_INSTRUCTION
1537
 
1538
  # Extract caption from metas if available (from LM CoT output)
1539
- actual_caption = captions[i]
1540
- if i < len(parsed_metas) and parsed_metas[i]:
1541
- meta_dict = parsed_metas[i]
1542
- if isinstance(meta_dict, dict) and 'caption' in meta_dict and meta_dict['caption']:
1543
- actual_caption = str(meta_dict['caption'])
1544
 
1545
  # Format text prompt with custom instruction (using LM-generated caption if available)
1546
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
@@ -1558,19 +1592,8 @@ class AceStepHandler:
1558
  non_cover_text_input_ids.append(text_token_ids)
1559
  non_cover_text_attention_masks.append(non_cover_text_attention_mask)
1560
 
1561
- padded_non_cover_text_input_ids = torch.stack([
1562
- torch.nn.functional.pad(
1563
- seq, (0, max_text_length - len(seq)), 'constant',
1564
- self.text_tokenizer.pad_token_id
1565
- )
1566
- for seq in non_cover_text_input_ids
1567
- ])
1568
- padded_non_cover_text_attention_masks = torch.stack([
1569
- torch.nn.functional.pad(
1570
- seq, (0, max_text_length - len(seq)), 'constant', 0
1571
- )
1572
- for seq in non_cover_text_attention_masks
1573
- ])
1574
 
1575
  if audio_cover_strength < 1.0:
1576
  assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
@@ -1804,7 +1827,7 @@ class AceStepHandler:
1804
  if self.config.is_turbo:
1805
  # Limit inference steps to maximum 8
1806
  if infer_steps > 8:
1807
- logger.warning(f"dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
1808
  infer_steps = 8
1809
  # CFG parameters are not adjustable for dmd_gan (they will be ignored)
1810
  # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
@@ -1827,30 +1850,12 @@ class AceStepHandler:
1827
  if isinstance(repainting_end, (int, float)):
1828
  repainting_end = [repainting_end]
1829
 
1830
- # Convert instructions to list
1831
- if isinstance(instructions, str):
1832
- instructions = [instructions]
1833
- elif instructions is None:
1834
- instructions = None
1835
-
1836
- # Convert audio_code_hints to list
1837
- if isinstance(audio_code_hints, str):
1838
- audio_code_hints = [audio_code_hints]
1839
- elif audio_code_hints is None:
1840
- audio_code_hints = None
1841
-
1842
  # Get batch size from captions
1843
  batch_size = len(captions)
1844
 
1845
- # Ensure audio_code_hints matches batch size
1846
- if audio_code_hints is not None:
1847
- if len(audio_code_hints) != batch_size:
1848
- if len(audio_code_hints) == 1:
1849
- audio_code_hints = audio_code_hints * batch_size
1850
- else:
1851
- audio_code_hints = audio_code_hints[:batch_size]
1852
- while len(audio_code_hints) < batch_size:
1853
- audio_code_hints.append(None)
1854
 
1855
  # Convert seed to list format
1856
  if seed is None:
@@ -1947,6 +1952,14 @@ class AceStepHandler:
1947
  logger.info("[service_generate] Generating audio...")
1948
  with self._load_model_context("model"):
1949
  outputs = self.model.generate_audio(**generate_kwargs)
 
1950
  return outputs
1951
 
1952
  def tiled_decode(self, latents, chunk_size=512, overlap=64):
@@ -2042,25 +2055,33 @@ class AceStepHandler:
2042
  use_adg: bool = False,
2043
  cfg_interval_start: float = 0.0,
2044
  cfg_interval_end: float = 1.0,
2045
- audio_format: str = "mp3",
2046
- lm_temperature: float = 0.6,
2047
  use_tiled_decode: bool = True,
2048
  progress=None
2049
- ) -> Tuple[Optional[str], Optional[str], List[str], str, str, str, str, str, Optional[Any], str, str, Optional[Any]]:
2050
  """
2051
  Main interface for music generation
2052
 
2053
  Returns:
2054
- (first_audio, second_audio, all_audio_paths, generation_info, status_message,
2055
- seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
2056
- align_score_2, align_text_2, align_plot_2)
 
 
  """
2058
  if progress is None:
2059
  def progress(*args, **kwargs):
2060
  pass
2061
 
2062
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
2063
- return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
 
 
2065
  def _has_audio_codes(v: Union[str, List[str]]) -> bool:
2066
  if isinstance(v, list):
@@ -2079,7 +2100,7 @@ class AceStepHandler:
2079
 
2080
  logger.info("[generate_music] Starting generation...")
2081
  if progress:
2082
- progress(0.05, desc="Preparing inputs...")
2083
  logger.info("[generate_music] Preparing inputs...")
2084
 
2085
  # Reset offload cost
@@ -2101,8 +2122,6 @@ class AceStepHandler:
2101
  repainting_end = None
2102
 
2103
  try:
2104
- progress(0.1, desc="Preparing inputs...")
2105
-
2106
  # 1. Process reference audio
2107
  refer_audios = None
2108
  if reference_audio is not None:
@@ -2154,7 +2173,7 @@ class AceStepHandler:
2154
  can_use_repainting
2155
  )
2156
 
2157
- progress(0.3, desc=f"Generating music (batch size: {actual_batch_size})...")
2158
 
2159
  # Prepare audio_code_hints - use if audio_code_string is provided
2160
  # This works for both text2music (auto-switched to cover) and cover tasks
@@ -2191,8 +2210,8 @@ class AceStepHandler:
2191
  pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
2192
  time_costs = outputs["time_costs"]
2193
  time_costs["offload_time_cost"] = self.current_offload_cost
2194
- logger.info(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
2195
- logger.info(f" - time_costs: {time_costs}")
2196
  if progress:
2197
  progress(0.8, desc="Decoding audio...")
2198
  logger.info("[generate_music] Decoding latents with VAE...")
@@ -2221,75 +2240,66 @@ class AceStepHandler:
2221
  # Update offload cost one last time to include VAE offloading
2222
  time_costs["offload_time_cost"] = self.current_offload_cost
2223
 
2224
- logger.info("[generate_music] VAE decode completed. Saving audio files...")
2225
  if progress:
2226
- progress(0.9, desc="Saving audio files...")
2227
 
2228
- # Save audio files using soundfile (supports wav, flac, mp3 via format param)
2229
- audio_format_lower = audio_format.lower() if audio_format else "wav"
2230
- if audio_format_lower not in ["wav", "flac", "mp3"]:
2231
- audio_format_lower = "wav"
2232
 
2233
- saved_files = []
2234
- saved_uuids = [] # Store UUIDs for each file
2235
  for i in range(actual_batch_size):
2236
- # Generate unique UUID for each audio file
2237
- file_uuid = str(uuid.uuid4())
2238
- audio_file = os.path.join(self.temp_dir, f"{file_uuid}.{audio_format_lower}")
2239
- # Convert to numpy: [channels, samples] -> [samples, channels]
2240
- audio_np = pred_wavs[i].cpu().float().numpy().T
2241
- sf.write(audio_file, audio_np, self.sample_rate)
2242
- saved_files.append(audio_file)
2243
- saved_uuids.append(file_uuid)
2244
-
2245
- # Prepare return values
2246
- first_audio = saved_files[0] if len(saved_files) > 0 else None
2247
- second_audio = saved_files[1] if len(saved_files) > 1 else None
2248
-
2249
- # Format time costs if available
2250
- time_costs_str = ""
2251
- if time_costs:
2252
- if isinstance(time_costs, dict):
2253
- time_costs_str = "\n\n**⏱️ Time Costs:**\n"
2254
- for key, value in time_costs.items():
2255
- # Format key: encoder_time_cost -> Encoder
2256
- formatted_key = key.replace("_time_cost", "").replace("_", " ").title()
2257
- time_costs_str += f" - {formatted_key}: {value:.2f}s\n"
2258
- elif isinstance(time_costs, (int, float)):
2259
- time_costs_str = f"\n\n**⏱️ Time Cost:** {time_costs:.2f}s"
2260
-
2261
- generation_info = f"""**🎵 Generation Complete**
2262
-
2263
- **Seeds:** {seed_value_for_ui}
2264
- **Steps:** {inference_steps}
2265
- **Files:** {len(saved_files)} audio(s){time_costs_str}"""
2266
  status_message = f"✅ Generation completed successfully!"
2267
- logger.info(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
2268
-
2269
- # Alignment scores and plots (placeholder for now)
2270
- align_score_1 = ""
2271
- align_text_1 = ""
2272
- align_plot_1 = None
2273
- align_score_2 = ""
2274
- align_text_2 = ""
2275
- align_plot_2 = None
2276
-
2277
- return (
2278
- first_audio,
2279
- second_audio,
2280
- saved_files,
2281
- generation_info,
2282
- status_message,
2283
- seed_value_for_ui,
2284
- align_score_1,
2285
- align_text_1,
2286
- align_plot_1,
2287
- align_score_2,
2288
- align_text_2,
2289
- align_plot_2,
2290
- )
 
2291
 
2292
  except Exception as e:
2293
  error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
2294
- return None, None, [], "", error_msg, seed_value_for_ui, "", "", None, "", "", None
2295
-
 
 
10
  import re
11
  import random
12
  import uuid
13
+ import hashlib
14
+ import json
15
  from contextlib import contextmanager
16
  from typing import Optional, Dict, Any, Tuple, List, Union
17
 
 
39
  class AceStepHandler:
40
  """ACE-Step Business Logic Handler"""
41
 
42
+ def __init__(self):
43
  self.model = None
44
  self.config = None
45
  self.device = "cpu"
46
  self.dtype = torch.float32 # Will be set based on device in initialize_service
47
+
 
48
  # VAE for audio encoding/decoding
49
  self.vae = None
50
 
 
79
  def get_available_checkpoints(self) -> str:
80
  """Return project root directory path"""
81
  # Get project root (handler.py is in acestep/, so go up two levels to project root)
82
+ project_root = self._get_project_root()
 
83
  # default checkpoints
84
  checkpoint_dir = os.path.join(project_root, "checkpoints")
85
  if os.path.exists(checkpoint_dir):
 
90
  def get_available_acestep_v15_models(self) -> List[str]:
91
  """Scan and return all model directory names starting with 'acestep-v15-'"""
92
  # Get project root
93
+ project_root = self._get_project_root()
 
94
  checkpoint_dir = os.path.join(project_root, "checkpoints")
95
 
96
  models = []
 
167
 
168
 
169
  # Auto-detect project root (independent of passed project_root parameter)
170
+ actual_project_root = self._get_project_root()
 
171
  checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
172
 
173
  # 1. Load main model
 
182
  attn_implementation = "sdpa"
183
 
184
  try:
185
+ logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
186
  self.model = AutoModel.from_pretrained(
187
  acestep_v15_checkpoint_path,
188
  trust_remote_code=True,
 
190
  dtype="bfloat16"
191
  )
192
  except Exception as e:
193
+ logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
194
  if attn_implementation == "sdpa":
195
+ logger.info("[initialize_service] Falling back to eager attention")
196
  attn_implementation = "eager"
197
  self.model = AutoModel.from_pretrained(
198
  acestep_v15_checkpoint_path,
 
210
  else:
211
  # If offload_to_cpu is True, check if we should keep DiT on GPU
212
  if not self.offload_dit_to_cpu:
213
+ logger.info(f"[initialize_service] Keeping main model on {device} (persistent)")
214
  self.model = self.model.to(device).to(self.dtype)
215
  else:
216
  self.model = self.model.to("cpu").to(self.dtype)
 
234
  raise ValueError(f"Unsupported quantization type: {self.quantization}")
235
 
236
  quantize_(self.model, quant_config)
237
+ logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
238
 
239
 
240
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
 
255
  if os.path.exists(vae_checkpoint_path):
256
  self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
257
  # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
258
+ vae_dtype = self._get_vae_dtype(device)
259
  if not self.offload_to_cpu:
260
  self.vae = self.vae.to(device).to(vae_dtype)
261
  else:
 
297
 
298
  except Exception as e:
299
  error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
300
+ logger.exception("[initialize_service] Error initializing model")
301
  return error_msg, False
302
 
303
  @contextmanager
 
322
  try:
323
  param = next(model.parameters())
324
  if param.device.type == "cpu":
325
+ logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
326
  model.to(self.device).to(self.dtype)
327
  if hasattr(self, "silence_latent"):
328
  self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
 
337
  return
338
 
339
  # Load to GPU
340
+ logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
341
  start_time = time.time()
342
  if model_name == "vae":
343
+ vae_dtype = self._get_vae_dtype()
344
  model.to(self.device).to(vae_dtype)
345
  else:
346
  model.to(self.device).to(self.dtype)
 
350
 
351
  load_time = time.time() - start_time
352
  self.current_offload_cost += load_time
353
+ logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")
354
 
355
  try:
356
  yield
357
  finally:
358
  # Offload to CPU
359
+ logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
360
  start_time = time.time()
361
  model.to("cpu")
362
 
 
366
  torch.cuda.empty_cache()
367
  offload_time = time.time() - start_time
368
  self.current_offload_cost += offload_time
369
+ logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")
370
 
371
  def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
372
  """Process target audio"""
 
382
  else:
383
  audio = torch.from_numpy(audio_np.T)
384
 
385
+ # Normalize to stereo 48kHz
386
+ audio = self._normalize_audio_to_stereo_48k(audio, sr)
 
387
 
388
  return audio
389
  except Exception as e:
390
+ logger.exception("[process_target_audio] Error processing target audio")
391
  return None
392
 
393
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
 
396
  return []
397
  try:
398
  return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
399
+ except Exception as e:
400
+ logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
401
  return []
402
 
403
  def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
 
524
  )
525
  """
526
  # Align instruction formatting with _prepare_batch
527
+ final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)
 
 
528
 
529
  # Extract caption and language from metas if available (from LM CoT output)
530
  # Fallback to user-provided values if not in metas
 
555
 
556
  parsed_meta = self._parse_metas([metas])[0]
557
  caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
558
+ lyrics_input = self._format_lyrics(lyrics, actual_language)
559
  return caption_input, lyrics_input
560
 
561
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
 
598
  return match.group(1).strip()
599
  return caption
600
  except Exception as e:
601
+ logger.exception("[extract_caption_from_sft_format] Error extracting caption")
602
  return caption
603
 
604
  def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
 
622
  else:
623
  try:
624
  seed_list.append(int(float(s)))
625
+ except (ValueError, TypeError) as e:
626
+ logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}")
627
  seed_list.append(-1)
628
  elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
629
  # If seed is None or negative, use -1 for all items
 
664
  return actual_seed_list, seed_value_for_ui
665
 
666
  def prepare_metadata(self, bpm, key_scale, time_signature):
667
+ """Build metadata dict - use "N/A" as default for empty fields."""
668
+ return self._build_metadata_dict(bpm, key_scale, time_signature)
669
+
670
+ def is_silence(self, audio):
671
+ return torch.all(audio.abs() < 1e-6)
672
+
673
+ def _get_project_root(self) -> str:
674
+ """Get project root directory path."""
675
+ current_file = os.path.abspath(__file__)
676
+ return os.path.dirname(os.path.dirname(current_file))
677
+
678
+ def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
679
+ """Get VAE dtype based on device."""
680
+ device = device or self.device
681
+ return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
682
+
683
+ def _format_instruction(self, instruction: str) -> str:
684
+ """Format instruction to ensure it ends with colon."""
685
+ if not instruction.endswith(":"):
686
+ instruction = instruction + ":"
687
+ return instruction
688
+
689
+ def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
690
+ """
691
+ Normalize audio to stereo 48kHz format.
692
+
693
+ Args:
694
+ audio: Audio tensor [channels, samples] or [samples]
695
+ sr: Sample rate
696
+
697
+ Returns:
698
+ Normalized audio tensor [2, samples] at 48kHz
699
+ """
700
+ # Convert to stereo (duplicate channel if mono)
701
+ if audio.shape[0] == 1:
702
+ audio = torch.cat([audio, audio], dim=0)
703
+
704
+ # Keep only first 2 channels
705
+ audio = audio[:2]
706
+
707
+ # Resample to 48kHz if needed
708
+ if sr != 48000:
709
+ audio = torchaudio.transforms.Resample(sr, 48000)(audio)
710
+
711
+ # Clamp values to [-1.0, 1.0]
712
+ audio = torch.clamp(audio, -1.0, 1.0)
713
+
714
+ return audio
715
+
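A minimal usage sketch for the normalization helper, assuming an already constructed `AceStepHandler` instance named `handler`:

```python
import torch

mono_44k = torch.randn(1, 44_100 * 5)                      # 5 s of mono audio at 44.1 kHz
stereo_48k = handler._normalize_audio_to_stereo_48k(mono_44k, 44_100)
assert stereo_48k.shape[0] == 2                            # mono duplicated to stereo
assert abs(stereo_48k.shape[-1] - 48_000 * 5) <= 1         # resampled to 48 kHz
assert stereo_48k.abs().max() <= 1.0                       # clamped to [-1, 1]
```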
716
+ def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
717
+ """Normalize audio_code_hints to list of correct length."""
718
+ if audio_code_hints is None:
719
+ normalized = [None] * batch_size
720
+ elif isinstance(audio_code_hints, str):
721
+ normalized = [audio_code_hints] * batch_size
722
+ elif len(audio_code_hints) == 1 and batch_size > 1:
723
+ normalized = audio_code_hints * batch_size
724
+ elif len(audio_code_hints) != batch_size:
725
+ # Pad or truncate to match batch_size
726
+ normalized = list(audio_code_hints[:batch_size])
727
+ while len(normalized) < batch_size:
728
+ normalized.append(None)
729
+ else:
730
+ normalized = list(audio_code_hints)
731
+
732
+ # Clean up: convert empty strings to None
733
+ normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized]
734
+ return normalized
735
+
736
+ def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]:
737
+ """Normalize instructions to list of correct length."""
738
+ if instructions is None:
739
+ default_instruction = default or DEFAULT_DIT_INSTRUCTION
740
+ return [default_instruction] * batch_size
741
+ elif isinstance(instructions, str):
742
+ return [instructions] * batch_size
743
+ elif len(instructions) == 1:
744
+ return instructions * batch_size
745
+ elif len(instructions) != batch_size:
746
+ # Pad or truncate to match batch_size
747
+ normalized = list(instructions[:batch_size])
748
+ default_instruction = default or DEFAULT_DIT_INSTRUCTION
749
+ while len(normalized) < batch_size:
750
+ normalized.append(default_instruction)
751
+ return normalized
752
+ else:
753
+ return list(instructions)
754
+
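The two normalization helpers above broadcast a scalar or short list to the batch size and turn empty hints into `None`. A small hedged usage sketch, again assuming an `AceStepHandler` instance named `handler`:

```python
hints = handler._normalize_audio_code_hints("<|audio_code_1|>", batch_size=3)
# the single string is broadcast to every item:
#   ['<|audio_code_1|>', '<|audio_code_1|>', '<|audio_code_1|>']

cleaned = handler._normalize_audio_code_hints(["", None], batch_size=2)
# empty strings and None both normalize to None: [None, None]

instrs = handler._normalize_instructions(["a:", "b:", "c:", "d:"], batch_size=2)
# longer lists are truncated to the batch size: ['a:', 'b:']
```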
755
+ def _format_lyrics(self, lyrics: str, language: str) -> str:
756
+ """Format lyrics text with language header."""
757
+ return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>"
758
+
759
+ def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor:
760
+ """Pad sequences to same length."""
761
+ return torch.stack([
762
+ torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value)
763
+ for seq in sequences
764
+ ])
765
+
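A usage sketch for the padding helper; the tensors and pad value are made up for illustration, and `handler` is assumed to be an `AceStepHandler` instance:

```python
import torch

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
max_len = max(len(s) for s in seqs)
padded = handler._pad_sequences(seqs, max_len, pad_value=0)
# padded -> tensor([[5, 6, 7],
#                   [8, 9, 0]])   # shorter sequences are right-padded with pad_value
```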
766
+ def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]:
767
+ """Extract caption and language from metas with fallback to provided values."""
768
+ actual_captions = list(captions)
769
+ actual_languages = list(vocal_languages)
770
+
771
+ for i, meta in enumerate(metas):
772
+ if i >= len(actual_captions):
773
+ break
774
+
775
+ meta_dict = None
776
+ if isinstance(meta, str):
777
+ parsed = self._parse_metas([meta])
778
+ if parsed and isinstance(parsed[0], dict):
779
+ meta_dict = parsed[0]
780
+ elif isinstance(meta, dict):
781
+ meta_dict = meta
782
+
783
+ if meta_dict:
784
+ if 'caption' in meta_dict and meta_dict['caption']:
785
+ actual_captions[i] = str(meta_dict['caption'])
786
+ if 'language' in meta_dict and meta_dict['language']:
787
+ actual_languages[i] = str(meta_dict['language'])
788
+
789
+ return actual_captions, actual_languages
790
+
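A hedged usage sketch for the caption/language extraction helper, assuming an `AceStepHandler` instance named `handler` and dict-style metas:

```python
metas = [{"caption": "dreamy synthwave", "language": "en"}, {}]
caps, langs = handler._extract_caption_and_language(
    metas,
    ["user caption", "other caption"],   # user-provided captions as fallback
    ["ja", "ja"],                        # user-provided languages as fallback
)
# caps  -> ['dreamy synthwave', 'other caption']   # LM caption overrides item 0 only
# langs -> ['en', 'ja']
```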
791
+ def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor:
792
+ """
793
+ Encode audio to latents using VAE.
794
+
795
+ Args:
796
+ audio: Audio tensor [channels, samples] or [batch, channels, samples]
797
+
798
+ Returns:
799
+ Latents tensor [T, D] or [batch, T, D]
800
+ """
801
+ needs_squeeze = audio.dim() == 2  # remember whether the input lacked a batch dimension
802
+ if needs_squeeze:
803
+ audio = audio.unsqueeze(0)
804
+
805
+ # Ensure input is in VAE's dtype
806
+ vae_input = audio.to(self.device).to(self.vae.dtype)
807
+
808
+ # Encode to latents
809
+ with torch.no_grad():
810
+ latents = self.vae.encode(vae_input).latent_dist.sample()
811
+
812
+ # Cast back to model dtype
813
+ latents = latents.to(self.dtype)
814
+
815
+ # Transpose: [batch, d, T] -> [batch, T, d]
816
+ latents = latents.transpose(1, 2)
817
+
818
+ # Remove batch dimension if input didn't have it
819
+ if needs_squeeze:
820
+ latents = latents.squeeze(0)
821
+
822
+ return latents
823
+
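A shape sketch for the VAE encoding helper, assuming a handler whose VAE is already initialized; the `// 1920` relation below is the same one this file uses when computing expected latent lengths (roughly 25 latent frames per second of 48 kHz audio):

```python
import torch

stereo_48k = torch.zeros(2, 48_000 * 10)                 # 10 s of stereo 48 kHz audio
latents = handler._encode_audio_to_latents(stereo_48k)   # -> [T, latent_dim]
print(latents.shape[0], stereo_48k.shape[-1] // 1920)    # both around 250 frames
```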
824
+ def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]:
825
+ """
826
+ Build metadata dictionary with default values.
827
+
828
+ Args:
829
+ bpm: BPM value (optional)
830
+ key_scale: Key/scale string
831
+ time_signature: Time signature string
832
+ duration: Duration in seconds (optional)
833
+
834
+ Returns:
835
+ Metadata dictionary
836
+ """
837
  metadata_dict = {}
838
  if bpm:
839
  metadata_dict["bpm"] = bpm
 
849
  metadata_dict["timesignature"] = time_signature
850
  else:
851
  metadata_dict["timesignature"] = "N/A"
852
+
853
+ # Add duration if provided
854
+ if duration is not None:
855
+ metadata_dict["duration"] = f"{int(duration)} seconds"
856
+
857
  return metadata_dict
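A usage sketch for the metadata builder, assuming an `AceStepHandler` instance named `handler`:

```python
meta = handler._build_metadata_dict(120, "C Major", "4/4", duration=32.7)
# -> {'bpm': 120, 'keyscale': 'C Major', 'timesignature': '4/4', 'duration': '32 seconds'}

defaults = handler._build_metadata_dict(None, "", "")
# -> {'bpm': 'N/A', 'keyscale': 'N/A', 'timesignature': 'N/A'}
```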
 
858
 
859
  def generate_instruction(
860
  self,
 
901
  # Load audio file
902
  audio, sr = torchaudio.load(audio_file)
903
 
904
+ logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
905
+ logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
906
+ logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
 
907
 
908
+ # Normalize to stereo 48kHz
909
+ audio = self._normalize_audio_to_stereo_48k(audio, sr)
 
910
 
911
  is_silence = self.is_silence(audio)
912
  if is_silence:
 
945
  return audio
946
 
947
  except Exception as e:
948
+ logger.exception("[process_reference_audio] Error processing reference audio")
949
  return None
950
 
951
  def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
 
956
  # Load audio file
957
  audio, sr = torchaudio.load(audio_file)
958
 
959
+ # Normalize to stereo 48kHz
960
+ audio = self._normalize_audio_to_stereo_48k(audio, sr)
 
961
 
962
  return audio
963
 
964
  except Exception as e:
965
+ logger.exception("[process_src_audio] Error processing source audio")
966
  return None
967
 
968
  def convert_src_audio_to_codes(self, audio_file) -> str:
 
990
  # Encode audio to latents using VAE
991
  with torch.no_grad():
992
  with self._load_model_context("vae"):
 
 
  # Check if audio is silence
994
+ if self.is_silence(processed_audio.unsqueeze(0)):
995
  return "❌ Audio file appears to be silent"
996
 
997
+ # Encode to latents using helper method
998
+ latents = self._encode_audio_to_latents(processed_audio) # [T, d]
 
999
 
1000
  # Create attention mask for latents
1001
  attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
 
1020
 
1021
  except Exception as e:
1022
  error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
1023
+ logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
1024
  return error_msg
1025
 
1026
  def prepare_batch_data(
 
1049
  calculated_duration = audio_duration
1050
 
1051
  # Build metadata dict - use "N/A" as default for empty fields
1052
+ metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1053
 
1054
  # Format metadata - inference service accepts dict and will convert to string
1055
  # Create a copy for each batch item (in case we modify it)
 
1085
  target_wavs = torch.zeros(2, frames)
1086
  return target_wavs
1087
  except Exception as e:
1088
+ logger.exception("[create_target_wavs] Error creating target audio")
1089
  # Fallback to 30 seconds if error
1090
  return torch.zeros(2, 30 * 48000)
1091
 
 
1266
  """
1267
  batch_size = len(captions)
1268
 
1269
+ # Normalize audio_code_hints to batch list
1270
+ audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
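# Illustrative sketch (not part of the diff): a plausible "broadcast a single hint to the batch"
# behaviour. The real logic is self._normalize_audio_code_hints(); this sketch is an assumption.
def normalize_hints_sketch(hints, batch_size):
    if hints is None:
        return [None] * batch_size
    if isinstance(hints, str):
        hints = [hints]
    return (list(hints) + [None] * batch_size)[:batch_size]

print(normalize_hints_sketch("codes_for_item_0", 3))   # ['codes_for_item_0', None, None]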
 
 
 
 
 
 
 
 
1271
 
1272
  for ii, refer_audio_list in enumerate(refer_audios):
1273
  if isinstance(refer_audio_list, list):
 
1279
  if vocal_languages is None:
1280
  vocal_languages = self._create_fallback_vocal_languages(batch_size)
1281
 
 
 
 
 
 
 
 
 
 
 
 
1282
  # Parse metas with fallbacks
1283
  parsed_metas = self._parse_metas(metas)
1284
 
 
1312
  expected_latent_length = current_wav.shape[-1] // 1920
1313
  target_latent = self.silence_latent[0, :expected_latent_length, :]
1314
  else:
1315
+ # Encode using helper method
1316
  logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
1317
+ target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) # Remove batch dim for helper
 
 
 
 
1318
  target_latents_list.append(target_latent)
1319
  latent_lengths.append(target_latent.shape[0])
1320
 
 
1353
 
1354
  # Process instructions early so we can use them for task type detection
1355
  # Use custom instructions if provided, otherwise use default
1356
+ instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION)
 
 
 
 
 
 
 
 
 
 
 
1357
 
1358
  # Generate chunk_masks and spans based on repainting parameters
1359
  # Also determine if this is a cover task (target audio provided without repainting)
 
1502
  else:
1503
  precomputed_lm_hints_25Hz = None
1504
 
1505
+ # Extract caption and language from metas if available (from LM CoT output)
1506
+ # Fall back to user-provided values if not in metas
1507
+ actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages)
1508
+
1509
  # Format text_inputs
1510
  text_inputs = []
1511
  text_token_idss = []
 
1515
 
1516
  for i in range(batch_size):
1517
  # Use custom instruction for this batch item
1518
+ instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION)
1519
+
1520
+ actual_caption = actual_captions[i]
1521
+ actual_language = actual_languages[i]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1522
 
1523
  # Format text prompt with custom instruction (using LM-generated caption if available)
1524
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
 
1535
  text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1536
 
1537
  # Format and tokenize lyrics (using LM-generated language if available)
1538
+ lyrics_text = self._format_lyrics(lyrics[i], actual_language)
1539
  lyrics_inputs_dict = self.text_tokenizer(
1540
  lyrics_text,
1541
  padding="longest",
 
1557
 
1558
  # Pad tokenized sequences
1559
  max_text_length = max(len(seq) for seq in text_token_idss)
1560
+ padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, self.text_tokenizer.pad_token_id)
1561
+ padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0)
 
 
 
 
 
 
 
 
 
 
 
 
1562
 
1563
  max_lyric_length = max(len(seq) for seq in lyric_token_idss)
1564
+ padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id)
1565
+ padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0)
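# Illustrative sketch (not part of the diff): what a _pad_sequences-style helper typically does,
# shown standalone. The real helper lives on the handler; the padding semantics here are an assumption.
import torch

def pad_sequences_sketch(seqs, max_len, pad_value):
    padded = []
    for seq in seqs:
        filler = torch.full((max_len - len(seq),), pad_value, dtype=seq.dtype)
        padded.append(torch.cat([seq, filler]))
    return torch.stack(padded)

ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
print(pad_sequences_sketch(ids, max_len=4, pad_value=0))
# tensor([[5, 6, 7, 0],
#         [8, 9, 0, 0]])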
 
 
 
 
 
 
 
 
 
 
 
 
1566
 
1567
  padded_non_cover_text_input_ids = None
1568
  padded_non_cover_text_attention_masks = None
 
1571
  non_cover_text_attention_masks = []
1572
  for i in range(batch_size):
1573
  # Use custom instruction for this batch item
1574
+ instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION)
1575
 
1576
  # Extract caption from metas if available (from LM CoT output)
1577
+ actual_caption = actual_captions[i]
 
 
 
 
1578
 
1579
  # Format text prompt with custom instruction (using LM-generated caption if available)
1580
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
 
1592
  non_cover_text_input_ids.append(text_token_ids)
1593
  non_cover_text_attention_masks.append(non_cover_text_attention_mask)
1594
 
1595
+ padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id)
1596
+ padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0)
 
 
 
 
 
 
 
 
 
 
 
1597
 
1598
  if audio_cover_strength < 1.0:
1599
  assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
 
1827
  if self.config.is_turbo:
1828
  # Limit inference steps to maximum 8
1829
  if infer_steps > 8:
1830
+ logger.warning(f"[service_generate] dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
1831
  infer_steps = 8
1832
  # CFG parameters are not adjustable for dmd_gan (they will be ignored)
1833
  # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
 
1850
  if isinstance(repainting_end, (int, float)):
1851
  repainting_end = [repainting_end]
1852
 
 
 
 
 
 
 
 
 
 
 
 
 
1853
  # Get batch size from captions
1854
  batch_size = len(captions)
1855
 
1856
+ # Normalize instructions and audio_code_hints to match batch size
1857
+ instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None
1858
+ audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None
 
 
 
 
 
 
1859
 
1860
  # Convert seed to list format
1861
  if seed is None:
 
1952
  logger.info("[service_generate] Generating audio...")
1953
  with self._load_model_context("model"):
1954
  outputs = self.model.generate_audio(**generate_kwargs)
1955
+
1956
+ # Add intermediate information to outputs for extra_outputs
1957
+ outputs["src_latents"] = src_latents
1958
+ outputs["target_latents_input"] = target_latents # Input target latents (before generation)
1959
+ outputs["chunk_masks"] = chunk_mask
1960
+ outputs["spans"] = spans
1961
+ outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
1962
+
1963
  return outputs
1964
 
1965
  def tiled_decode(self, latents, chunk_size=512, overlap=64):
 
2055
  use_adg: bool = False,
2056
  cfg_interval_start: float = 0.0,
2057
  cfg_interval_end: float = 1.0,
 
 
2058
  use_tiled_decode: bool = True,
2059
  progress=None
2060
+ ) -> Dict[str, Any]:
2061
  """
2062
  Main interface for music generation
2063
 
2064
  Returns:
2065
+ Dictionary containing:
2066
+ - audios: List of audio dictionaries with path, key, params
2067
+ - generation_info: Markdown-formatted generation information
2068
+ - status_message: Status message
2069
+ - extra_outputs: Dictionary with latents, masks, time_costs, etc.
2070
+ - success: Whether generation completed successfully
2071
+ - error: Error message if generation failed
2072
  """
2073
  if progress is None:
2074
  def progress(*args, **kwargs):
2075
  pass
2076
 
2077
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
2078
+ return {
2079
+ "audios": [],
2080
+ "status_message": "❌ Model not fully initialized. Please initialize all components first.",
2081
+ "extra_outputs": {},
2082
+ "success": False,
2083
+ "error": "Model not fully initialized",
2084
+ }
2085
 
2086
  def _has_audio_codes(v: Union[str, List[str]]) -> bool:
2087
  if isinstance(v, list):
 
2100
 
2101
  logger.info("[generate_music] Starting generation...")
2102
  if progress:
2103
+ progress(0.51, desc="Preparing inputs...")
2104
  logger.info("[generate_music] Preparing inputs...")
2105
 
2106
  # Reset offload cost
 
2122
  repainting_end = None
2123
 
2124
  try:
 
 
2125
  # 1. Process reference audio
2126
  refer_audios = None
2127
  if reference_audio is not None:
 
2173
  can_use_repainting
2174
  )
2175
 
2176
+ progress(0.52, desc=f"Generating music (batch size: {actual_batch_size})...")
2177
 
2178
  # Prepare audio_code_hints - use if audio_code_string is provided
2179
  # This works for both text2music (auto-switched to cover) and cover tasks
 
2210
  pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
2211
  time_costs = outputs["time_costs"]
2212
  time_costs["offload_time_cost"] = self.current_offload_cost
2213
+ logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
2214
+ logger.debug(f"[generate_music] time_costs: {time_costs}")
2215
  if progress:
2216
  progress(0.8, desc="Decoding audio...")
2217
  logger.info("[generate_music] Decoding latents with VAE...")
 
2240
  # Update offload cost one last time to include VAE offloading
2241
  time_costs["offload_time_cost"] = self.current_offload_cost
2242
 
2243
+ logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
2244
  if progress:
2245
+ progress(0.99, desc="Preparing audio data...")
2246
 
2247
+ # Prepare audio tensors (no file I/O here, no UUID generation)
2248
+ # pred_wavs is already [batch, channels, samples] format
2249
+ # Move to CPU and convert to float32 for return
2250
+ audio_tensors = []
2251
 
 
 
2252
  for i in range(actual_batch_size):
2253
+ # Extract audio tensor: [channels, samples] format, CPU, float32
2254
+ audio_tensor = pred_wavs[i].cpu().float()
2255
+ audio_tensors.append(audio_tensor)
2256
+
 
 
 
 
 
 
 
2257
  status_message = f"✅ Generation completed successfully!"
2258
+ logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
2259
+
2260
+ # Extract intermediate information from outputs
2261
+ src_latents = outputs.get("src_latents") # [batch, T, D]
2262
+ target_latents_input = outputs.get("target_latents_input") # [batch, T, D]
2263
+ chunk_masks = outputs.get("chunk_masks") # [batch, T]
2264
+ spans = outputs.get("spans", []) # List of tuples
2265
+ latent_masks = outputs.get("latent_masks") # [batch, T]
2266
+
2267
+ # Move latents to CPU to save memory (they can be large)
2268
+ extra_outputs = {
2269
+ "pred_latents": pred_latents.cpu() if pred_latents is not None else None,
2270
+ "target_latents": target_latents_input.cpu() if target_latents_input is not None else None,
2271
+ "src_latents": src_latents.cpu() if src_latents is not None else None,
2272
+ "chunk_masks": chunk_masks.cpu() if chunk_masks is not None else None,
2273
+ "latent_masks": latent_masks.cpu() if latent_masks is not None else None,
2274
+ "spans": spans,
2275
+ "time_costs": time_costs,
2276
+ "seed_value": seed_value_for_ui,
2277
+ }
2278
+
2279
+ # Build audios list with tensor data (no file paths, no UUIDs, handled outside)
2280
+ audios = []
2281
+ for idx, audio_tensor in enumerate(audio_tensors):
2282
+ audio_dict = {
2283
+ "tensor": audio_tensor, # torch.Tensor [channels, samples], CPU, float32
2284
+ "sample_rate": self.sample_rate,
2285
+ }
2286
+ audios.append(audio_dict)
2287
+
2288
+ return {
2289
+ "audios": audios,
2290
+ "status_message": status_message,
2291
+ "extra_outputs": extra_outputs,
2292
+ "success": True,
2293
+ "error": None,
2294
+ }
2295
 
2296
  except Exception as e:
2297
  error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
2298
+ logger.exception("[generate_music] Generation failed")
2299
+ return {
2300
+ "audios": [],
2301
+ "status_message": error_msg,
2302
+ "extra_outputs": {},
2303
+ "success": False,
2304
+ "error": str(e),
2305
+ }
acestep/inference.py CHANGED
@@ -7,105 +7,100 @@ backward-compatible Gradio UI support.
7
  """
8
 
9
  import math
 
 
10
  from typing import Optional, Union, List, Dict, Any, Tuple
11
  from dataclasses import dataclass, field, asdict
12
  from loguru import logger
13
- import time as time_module
 
14
 
15
 
16
  @dataclass
17
- class GenerationConfig:
18
- """Configuration for music generation.
19
 
20
  Attributes:
21
  # Text Inputs
22
- caption: Text description of the desired music
23
- lyrics: Lyrics text for vocal music (use "[Instrumental]" for instrumental)
 
24
 
25
  # Music Metadata
26
- bpm: Beats per minute (e.g., 120). None for auto-detection
27
- key_scale: Musical key (e.g., "C Major", "Am"). Empty for auto-detection
28
- time_signature: Time signature (e.g., "4/4", "3/4"). Empty for auto-detection
29
- vocal_language: Language code for vocals (e.g., "en", "zh", "ja")
30
- audio_duration: Duration in seconds. None for auto-detection
31
 
32
  # Generation Parameters
33
- inference_steps: Number of denoising steps (8 for turbo, 32-100 for base)
34
- guidance_scale: Classifier-free guidance scale (higher = more adherence to prompt)
35
- use_random_seed: Whether to use random seed (True) or fixed seed
36
- seed: Random seed for reproducibility (-1 for random)
37
- batch_size: Number of samples to generate (1-8)
38
 
39
  # Advanced DiT Parameters
40
- use_adg: Use Adaptive Dual Guidance (base model only)
41
- cfg_interval_start: CFG application start ratio (0.0-1.0)
42
- cfg_interval_end: CFG application end ratio (0.0-1.0)
43
- audio_format: Output audio format ("mp3", "wav", "flac")
44
 
45
  # Task-Specific Parameters
46
- task_type: Generation task type ("text2music", "cover", "repaint", "lego", "extract", "complete")
47
- reference_audio: Path to reference audio file (for style transfer)
48
- src_audio: Path to source audio file (for audio-to-audio tasks)
49
- audio_code_string: Pre-extracted audio codes (advanced use)
50
- repainting_start: Repainting start time in seconds (for repaint/lego tasks)
51
- repainting_end: Repainting end time in seconds (-1 for end of audio)
52
- audio_cover_strength: Strength of audio cover/codes influence (0.0-1.0)
53
- instruction: Task-specific instruction prompt (auto-generated if empty)
54
 
55
- # 5Hz Language Model Parameters (CoT Reasoning)
56
- use_llm_thinking: Enable LM-based Chain-of-Thought reasoning for metadata/codes
57
- lm_temperature: LM sampling temperature (0.0-2.0, higher = more creative)
58
- lm_cfg_scale: LM classifier-free guidance scale
59
- lm_top_k: LM top-k sampling (0 = disabled)
60
- lm_top_p: LM nucleus sampling (1.0 = disabled)
61
- lm_negative_prompt: Negative prompt for LM guidance
62
- use_cot_metas: Generate metadata using LM CoT
63
- use_cot_caption: Refine caption using LM CoT
64
- use_cot_language: Detect language using LM CoT
65
- is_format_caption: Whether caption is already formatted
66
- constrained_decoding_debug: Enable debug logging for constrained decoding
67
-
68
- # Batch LM Generation
69
- allow_lm_batch: Allow batch LM code generation (faster for batch_size >= 2)
70
- lm_batch_chunk_size: Maximum batch size per LM inference chunk (GPU memory constraint)
71
  """
72
-
 
 
 
 
 
 
 
 
 
 
73
  # Text Inputs
74
  caption: str = ""
75
  lyrics: str = ""
76
-
77
- # Music Metadata
78
- bpm: Optional[int] = None
79
- key_scale: str = ""
80
- time_signature: str = ""
81
  vocal_language: str = "unknown"
82
- audio_duration: Optional[float] = None
83
-
84
- # Generation Parameters
 
 
 
85
  inference_steps: int = 8
86
- guidance_scale: float = 7.0
87
- use_random_seed: bool = True
88
  seed: int = -1
89
- batch_size: int = 1
90
-
91
- # Advanced DiT Parameters
92
  use_adg: bool = False
93
  cfg_interval_start: float = 0.0
94
  cfg_interval_end: float = 1.0
95
- audio_format: str = "mp3"
96
-
97
- # Task-Specific Parameters
98
- task_type: str = "text2music"
99
- reference_audio: Optional[str] = None
100
- src_audio: Optional[str] = None
101
- audio_code_string: Union[str, List[str]] = ""
102
  repainting_start: float = 0.0
103
  repainting_end: float = -1
104
  audio_cover_strength: float = 1.0
105
- instruction: str = ""
106
-
107
  # 5Hz Language Model Parameters
108
- use_llm_thinking: bool = False
109
  lm_temperature: float = 0.85
110
  lm_cfg_scale: float = 2.0
111
  lm_top_k: int = 0
@@ -113,13 +108,50 @@ class GenerationConfig:
113
  lm_negative_prompt: str = "NO USER INPUT"
114
  use_cot_metas: bool = True
115
  use_cot_caption: bool = True
 
116
  use_cot_language: bool = True
117
- is_format_caption: bool = False
118
- constrained_decoding_debug: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- # Batch LM Generation
 
 
 
 
 
 
 
 
 
 
 
 
121
  allow_lm_batch: bool = False
122
- lm_batch_chunk_size: int = 4
 
 
 
 
 
 
 
 
123
 
124
 
125
  @dataclass
@@ -128,801 +160,461 @@ class GenerationResult:
128
 
129
  Attributes:
130
  # Audio Outputs
131
- audio_paths: List of paths to generated audio files
132
- first_audio: Path to first generated audio (backward compatibility)
133
- second_audio: Path to second generated audio (backward compatibility)
134
-
135
- # Generation Information
136
- generation_info: Markdown-formatted generation information
137
  status_message: Status message from generation
138
- seed_value: Actual seed value used for generation
139
-
140
- # LM-Generated Metadata (if applicable)
141
- lm_metadata: Metadata generated by language model (dict or None)
142
-
143
- # Audio-Text Alignment Scores (if available)
144
- align_score_1: First alignment score
145
- align_text_1: First alignment text description
146
- align_plot_1: First alignment plot image
147
- align_score_2: Second alignment score
148
- align_text_2: Second alignment text description
149
- align_plot_2: Second alignment plot image
150
-
151
- # Success Status
152
  success: Whether generation completed successfully
153
  error: Error message if generation failed
154
  """
155
-
156
  # Audio Outputs
157
- audio_paths: List[str] = field(default_factory=list)
158
- first_audio: Optional[str] = None
159
- second_audio: Optional[str] = None
160
-
161
  # Generation Information
162
- generation_info: str = ""
163
  status_message: str = ""
164
- seed_value: str = ""
165
-
166
- # LM-Generated Metadata
167
- lm_metadata: Optional[Dict[str, Any]] = None
168
-
169
- # Audio-Text Alignment Scores
170
- align_score_1: Optional[float] = None
171
- align_text_1: Optional[str] = None
172
- align_plot_1: Optional[Any] = None
173
- align_score_2: Optional[float] = None
174
- align_text_2: Optional[str] = None
175
- align_plot_2: Optional[Any] = None
176
-
177
  # Success Status
178
  success: bool = True
179
  error: Optional[str] = None
180
-
181
  def to_dict(self) -> Dict[str, Any]:
182
  """Convert result to dictionary for JSON serialization."""
183
  return asdict(self)
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  def generate_music(
187
  dit_handler,
188
  llm_handler,
 
189
  config: GenerationConfig,
 
 
190
  ) -> GenerationResult:
191
  """Generate music using ACE-Step model with optional LM reasoning.
192
 
193
- This is the main inference API for music generation. It supports various task types
194
- (text2music, cover, repaint, etc.) and can optionally use a 5Hz Language Model for
195
- Chain-of-Thought reasoning to generate metadata and audio codes.
196
-
197
  Args:
198
  dit_handler: Initialized DiT model handler (AceStepHandler instance)
199
  llm_handler: Initialized LLM handler (LLMHandler instance)
 
200
  config: Generation configuration (GenerationConfig instance)
201
 
202
  Returns:
203
- GenerationResult: Generation result containing audio paths and metadata
204
-
205
- Example:
206
- >>> from acestep.handler import AceStepHandler
207
- >>> from acestep.llm_inference import LLMHandler
208
- >>> from acestep.inference import GenerationConfig, generate_music
209
- >>>
210
- >>> # Initialize handlers
211
- >>> dit_handler = AceStepHandler()
212
- >>> llm_handler = LLMHandler()
213
- >>> dit_handler.initialize_service(...)
214
- >>> llm_handler.initialize(...)
215
- >>>
216
- >>> # Configure generation
217
- >>> config = GenerationConfig(
218
- ... caption="upbeat electronic dance music",
219
- ... bpm=128,
220
- ... audio_duration=30,
221
- ... batch_size=2,
222
- ... )
223
- >>>
224
- >>> # Generate music
225
- >>> result = generate_music(dit_handler, llm_handler, config)
226
- >>> print(f"Generated {len(result.audio_paths)} audio files")
227
- >>> for path in result.audio_paths:
228
- ... print(f"Audio: {path}")
229
  """
230
-
231
  try:
232
  # Phase 1: LM-based metadata and code generation (if enabled)
233
- audio_code_string_to_use = config.audio_code_string
234
  lm_generated_metadata = None
235
- lm_generated_audio_codes = None
236
  lm_generated_audio_codes_list = []
237
-
 
 
 
 
 
238
  # Extract mutable copies of metadata (will be updated by LM if needed)
239
- bpm = config.bpm
240
- key_scale = config.key_scale
241
- time_signature = config.time_signature
242
- audio_duration = config.audio_duration
243
-
244
- # Determine if we should use batch LM generation
245
- should_use_lm_batch = (
246
- config.use_llm_thinking
247
- and llm_handler.llm_initialized
248
- and config.use_cot_metas
249
- and config.allow_lm_batch
250
- and config.batch_size >= 2
251
- )
252
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  # LM-based Chain-of-Thought reasoning
254
- if config.use_llm_thinking and llm_handler.llm_initialized and config.use_cot_metas:
255
- # Convert sampling parameters
256
- top_k_value = None if config.lm_top_k == 0 else int(config.lm_top_k)
257
- top_p_value = None if config.lm_top_p >= 1.0 else config.lm_top_p
258
-
 
 
259
  # Build user_metadata from user-provided values
260
  user_metadata = {}
261
  if bpm is not None:
262
  try:
263
  bpm_value = float(bpm)
264
  if bpm_value > 0:
265
- user_metadata['bpm'] = str(int(bpm_value))
266
  except (ValueError, TypeError):
267
  pass
268
-
269
  if key_scale and key_scale.strip():
270
  key_scale_clean = key_scale.strip()
271
  if key_scale_clean.lower() not in ["n/a", ""]:
272
  user_metadata['keyscale'] = key_scale_clean
273
-
274
  if time_signature and time_signature.strip():
275
  time_sig_clean = time_signature.strip()
276
  if time_sig_clean.lower() not in ["n/a", ""]:
277
  user_metadata['timesignature'] = time_sig_clean
278
-
279
  if audio_duration is not None:
280
  try:
281
  duration_value = float(audio_duration)
282
  if duration_value > 0:
283
- user_metadata['duration'] = str(int(duration_value))
284
  except (ValueError, TypeError):
285
  pass
286
-
287
  user_metadata_to_pass = user_metadata if user_metadata else None
288
-
289
- # Batch LM generation (faster for multiple samples)
290
- if should_use_lm_batch:
291
- actual_seed_list, _ = dit_handler.prepare_seeds(
292
- config.batch_size, config.seed, config.use_random_seed
293
- )
294
-
295
- max_inference_batch_size = int(config.lm_batch_chunk_size)
296
- num_chunks = math.ceil(config.batch_size / max_inference_batch_size)
297
-
298
- all_metadata_list = []
299
- all_audio_codes_list = []
300
-
301
- for chunk_idx in range(num_chunks):
302
- chunk_start = chunk_idx * max_inference_batch_size
303
- chunk_end = min(chunk_start + max_inference_batch_size, config.batch_size)
304
- chunk_size = chunk_end - chunk_start
305
- chunk_seeds = actual_seed_list[chunk_start:chunk_end]
306
-
307
- logger.info(
308
- f"LM batch chunk {chunk_idx+1}/{num_chunks} "
309
- f"(size: {chunk_size}, seeds: {chunk_seeds})"
310
- )
311
-
312
- metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
313
- caption=config.caption or "",
314
- lyrics=config.lyrics or "",
315
- batch_size=chunk_size,
316
- infer_type="llm_dit",
317
- temperature=config.lm_temperature,
318
- cfg_scale=config.lm_cfg_scale,
319
- negative_prompt=config.lm_negative_prompt,
320
- top_k=top_k_value,
321
- top_p=top_p_value,
322
- user_metadata=user_metadata_to_pass,
323
- use_cot_caption=config.use_cot_caption,
324
- use_cot_language=config.use_cot_language,
325
- is_format_caption=config.is_format_caption,
326
- constrained_decoding_debug=config.constrained_decoding_debug,
327
- seeds=chunk_seeds,
328
- )
329
-
330
- all_metadata_list.extend(metadata_list)
331
- all_audio_codes_list.extend(audio_codes_list)
332
-
333
- lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
334
- lm_generated_audio_codes_list = all_audio_codes_list
335
- audio_code_string_to_use = all_audio_codes_list
336
-
337
- # Update metadata from LM if not provided by user
338
- if lm_generated_metadata:
339
- bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
340
- lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
341
- )
342
-
343
- else:
344
- # Sequential LM generation (current behavior)
345
- # Phase 1: Generate CoT metadata
346
- phase1_start = time_module.time()
347
- metadata, _, status = llm_handler.generate_with_stop_condition(
348
- caption=config.caption or "",
349
- lyrics=config.lyrics or "",
350
- infer_type="dit",
351
- temperature=config.lm_temperature,
352
- cfg_scale=config.lm_cfg_scale,
353
- negative_prompt=config.lm_negative_prompt,
354
- top_k=top_k_value,
355
- top_p=top_p_value,
356
- user_metadata=user_metadata_to_pass,
357
- use_cot_caption=config.use_cot_caption,
358
- use_cot_language=config.use_cot_language,
359
- is_format_caption=config.is_format_caption,
360
- constrained_decoding_debug=config.constrained_decoding_debug,
361
- )
362
- lm_phase1_time = time_module.time() - phase1_start
363
- logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
364
-
365
- # Phase 2: Generate audio codes
366
- phase2_start = time_module.time()
367
- metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
368
- caption=config.caption or "",
369
- lyrics=config.lyrics or "",
370
- infer_type="llm_dit",
371
- temperature=config.lm_temperature,
372
- cfg_scale=config.lm_cfg_scale,
373
- negative_prompt=config.lm_negative_prompt,
374
  top_k=top_k_value,
375
  top_p=top_p_value,
376
  user_metadata=user_metadata_to_pass,
377
- use_cot_caption=config.use_cot_caption,
378
- use_cot_language=config.use_cot_language,
379
- is_format_caption=config.is_format_caption,
 
380
  constrained_decoding_debug=config.constrained_decoding_debug,
 
 
 
381
  )
382
- lm_phase2_time = time_module.time() - phase2_start
383
- logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
384
-
385
- lm_generated_metadata = metadata
386
- if audio_codes:
387
- audio_code_string_to_use = audio_codes
388
- lm_generated_audio_codes = audio_codes
389
-
390
- # Update metadata from LM if not provided by user
391
- bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
392
- metadata, bpm, key_scale, time_signature, audio_duration
 
393
  )
394
-
 
 
 
 
 
 
 
 
 
395
  # Phase 2: DiT music generation
 
396
  result = dit_handler.generate_music(
397
- captions=config.caption,
398
- lyrics=config.lyrics,
399
  bpm=bpm,
400
  key_scale=key_scale,
401
  time_signature=time_signature,
402
- vocal_language=config.vocal_language,
403
- inference_steps=config.inference_steps,
404
- guidance_scale=config.guidance_scale,
405
  use_random_seed=config.use_random_seed,
406
- seed=config.seed,
407
- reference_audio=config.reference_audio,
408
  audio_duration=audio_duration,
409
- batch_size=config.batch_size,
410
- src_audio=config.src_audio,
411
  audio_code_string=audio_code_string_to_use,
412
- repainting_start=config.repainting_start,
413
- repainting_end=config.repainting_end,
414
- instruction=config.instruction,
415
- audio_cover_strength=config.audio_cover_strength,
416
- task_type=config.task_type,
417
- use_adg=config.use_adg,
418
- cfg_interval_start=config.cfg_interval_start,
419
- cfg_interval_end=config.cfg_interval_end,
420
- audio_format=config.audio_format,
421
- lm_temperature=config.lm_temperature,
422
  )
423
-
424
- # Extract results
425
- (first_audio, second_audio, all_audio_paths, generation_info, status_message,
426
- seed_value, align_score_1, align_text_1, align_plot_1,
427
- align_score_2, align_text_2, align_plot_2) = result
428
-
429
- # Append LM metadata to generation info
430
- if lm_generated_metadata:
431
- generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
432
-
433
- # Create result object
 
 
 
 
 
 
 
 
 
434
  return GenerationResult(
435
- audio_paths=all_audio_paths or [],
436
- first_audio=first_audio,
437
- second_audio=second_audio,
438
- generation_info=generation_info,
439
  status_message=status_message,
440
- seed_value=seed_value,
441
- lm_metadata=lm_generated_metadata,
442
- align_score_1=align_score_1,
443
- align_text_1=align_text_1,
444
- align_plot_1=align_plot_1,
445
- align_score_2=align_score_2,
446
- align_text_2=align_text_2,
447
- align_plot_2=align_plot_2,
448
  success=True,
449
  error=None,
450
  )
451
-
452
  except Exception as e:
453
  logger.exception("Music generation failed")
454
  return GenerationResult(
 
 
 
455
  success=False,
456
  error=str(e),
457
- generation_info=f"❌ Generation failed: {str(e)}",
458
- status_message=f"Error: {str(e)}",
459
  )
460
-
461
-
462
- def _update_metadata_from_lm(
463
- metadata: Dict[str, Any],
464
- bpm: Optional[int],
465
- key_scale: str,
466
- time_signature: str,
467
- audio_duration: Optional[float],
468
- ) -> Tuple[Optional[int], str, str, Optional[float]]:
469
- """Update metadata fields from LM output if not provided by user."""
470
-
471
- if bpm is None and metadata.get('bpm'):
472
- bpm_value = metadata.get('bpm')
473
- if bpm_value not in ["N/A", ""]:
474
- try:
475
- bpm = int(bpm_value)
476
- except (ValueError, TypeError):
477
- pass
478
-
479
- if not key_scale and metadata.get('keyscale'):
480
- key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
481
- if key_scale_value != "N/A":
482
- key_scale = key_scale_value
483
-
484
- if not time_signature and metadata.get('timesignature'):
485
- time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
486
- if time_signature_value != "N/A":
487
- time_signature = time_signature_value
488
-
489
- if audio_duration is None or audio_duration <= 0:
490
- audio_duration_value = metadata.get('duration', -1)
491
- if audio_duration_value not in ["N/A", ""]:
492
- try:
493
- audio_duration = float(audio_duration_value)
494
- except (ValueError, TypeError):
495
- pass
496
-
497
- return bpm, key_scale, time_signature, audio_duration
498
-
499
-
500
- def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any]) -> str:
501
- """Append LM-generated metadata to generation info string."""
502
-
503
- metadata_lines = []
504
- if metadata.get('bpm'):
505
- metadata_lines.append(f"- **BPM:** {metadata['bpm']}")
506
- if metadata.get('caption'):
507
- metadata_lines.append(f"- **Refined Caption:** {metadata['caption']}")
508
- if metadata.get('duration'):
509
- metadata_lines.append(f"- **Duration:** {metadata['duration']} seconds")
510
- if metadata.get('keyscale'):
511
- metadata_lines.append(f"- **Key Scale:** {metadata['keyscale']}")
512
- if metadata.get('language'):
513
- metadata_lines.append(f"- **Language:** {metadata['language']}")
514
- if metadata.get('timesignature'):
515
- metadata_lines.append(f"- **Time Signature:** {metadata['timesignature']}")
516
-
517
- if metadata_lines:
518
- metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
519
- return metadata_section + "\n\n" + generation_info
520
-
521
- return generation_info
522
-
523
-
524
- # ============================================================================
525
- # LEGACY GRADIO UI COMPATIBILITY LAYER
526
- # ============================================================================
527
-
528
- def generate(
529
- dit_handler,
530
- llm_handler,
531
- captions,
532
- lyrics,
533
- bpm,
534
- key_scale,
535
- time_signature,
536
- vocal_language,
537
- inference_steps,
538
- guidance_scale,
539
- random_seed_checkbox,
540
- seed,
541
- reference_audio,
542
- audio_duration,
543
- batch_size_input,
544
- src_audio,
545
- text2music_audio_code_string,
546
- repainting_start,
547
- repainting_end,
548
- instruction_display_gen,
549
- audio_cover_strength,
550
- task_type,
551
- use_adg,
552
- cfg_interval_start,
553
- cfg_interval_end,
554
- audio_format,
555
- lm_temperature,
556
- think_checkbox,
557
- lm_cfg_scale,
558
- lm_top_k,
559
- lm_top_p,
560
- lm_negative_prompt,
561
- use_cot_metas,
562
- use_cot_caption,
563
- use_cot_language,
564
- is_format_caption,
565
- constrained_decoding_debug,
566
- allow_lm_batch,
567
- lm_batch_chunk_size,
568
- ):
569
- """Legacy Gradio UI compatibility wrapper.
570
-
571
- This function maintains backward compatibility with the Gradio UI.
572
- For new integrations, use generate_music() with GenerationConfig instead.
573
-
574
- Returns:
575
- Tuple with 28 elements for Gradio UI component updates
576
- """
577
-
578
- # Convert legacy parameters to new config
579
- config = GenerationConfig(
580
- caption=captions,
581
- lyrics=lyrics,
582
- bpm=bpm,
583
- key_scale=key_scale,
584
- time_signature=time_signature,
585
- vocal_language=vocal_language,
586
- audio_duration=audio_duration,
587
- inference_steps=inference_steps,
588
- guidance_scale=guidance_scale,
589
- use_random_seed=random_seed_checkbox,
590
- seed=seed,
591
- batch_size=batch_size_input,
592
- use_adg=use_adg,
593
- cfg_interval_start=cfg_interval_start,
594
- cfg_interval_end=cfg_interval_end,
595
- audio_format=audio_format,
596
- task_type=task_type,
597
- reference_audio=reference_audio,
598
- src_audio=src_audio,
599
- audio_code_string=text2music_audio_code_string,
600
- repainting_start=repainting_start,
601
- repainting_end=repainting_end,
602
- audio_cover_strength=audio_cover_strength,
603
- instruction=instruction_display_gen,
604
- use_llm_thinking=think_checkbox,
605
- lm_temperature=lm_temperature,
606
- lm_cfg_scale=lm_cfg_scale,
607
- lm_top_k=lm_top_k,
608
- lm_top_p=lm_top_p,
609
- lm_negative_prompt=lm_negative_prompt,
610
- use_cot_metas=use_cot_metas,
611
- use_cot_caption=use_cot_caption,
612
- use_cot_language=use_cot_language,
613
- is_format_caption=is_format_caption,
614
- constrained_decoding_debug=constrained_decoding_debug,
615
- allow_lm_batch=allow_lm_batch,
616
- lm_batch_chunk_size=lm_batch_chunk_size,
617
- )
618
-
619
- # Call new API
620
- result = generate_music(dit_handler, llm_handler, config)
621
-
622
- # Determine which codes to update in UI
623
- if config.allow_lm_batch and result.lm_metadata:
624
- # Batch mode: extract codes from metadata if available
625
- lm_codes_list = result.lm_metadata.get('audio_codes_list', [])
626
- updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
627
- codes_outputs = (lm_codes_list + [""] * 8)[:8]
628
- else:
629
- # Single mode
630
- lm_codes = result.lm_metadata.get('audio_codes', '') if result.lm_metadata else ''
631
- updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
632
- codes_outputs = [""] * 8
633
-
634
- # Prepare audio outputs (up to 8)
635
- audio_outputs = (result.audio_paths + [None] * 8)[:8]
636
-
637
- # Return tuple for Gradio UI (28 elements)
638
- return (
639
- audio_outputs[0], # generated_audio_1
640
- audio_outputs[1], # generated_audio_2
641
- audio_outputs[2], # generated_audio_3
642
- audio_outputs[3], # generated_audio_4
643
- audio_outputs[4], # generated_audio_5
644
- audio_outputs[5], # generated_audio_6
645
- audio_outputs[6], # generated_audio_7
646
- audio_outputs[7], # generated_audio_8
647
- result.audio_paths, # generated_audio_batch
648
- result.generation_info,
649
- result.status_message,
650
- result.seed_value,
651
- result.align_score_1,
652
- result.align_text_1,
653
- result.align_plot_1,
654
- result.align_score_2,
655
- result.align_text_2,
656
- result.align_plot_2,
657
- updated_audio_codes, # Update main audio codes in UI
658
- codes_outputs[0], # text2music_audio_code_string_1
659
- codes_outputs[1], # text2music_audio_code_string_2
660
- codes_outputs[2], # text2music_audio_code_string_3
661
- codes_outputs[3], # text2music_audio_code_string_4
662
- codes_outputs[4], # text2music_audio_code_string_5
663
- codes_outputs[5], # text2music_audio_code_string_6
664
- codes_outputs[6], # text2music_audio_code_string_7
665
- codes_outputs[7], # text2music_audio_code_string_8
666
- result.lm_metadata, # Store metadata for "Send to src audio" buttons
667
- is_format_caption, # Keep is_format_caption unchanged
668
- )
669
-
670
-
671
- # ============================================================================
672
- # TESTING & EXAMPLES
673
- # ============================================================================
674
-
675
- if __name__ == "__main__":
676
- """
677
- Test suite for the inference API.
678
- Demonstrates various usage scenarios and validates functionality.
679
-
680
- Usage:
681
- python -m acestep.inference
682
- """
683
-
684
- import os
685
- import json
686
- from acestep.handler import AceStepHandler
687
- from acestep.llm_inference import LLMHandler
688
-
689
- # Initialize paths
690
- project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
691
- checkpoint_dir = os.path.join(project_root, "checkpoints")
692
-
693
- print("=" * 80)
694
- print("ACE-Step Inference API Test Suite")
695
- print("=" * 80)
696
-
697
- # ========================================================================
698
- # Initialize Handlers
699
- # ========================================================================
700
- print("\n[1/3] Initializing handlers...")
701
- dit_handler = AceStepHandler(save_root="./")
702
- llm_handler = LLMHandler()
703
-
704
- try:
705
- # Initialize DiT handler
706
- print(" - Initializing DiT model...")
707
- status_dit, success_dit = dit_handler.initialize_service(
708
- project_root=project_root,
709
- config_path="acestep-v15-turbo-rl",
710
- device="cuda",
711
- )
712
- if not success_dit:
713
- print(f" ❌ DiT initialization failed: {status_dit}")
714
- exit(1)
715
- print(f" ✓ DiT model initialized successfully")
716
-
717
- # Initialize LLM handler
718
- print(" - Initializing 5Hz LM model...")
719
- status_llm, success_llm = llm_handler.initialize(
720
- checkpoint_dir=checkpoint_dir,
721
- lm_model_path="acestep-5Hz-lm-0.6B-v3",
722
- backend="vllm",
723
- device="cuda",
724
- )
725
- if success_llm:
726
- print(f" ✓ LM model initialized successfully")
727
- else:
728
- print(f" ⚠ LM initialization failed (will skip LM tests): {status_llm}")
729
-
730
- except Exception as e:
731
- print(f" ❌ Initialization error: {e}")
732
- exit(1)
733
-
734
- # ========================================================================
735
- # Helper Functions
736
- # ========================================================================
737
- def load_example_config(example_file: str) -> GenerationConfig:
738
- """Load configuration from an example JSON file."""
739
- try:
740
- with open(example_file, 'r', encoding='utf-8') as f:
741
- data = json.load(f)
742
-
743
- # Convert example format to GenerationConfig
744
- # Handle time signature format (example uses "4" instead of "4/4")
745
- time_sig = data.get('timesignature', '')
746
- if time_sig and '/' not in time_sig:
747
- time_sig = f"{time_sig}/4" # Default to /4 if only numerator given
748
-
749
- config = GenerationConfig(
750
- caption=data.get('caption', ''),
751
- lyrics=data.get('lyrics', ''),
752
- bpm=data.get('bpm'),
753
- key_scale=data.get('keyscale', ''),
754
- time_signature=time_sig,
755
- vocal_language=data.get('language', 'unknown'),
756
- audio_duration=data.get('duration'),
757
- use_llm_thinking=data.get('think', False),
758
- batch_size=data.get('batch_size', 1),
759
- inference_steps=data.get('inference_steps', 8),
760
- )
761
- return config
762
-
763
- except Exception as e:
764
- print(f" ⚠ Failed to load example file: {e}")
765
- return None
766
-
767
- # ========================================================================
768
- # Test Cases
769
- # ========================================================================
770
- test_results = []
771
-
772
- def run_test(test_name: str, config: GenerationConfig, expected_outputs: int = 1):
773
- """Run a single test case and collect results."""
774
- print(f"\n{'=' * 80}")
775
- print(f"Test: {test_name}")
776
- print(f"{'=' * 80}")
777
-
778
- # Display configuration
779
- print("\nConfiguration:")
780
- print(f" Task Type: {config.task_type}")
781
- print(f" Caption: {config.caption[:60]}..." if len(config.caption) > 60 else f" Caption: {config.caption}")
782
- if config.lyrics:
783
- print(f" Lyrics: {config.lyrics[:60]}..." if len(config.lyrics) > 60 else f" Lyrics: {config.lyrics}")
784
- if config.bpm:
785
- print(f" BPM: {config.bpm}")
786
- if config.key_scale:
787
- print(f" Key Scale: {config.key_scale}")
788
- if config.time_signature:
789
- print(f" Time Signature: {config.time_signature}")
790
- if config.audio_duration:
791
- print(f" Duration: {config.audio_duration}s")
792
- print(f" Batch Size: {config.batch_size}")
793
- print(f" Inference Steps: {config.inference_steps}")
794
- print(f" Use LLM Thinking: {config.use_llm_thinking}")
795
-
796
- # Run generation
797
- print("\nGenerating...")
798
- import time
799
- start_time = time.time()
800
-
801
- result = generate_music(dit_handler, llm_handler, config)
802
-
803
- elapsed_time = time.time() - start_time
804
-
805
- # Display results
806
- print("\nResults:")
807
- print(f" Success: {'✓' if result.success else '✗'}")
808
-
809
- if result.success:
810
- print(f" Generated Files: {len(result.audio_paths)}")
811
- for i, path in enumerate(result.audio_paths, 1):
812
- if os.path.exists(path):
813
- file_size = os.path.getsize(path) / (1024 * 1024) # MB
814
- print(f" [{i}] {os.path.basename(path)} ({file_size:.2f} MB)")
815
- else:
816
- print(f" [{i}] {os.path.basename(path)} (file not found)")
817
-
818
- print(f" Seed: {result.seed_value}")
819
- print(f" Generation Time: {elapsed_time:.2f}s")
820
-
821
- # Display LM metadata if available
822
- if result.lm_metadata:
823
- print(f"\n LM-Generated Metadata:")
824
- for key, value in result.lm_metadata.items():
825
- if key not in ['audio_codes', 'audio_codes_list']: # Skip large code strings
826
- print(f" {key}: {value}")
827
-
828
- # Validate outputs
829
- if len(result.audio_paths) != expected_outputs:
830
- print(f" ⚠ Warning: Expected {expected_outputs} outputs, got {len(result.audio_paths)}")
831
- success = False
832
- else:
833
- success = True
834
-
835
- else:
836
- print(f" Error: {result.error}")
837
- success = False
838
-
839
- # Store test result
840
- test_results.append({
841
- "test_name": test_name,
842
- "success": success,
843
- "generation_success": result.success,
844
- "num_outputs": len(result.audio_paths) if result.success else 0,
845
- "expected_outputs": expected_outputs,
846
- "elapsed_time": elapsed_time,
847
- "error": result.error if not result.success else None,
848
- })
849
-
850
- return result
851
-
852
- # ========================================================================
853
- # Test: Production Example (from examples directory)
854
- # ========================================================================
855
- print("\n[2/3] Running Test...")
856
-
857
- # Load production example (J-Rock song from examples/text2music/example_05.json)
858
- example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
859
-
860
- if not os.path.exists(example_file):
861
- print(f"\n ❌ Example file not found: {example_file}")
862
- print(" Please ensure the examples directory exists.")
863
- exit(1)
864
-
865
- print(f" Loading example: {os.path.basename(example_file)}")
866
- config = load_example_config(example_file)
867
-
868
- if not config:
869
- print(" ❌ Failed to load example configuration")
870
- exit(1)
871
-
872
- # Reduce duration for faster testing (original is 200s)
873
- print(f" Original duration: {config.audio_duration}s")
874
- config.audio_duration = 30
875
- config.use_random_seed = False
876
- config.seed = 42
877
- print(f" Test duration: {config.audio_duration}s (reduced for testing)")
878
-
879
- run_test("Production Example (J-Rock Song)", config, expected_outputs=1)
880
-
881
- # ========================================================================
882
- # Test Summary
883
- # ========================================================================
884
- print("\n[3/3] Test Summary")
885
- print("=" * 80)
886
-
887
- if len(test_results) == 0:
888
- print("No tests were run.")
889
- exit(1)
890
-
891
- result = test_results[0]
892
-
893
- print(f"\nTest: {result['test_name']}")
894
- print(f"Status: {'✓ PASS' if result['success'] else '✗ FAIL'}")
895
- print(f"Generation: {'Success' if result['generation_success'] else 'Failed'}")
896
- print(f"Outputs: {result['num_outputs']}/{result['expected_outputs']}")
897
- print(f"Time: {result['elapsed_time']:.2f}s")
898
-
899
- if result["error"]:
900
- print(f"Error: {result['error']}")
901
-
902
- # Save test results to JSON
903
- results_file = os.path.join(project_root, "test_results.json")
904
- try:
905
- with open(results_file, "w") as f:
906
- json.dump({
907
- "test_name": result['test_name'],
908
- "success": result['success'],
909
- "generation_success": result['generation_success'],
910
- "num_outputs": result['num_outputs'],
911
- "expected_outputs": result['expected_outputs'],
912
- "elapsed_time": result['elapsed_time'],
913
- "error": result['error'],
914
- }, f, indent=2)
915
- print(f"\n✓ Test results saved to: {results_file}")
916
- except Exception as e:
917
- print(f"\n⚠ Failed to save test results: {e}")
918
-
919
- # Exit with appropriate code
920
- print("\n" + "=" * 80)
921
- if result['success']:
922
- print("Test passed! ✓")
923
- print("=" * 80)
924
- exit(0)
925
- else:
926
- print("Test failed! ✗")
927
- print("=" * 80)
928
- exit(1)
 
7
  """
8
 
9
  import math
10
+ import os
11
+ import tempfile
12
  from typing import Optional, Union, List, Dict, Any, Tuple
13
  from dataclasses import dataclass, field, asdict
14
  from loguru import logger
15
+
16
+ from acestep.audio_utils import AudioSaver, generate_uuid_from_params
17
 
18
 
19
  @dataclass
20
+ class GenerationParams:
21
+ """Configuration for music generation parameters.
22
 
23
  Attributes:
24
  # Text Inputs
25
+ caption: A short text prompt describing the desired music (main prompt). < 512 characters
26
+ lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
27
+ instrumental: If True, generate instrumental music regardless of lyrics.
28
 
29
  # Music Metadata
30
+ bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
31
+ keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
32
+ timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
33
+ vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". See acestep/constants.py:VALID_LANGUAGES.
34
+ duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
35
 
36
  # Generation Parameters
37
+ inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32-100 for base model).
38
+ guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only supported for the non-turbo model.
39
+ seed: Integer seed for reproducibility. -1 means use random seed each time.
 
 
40
 
41
  # Advanced DiT Parameters
42
+ use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
43
+ cfg_interval_start: Start ratio (0.0-1.0) to apply CFG.
44
+ cfg_interval_end: End ratio (0.0-1.0) to apply CFG.
 
45
 
46
  # Task-Specific Parameters
47
+ task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
48
+ reference_audio: Path to a reference audio file for style transfer or cover tasks.
49
+ src_audio: Path to a source audio file for audio-to-audio tasks.
50
+ audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
51
+ repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
52
+ repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
53
+ audio_cover_strength: Strength of reference audio/codes influence (range 0.0-1.0). Set a smaller value (e.g., 0.2) for style transfer tasks.
54
+ instruction: Optional task instruction prompt. If empty, auto-generated by system.
55
 
56
+ # 5Hz Language Model Parameters for CoT reasoning
57
+ thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
58
+ lm_temperature: Sampling temperature for the LLM (0.0-2.0). Higher = more creative/varied results.
59
+ lm_cfg_scale: Classifier-free guidance scale for the LLM.
60
+ lm_top_k: LLM top-k sampling (0 = disabled).
61
+ lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
62
+ lm_negative_prompt: Negative prompt to use for LLM (for control).
63
+ use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
64
+ use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
65
+ use_cot_language: Whether to let LLM detect vocal language via CoT.
 
 
 
 
 
 
66
  """
67
+ # Required Inputs
68
+ task_type: str = "text2music"
69
+ instruction: str = "Fill the audio semantic mask based on the given conditions:"
70
+
71
+ # Audio Uploads
72
+ reference_audio: Optional[str] = None
73
+ src_audio: Optional[str] = None
74
+
75
+ # LM Codes Hints
76
+ audio_codes: str = ""
77
+
78
  # Text Inputs
79
  caption: str = ""
80
  lyrics: str = ""
81
+ instrumental: bool = False
82
+
83
+ # Metadata
 
 
84
  vocal_language: str = "unknown"
85
+ bpm: Optional[int] = None
86
+ keyscale: str = ""
87
+ timesignature: str = ""
88
+ duration: float = -1.0
89
+
90
+ # Advanced Settings
91
  inference_steps: int = 8
 
 
92
  seed: int = -1
93
+ guidance_scale: float = 7.0
 
 
94
  use_adg: bool = False
95
  cfg_interval_start: float = 0.0
96
  cfg_interval_end: float = 1.0
97
+
 
 
 
 
 
 
98
  repainting_start: float = 0.0
99
  repainting_end: float = -1
100
  audio_cover_strength: float = 1.0
101
+
 
102
  # 5Hz Language Model Parameters
103
+ thinking: bool = True
104
  lm_temperature: float = 0.85
105
  lm_cfg_scale: float = 2.0
106
  lm_top_k: int = 0
 
108
  lm_negative_prompt: str = "NO USER INPUT"
109
  use_cot_metas: bool = True
110
  use_cot_caption: bool = True
111
+ use_cot_lyrics: bool = False # TODO: not used yet
112
  use_cot_language: bool = True
113
+ use_constrained_decoding: bool = True
114
+
115
+ cot_bpm: Optional[int] = None
116
+ cot_keyscale: str = ""
117
+ cot_timesignature: str = ""
118
+ cot_duration: Optional[float] = None
119
+ cot_vocal_language: str = "unknown"
120
+ cot_caption: str = ""
121
+ cot_lyrics: str = ""
122
+
123
+ def to_dict(self) -> Dict[str, Any]:
124
+ """Convert config to dictionary for JSON serialization."""
125
+ return asdict(self)
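# Illustrative sketch (not part of the diff): constructing GenerationParams for a plain
# text2music request. Only fields defined above are used; the values are examples.
params = GenerationParams(
    task_type="text2music",
    caption="upbeat electronic dance music",
    lyrics="[Instrumental]",
    bpm=128,
    duration=30.0,
    thinking=True,
)
print(params.to_dict()["caption"])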
126
+
127
+
128
+ @dataclass
129
+ class GenerationConfig:
130
+ """Configuration for music generation.
131
 
132
+ Attributes:
133
+ batch_size: Number of audio samples to generate
134
+ allow_lm_batch: Whether to allow batch processing in LM
135
+ use_random_seed: Whether to use random seed
136
+ seeds: Seed(s) for batch generation. Can be:
137
+ - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
138
+ - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
139
+ - int: Single seed value (will be converted to list and padded)
140
+ lm_batch_chunk_size: Batch chunk size for LM processing
141
+ constrained_decoding_debug: Whether to enable constrained decoding debug
142
+ audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
143
+ """
144
+ batch_size: int = 2
145
  allow_lm_batch: bool = False
146
+ use_random_seed: bool = True
147
+ seeds: Optional[List[int]] = None
148
+ lm_batch_chunk_size: int = 8
149
+ constrained_decoding_debug: bool = False
150
+ audio_format: str = "flac" # Default to FLAC for fast saving
151
+
152
+ def to_dict(self) -> Dict[str, Any]:
153
+ """Convert config to dictionary for JSON serialization."""
154
+ return asdict(self)
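# Illustrative sketch (not part of the diff): a batch-of-two configuration with fixed seeds and
# WAV output. Only fields defined above are used.
config = GenerationConfig(
    batch_size=2,
    use_random_seed=False,
    seeds=[42, 43],
    audio_format="wav",
)
print(config.to_dict())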
155
 
156
 
157
  @dataclass
 
160
 
161
  Attributes:
162
  # Audio Outputs
163
+ audios: List of audio dictionaries with paths, keys, params
 
 
 
 
 
164
  status_message: Status message from generation
165
+ extra_outputs: Extra outputs from generation
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  success: Whether generation completed successfully
167
  error: Error message if generation failed
168
  """
169
+
170
  # Audio Outputs
171
+ audios: List[Dict[str, Any]] = field(default_factory=list)
 
 
 
172
  # Generation Information
 
173
  status_message: str = ""
174
+ extra_outputs: Dict[str, Any] = field(default_factory=dict)
 
 
 
 
 
 
 
 
 
 
 
 
175
  # Success Status
176
  success: bool = True
177
  error: Optional[str] = None
178
+
179
  def to_dict(self) -> Dict[str, Any]:
180
  """Convert result to dictionary for JSON serialization."""
181
  return asdict(self)
182
 
183
 
184
+ def _update_metadata_from_lm(
185
+ metadata: Dict[str, Any],
186
+ bpm: Optional[int],
187
+ key_scale: str,
188
+ time_signature: str,
189
+ audio_duration: Optional[float],
190
+ vocal_language: str,
191
+ caption: str,
192
+ lyrics: str,
193
+ ) -> Tuple[Optional[int], str, str, Optional[float]]:
194
+ """Update metadata fields from LM output if not provided by user."""
195
+
196
+ if bpm is None and metadata.get('bpm'):
197
+ bpm_value = metadata.get('bpm')
198
+ if bpm_value not in ["N/A", ""]:
199
+ try:
200
+ bpm = int(bpm_value)
201
+ except (ValueError, TypeError):
202
+ pass
203
+
204
+ if not key_scale and metadata.get('keyscale'):
205
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
206
+ if key_scale_value != "N/A":
207
+ key_scale = key_scale_value
208
+
209
+ if not time_signature and metadata.get('timesignature'):
210
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
211
+ if time_signature_value != "N/A":
212
+ time_signature = time_signature_value
213
+
214
+ if audio_duration is None or audio_duration <= 0:
215
+ audio_duration_value = metadata.get('duration', -1)
216
+ if audio_duration_value not in ["N/A", ""]:
217
+ try:
218
+ audio_duration = float(audio_duration_value)
219
+ except (ValueError, TypeError):
220
+ pass
221
+
222
+ if not vocal_language and metadata.get('vocal_language'):
223
+ vocal_language = metadata.get('vocal_language')
224
+ if not caption and metadata.get('caption'):
225
+ caption = metadata.get('caption')
226
+ if not lyrics and metadata.get('lyrics'):
227
+ lyrics = metadata.get('lyrics')
228
+ return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
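# Illustrative sketch (not part of the diff): user-supplied values win, LM metadata fills the gaps.
lm_meta = {"bpm": "120", "keyscale": "A minor", "duration": "95", "caption": "dark synthwave"}
print(_update_metadata_from_lm(
    lm_meta, bpm=None, key_scale="", time_signature="4/4", audio_duration=None,
    vocal_language="en", caption="", lyrics="[Instrumental]",
))
# -> (120, 'A minor', '4/4', 95.0, 'en', 'dark synthwave', '[Instrumental]')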
229
+
230
+
231
  def generate_music(
232
  dit_handler,
233
  llm_handler,
234
+ params: GenerationParams,
235
  config: GenerationConfig,
236
+ save_dir: Optional[str] = None,
237
+ progress=None,
238
  ) -> GenerationResult:
239
  """Generate music using ACE-Step model with optional LM reasoning.
240
 
 
 
 
 
241
  Args:
242
  dit_handler: Initialized DiT model handler (AceStepHandler instance)
243
  llm_handler: Initialized LLM handler (LLMHandler instance)
244
+ params: Generation parameters (GenerationParams instance)
245
  config: Generation configuration (GenerationConfig instance)
246
 
247
  Returns:
248
+ GenerationResult with generated audio files and metadata
 
 
 
 
 
249
  """
 
250
  try:
251
  # Phase 1: LM-based metadata and code generation (if enabled)
252
+ audio_code_string_to_use = params.audio_codes
253
  lm_generated_metadata = None
 
254
  lm_generated_audio_codes_list = []
255
+ lm_total_time_costs = {
256
+ "phase1_time": 0.0,
257
+ "phase2_time": 0.0,
258
+ "total_time": 0.0,
259
+ }
260
+
261
  # Extract mutable copies of metadata (will be updated by LM if needed)
262
+ bpm = params.bpm
263
+ key_scale = params.keyscale
264
+ time_signature = params.timesignature
265
+ audio_duration = params.duration
266
+ dit_input_caption = params.caption
267
+ dit_input_vocal_language = params.vocal_language
268
+ dit_input_lyrics = params.lyrics
269
+ # Determine if we need to generate audio codes
270
+ # If user has provided audio_codes, we don't need to generate them
271
+ # Otherwise, check if we need audio codes (llm_dit mode) or just metas (dit mode)
272
+ user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
273
+
274
+ # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
275
+ # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
276
+ # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
277
+ # Note: This logic can be refined based on specific requirements
278
+ need_audio_codes = not user_provided_audio_codes
279
+
280
+ # Determine if we should use chunk-based LM generation (always use chunks for consistency)
281
+ # Determine actual batch size for chunk processing
282
+ actual_batch_size = config.batch_size if config.batch_size is not None else 1
283
+
284
+ # Prepare seeds for batch generation
285
+ # Use config.seeds if provided; prepare_seeds pads with random seeds as needed
286
+ # Convert config.seeds (None or List[int]) to the comma-separated string prepare_seeds accepts
287
+ seed_for_generation = ""
288
+ if config.seeds is not None and len(config.seeds) > 0:
289
+ if isinstance(config.seeds, list):
290
+ # Convert List[int] to comma-separated string
291
+ seed_for_generation = ",".join(str(s) for s in config.seeds)
292
+
293
+ # Use dit_handler.prepare_seeds to handle seed list generation and padding
294
+ # This will handle all the logic: padding with random seeds if needed, etc.
295
+ actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
296
+
297
  # LM-based Chain-of-Thought reasoning
298
+ use_lm = params.thinking and llm_handler.llm_initialized
299
+ lm_status = []
300
+ if use_lm:
301
+ # Convert sampling parameters - handle None values safely
302
+ top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
303
+ top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
304
+
305
  # Build user_metadata from user-provided values
306
  user_metadata = {}
307
  if bpm is not None:
308
  try:
309
  bpm_value = float(bpm)
310
  if bpm_value > 0:
311
+ user_metadata['bpm'] = int(bpm_value)
312
  except (ValueError, TypeError):
313
  pass
314
+
315
  if key_scale and key_scale.strip():
316
  key_scale_clean = key_scale.strip()
317
  if key_scale_clean.lower() not in ["n/a", ""]:
318
  user_metadata['keyscale'] = key_scale_clean
319
+
320
  if time_signature and time_signature.strip():
321
  time_sig_clean = time_signature.strip()
322
  if time_sig_clean.lower() not in ["n/a", ""]:
323
  user_metadata['timesignature'] = time_sig_clean
324
+
325
  if audio_duration is not None:
326
  try:
327
  duration_value = float(audio_duration)
328
  if duration_value > 0:
329
+ user_metadata['duration'] = int(duration_value)
330
  except (ValueError, TypeError):
331
  pass
332
+
333
  user_metadata_to_pass = user_metadata if user_metadata else None
334
+
335
+ # Determine infer_type based on whether we need audio codes
336
+ # - "llm_dit": generates both metas and audio codes (two-phase internally)
337
+ # - "dit": generates only metas (single phase)
338
+ infer_type = "llm_dit" if need_audio_codes else "dit"
339
+
340
+ # Use chunk size from config, or default to batch_size if not set
341
+ max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
342
+ num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
343
+
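The chunking above splits the requested batch into slices of at most lm_batch_chunk_size items; a standalone sketch of the arithmetic with hypothetical sizes:

import math
batch_size, chunk_cap = 5, 2
num_chunks = math.ceil(batch_size / chunk_cap)  # 3
bounds = [(i * chunk_cap, min((i + 1) * chunk_cap, batch_size)) for i in range(num_chunks)]
# bounds == [(0, 2), (2, 4), (4, 5)] -> chunk sizes 2, 2, 1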
344
+ all_metadata_list = []
345
+ all_audio_codes_list = []
346
+
347
+ for chunk_idx in range(num_chunks):
348
+ chunk_start = chunk_idx * max_inference_batch_size
349
+ chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
350
+ chunk_size = chunk_end - chunk_start
351
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
352
+
353
+ logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
354
+ f"(size: {chunk_size}, seeds: {chunk_seeds})")
355
+
356
+ # Use the determined infer_type
357
+ # - "llm_dit" will internally run two phases (metas + codes)
358
+ # - "dit" will only run phase 1 (metas only)
359
+ result = llm_handler.generate_with_stop_condition(
360
+ caption=params.caption or "",
361
+ lyrics=params.lyrics or "",
362
+ infer_type=infer_type,
363
+ temperature=params.lm_temperature,
364
+ cfg_scale=params.lm_cfg_scale,
365
+ negative_prompt=params.lm_negative_prompt,
366
  top_k=top_k_value,
367
  top_p=top_p_value,
368
  user_metadata=user_metadata_to_pass,
369
+ use_cot_caption=params.use_cot_caption,
370
+ use_cot_language=params.use_cot_language,
371
+ use_cot_metas=params.use_cot_metas,
372
+ use_constrained_decoding=params.use_constrained_decoding,
373
  constrained_decoding_debug=config.constrained_decoding_debug,
374
+ batch_size=chunk_size,
375
+ seeds=chunk_seeds,
376
+ progress=progress,
377
  )
378
+
379
+ # Check if LM generation failed
380
+ if not result.get("success", False):
381
+ error_msg = result.get("error", "Unknown LM error")
382
+ lm_status.append(f"❌ LM Error: {error_msg}")
383
+ # Return early with error
384
+ return GenerationResult(
385
+ audios=[],
386
+ status_message=f"❌ LM generation failed: {error_msg}",
387
+ extra_outputs={},
388
+ success=False,
389
+ error=error_msg,
390
  )
391
+
392
+ # Extract metadata and audio_codes from result dict
393
+ if chunk_size > 1:
394
+ metadata_list = result.get("metadata", [])
395
+ audio_codes_list = result.get("audio_codes", [])
396
+ all_metadata_list.extend(metadata_list)
397
+ all_audio_codes_list.extend(audio_codes_list)
398
+ else:
399
+ metadata = result.get("metadata", {})
400
+ audio_codes = result.get("audio_codes", "")
401
+ all_metadata_list.append(metadata)
402
+ all_audio_codes_list.append(audio_codes)
403
+
404
+ # Collect time costs from LM extra_outputs
405
+ lm_extra = result.get("extra_outputs", {})
406
+ lm_chunk_time_costs = lm_extra.get("time_costs", {})
407
+ if lm_chunk_time_costs:
408
+ # Accumulate time costs from all chunks
409
+ for key in ["phase1_time", "phase2_time", "total_time"]:
410
+ if key in lm_chunk_time_costs:
411
+ lm_total_time_costs[key] += lm_chunk_time_costs[key]
412
+
413
+ time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
414
+ lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
415
+
416
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
417
+ lm_generated_audio_codes_list = all_audio_codes_list
418
+
419
+ # Set audio_code_string_to_use based on infer_type
420
+ if infer_type == "llm_dit":
421
+ # If batch mode, use list; otherwise use single string
422
+ if actual_batch_size > 1:
423
+ audio_code_string_to_use = all_audio_codes_list
424
+ else:
425
+ audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
426
+ else:
427
+ # For "dit" mode, keep user-provided codes or empty
428
+ audio_code_string_to_use = params.audio_codes
429
+
430
+ # Update metadata from LM if not provided by user
431
+ if lm_generated_metadata:
432
+ bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
433
+ metadata=lm_generated_metadata,
434
+ bpm=bpm,
435
+ key_scale=key_scale,
436
+ time_signature=time_signature,
437
+ audio_duration=audio_duration,
438
+ vocal_language=dit_input_vocal_language,
439
+ caption=dit_input_caption,
440
+ lyrics=dit_input_lyrics)
441
+ if not params.bpm:
442
+ params.cot_bpm = bpm
443
+ if not params.keyscale:
444
+ params.cot_keyscale = key_scale
445
+ if not params.timesignature:
446
+ params.cot_timesignature = time_signature
447
+ if not params.duration:
448
+ params.cot_duration = audio_duration
449
+ if not params.vocal_language:
450
+ params.cot_vocal_language = vocal_language
451
+ if not params.caption:
452
+ params.cot_caption = caption
453
+ if not params.lyrics:
454
+ params.cot_lyrics = lyrics
455
+
456
+ # Set CoT caption and language if needed
457
+ if params.use_cot_caption:
458
+ dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
459
+ if params.use_cot_language:
460
+ dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
461
+
462
  # Phase 2: DiT music generation
463
+ # Use the seed string derived from config.seeds for the actual generation
464
  result = dit_handler.generate_music(
465
+ captions=dit_input_caption,
466
+ lyrics=dit_input_lyrics,
467
  bpm=bpm,
468
  key_scale=key_scale,
469
  time_signature=time_signature,
470
+ vocal_language=dit_input_vocal_language,
471
+ inference_steps=params.inference_steps,
472
+ guidance_scale=params.guidance_scale,
473
  use_random_seed=config.use_random_seed,
474
+ seed=seed_for_generation,  # seed string derived from config.seeds
475
+ reference_audio=params.reference_audio,
476
  audio_duration=audio_duration,
477
+ batch_size=config.batch_size if config.batch_size is not None else 1,
478
+ src_audio=params.src_audio,
479
  audio_code_string=audio_code_string_to_use,
480
+ repainting_start=params.repainting_start,
481
+ repainting_end=params.repainting_end,
482
+ instruction=params.instruction,
483
+ audio_cover_strength=params.audio_cover_strength,
484
+ task_type=params.task_type,
485
+ use_adg=params.use_adg,
486
+ cfg_interval_start=params.cfg_interval_start,
487
+ cfg_interval_end=params.cfg_interval_end,
488
+ progress=progress,
 
489
  )
490
+
491
+ # Check if generation failed
492
+ if not result.get("success", False):
493
+ return GenerationResult(
494
+ audios=[],
495
+ status_message=result.get("status_message", ""),
496
+ extra_outputs={},
497
+ success=False,
498
+ error=result.get("error"),
499
+ )
500
+
501
+ # Extract results from dit_handler.generate_music dict
502
+ dit_audios = result.get("audios", [])
503
+ status_message = result.get("status_message", "")
504
+ dit_extra_outputs = result.get("extra_outputs", {})
505
+
506
+ # Use the seed list already prepared above (from config.seed or params.seed fallback)
507
+ # actual_seed_list was computed earlier using dit_handler.prepare_seeds
508
+ seed_list = actual_seed_list
509
+
510
+ # Get base params dictionary
511
+ base_params_dict = params.to_dict()
512
+
513
+ # Save audio files using AudioSaver (format from config)
514
+ audio_format = config.audio_format if config.audio_format else "flac"
515
+ audio_saver = AudioSaver(default_format=audio_format)
516
+
517
+ # Use handler's temp_dir for saving files
518
+ if save_dir is not None:
519
+ os.makedirs(save_dir, exist_ok=True)
520
+
521
+ # Build audios list for GenerationResult with params and save files
522
+ # Audio saving and UUID generation handled here, outside of handler
523
+ audios = []
524
+ for idx, dit_audio in enumerate(dit_audios):
525
+ # Create a copy of params dict for this audio
526
+ audio_params = base_params_dict.copy()
527
+
528
+ # Update audio-specific values
529
+ audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
530
+
531
+ # Add audio codes if batch mode
532
+ if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
533
+ audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
534
+
535
+ # Get audio tensor and metadata
536
+ audio_tensor = dit_audio.get("tensor")
537
+ sample_rate = dit_audio.get("sample_rate", 48000)
538
+
539
+ # Generate UUID for this audio (moved from handler)
540
+ batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
541
+ audio_code_str = lm_generated_audio_codes_list[idx] if (
542
+ lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
543
+ if isinstance(audio_code_str, list):
544
+ audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
545
+
546
+ audio_key = generate_uuid_from_params(audio_params)
547
+
548
+ # Save audio file (handled outside handler)
549
+ audio_path = None
550
+ if audio_tensor is not None and save_dir is not None:
551
+ try:
552
+ audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
553
+ audio_path = audio_saver.save_audio(audio_tensor,
554
+ audio_file,
555
+ sample_rate=sample_rate,
556
+ format=audio_format,
557
+ channels_first=True)
558
+ except Exception as e:
559
+ logger.error(f"[generate_music] Failed to save audio file: {e}")
560
+ audio_path = "" # Fallback to empty path
561
+
562
+ audio_dict = {
563
+ "path": audio_path or "", # File path (saved here, not in handler)
564
+ "tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
565
+ "key": audio_key,
566
+ "sample_rate": sample_rate,
567
+ "params": audio_params,
568
+ }
569
+
570
+ audios.append(audio_dict)
571
+
572
+ # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
573
+ extra_outputs = dit_extra_outputs.copy()
574
+ extra_outputs["lm_metadata"] = lm_generated_metadata
575
+
576
+ # Merge time_costs from both LM and DiT into a unified dictionary
577
+ unified_time_costs = {}
578
+
579
+ # Add LM time costs (if LM was used)
580
+ if use_lm and lm_total_time_costs:
581
+ for key, value in lm_total_time_costs.items():
582
+ unified_time_costs[f"lm_{key}"] = value
583
+
584
+ # Add DiT time costs (if available)
585
+ dit_time_costs = dit_extra_outputs.get("time_costs", {})
586
+ if dit_time_costs:
587
+ for key, value in dit_time_costs.items():
588
+ unified_time_costs[f"dit_{key}"] = value
589
+
590
+ # Calculate total pipeline time
591
+ if unified_time_costs:
592
+ lm_total = unified_time_costs.get("lm_total_time", 0.0)
593
+ dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
594
+ unified_time_costs["pipeline_total_time"] = lm_total + dit_total
595
+
596
+ # Update extra_outputs with unified time_costs
597
+ extra_outputs["time_costs"] = unified_time_costs
598
+
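A standalone sketch of the time-cost merge above with illustrative numbers (the keys mirror the dictionaries built in this function):

lm_costs = {"phase1_time": 1.2, "phase2_time": 3.4, "total_time": 4.6}
dit_costs = {"total_time_cost": 7.5}
unified = {f"lm_{k}": v for k, v in lm_costs.items()}
unified.update({f"dit_{k}": v for k, v in dit_costs.items()})
unified["pipeline_total_time"] = unified.get("lm_total_time", 0.0) + unified.get("dit_total_time_cost", 0.0)
# unified["pipeline_total_time"] == 12.1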
599
+ if lm_status:
600
+ status_message = "\n".join(lm_status) + "\n" + status_message
601
+ else:
602
+ status_message = status_message
603
+ # Create and return GenerationResult
604
  return GenerationResult(
605
+ audios=audios,
606
  status_message=status_message,
607
+ extra_outputs=extra_outputs,
608
  success=True,
609
  error=None,
610
  )
611
+
612
  except Exception as e:
613
  logger.exception("Music generation failed")
614
  return GenerationResult(
615
+ audios=[],
616
+ status_message=f"Error: {str(e)}",
617
+ extra_outputs={},
618
  success=False,
619
  error=str(e),
620
  )
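For orientation, a hedged end-to-end sketch of calling the refactored entry point; the GenerationParams and GenerationConfig constructor arguments are assumptions inferred from the attribute accesses in this diff (not a verified signature), and dit_handler/llm_handler are assumed to be already-initialized handler instances:

params = GenerationParams(caption="dreamy synthwave", lyrics="", thinking=True)
config = GenerationConfig(batch_size=2, seeds=[42, 43], use_random_seed=False, audio_format="flac")
result = generate_music(dit_handler, llm_handler, params, config, save_dir="outputs")
if result.success:
    for audio in result.audios:
        print(audio["path"], audio["params"].get("seed"))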
acestep/llm_inference.py CHANGED
@@ -5,7 +5,8 @@ Handles all LM-related operations including initialization and generation
5
  import os
6
  import traceback
7
  import time
8
- from typing import Optional, Dict, Any, Tuple, List
 
9
  from contextlib import contextmanager
10
 
11
  import yaml
@@ -85,6 +86,189 @@ class LLMHandler:
85
  except Exception as e:
86
  return 0.9, False
87
 
88
  def initialize(
89
  self,
90
  checkpoint_dir: str,
@@ -126,6 +310,7 @@ class LLMHandler:
126
 
127
  logger.info("loading 5Hz LM tokenizer...")
128
  start_time = time.time()
 
129
  llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
130
  logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
131
  self.llm_tokenizer = llm_tokenizer
@@ -150,41 +335,21 @@ class LLMHandler:
150
  # vllm initialization failed, fallback to PyTorch
151
  if not self.llm_initialized:
152
  logger.warning("vllm initialization failed, falling back to PyTorch backend")
153
- try:
154
- self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
155
- if not self.offload_to_cpu:
156
- self.llm = self.llm.to(device).to(self.dtype)
157
- else:
158
- self.llm = self.llm.to("cpu").to(self.dtype)
159
- self.llm.eval()
160
- self.llm_backend = "pt"
161
- self.llm_initialized = True
162
- logger.info("5Hz LM initialized successfully using PyTorch backend (fallback)")
163
- status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
164
- except Exception as e:
165
- return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
166
  # If vllm initialization succeeded, self.llm_initialized should already be True
167
  else:
168
  # Use PyTorch backend (pt)
169
- try:
170
- self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
171
- if not self.offload_to_cpu:
172
- self.llm = self.llm.to(device).to(self.dtype)
173
- else:
174
- self.llm = self.llm.to("cpu").to(self.dtype)
175
- self.llm.eval()
176
- self.llm_backend = "pt"
177
- self.llm_initialized = True
178
- logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
179
- status_msg = f"✅ 5Hz LM initialized successfully\nModel: {full_lm_model_path}\nBackend: PyTorch\nDevice: {device}"
180
- except Exception as e:
181
- return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
182
 
183
  return status_msg, True
184
 
185
  except Exception as e:
186
- error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
187
- return error_msg, False
188
 
189
  def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
190
  """Initialize 5Hz LM model using vllm backend"""
@@ -230,12 +395,11 @@ class LLMHandler:
230
  return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
231
  except Exception as e:
232
  self.llm_initialized = False
233
- error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
234
- return error_msg
235
 
236
- def _run_vllm_from_formatted(
237
  self,
238
- formatted_prompt: str,
239
  temperature: float,
240
  cfg_scale: float,
241
  negative_prompt: str,
@@ -244,7 +408,7 @@ class LLMHandler:
244
  repetition_penalty: float,
245
  use_constrained_decoding: bool = True,
246
  constrained_decoding_debug: bool = False,
247
- metadata_temperature: Optional[float] = 0.85,
248
  codes_temperature: Optional[float] = None,
249
  target_duration: Optional[float] = None,
250
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
@@ -256,37 +420,40 @@ class LLMHandler:
256
  caption: str = "",
257
  lyrics: str = "",
258
  cot_text: str = "",
259
- ) -> str:
260
- """Shared vllm path: accept prebuilt formatted prompt and return text."""
261
  from nanovllm import SamplingParams
262
 
263
  # Determine effective temperature for sampler
264
- use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
 
 
265
  effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
266
 
267
- # Use shared constrained processor if enabled
268
- constrained_processor = None
269
- if use_constrained_decoding or use_phase_temperatures:
270
- # Reset processor state for new generation
271
- self.constrained_processor.reset()
272
-
273
- # Use shared processor, just update caption and settings
274
- self.constrained_processor.enabled = use_constrained_decoding
275
- self.constrained_processor.debug = constrained_decoding_debug
276
- self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
277
- self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
278
- self.constrained_processor.set_target_duration(target_duration)
279
- # Always call set_user_metadata to ensure previous settings are cleared if None
280
- self.constrained_processor.set_user_metadata(user_metadata)
281
- self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
282
- # Set skip_caption and skip_language based on flags
283
- self.constrained_processor.set_skip_genres(skip_genres)
284
- self.constrained_processor.set_skip_caption(skip_caption)
285
- self.constrained_processor.set_skip_language(skip_language)
286
- # Set generation phase for phase-aware processing
287
- self.constrained_processor.set_generation_phase(generation_phase)
288
-
289
- constrained_processor = self.constrained_processor
290
 
291
  sampling_params = SamplingParams(
292
  max_tokens=self.max_model_len - 64,
@@ -301,119 +468,25 @@ class LLMHandler:
301
 
302
  if cfg_scale > 1.0:
303
  # Build unconditional prompt based on generation phase
304
- if generation_phase == "codes":
305
- # Codes phase: use empty CoT in unconditional prompt
306
- # formatted_prompt was built with build_formatted_prompt_with_cot(caption, lyrics, cot_text)
307
- # For unconditional, we use empty CoT: build_formatted_prompt_with_cot(caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=...)
308
- formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
309
- caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
310
- )
311
- else:
312
- # CoT phase: unconditional prompt
313
- # If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
314
- formatted_unconditional_prompt = self.build_formatted_prompt(
315
- caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
316
- )
317
-
318
- outputs = self.llm.generate(
319
- [formatted_prompt],
320
- sampling_params,
321
- unconditional_prompts=[formatted_unconditional_prompt],
322
- )
323
- else:
324
- outputs = self.llm.generate([formatted_prompt], sampling_params)
325
-
326
- # Extract text (retain original selection order/logic)
327
- if isinstance(outputs, list) and len(outputs) > 0:
328
- if hasattr(outputs[0], "outputs") and len(outputs[0].outputs) > 0:
329
- output_text = outputs[0].outputs[0].text
330
- elif hasattr(outputs[0], "text"):
331
- output_text = outputs[0].text
332
- elif isinstance(outputs[0], dict) and "text" in outputs[0]:
333
- output_text = outputs[0]["text"]
334
- else:
335
- output_text = str(outputs[0])
336
- else:
337
- output_text = str(outputs)
338
-
339
- return output_text
340
-
341
- def _run_vllm_batch(
342
- self,
343
- formatted_prompts: List[str],
344
- temperature: float,
345
- cfg_scale: float,
346
- negative_prompt: str,
347
- top_k: Optional[int],
348
- top_p: Optional[float],
349
- repetition_penalty: float,
350
- use_constrained_decoding: bool = True,
351
- constrained_decoding_debug: bool = False,
352
- target_duration: Optional[float] = None,
353
- generation_phase: str = "codes",
354
- caption: str = "",
355
- lyrics: str = "",
356
- cot_text: str = "",
357
- seeds: Optional[List[int]] = None,
358
- ) -> List[str]:
359
- """Batch generation using vllm backend"""
360
- from nanovllm import SamplingParams
361
-
362
- batch_size = len(formatted_prompts)
363
-
364
- # Determine effective temperature for sampler
365
- effective_sampler_temp = temperature
366
-
367
- # Use shared constrained processor if enabled
368
- # Note: vllm batch mode uses same processor for all items
369
- constrained_processor = None
370
- if use_constrained_decoding:
371
- # Reset processor state for new generation
372
- self.constrained_processor.reset()
373
-
374
- self.constrained_processor.enabled = use_constrained_decoding
375
- self.constrained_processor.debug = constrained_decoding_debug
376
- self.constrained_processor.metadata_temperature = None
377
- self.constrained_processor.codes_temperature = None
378
- self.constrained_processor.set_target_duration(target_duration)
379
- self.constrained_processor.set_user_metadata(None)
380
- self.constrained_processor.set_stop_at_reasoning(False)
381
- self.constrained_processor.set_skip_genres(True)
382
- self.constrained_processor.set_skip_caption(True)
383
- self.constrained_processor.set_skip_language(True)
384
- self.constrained_processor.set_generation_phase(generation_phase)
385
-
386
- constrained_processor = self.constrained_processor
387
-
388
- # Build sampling params
389
- sampling_params = SamplingParams(
390
- max_tokens=self.max_model_len - 64,
391
- temperature=effective_sampler_temp,
392
- cfg_scale=cfg_scale,
393
- top_k=top_k,
394
- top_p=top_p,
395
- repetition_penalty=repetition_penalty,
396
- logits_processor=constrained_processor,
397
- logits_processor_update_state=constrained_processor.update_state if constrained_processor else None,
398
- )
399
-
400
- # Generate with or without CFG
401
- if cfg_scale > 1.0:
402
- # Build unconditional prompts
403
- formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
404
- caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
405
  )
406
  unconditional_prompts = [formatted_unconditional_prompt] * batch_size
407
 
408
  outputs = self.llm.generate(
409
- formatted_prompts,
410
  sampling_params,
411
  unconditional_prompts=unconditional_prompts,
412
  )
413
  else:
414
- outputs = self.llm.generate(formatted_prompts, sampling_params)
415
-
416
- # Extract text from each output
417
  output_texts = []
418
  for output in outputs:
419
  if hasattr(output, "outputs") and len(output.outputs) > 0:
@@ -424,70 +497,11 @@ class LLMHandler:
424
  output_texts.append(output["text"])
425
  else:
426
  output_texts.append(str(output))
427
-
428
- return output_texts
429
 
430
- def _run_pt_batch(
431
- self,
432
- formatted_prompts: List[str],
433
- temperature: float,
434
- cfg_scale: float,
435
- negative_prompt: str,
436
- top_k: Optional[int],
437
- top_p: Optional[float],
438
- repetition_penalty: float,
439
- use_constrained_decoding: bool = True,
440
- constrained_decoding_debug: bool = False,
441
- target_duration: Optional[float] = None,
442
- generation_phase: str = "codes",
443
- caption: str = "",
444
- lyrics: str = "",
445
- cot_text: str = "",
446
- seeds: Optional[List[int]] = None,
447
- ) -> List[str]:
448
- """Batch generation using PyTorch backend"""
449
- import random
450
-
451
- batch_size = len(formatted_prompts)
452
- output_texts = []
453
-
454
- # Generate each item sequentially with different seeds
455
- # (PyTorch backend doesn't support true batching efficiently)
456
- for i, formatted_prompt in enumerate(formatted_prompts):
457
- # Set seed for this item if provided
458
- if seeds and i < len(seeds):
459
- torch.manual_seed(seeds[i])
460
- if torch.cuda.is_available():
461
- torch.cuda.manual_seed_all(seeds[i])
462
-
463
- # Generate using single-item method
464
- output_text = self._run_pt_from_formatted(
465
- formatted_prompt=formatted_prompt,
466
- temperature=temperature,
467
- cfg_scale=cfg_scale,
468
- negative_prompt=negative_prompt,
469
- top_k=top_k,
470
- top_p=top_p,
471
- repetition_penalty=repetition_penalty,
472
- use_constrained_decoding=use_constrained_decoding,
473
- constrained_decoding_debug=constrained_decoding_debug,
474
- target_duration=target_duration,
475
- user_metadata=None,
476
- stop_at_reasoning=False,
477
- skip_genres=True,
478
- skip_caption=True,
479
- skip_language=True,
480
- generation_phase=generation_phase,
481
- caption=caption,
482
- lyrics=lyrics,
483
- cot_text=cot_text,
484
- )
485
-
486
- output_texts.append(output_text)
487
-
488
- return output_texts
489
 
490
- def _run_pt_from_formatted(
491
  self,
492
  formatted_prompt: str,
493
  temperature: float,
@@ -496,20 +510,20 @@ class LLMHandler:
496
  top_k: Optional[int],
497
  top_p: Optional[float],
498
  repetition_penalty: float,
499
- use_constrained_decoding: bool = True,
500
- constrained_decoding_debug: bool = False,
501
- target_duration: Optional[float] = None,
502
- user_metadata: Optional[Dict[str, Optional[str]]] = None,
503
- stop_at_reasoning: bool = False,
504
- skip_genres: bool = True,
505
- skip_caption: bool = False,
506
- skip_language: bool = False,
507
- generation_phase: str = "cot",
508
- caption: str = "",
509
- lyrics: str = "",
510
- cot_text: str = "",
511
  ) -> str:
512
- """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
513
  inputs = self.llm_tokenizer(
514
  formatted_prompt,
515
  return_tensors="pt",
@@ -517,27 +531,19 @@ class LLMHandler:
517
  truncation=True,
518
  )
519
 
520
- # Use shared constrained processor if enabled
521
- constrained_processor = None
522
- if use_constrained_decoding:
523
- # Reset processor state for new generation
524
- self.constrained_processor.reset()
525
-
526
- # Use shared processor, just update caption and settings
527
- self.constrained_processor.enabled = use_constrained_decoding
528
- self.constrained_processor.debug = constrained_decoding_debug
529
- self.constrained_processor.set_target_duration(target_duration)
530
- # Always call set_user_metadata to ensure previous settings are cleared if None
531
- self.constrained_processor.set_user_metadata(user_metadata)
532
- self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
533
- # Set skip_caption and skip_language based on flags
534
- self.constrained_processor.set_skip_genres(skip_genres)
535
- self.constrained_processor.set_skip_caption(skip_caption)
536
- self.constrained_processor.set_skip_language(skip_language)
537
- # Set generation phase for phase-aware processing
538
- self.constrained_processor.set_generation_phase(generation_phase)
539
-
540
- constrained_processor = self.constrained_processor
541
 
542
  with self._load_model_context():
543
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -546,25 +552,18 @@ class LLMHandler:
546
  max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
547
 
548
  # Build logits processor list (only for CFG and repetition penalty)
549
- logits_processor = LogitsProcessorList()
550
-
551
- # Add repetition penalty if needed
552
- if repetition_penalty != 1.0:
553
- logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
554
 
555
  if cfg_scale > 1.0:
556
  # Build unconditional prompt based on generation phase
557
- if generation_phase == "codes":
558
- # Codes phase: use empty CoT in unconditional prompt
559
- formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
560
- caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
561
- )
562
- else:
563
- # CoT phase: unconditional prompt
564
- # If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
565
- formatted_unconditional_prompt = self.build_formatted_prompt(
566
- caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
567
- )
568
 
569
  # Tokenize both prompts together to ensure same length (with left padding)
570
  # Left padding is important for generation tasks
@@ -657,7 +656,101 @@ class LLMHandler:
657
 
658
  output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
659
  return output_text
660
-
661
  def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
662
  """Check if all required metadata are present."""
663
  if user_metadata is None:
@@ -705,10 +798,13 @@ class LLMHandler:
705
  constrained_decoding_debug: bool = False,
706
  target_duration: Optional[float] = None,
707
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
 
708
  use_cot_caption: bool = True,
709
  use_cot_language: bool = True,
710
- is_format_caption: bool = False,
711
- ) -> Tuple[Dict[str, Any], str, str]:
 
 
712
  """Two-phase LM generation: CoT generation followed by audio codes generation.
713
 
714
  - infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
@@ -721,30 +817,67 @@ class LLMHandler:
721
  If specified, constrained decoding will inject these values directly.
722
  use_cot_caption: Whether to generate caption in CoT (default True).
723
  use_cot_language: Whether to generate language in CoT (default True).
724
- """
725
- import time
 
 
726
 
727
  infer_type = (infer_type or "").strip().lower()
728
  if infer_type not in {"dit", "llm_dit"}:
729
- return {}, "", f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
730
-
731
  metadata = {}
732
  audio_codes = ""
733
  has_all_metas = self.has_all_metas(user_metadata)
734
-
735
- # Timing variables
736
  phase1_time = 0.0
737
  phase2_time = 0.0
738
 
 
739
  # ========== PHASE 1: CoT Generation ==========
740
- # Always generate CoT unless all metadata are user-provided
741
- if not has_all_metas or not is_format_caption:
742
- logger.info("Phase 1: Generating CoT metadata...")
743
  phase1_start = time.time()
744
 
745
  # Build formatted prompt for CoT phase
746
  formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
747
-
748
  logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
749
  # Generate CoT (stop at </think>)
750
  cot_output_text, status = self.generate_from_formatted_prompt(
@@ -774,23 +907,63 @@ class LLMHandler:
774
  phase1_time = time.time() - phase1_start
775
 
776
  if not cot_output_text:
777
- return {}, "", status
 
 
779
  # Parse metadata from CoT output
780
  metadata, _ = self.parse_lm_output(cot_output_text)
781
- logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
 
  else:
783
  # Use user-provided metadata
784
- logger.info("Phase 1: Using user-provided metadata (skipping generation)")
 
  metadata = {k: v for k, v in user_metadata.items() if v is not None}
786
 
787
  # If infer_type is 'dit', stop here and return only metadata
788
  if infer_type == "dit":
789
- status_msg = f"✅ Generated CoT metadata successfully\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
790
- return metadata, "", status_msg
 
791
 
792
  # ========== PHASE 2: Audio Codes Generation ==========
793
- logger.info("Phase 2: Generating audio codes...")
794
  phase2_start = time.time()
795
 
796
  # Format metadata as CoT using YAML (matching training format)
@@ -799,221 +972,163 @@ class LLMHandler:
799
  # Build formatted prompt with CoT for codes generation phase
800
  formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
801
  logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
802
- # Generate audio codes
803
- codes_output_text, status = self.generate_from_formatted_prompt(
804
- formatted_prompt=formatted_prompt_with_cot,
805
- cfg={
806
- "temperature": temperature,
807
- "cfg_scale": cfg_scale,
808
- "negative_prompt": negative_prompt,
809
- "top_k": top_k,
810
- "top_p": top_p,
811
- "repetition_penalty": repetition_penalty,
812
- "target_duration": target_duration,
813
- "user_metadata": None, # No user metadata injection in Phase 2
814
- "skip_caption": True, # Skip caption since CoT is already included
815
- "skip_language": True, # Skip language since CoT is already included
816
- "generation_phase": "codes",
817
- # Pass context for building unconditional prompt in codes phase
818
- "caption": caption,
819
- "lyrics": lyrics,
820
- "cot_text": cot_text,
821
- },
822
- use_constrained_decoding=use_constrained_decoding,
823
- constrained_decoding_debug=constrained_decoding_debug,
824
- stop_at_reasoning=False, # Generate codes until EOS
825
- )
826
-
827
- if not codes_output_text:
828
- return metadata, "", status
829
-
830
- phase2_time = time.time() - phase2_start
831
-
832
- # Parse audio codes from output (metadata should be same as Phase 1)
833
- _, audio_codes = self.parse_lm_output(codes_output_text)
834
-
835
- codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
836
- logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
837
-
838
- status_msg = f"✅ Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
839
- return metadata, audio_codes, status_msg
840
-
841
- def generate_with_stop_condition_batch(
842
- self,
843
- caption: str,
844
- lyrics: str,
845
- batch_size: int,
846
- infer_type: str = "llm_dit",
847
- temperature: float = 0.85,
848
- cfg_scale: float = 1.0,
849
- negative_prompt: str = "NO USER INPUT",
850
- top_k: Optional[int] = None,
851
- top_p: Optional[float] = None,
852
- repetition_penalty: float = 1.0,
853
- use_constrained_decoding: bool = True,
854
- constrained_decoding_debug: bool = False,
855
- target_duration: Optional[float] = None,
856
- user_metadata: Optional[Dict[str, Optional[str]]] = None,
857
- use_cot_caption: bool = True,
858
- use_cot_language: bool = True,
859
- is_format_caption: bool = False,
860
- seeds: Optional[List[int]] = None,
861
- ) -> Tuple[List[Dict[str, Any]], List[str], str]:
862
- """
863
- Batch version of generate_with_stop_condition.
864
-
865
- Generates multiple audio codes with same conditions but different seeds (for diversity).
866
-
867
- Args:
868
- caption: Same caption for all items
869
- lyrics: Same lyrics for all items
870
- batch_size: Number of items to generate
871
- seeds: Optional list of seeds for each batch item (for reproducibility)
872
- ... (other args same as generate_with_stop_condition)
873
-
874
- Returns:
875
- Tuple of (metadata_list, audio_codes_list, status_message)
876
- - metadata_list: List of metadata dicts (same metadata for all items)
877
- - audio_codes_list: List of audio code strings (one per item, different due to sampling)
878
- - status_message: Generation status
879
- """
880
- import random
881
- import time
882
 
883
- infer_type = (infer_type or "").strip().lower()
884
- if infer_type not in {"dit", "llm_dit"}:
885
- return [], [], f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
886
-
887
- # Generate seeds if not provided
888
- if seeds is None:
889
- seeds = [random.randint(0, 2**32 - 1) for _ in range(batch_size)]
890
- elif len(seeds) < batch_size:
891
- # Pad with random seeds if not enough provided
892
- seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(batch_size - len(seeds))]
893
- else:
894
- seeds = seeds[:batch_size] # Truncate if too many
895
-
896
- # Timing variables
897
- phase1_time = 0.0
898
- phase2_time = 0.0
899
-
900
- # ========== PHASE 1: CoT Generation (ONCE for all items) ==========
901
- has_all_metas = self.has_all_metas(user_metadata)
902
-
903
- if not has_all_metas or not is_format_caption:
904
- logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
905
- phase1_start = time.time()
906
 
907
- # Generate CoT metadata once (same for all batch items)
908
- metadata, _, status = self.generate_with_stop_condition(
909
- caption=caption,
910
- lyrics=lyrics,
911
- infer_type="dit", # Only generate metadata
912
- temperature=temperature,
913
- cfg_scale=cfg_scale,
914
- negative_prompt=negative_prompt,
915
- top_k=top_k,
916
- top_p=top_p,
917
- repetition_penalty=repetition_penalty,
918
  use_constrained_decoding=use_constrained_decoding,
919
  constrained_decoding_debug=constrained_decoding_debug,
920
- target_duration=target_duration,
921
- user_metadata=user_metadata,
922
- use_cot_caption=use_cot_caption,
923
- use_cot_language=use_cot_language,
924
- is_format_caption=is_format_caption,
925
  )
926
 
927
- phase1_time = time.time() - phase1_start
928
 
929
- if not metadata:
930
- return [], [], status
931
 
932
- logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
933
- else:
934
- # Use user-provided metadata
935
- logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
936
- metadata = {k: v for k, v in user_metadata.items() if v is not None}
937
-
938
- # If infer_type is 'dit', stop here and return only metadata
939
- if infer_type == "dit":
940
- metadata_list = [metadata.copy() for _ in range(batch_size)]
941
- status_msg = f"✅ Generated CoT metadata successfully (batch mode)\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
942
- return metadata_list, [""] * batch_size, status_msg
943
-
944
- # ========== PHASE 2: Audio Codes Generation (BATCH) ==========
945
- logger.info(f"Batch Phase 2: Generating audio codes for {batch_size} items...")
946
- phase2_start = time.time()
947
-
948
- # Format metadata as CoT
949
- cot_text = self._format_metadata_as_cot(metadata)
950
-
951
- # Build formatted prompt with CoT
952
- formatted_prompt = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
953
-
954
- # Replicate prompt for batch (all items have same prompt, differ by seeds)
955
- formatted_prompts = [formatted_prompt] * batch_size
956
-
957
- # Call backend-specific batch generation
958
- try:
959
- if self.llm_backend == "vllm":
960
- codes_outputs = self._run_vllm_batch(
961
- formatted_prompts=formatted_prompts,
962
- temperature=temperature,
963
- cfg_scale=cfg_scale,
964
- negative_prompt=negative_prompt,
965
- top_k=top_k,
966
- top_p=top_p,
967
- repetition_penalty=repetition_penalty,
968
- use_constrained_decoding=use_constrained_decoding,
969
- constrained_decoding_debug=constrained_decoding_debug,
970
- target_duration=target_duration,
971
- generation_phase="codes",
972
- caption=caption,
973
- lyrics=lyrics,
974
- cot_text=cot_text,
975
- seeds=seeds,
976
- )
977
- else: # pt backend
978
- codes_outputs = self._run_pt_batch(
979
- formatted_prompts=formatted_prompts,
980
- temperature=temperature,
981
- cfg_scale=cfg_scale,
982
- negative_prompt=negative_prompt,
983
- top_k=top_k,
984
- top_p=top_p,
985
- repetition_penalty=repetition_penalty,
986
- use_constrained_decoding=use_constrained_decoding,
987
- constrained_decoding_debug=constrained_decoding_debug,
988
- target_duration=target_duration,
989
- generation_phase="codes",
990
- caption=caption,
991
- lyrics=lyrics,
992
- cot_text=cot_text,
993
- seeds=seeds,
994
- )
995
- except Exception as e:
996
- error_msg = f"❌ Error in batch codes generation: {str(e)}"
997
- logger.error(error_msg)
998
- return [], [], error_msg
999
-
1000
- # Parse audio codes from each output
1001
- audio_codes_list = []
1002
- metadata_list = []
1003
- for output_text in codes_outputs:
1004
- _, audio_codes = self.parse_lm_output(output_text)
1005
- audio_codes_list.append(audio_codes)
1006
- metadata_list.append(metadata.copy()) # Same metadata for all
1007
-
1008
- phase2_time = time.time() - phase2_start
1009
-
1010
- # Log results
1011
- codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
1012
- logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
1013
-
1014
- status_msg = f"✅ Batch generation completed ({batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
1015
- return metadata_list, audio_codes_list, status_msg
1016
-
1017
  def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
1018
  """
1019
  Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
@@ -1035,7 +1150,7 @@ class LLMHandler:
1035
  if is_negative_prompt:
1036
  # Unconditional prompt for CFG
1037
  # Check if user provided a meaningful negative prompt (not the default)
1038
- has_negative_prompt = negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"
1039
 
1040
  if generation_phase == "cot":
1041
  # CoT phase unconditional prompt
@@ -1086,7 +1201,7 @@ class LLMHandler:
1086
  if is_negative_prompt:
1087
  # Unconditional prompt for codes phase
1088
  # Check if user provided a meaningful negative prompt
1089
- has_negative_prompt = negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"
1090
 
1091
  # Use empty CoT for unconditional
1092
  cot_for_prompt = "<think>\n</think>"
@@ -1369,8 +1484,8 @@ class LLMHandler:
1369
 
1370
  try:
1371
  if self.llm_backend == "vllm":
1372
- output_text = self._run_vllm_from_formatted(
1373
- formatted_prompt=formatted_prompt,
1374
  temperature=temperature,
1375
  cfg_scale=cfg_scale,
1376
  negative_prompt=negative_prompt,
@@ -1393,8 +1508,8 @@ class LLMHandler:
1393
  return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
1394
 
1395
  # PyTorch backend
1396
- output_text = self._run_pt_from_formatted(
1397
- formatted_prompt=formatted_prompt,
1398
  temperature=temperature,
1399
  cfg_scale=cfg_scale,
1400
  negative_prompt=negative_prompt,
@@ -1459,26 +1574,12 @@ class LLMHandler:
1459
  eos_token_id = pad_token_id
1460
 
1461
  # Build logits processor for repetition penalty
1462
- logits_processor = LogitsProcessorList()
1463
- if repetition_penalty != 1.0:
1464
- logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
1465
 
1466
  with torch.no_grad():
1467
  for step in range(max_new_tokens):
1468
  # Forward pass
1469
- if past_key_values is None:
1470
- outputs = model(
1471
- input_ids=generated_ids,
1472
- **model_kwargs,
1473
- use_cache=use_cache,
1474
- )
1475
- else:
1476
- outputs = model(
1477
- input_ids=generated_ids[:, -1:],
1478
- past_key_values=past_key_values,
1479
- **model_kwargs,
1480
- use_cache=use_cache,
1481
- )
1482
 
1483
  # Get logits for the last position
1484
  next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
@@ -1491,41 +1592,18 @@ class LLMHandler:
1491
  for processor in logits_processor:
1492
  next_token_logits = processor(generated_ids, next_token_logits)
1493
 
1494
- # Apply top-k filtering
1495
- if top_k is not None and top_k > 0:
1496
- indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
1497
- next_token_logits[indices_to_remove] = float('-inf')
1498
-
1499
- # Apply top-p filtering
1500
- if top_p is not None and 0.0 < top_p < 1.0:
1501
- sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
1502
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
1503
- sorted_indices_to_remove = cumulative_probs > top_p
1504
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1505
- sorted_indices_to_remove[..., 0] = 0
1506
- indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
1507
- next_token_logits[indices_to_remove] = float('-inf')
1508
 
1509
  # Apply temperature and sample
1510
- if temperature > 0:
1511
- next_token_logits = next_token_logits / temperature
1512
- probs = torch.softmax(next_token_logits, dim=-1)
1513
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1514
- else:
1515
- next_tokens = torch.argmax(next_token_logits, dim=-1)
1516
 
1517
  # Update constrained processor state
1518
- if constrained_processor is not None:
1519
- for b in range(next_tokens.shape[0]):
1520
- constrained_processor.update_state(next_tokens[b].item())
1521
 
1522
  # Check for EOS token
1523
- should_stop = False
1524
- if torch.any(next_tokens == eos_token_id):
1525
- should_stop = True
1526
- elif pad_token_id is not None and pad_token_id != eos_token_id:
1527
- if torch.any(next_tokens == pad_token_id):
1528
- should_stop = True
1529
 
1530
  # Append token to sequence
1531
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
@@ -1601,28 +1679,12 @@ class LLMHandler:
1601
  eos_token_id = pad_token_id
1602
 
1603
  # Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
1604
- logits_processor = LogitsProcessorList()
1605
- if repetition_penalty != 1.0:
1606
- logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
1607
 
1608
  with torch.no_grad():
1609
  for step in range(max_new_tokens):
1610
  # Forward pass for the entire batch (conditional + unconditional)
1611
- if past_key_values is None:
1612
- # First step: full forward pass
1613
- outputs = model(
1614
- input_ids=generated_ids,
1615
- **model_kwargs,
1616
- use_cache=use_cache,
1617
- )
1618
- else:
1619
- # Subsequent steps: only forward the last token (utilizing KV cache)
1620
- outputs = model(
1621
- input_ids=generated_ids[:, -1:],
1622
- past_key_values=past_key_values,
1623
- **model_kwargs,
1624
- use_cache=use_cache,
1625
- )
1626
 
1627
  # Get logits for the last position
1628
  next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
@@ -1645,45 +1707,20 @@ class LLMHandler:
1645
  for processor in logits_processor:
1646
  cfg_logits = processor(current_input_ids, cfg_logits)
1647
 
1648
- # Apply top-k filtering
1649
- if top_k is not None and top_k > 0:
1650
- indices_to_remove = cfg_logits < torch.topk(cfg_logits, top_k)[0][..., -1, None]
1651
- cfg_logits[indices_to_remove] = float('-inf')
1652
-
1653
- # Apply top-p (nucleus) filtering
1654
- if top_p is not None and 0.0 < top_p < 1.0:
1655
- sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
1656
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
1657
- # Remove tokens with cumulative probability above the threshold
1658
- sorted_indices_to_remove = cumulative_probs > top_p
1659
- # Shift the indices to the right to keep also the first token above the threshold
1660
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1661
- sorted_indices_to_remove[..., 0] = 0
1662
- indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
1663
- cfg_logits[indices_to_remove] = float('-inf')
1664
 
1665
  # Apply temperature and sample
1666
- if temperature > 0:
1667
- cfg_logits = cfg_logits / temperature
1668
- probs = torch.softmax(cfg_logits, dim=-1)
1669
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1670
- else:
1671
- next_tokens = torch.argmax(cfg_logits, dim=-1)
1672
 
1673
  # Update constrained processor state AFTER sampling
1674
- if constrained_processor is not None:
1675
- for b in range(next_tokens.shape[0]):
1676
- constrained_processor.update_state(next_tokens[b].item())
1677
 
1678
  # Check for EOS token in conditional sequences BEFORE unsqueezing
1679
  # Stop if any conditional sequence generates EOS token
1680
  # next_tokens shape: [batch_size] (only conditional tokens)
1681
- should_stop = False
1682
- if torch.any(next_tokens == eos_token_id):
1683
- should_stop = True
1684
- elif pad_token_id is not None and pad_token_id != eos_token_id:
1685
- if torch.any(next_tokens == pad_token_id):
1686
- should_stop = True
1687
 
1688
  # Apply the same sampled tokens to both conditional and unconditional sequences
1689
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
 
5
  import os
6
  import traceback
7
  import time
8
+ import random
9
+ from typing import Optional, Dict, Any, Tuple, List, Union
10
  from contextlib import contextmanager
11
 
12
  import yaml
 
86
  except Exception as e:
87
  return 0.9, False
88
 
89
+ def _has_meaningful_negative_prompt(self, negative_prompt: str) -> bool:
90
+ """Check if negative prompt is meaningful (not default/empty)"""
91
+ return bool(negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT")
92
+
93
+ def _build_logits_processor(self, repetition_penalty: float) -> LogitsProcessorList:
94
+ """Build logits processor list with repetition penalty if needed"""
95
+ logits_processor = LogitsProcessorList()
96
+ if repetition_penalty != 1.0:
97
+ logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
98
+ return logits_processor
99
+
100
+ def _setup_constrained_processor(
101
+ self,
102
+ use_constrained_decoding: bool,
103
+ constrained_decoding_debug: bool,
104
+ target_duration: Optional[float],
105
+ user_metadata: Optional[Dict[str, Optional[str]]],
106
+ stop_at_reasoning: bool,
107
+ skip_genres: bool,
108
+ skip_caption: bool,
109
+ skip_language: bool,
110
+ generation_phase: str,
111
+ is_batch: bool = False,
112
+ metadata_temperature: Optional[float] = None,
113
+ codes_temperature: Optional[float] = None,
114
+ ) -> Optional[MetadataConstrainedLogitsProcessor]:
115
+ """Setup and configure constrained processor for generation"""
116
+ use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
117
+
118
+ if not use_constrained_decoding and not use_phase_temperatures:
119
+ return None
120
+
121
+ # Reset processor state for new generation
122
+ self.constrained_processor.reset()
123
+
124
+ # Use shared processor, just update settings
125
+ self.constrained_processor.enabled = use_constrained_decoding
126
+ self.constrained_processor.debug = constrained_decoding_debug
127
+
128
+ # Phase temperatures only supported in single mode
129
+ if use_phase_temperatures:
130
+ self.constrained_processor.metadata_temperature = metadata_temperature
131
+ self.constrained_processor.codes_temperature = codes_temperature
132
+ else:
133
+ self.constrained_processor.metadata_temperature = None
134
+ self.constrained_processor.codes_temperature = None
135
+
136
+ self.constrained_processor.set_target_duration(target_duration)
137
+
138
+ # Batch mode uses default/disabled settings for these options
139
+ if is_batch:
140
+ self.constrained_processor.set_user_metadata(None)
141
+ self.constrained_processor.set_stop_at_reasoning(False)
142
+ self.constrained_processor.set_skip_genres(True)
143
+ self.constrained_processor.set_skip_caption(True)
144
+ self.constrained_processor.set_skip_language(True)
145
+ else:
146
+ # Single mode uses provided settings
147
+ self.constrained_processor.set_user_metadata(user_metadata)
148
+ self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
149
+ self.constrained_processor.set_skip_genres(skip_genres)
150
+ self.constrained_processor.set_skip_caption(skip_caption)
151
+ self.constrained_processor.set_skip_language(skip_language)
152
+
153
+ # Set generation phase for phase-aware processing
154
+ self.constrained_processor.set_generation_phase(generation_phase)
155
+
156
+ return self.constrained_processor
157
+
158
+ def _build_unconditional_prompt(
159
+ self,
160
+ caption: str,
161
+ lyrics: str,
162
+ cot_text: str,
163
+ negative_prompt: str,
164
+ generation_phase: str,
165
+ is_batch: bool = False,
166
+ ) -> str:
167
+ """Build unconditional prompt for CFG based on generation phase and batch mode"""
168
+ if is_batch or generation_phase == "codes":
169
+ # Codes phase or batch mode: use empty CoT in unconditional prompt
170
+ return self.build_formatted_prompt_with_cot(
171
+ caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
172
+ )
173
+ else:
174
+ # CoT phase (single mode only): unconditional prompt
175
+ # If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
176
+ return self.build_formatted_prompt(
177
+ caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
178
+ )
179
+
180
+ def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]:
181
+ """Load PyTorch model from path and return (success, status_message)"""
182
+ try:
183
+ self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
184
+ if not self.offload_to_cpu:
185
+ self.llm = self.llm.to(device).to(self.dtype)
186
+ else:
187
+ self.llm = self.llm.to("cpu").to(self.dtype)
188
+ self.llm.eval()
189
+ self.llm_backend = "pt"
190
+ self.llm_initialized = True
191
+ logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
192
+ status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}"
193
+ return True, status_msg
194
+ except Exception as e:
195
+ return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
196
+
197
+ def _apply_top_k_filter(self, logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor:
198
+ """Apply top-k filtering to logits"""
199
+ if top_k is not None and top_k > 0:
200
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
201
+ logits[indices_to_remove] = float('-inf')
202
+ return logits
203
+
204
+ def _apply_top_p_filter(self, logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor:
205
+ """Apply top-p (nucleus) filtering to logits"""
206
+ if top_p is not None and 0.0 < top_p < 1.0:
207
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
208
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
209
+ sorted_indices_to_remove = cumulative_probs > top_p
210
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
211
+ sorted_indices_to_remove[..., 0] = 0
212
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
213
+ logits[indices_to_remove] = float('-inf')
214
+ return logits
215
+
216
+ def _sample_tokens(self, logits: torch.Tensor, temperature: float) -> torch.Tensor:
217
+ """Sample tokens from logits with temperature"""
218
+ if temperature > 0:
219
+ logits = logits / temperature
220
+ probs = torch.softmax(logits, dim=-1)
221
+ return torch.multinomial(probs, num_samples=1).squeeze(1)
222
+ else:
223
+ return torch.argmax(logits, dim=-1)
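An illustrative, self-contained sketch (not part of this diff) of how the top-k, top-p, and temperature helpers above combine for one decoding step; the logits values and sampling settings are made up:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])        # [batch=1, vocab=4], toy values
top_k, top_p, temperature = 2, 0.9, 0.8

# Top-k: keep the 2 largest logits, mask the rest to -inf
kth_value = torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(logits < kth_value, float("-inf"))

# Top-p: drop the lowest-probability tail beyond cumulative 0.9
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
logits = logits.masked_fill(remove.scatter(1, sorted_idx, remove), float("-inf"))

# Temperature sampling (greedy argmax would be used when temperature == 0)
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)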
224
+
225
+ def _check_eos_token(self, tokens: torch.Tensor, eos_token_id: int, pad_token_id: Optional[int]) -> bool:
226
+ """Check if any token in the batch is EOS or pad token"""
227
+ if torch.any(tokens == eos_token_id):
228
+ return True
229
+ if pad_token_id is not None and pad_token_id != eos_token_id:
230
+ if torch.any(tokens == pad_token_id):
231
+ return True
232
+ return False
233
+
234
+ def _update_constrained_processor_state(self, constrained_processor: Optional[MetadataConstrainedLogitsProcessor], tokens: torch.Tensor):
235
+ """Update constrained processor state with generated tokens"""
236
+ if constrained_processor is not None:
237
+ for b in range(tokens.shape[0]):
238
+ constrained_processor.update_state(tokens[b].item())
239
+
240
+ def _forward_pass(
241
+ self,
242
+ model: Any,
243
+ generated_ids: torch.Tensor,
244
+ model_kwargs: Dict[str, Any],
245
+ past_key_values: Optional[Any],
246
+ use_cache: bool,
247
+ ) -> Any:
248
+ """Perform forward pass with KV cache support"""
249
+ if past_key_values is None:
250
+ outputs = model(
251
+ input_ids=generated_ids,
252
+ **model_kwargs,
253
+ use_cache=use_cache,
254
+ )
255
+ else:
256
+ outputs = model(
257
+ input_ids=generated_ids[:, -1:],
258
+ past_key_values=past_key_values,
259
+ **model_kwargs,
260
+ use_cache=use_cache,
261
+ )
262
+ return outputs
263
+
264
+ def _normalize_batch_input(self, formatted_prompts: Union[str, List[str]]) -> Tuple[List[str], bool]:
265
+ """Normalize batch input: convert single string to list and return (list, is_batch)"""
266
+ is_batch = isinstance(formatted_prompts, list)
267
+ if is_batch:
268
+ return formatted_prompts, is_batch
269
+ else:
270
+ return [formatted_prompts], is_batch
271
+
272
  def initialize(
273
  self,
274
  checkpoint_dir: str,
 
310
 
311
  logger.info("loading 5Hz LM tokenizer...")
312
  start_time = time.time()
313
+ # TODO: tokenizer loading is very slow; no solution found yet
314
  llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
315
  logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
316
  self.llm_tokenizer = llm_tokenizer
 
335
  # vllm initialization failed, fallback to PyTorch
336
  if not self.llm_initialized:
337
  logger.warning("vllm initialization failed, falling back to PyTorch backend")
338
+ success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
339
+ if not success:
340
+ return status_msg, False
341
+ status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
 
 
 
 
 
 
 
 
 
342
  # If vllm initialization succeeded, self.llm_initialized should already be True
343
  else:
344
  # Use PyTorch backend (pt)
345
+ success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
346
+ if not success:
347
+ return status_msg, False
 
 
 
 
 
 
 
 
 
 
348
 
349
  return status_msg, True
350
 
351
  except Exception as e:
352
+ return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
 
353
 
354
  def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
355
  """Initialize 5Hz LM model using vllm backend"""
 
395
  return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
396
  except Exception as e:
397
  self.llm_initialized = False
398
+ return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
 
399
 
400
+ def _run_vllm(
401
  self,
402
+ formatted_prompts: Union[str, List[str]],
403
  temperature: float,
404
  cfg_scale: float,
405
  negative_prompt: str,
 
408
  repetition_penalty: float,
409
  use_constrained_decoding: bool = True,
410
  constrained_decoding_debug: bool = False,
411
+ metadata_temperature: Optional[float] = None,
412
  codes_temperature: Optional[float] = None,
413
  target_duration: Optional[float] = None,
414
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
 
420
  caption: str = "",
421
  lyrics: str = "",
422
  cot_text: str = "",
423
+ seeds: Optional[List[int]] = None,
424
+ ) -> Union[str, List[str]]:
425
+ """
426
+ Unified vllm generation function supporting both single and batch modes.
427
+ Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
428
+ Returns a single string for single mode, or a list of strings for batch mode.
429
+ """
430
  from nanovllm import SamplingParams
431
 
432
+ # Determine if batch mode
433
+ formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
434
+ batch_size = len(formatted_prompt_list)
435
+
436
  # Determine effective temperature for sampler
437
+ # Batch mode doesn't support phase temperatures, so use simple temperature
438
+ # Single mode supports phase temperatures
439
+ use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
440
  effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
441
 
442
+ # Setup constrained processor
443
+ constrained_processor = self._setup_constrained_processor(
444
+ use_constrained_decoding=use_constrained_decoding or use_phase_temperatures,
445
+ constrained_decoding_debug=constrained_decoding_debug,
446
+ target_duration=target_duration,
447
+ user_metadata=user_metadata,
448
+ stop_at_reasoning=stop_at_reasoning,
449
+ skip_genres=skip_genres,
450
+ skip_caption=skip_caption,
451
+ skip_language=skip_language,
452
+ generation_phase=generation_phase,
453
+ is_batch=is_batch,
454
+ metadata_temperature=metadata_temperature,
455
+ codes_temperature=codes_temperature,
456
+ )
 
 
 
 
 
 
 
 
457
 
458
  sampling_params = SamplingParams(
459
  max_tokens=self.max_model_len - 64,
 
468
 
469
  if cfg_scale > 1.0:
470
  # Build unconditional prompt based on generation phase
471
+ formatted_unconditional_prompt = self._build_unconditional_prompt(
472
+ caption=caption,
473
+ lyrics=lyrics,
474
+ cot_text=cot_text,
475
+ negative_prompt=negative_prompt,
476
+ generation_phase=generation_phase,
477
+ is_batch=is_batch,
 
 
478
  )
479
  unconditional_prompts = [formatted_unconditional_prompt] * batch_size
480
 
481
  outputs = self.llm.generate(
482
+ formatted_prompt_list,
483
  sampling_params,
484
  unconditional_prompts=unconditional_prompts,
485
  )
486
  else:
487
+ outputs = self.llm.generate(formatted_prompt_list, sampling_params)
488
+
489
+ # Extract text from outputs
490
  output_texts = []
491
  for output in outputs:
492
  if hasattr(output, "outputs") and len(output.outputs) > 0:
 
497
  output_texts.append(output["text"])
498
  else:
499
  output_texts.append(str(output))
 
 
500
 
501
+ # Return single string for single mode, list for batch mode
502
+ return output_texts[0] if not is_batch else output_texts
 
 
503
 
504
+ def _run_pt_single(
505
  self,
506
  formatted_prompt: str,
507
  temperature: float,
 
510
  top_k: Optional[int],
511
  top_p: Optional[float],
512
  repetition_penalty: float,
513
+ use_constrained_decoding: bool,
514
+ constrained_decoding_debug: bool,
515
+ target_duration: Optional[float],
516
+ user_metadata: Optional[Dict[str, Optional[str]]],
517
+ stop_at_reasoning: bool,
518
+ skip_genres: bool,
519
+ skip_caption: bool,
520
+ skip_language: bool,
521
+ generation_phase: str,
522
+ caption: str,
523
+ lyrics: str,
524
+ cot_text: str,
525
  ) -> str:
526
+ """Internal helper function for single-item PyTorch generation."""
527
  inputs = self.llm_tokenizer(
528
  formatted_prompt,
529
  return_tensors="pt",
 
531
  truncation=True,
532
  )
533
 
534
+ # Setup constrained processor
535
+ constrained_processor = self._setup_constrained_processor(
536
+ use_constrained_decoding=use_constrained_decoding,
537
+ constrained_decoding_debug=constrained_decoding_debug,
538
+ target_duration=target_duration,
539
+ user_metadata=user_metadata,
540
+ stop_at_reasoning=stop_at_reasoning,
541
+ skip_genres=skip_genres,
542
+ skip_caption=skip_caption,
543
+ skip_language=skip_language,
544
+ generation_phase=generation_phase,
545
+ is_batch=False,
546
+ )
 
 
 
 
 
 
 
 
547
 
548
  with self._load_model_context():
549
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
552
  max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
553
 
554
  # Build logits processor list (only for CFG and repetition penalty)
555
+ logits_processor = self._build_logits_processor(repetition_penalty)
 
 
 
 
556
 
557
  if cfg_scale > 1.0:
558
  # Build unconditional prompt based on generation phase
559
+ formatted_unconditional_prompt = self._build_unconditional_prompt(
560
+ caption=caption,
561
+ lyrics=lyrics,
562
+ cot_text=cot_text,
563
+ negative_prompt=negative_prompt,
564
+ generation_phase=generation_phase,
565
+ is_batch=False,
566
+ )
 
 
 
567
 
568
  # Tokenize both prompts together to ensure same length (with left padding)
569
  # Left padding is important for generation tasks
 
656
 
657
  output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
658
  return output_text
659
+
660
+ def _run_pt(
661
+ self,
662
+ formatted_prompts: Union[str, List[str]],
663
+ temperature: float,
664
+ cfg_scale: float,
665
+ negative_prompt: str,
666
+ top_k: Optional[int],
667
+ top_p: Optional[float],
668
+ repetition_penalty: float,
669
+ use_constrained_decoding: bool = True,
670
+ constrained_decoding_debug: bool = False,
671
+ target_duration: Optional[float] = None,
672
+ user_metadata: Optional[Dict[str, Optional[str]]] = None,
673
+ stop_at_reasoning: bool = False,
674
+ skip_genres: bool = True,
675
+ skip_caption: bool = False,
676
+ skip_language: bool = False,
677
+ generation_phase: str = "cot",
678
+ caption: str = "",
679
+ lyrics: str = "",
680
+ cot_text: str = "",
681
+ seeds: Optional[List[int]] = None,
682
+ ) -> Union[str, List[str]]:
683
+ """
684
+ Unified PyTorch generation function supporting both single and batch modes.
685
+ Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
686
+ Returns a single string for single mode, or a list of strings for batch mode.
687
+ Note: the PyTorch backend processes batch items sequentially (it does not support efficient true batching).
688
+ """
689
+ # Determine if batch mode
690
+ formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
691
+
692
+ # For batch mode, process each item sequentially with different seeds
693
+ if is_batch:
694
+ output_texts = []
695
+ for i, formatted_prompt in enumerate(formatted_prompt_list):
696
+ # Set seed for this item if provided
697
+ if seeds and i < len(seeds):
698
+ torch.manual_seed(seeds[i])
699
+ if torch.cuda.is_available():
700
+ torch.cuda.manual_seed_all(seeds[i])
701
+
702
+ # Generate using single-item method with batch-mode defaults
703
+ output_text = self._run_pt_single(
704
+ formatted_prompt=formatted_prompt,
705
+ temperature=temperature,
706
+ cfg_scale=cfg_scale,
707
+ negative_prompt=negative_prompt,
708
+ top_k=top_k,
709
+ top_p=top_p,
710
+ repetition_penalty=repetition_penalty,
711
+ use_constrained_decoding=use_constrained_decoding,
712
+ constrained_decoding_debug=constrained_decoding_debug,
713
+ target_duration=target_duration,
714
+ user_metadata=None,
715
+ stop_at_reasoning=False,
716
+ skip_genres=True,
717
+ skip_caption=True,
718
+ skip_language=True,
719
+ generation_phase=generation_phase,
720
+ caption=caption,
721
+ lyrics=lyrics,
722
+ cot_text=cot_text,
723
+ )
724
+
725
+ output_texts.append(output_text)
726
+
727
+ return output_texts
728
+
729
+ # Single mode: process the formatted prompt
730
+ formatted_prompt = formatted_prompt_list[0]
731
+
732
+ return self._run_pt_single(
733
+ formatted_prompt=formatted_prompt,
734
+ temperature=temperature,
735
+ cfg_scale=cfg_scale,
736
+ negative_prompt=negative_prompt,
737
+ top_k=top_k,
738
+ top_p=top_p,
739
+ repetition_penalty=repetition_penalty,
740
+ use_constrained_decoding=use_constrained_decoding,
741
+ constrained_decoding_debug=constrained_decoding_debug,
742
+ target_duration=target_duration,
743
+ user_metadata=user_metadata,
744
+ stop_at_reasoning=stop_at_reasoning,
745
+ skip_genres=skip_genres,
746
+ skip_caption=skip_caption,
747
+ skip_language=skip_language,
748
+ generation_phase=generation_phase,
749
+ caption=caption,
750
+ lyrics=lyrics,
751
+ cot_text=cot_text,
752
+ )
753
+
754
  def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
755
  """Check if all required metadata are present."""
756
  if user_metadata is None:
 
798
  constrained_decoding_debug: bool = False,
799
  target_duration: Optional[float] = None,
800
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
801
+ use_cot_metas: bool = True,
802
  use_cot_caption: bool = True,
803
  use_cot_language: bool = True,
804
+ batch_size: Optional[int] = None,
805
+ seeds: Optional[List[int]] = None,
806
+ progress=None,
807
+ ) -> Dict[str, Any]:
808
  """Two-phase LM generation: CoT generation followed by audio codes generation.
809
 
810
  - infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
 
817
  If specified, constrained decoding will inject these values directly.
818
  use_cot_caption: Whether to generate caption in CoT (default True).
819
  use_cot_language: Whether to generate language in CoT (default True).
820
+ batch_size: Optional batch size for batch generation. If None or 1, returns single result.
821
+ If > 1, returns batch results (lists).
822
+ seeds: Optional list of seeds for batch generation (for reproducibility).
823
+ Only used when batch_size > 1. TODO: currently unused.
824
 
825
+ Returns:
826
+ Dictionary containing:
827
+ - metadata: Dict or List[Dict] - Generated metadata
828
+ - audio_codes: str or List[str] - Generated audio codes
829
+ - success: bool - Whether generation succeeded
830
+ - error: Optional[str] - Error message if failed
831
+ - extra_outputs: Dict with time_costs and other info
832
+ """
833
+ if progress is None:
834
+ def progress(*args, **kwargs):
835
+ pass
836
+
837
  infer_type = (infer_type or "").strip().lower()
838
  if infer_type not in {"dit", "llm_dit"}:
839
+ error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
840
+ return {
841
+ "metadata": [] if (batch_size and batch_size > 1) else {},
842
+ "audio_codes": [] if (batch_size and batch_size > 1) else "",
843
+ "success": False,
844
+ "error": error_msg,
845
+ "extra_outputs": {"time_costs": {}},
846
+ }
847
+
848
+ # Determine if batch mode
849
+ is_batch = bool(batch_size and batch_size > 1)
850
+ actual_batch_size = batch_size if is_batch else 1
851
+
852
+ # Initialize variables
853
  metadata = {}
854
  audio_codes = ""
855
  has_all_metas = self.has_all_metas(user_metadata)
 
 
856
  phase1_time = 0.0
857
  phase2_time = 0.0
858
 
859
+ # Handle seeds for batch mode
860
+ if is_batch:
861
+ if seeds is None:
862
+ seeds = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
863
+ elif len(seeds) < actual_batch_size:
864
+ seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size - len(seeds))]
865
+ else:
866
+ seeds = seeds[:actual_batch_size]
867
+
868
  # ========== PHASE 1: CoT Generation ==========
869
+ # Skip CoT if all metadata are user-provided OR caption is already formatted
870
+ progress(0.1, "Phase 1: Generating CoT metadata...")
871
+ if not has_all_metas and use_cot_metas:
872
+ if is_batch:
873
+ logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
874
+ else:
875
+ logger.info("Phase 1: Generating CoT metadata...")
876
  phase1_start = time.time()
877
 
878
  # Build formatted prompt for CoT phase
879
  formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
880
+
881
  logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
882
  # Generate CoT (stop at </think>)
883
  cot_output_text, status = self.generate_from_formatted_prompt(
 
907
  phase1_time = time.time() - phase1_start
908
 
909
  if not cot_output_text:
910
+ return {
911
+ "metadata": [] if is_batch else {},
912
+ "audio_codes": [] if is_batch else "",
913
+ "success": False,
914
+ "error": status,
915
+ "extra_outputs": {"time_costs": {"phase1_time": phase1_time}},
916
+ }
917
 
918
  # Parse metadata from CoT output
919
  metadata, _ = self.parse_lm_output(cot_output_text)
920
+ if is_batch:
921
+ logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
922
+ else:
923
+ logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
924
  else:
925
  # Use user-provided metadata
926
+ if is_batch:
927
+ logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
928
+ else:
929
+ logger.info("Phase 1: Using user-provided metadata (skipping generation)")
930
  metadata = {k: v for k, v in user_metadata.items() if v is not None}
931
 
932
  # If infer_type is 'dit', stop here and return only metadata
933
  if infer_type == "dit":
934
+ if is_batch:
935
+ metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
936
+ return {
937
+ "metadata": metadata_list,
938
+ "audio_codes": [""] * actual_batch_size,
939
+ "success": True,
940
+ "error": None,
941
+ "extra_outputs": {
942
+ "time_costs": {
943
+ "phase1_time": phase1_time,
944
+ "total_time": phase1_time,
945
+ }
946
+ },
947
+ }
948
+ else:
949
+ return {
950
+ "metadata": metadata,
951
+ "audio_codes": "",
952
+ "success": True,
953
+ "error": None,
954
+ "extra_outputs": {
955
+ "time_costs": {
956
+ "phase1_time": phase1_time,
957
+ "total_time": phase1_time,
958
+ }
959
+ },
960
+ }
961
 
962
  # ========== PHASE 2: Audio Codes Generation ==========
963
+ if is_batch:
964
+ logger.info(f"Batch Phase 2: Generating audio codes for {actual_batch_size} items...")
965
+ else:
966
+ logger.info("Phase 2: Generating audio codes...")
967
  phase2_start = time.time()
968
 
969
  # Format metadata as CoT using YAML (matching training format)
 
972
  # Build formatted prompt with CoT for codes generation phase
973
  formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
974
  logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
 
 
 
975
 
976
+ progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
977
+ if is_batch:
978
+ # Batch mode: generate codes for all items
979
+ formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
 
 
980
 
981
+ # Call backend-specific batch generation
982
+ try:
983
+ if self.llm_backend == "vllm":
984
+ codes_outputs = self._run_vllm(
985
+ formatted_prompts=formatted_prompts,
986
+ temperature=temperature,
987
+ cfg_scale=cfg_scale,
988
+ negative_prompt=negative_prompt,
989
+ top_k=top_k,
990
+ top_p=top_p,
991
+ repetition_penalty=repetition_penalty,
992
+ use_constrained_decoding=use_constrained_decoding,
993
+ constrained_decoding_debug=constrained_decoding_debug,
994
+ target_duration=target_duration,
995
+ generation_phase="codes",
996
+ caption=caption,
997
+ lyrics=lyrics,
998
+ cot_text=cot_text,
999
+ seeds=seeds,
1000
+ )
1001
+ else: # pt backend
1002
+ codes_outputs = self._run_pt(
1003
+ formatted_prompts=formatted_prompts,
1004
+ temperature=temperature,
1005
+ cfg_scale=cfg_scale,
1006
+ negative_prompt=negative_prompt,
1007
+ top_k=top_k,
1008
+ top_p=top_p,
1009
+ repetition_penalty=repetition_penalty,
1010
+ use_constrained_decoding=use_constrained_decoding,
1011
+ constrained_decoding_debug=constrained_decoding_debug,
1012
+ target_duration=target_duration,
1013
+ generation_phase="codes",
1014
+ caption=caption,
1015
+ lyrics=lyrics,
1016
+ cot_text=cot_text,
1017
+ seeds=seeds,
1018
+ )
1019
+ except Exception as e:
1020
+ error_msg = f"Error in batch codes generation: {str(e)}"
1021
+ logger.error(error_msg)
1022
+ return {
1023
+ "metadata": [],
1024
+ "audio_codes": [],
1025
+ "success": False,
1026
+ "error": error_msg,
1027
+ "extra_outputs": {
1028
+ "time_costs": {
1029
+ "phase1_time": phase1_time,
1030
+ "phase2_time": 0.0,
1031
+ "total_time": phase1_time,
1032
+ }
1033
+ },
1034
+ }
1035
+
1036
+ # Parse audio codes from each output
1037
+ audio_codes_list = []
1038
+ metadata_list = []
1039
+ for output_text in codes_outputs:
1040
+ _, audio_codes_item = self.parse_lm_output(output_text)
1041
+ audio_codes_list.append(audio_codes_item)
1042
+ metadata_list.append(metadata.copy()) # Same metadata for all
1043
+
1044
+ phase2_time = time.time() - phase2_start
1045
+
1046
+ # Log results
1047
+ codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
1048
+ logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
1049
+
1050
+ total_time = phase1_time + phase2_time
1051
+ return {
1052
+ "metadata": metadata_list,
1053
+ "audio_codes": audio_codes_list,
1054
+ "success": True,
1055
+ "error": None,
1056
+ "extra_outputs": {
1057
+ "time_costs": {
1058
+ "phase1_time": phase1_time,
1059
+ "phase2_time": phase2_time,
1060
+ "total_time": total_time,
1061
+ },
1062
+ "codes_counts": codes_counts,
1063
+ "total_codes": sum(codes_counts),
1064
+ },
1065
+ }
1066
+ else:
1067
+ # Single mode: generate codes for one item
1068
+ codes_output_text, status = self.generate_from_formatted_prompt(
1069
+ formatted_prompt=formatted_prompt_with_cot,
1070
+ cfg={
1071
+ "temperature": temperature,
1072
+ "cfg_scale": cfg_scale,
1073
+ "negative_prompt": negative_prompt,
1074
+ "top_k": top_k,
1075
+ "top_p": top_p,
1076
+ "repetition_penalty": repetition_penalty,
1077
+ "target_duration": target_duration,
1078
+ "user_metadata": None, # No user metadata injection in Phase 2
1079
+ "skip_caption": True, # Skip caption since CoT is already included
1080
+ "skip_language": True, # Skip language since CoT is already included
1081
+ "generation_phase": "codes",
1082
+ # Pass context for building unconditional prompt in codes phase
1083
+ "caption": caption,
1084
+ "lyrics": lyrics,
1085
+ "cot_text": cot_text,
1086
+ },
1087
  use_constrained_decoding=use_constrained_decoding,
1088
  constrained_decoding_debug=constrained_decoding_debug,
1089
+ stop_at_reasoning=False, # Generate codes until EOS
 
 
 
 
1090
  )
1091
 
1092
+ if not codes_output_text:
1093
+ total_time = phase1_time + phase2_time
1094
+ return {
1095
+ "metadata": metadata,
1096
+ "audio_codes": "",
1097
+ "success": False,
1098
+ "error": status,
1099
+ "extra_outputs": {
1100
+ "time_costs": {
1101
+ "phase1_time": phase1_time,
1102
+ "phase2_time": phase2_time,
1103
+ "total_time": total_time,
1104
+ }
1105
+ },
1106
+ }
1107
 
1108
+ phase2_time = time.time() - phase2_start
 
1109
 
1110
+ # Parse audio codes from output (metadata should be same as Phase 1)
1111
+ _, audio_codes = self.parse_lm_output(codes_output_text)
1112
+
1113
+ codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
1114
+ logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
1115
+
1116
+ total_time = phase1_time + phase2_time
1117
+ return {
1118
+ "metadata": metadata,
1119
+ "audio_codes": audio_codes,
1120
+ "success": True,
1121
+ "error": None,
1122
+ "extra_outputs": {
1123
+ "time_costs": {
1124
+ "phase1_time": phase1_time,
1125
+ "phase2_time": phase2_time,
1126
+ "total_time": total_time,
1127
+ },
1128
+ "codes_count": codes_count,
1129
+ },
1130
+ }
1131
+
 
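An illustrative consumer of the result dictionary returned above (a hedged sketch; the function name is hypothetical and not part of this diff):

def consume_two_phase_result(result: dict):
    """Unpack the {'metadata', 'audio_codes', 'success', 'error', 'extra_outputs'} contract above."""
    if not result["success"]:
        raise RuntimeError(result["error"])
    timings = result["extra_outputs"]["time_costs"]  # phase1_time / phase2_time / total_time
    # metadata is a dict (single mode) or a list of dicts (batch mode); audio_codes mirrors that.
    return result["metadata"], result["audio_codes"], timings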
 
 
1132
  def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
1133
  """
1134
  Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
 
1150
  if is_negative_prompt:
1151
  # Unconditional prompt for CFG
1152
  # Check if user provided a meaningful negative prompt (not the default)
1153
+ has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
1154
 
1155
  if generation_phase == "cot":
1156
  # CoT phase unconditional prompt
 
1201
  if is_negative_prompt:
1202
  # Unconditional prompt for codes phase
1203
  # Check if user provided a meaningful negative prompt
1204
+ has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
1205
 
1206
  # Use empty CoT for unconditional
1207
  cot_for_prompt = "<think>\n</think>"
 
1484
 
1485
  try:
1486
  if self.llm_backend == "vllm":
1487
+ output_text = self._run_vllm(
1488
+ formatted_prompts=formatted_prompt,
1489
  temperature=temperature,
1490
  cfg_scale=cfg_scale,
1491
  negative_prompt=negative_prompt,
 
1508
  return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
1509
 
1510
  # PyTorch backend
1511
+ output_text = self._run_pt(
1512
+ formatted_prompts=formatted_prompt,
1513
  temperature=temperature,
1514
  cfg_scale=cfg_scale,
1515
  negative_prompt=negative_prompt,
 
1574
  eos_token_id = pad_token_id
1575
 
1576
  # Build logits processor for repetition penalty
1577
+ logits_processor = self._build_logits_processor(repetition_penalty)
 
 
1578
 
1579
  with torch.no_grad():
1580
  for step in range(max_new_tokens):
1581
  # Forward pass
1582
+ outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
 
 
 
 
 
 
 
 
 
 
 
 
1583
 
1584
  # Get logits for the last position
1585
  next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
 
1592
  for processor in logits_processor:
1593
  next_token_logits = processor(generated_ids, next_token_logits)
1594
 
1595
+ # Apply top-k and top-p filtering
1596
+ next_token_logits = self._apply_top_k_filter(next_token_logits, top_k)
1597
+ next_token_logits = self._apply_top_p_filter(next_token_logits, top_p)
 
 
 
 
 
 
 
 
 
 
 
1598
 
1599
  # Apply temperature and sample
1600
+ next_tokens = self._sample_tokens(next_token_logits, temperature)
 
 
 
 
 
1601
 
1602
  # Update constrained processor state
1603
+ self._update_constrained_processor_state(constrained_processor, next_tokens)
 
 
1604
 
1605
  # Check for EOS token
1606
+ should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
 
 
 
 
 
1607
 
1608
  # Append token to sequence
1609
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
 
1679
  eos_token_id = pad_token_id
1680
 
1681
  # Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
1682
+ logits_processor = self._build_logits_processor(repetition_penalty)
 
 
1683
 
1684
  with torch.no_grad():
1685
  for step in range(max_new_tokens):
1686
  # Forward pass for the entire batch (conditional + unconditional)
1687
+ outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1688
 
1689
  # Get logits for the last position
1690
  next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
 
1707
  for processor in logits_processor:
1708
  cfg_logits = processor(current_input_ids, cfg_logits)
1709
 
1710
+ # Apply top-k and top-p filtering
1711
+ cfg_logits = self._apply_top_k_filter(cfg_logits, top_k)
1712
+ cfg_logits = self._apply_top_p_filter(cfg_logits, top_p)
 
 
 
 
 
 
 
 
 
 
 
 
 
1713
 
1714
  # Apply temperature and sample
1715
+ next_tokens = self._sample_tokens(cfg_logits, temperature)
 
 
 
 
 
1716
 
1717
  # Update constrained processor state AFTER sampling
1718
+ self._update_constrained_processor_state(constrained_processor, next_tokens)
 
 
1719
 
1720
  # Check for EOS token in conditional sequences BEFORE unsqueezing
1721
  # Stop if any conditional sequence generates EOS token
1722
  # next_tokens shape: [batch_size] (only conditional tokens)
1723
+ should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
 
 
 
 
 
1724
 
1725
  # Apply the same sampled tokens to both conditional and unconditional sequences
1726
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -68,10 +68,16 @@ class ModelRunner:
68
  self.model = Qwen3ForCausalLM(hf_config)
69
  load_model(self.model, config.model)
70
  self.sampler = Sampler()
 
 
 
 
 
71
  self.warmup_model()
72
  self.allocate_kv_cache()
73
  if not self.enforce_eager:
74
  self.capture_cudagraph()
 
75
  torch.set_default_device("cpu")
76
  torch.set_default_dtype(default_dtype)
77
 
@@ -84,6 +90,39 @@ class ModelRunner:
84
  self.shm = SharedMemory(name="nanovllm")
85
  self.loop()
86
 
 
 
87
  def exit(self):
88
  if self.world_size > 1:
89
  self.shm.close()
@@ -203,7 +242,7 @@ class ModelRunner:
203
  if i != seq.num_blocks - 1:
204
  end = start + self.block_size
205
  else:
206
- end = start + seq.last_block_num_tokens
207
  slot_mapping.extend(list(range(start, end)))
208
  if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
209
  block_tables = self.prepare_block_tables(seqs)
@@ -216,57 +255,58 @@ class ModelRunner:
216
  return input_ids, positions
217
 
218
  def prepare_decode(self, seqs: list[Sequence]):
219
- input_ids = []
220
- positions = []
221
- slot_mapping = []
222
- context_lens = []
223
- for seq in seqs:
224
- input_ids.append(seq.last_token)
225
- positions.append(len(seq) - 1)
226
- context_lens.append(len(seq))
227
- slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
228
- input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
229
- positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
230
- slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
231
- context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
 
 
232
  block_tables = self.prepare_block_tables(seqs)
233
  set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
234
  return input_ids, positions
235
 
236
  def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
237
- """Prepare sampling parameters. For CFG batch, only return parameters for conditional sequences."""
238
  if is_cfg_batch:
239
- # For CFG batch, seqs contains [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
240
- # We only need parameters for conditional sequences (first half)
241
- num_cond = len(seqs) // 2
242
- temperatures = []
243
- cfg_scales = []
244
- top_ks = []
245
- top_ps = []
246
- repetition_penalties = []
247
- for seq in seqs[:num_cond]:
248
- temperatures.append(seq.temperature)
249
- cfg_scales.append(seq.cfg_scale)
250
- top_ks.append(seq.top_k if seq.top_k is not None else 0)
251
- top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
252
- repetition_penalties.append(seq.repetition_penalty)
253
  else:
254
- temperatures = []
255
- cfg_scales = []
256
- top_ks = []
257
- top_ps = []
258
- repetition_penalties = []
259
- for seq in seqs:
260
- temperatures.append(seq.temperature)
261
- cfg_scales.append(seq.cfg_scale)
262
- top_ks.append(seq.top_k if seq.top_k is not None else 0)
263
- top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
264
- repetition_penalties.append(seq.repetition_penalty)
265
- temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
266
- cfg_scales = torch.tensor(cfg_scales, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
267
- top_ks = torch.tensor(top_ks, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
268
- top_ps = torch.tensor(top_ps, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
269
- repetition_penalties = torch.tensor(repetition_penalties, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
 
 
 
 
 
 
 
 
 
 
 
270
  return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
271
 
272
  @torch.inference_mode()
@@ -293,27 +333,15 @@ class ModelRunner:
293
  [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
294
  where uncond_seqi is the paired unconditional sequence of cond_seqi."""
295
  # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
296
- is_cfg_batch = False
297
- if len(seqs) > 0:
298
- # CFG batch if first sequence has cfg_scale > 1.0 and paired_seq
299
- if seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
300
- is_cfg_batch = True
301
- # Verify batch structure: first half conditional, second half unconditional
302
- num_cond = len(seqs) // 2
303
- for i in range(num_cond):
304
- if seqs[i].is_unconditional or seqs[i + num_cond].is_unconditional == False:
305
- is_cfg_batch = False
306
- break
307
-
308
  if is_cfg_batch:
309
  # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
310
  num_cond = len(seqs) // 2
311
  cond_seqs = seqs[:num_cond]
312
- uncond_seqs = seqs[num_cond:]
313
 
314
  # Prepare inputs for both conditional and unconditional (they're already in the batch)
315
- input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
316
- else self.prepare_decode(seqs))
317
  sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
318
  if sample_params is not None:
319
  temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
@@ -364,7 +392,7 @@ class ModelRunner:
364
  logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
365
 
366
  # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
367
- cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
368
 
369
  # Sample from CFG logits
370
  token_ids_cfg = self.sampler(
@@ -373,7 +401,7 @@ class ModelRunner:
373
  top_ks=top_ks if top_ks is not None else None,
374
  top_ps=top_ps if top_ps is not None else None,
375
  repetition_penalties=None, # Already applied above
376
- input_ids=cond_input_ids,
377
  ).tolist()
378
 
379
  # Update logits processor state after sampling
@@ -432,7 +460,7 @@ class ModelRunner:
432
  logits[i] = processed[0]
433
 
434
  # Prepare input_ids for sampler
435
- seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
436
 
437
  token_ids = self.sampler(
438
  logits,
@@ -440,7 +468,7 @@ class ModelRunner:
440
  top_ks=top_ks if top_ks is not None else None,
441
  top_ps=top_ps if top_ps is not None else None,
442
  repetition_penalties=None, # Already applied above
443
- input_ids=seq_input_ids,
444
  ).tolist()
445
 
446
  # Update logits processor state after sampling
 
68
  self.model = Qwen3ForCausalLM(hf_config)
69
  load_model(self.model, config.model)
70
  self.sampler = Sampler()
71
+
72
+ # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
73
+ # Must be called before warmup_model() since it uses these buffers
74
+ self._allocate_sample_buffers()
75
+
76
  self.warmup_model()
77
  self.allocate_kv_cache()
78
  if not self.enforce_eager:
79
  self.capture_cudagraph()
80
+
81
  torch.set_default_device("cpu")
82
  torch.set_default_dtype(default_dtype)
83
 
 
90
  self.shm = SharedMemory(name="nanovllm")
91
  self.loop()
92
 
93
+ def _allocate_sample_buffers(self):
94
+ """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
95
+ max_bs = self.config.max_num_seqs
96
+ max_tokens = self.config.max_num_batched_tokens
97
+ max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
98
+
99
+ # Pre-allocate pinned memory buffers on CPU for fast transfer
100
+ # Must explicitly specify device="cpu" since default device may be "cuda"
101
+ self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
102
+ self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
103
+ self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
104
+ self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
105
+ self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
106
+
107
+ # Pre-allocate decode buffers on CPU with pinned memory
108
+ self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
109
+ self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
110
+ self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
111
+ self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
112
+
113
+ # Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
114
+ self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
115
+ self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
116
+ self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
117
+ self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
118
+
119
+ # Pre-allocate block tables buffer (shared by both decode and prefill)
120
+ self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
121
+
122
+ # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
123
+ # Max length is max_model_len since sequences can be that long
124
+ self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
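The pinned-buffer reuse pattern above, shown in isolation (an illustrative sketch with made-up sizes, not the actual ModelRunner code; assumes a CUDA device is available):

import torch

max_bs = 8
pinned = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)

def to_device(values):
    # Fill the pre-allocated pinned buffer, then issue one async host-to-device copy of a slice.
    for i, v in enumerate(values):
        pinned[i] = v
    return pinned[:len(values)].cuda(non_blocking=True)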
125
+
126
  def exit(self):
127
  if self.world_size > 1:
128
  self.shm.close()
 
242
  if i != seq.num_blocks - 1:
243
  end = start + self.block_size
244
  else:
245
+ end = start + seq.last_block_num_tokens
246
  slot_mapping.extend(list(range(start, end)))
247
  if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
248
  block_tables = self.prepare_block_tables(seqs)
 
255
  return input_ids, positions
256
 
257
  def prepare_decode(self, seqs: list[Sequence]):
258
+ """Optimized decode preparation using pre-allocated buffers."""
259
+ bs = len(seqs)
260
+
261
+ # Use pre-allocated CPU buffers
262
+ for i, seq in enumerate(seqs):
263
+ self._cpu_input_ids[i] = seq.last_token
264
+ self._cpu_positions[i] = len(seq) - 1
265
+ self._cpu_context_lens[i] = len(seq)
266
+ self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
267
+
268
+ # Transfer to GPU using sliced views
269
+ input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
270
+ positions = self._cpu_positions[:bs].cuda(non_blocking=True)
271
+ slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
272
+ context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
273
  block_tables = self.prepare_block_tables(seqs)
274
  set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
275
  return input_ids, positions
276
 
277
  def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
278
+ """Optimized sample preparation using pre-allocated buffers."""
279
  if is_cfg_batch:
280
+ num_seqs = len(seqs) // 2
281
+ target_seqs = seqs[:num_seqs]
 
 
 
 
 
 
 
 
 
 
 
 
282
  else:
283
+ num_seqs = len(seqs)
284
+ target_seqs = seqs
285
+
286
+ # Fill pre-allocated CPU buffers
287
+ top_ks_is_zero = True
288
+ top_ps_is_one = True
289
+ repetition_penalties_is_one = True
290
+ for i, seq in enumerate(target_seqs):
291
+ self._cpu_temperatures[i] = seq.temperature
292
+ self._cpu_cfg_scales[i] = seq.cfg_scale
293
+ self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
294
+ if seq.top_k is not None and seq.top_k > 0:
295
+ top_ks_is_zero = False
296
+ self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
297
+ if seq.top_p is not None and seq.top_p != 1.0:  # top_p is active for this seq
298
+ top_ps_is_one = False
299
+ self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
300
+ if seq.repetition_penalty is not None and seq.repetition_penalty != 1.0:  # penalty is active
301
+ repetition_penalties_is_one = False
302
+
303
+ # Transfer to GPU using sliced views (single batched transfer)
304
+ temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
305
+ cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
306
+ top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
307
+ top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
308
+ repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
309
+
310
  return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
311
 
312
  @torch.inference_mode()
 
333
  [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
334
  where uncond_seqi is the paired unconditional sequence of cond_seqi."""
335
  # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
336
+ is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
 
 
 
 
 
 
 
 
 
 
 
337
  if is_cfg_batch:
338
  # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
339
  num_cond = len(seqs) // 2
340
  cond_seqs = seqs[:num_cond]
341
+ # uncond_seqs = seqs[num_cond:]
342
 
343
  # Prepare inputs for both conditional and unconditional (they're already in the batch)
344
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
 
345
  sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
346
  if sample_params is not None:
347
  temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
 
392
  logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
393
 
394
  # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
395
+ # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
396
 
397
  # Sample from CFG logits
398
  token_ids_cfg = self.sampler(
 
401
  top_ks=top_ks if top_ks is not None else None,
402
  top_ps=top_ps if top_ps is not None else None,
403
  repetition_penalties=None, # Already applied above
404
+ # input_ids=cond_input_ids,
405
  ).tolist()
406
 
407
  # Update logits processor state after sampling
 
460
  logits[i] = processed[0]
461
 
462
  # Prepare input_ids for sampler
463
+ # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
464
 
465
  token_ids = self.sampler(
466
  logits,
 
468
  top_ks=top_ks if top_ks is not None else None,
469
  top_ps=top_ps if top_ps is not None else None,
470
  repetition_penalties=None, # Already applied above
471
+ # input_ids=seq_input_ids,
472
  ).tolist()
473
 
474
  # Update logits processor state after sampling
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py CHANGED
@@ -3,6 +3,83 @@ from torch import nn
3
  from typing import Optional
4
 
5
 
 
 
6
  class Sampler(nn.Module):
7
 
8
  def __init__(self):
@@ -19,56 +96,19 @@ class Sampler(nn.Module):
19
  input_ids: Optional[torch.Tensor] = None,
20
  ):
21
  """
22
- Sample tokens from logits with optional top-k, top-p, and repetition penalty.
23
 
24
- Args:
25
- logits: [batch_size, vocab_size] logits tensor
26
- temperatures: [batch_size] temperature values
27
- top_ks: Optional [batch_size] top-k values (None or 0 means no top-k filtering)
28
- top_ps: Optional [batch_size] top-p values (None or 1.0 means no top-p filtering)
29
- repetition_penalties: Optional [batch_size] repetition penalty values (1.0 means no penalty)
30
- input_ids: Optional [batch_size, seq_len] input token ids for repetition penalty
31
  """
32
- batch_size, vocab_size = logits.shape
33
-
34
- # Note: Repetition penalty is applied in ModelRunner before calling sampler
35
- # This allows us to use the full sequence context
36
-
37
  # Apply temperature
38
  logits = logits.float().div_(temperatures.unsqueeze(dim=1))
39
-
40
- # Apply top-k filtering if specified
41
- if top_ks is not None:
42
- for i in range(batch_size):
43
- top_k = top_ks[i].item()
44
- if top_k > 0 and top_k < vocab_size:
45
- # Get top-k logits, set others to -inf
46
- top_k_logits, top_k_indices = torch.topk(logits[i], int(top_k), dim=-1)
47
- filtered_logits = torch.full_like(logits[i], float('-inf'))
48
- filtered_logits[top_k_indices] = top_k_logits
49
- logits[i] = filtered_logits
50
-
51
- # Apply top-p (nucleus) filtering if specified
52
- if top_ps is not None:
53
- probs = torch.softmax(logits, dim=-1)
54
- for i in range(batch_size):
55
- top_p = top_ps[i].item()
56
- if 0.0 < top_p < 1.0:
57
- # Sort probabilities in descending order
58
- sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
59
- # Calculate cumulative probabilities
60
- cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
61
- # Find the cutoff point
62
- cutoff_idx = (cumsum_probs <= top_p).sum().item()
63
- if cutoff_idx < len(sorted_indices):
64
- cutoff_idx += 1 # Include one more token to ensure we have at least one
65
- # Create mask for tokens to keep
66
- mask = torch.zeros_like(probs[i])
67
- mask[sorted_indices[:cutoff_idx]] = 1.0
68
- # Apply mask: set filtered tokens to -inf
69
- logits[i] = torch.where(mask > 0, logits[i], torch.tensor(float('-inf'), device=logits.device))
70
-
71
- # Sample using Gumbel-max trick (equivalent to sampling from softmax)
72
  probs = torch.softmax(logits, dim=-1)
73
  sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
74
- return sample_tokens
 
3
  from typing import Optional
4
 
5
 
6
+ def apply_top_k_top_p(
7
+ logits: torch.Tensor,
8
+ k: Optional[torch.Tensor],
9
+ p: Optional[torch.Tensor],
10
+ ) -> torch.Tensor:
11
+ """Apply top-k and top-p masks to the logits (vLLM style).
12
+
13
+ The logits tensor is updated in-place.
14
+ """
15
+ if p is None:
16
+ if k is None:
17
+ return logits
18
+ # Avoid sorting vocab for top-k only case
19
+ return apply_top_k_only(logits, k)
20
+
21
+ # Need to sort for top-p
22
+ logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
23
+
24
+ if k is not None:
25
+ # Apply top-k first
26
+ vocab_size = logits_sort.size(1)
27
+ # Clamp k to valid range
28
+ k_clamped = k.clamp(1, vocab_size).long()
29
+ top_k_mask_idx = vocab_size - k_clamped # shape: [B]
30
+ # Get the threshold value for each batch
31
+ top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
32
+ top_k_mask = logits_sort < top_k_thresh
33
+ logits_sort.masked_fill_(top_k_mask, float('-inf'))
34
+
35
+ # Apply top-p
36
+ probs_sort = logits_sort.softmax(dim=-1)
37
+ probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) # reuse buffer
38
+ top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
39
+ # Ensure at least one token is kept
40
+ top_p_mask[:, -1] = False
41
+ logits_sort.masked_fill_(top_p_mask, float('-inf'))
42
+
43
+ # Re-sort back to original positions
44
+ logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
45
+ return logits
46
+
47
+
48
+ def apply_top_k_only(
49
+ logits: torch.Tensor,
50
+ k: torch.Tensor,
51
+ ) -> torch.Tensor:
52
+ """Apply top-k mask without sorting the entire vocab (vLLM style).
53
+
54
+ This is much faster than sorting for top-k only cases.
55
+ The logits tensor is updated in-place.
56
+ """
57
+ vocab_size = logits.shape[1]
58
+ # Handle cases where k >= vocab_size (no filtering needed)
59
+ no_top_k_mask = (k <= 0) | (k >= vocab_size)
60
+ # Set invalid k to 1 so we can still gather
61
+ k_safe = k.masked_fill(no_top_k_mask, 1).long()
62
+ # NOTE: This int() causes CPU-GPU sync, but torch.topk requires Python int
63
+ max_top_k = int(k_safe.max().clamp(max=vocab_size))
64
+
65
+ # Get top-k values for all batches
66
+ # topk.values has shape [batch_size, max_top_k]
67
+ topk_values = logits.topk(max_top_k, dim=1).values
68
+
69
+ # Convert k to 0-based index: we want the k-th largest value (index k-1)
70
+ # Clamp to valid range for gather
71
+ k_index = (k_safe - 1).clamp(0, max_top_k - 1).unsqueeze(1) # shape: [B, 1]
72
+ # Gather the threshold value (the k-th largest)
73
+ top_k_thresh = topk_values.gather(1, k_index)
74
+
75
+ # For rows with no top-k filtering, set threshold to -inf so nothing gets masked
76
+ top_k_thresh.masked_fill_(no_top_k_mask.unsqueeze(1), float('-inf'))
77
+
78
+ # Mask all values below the threshold
79
+ logits.masked_fill_(logits < top_k_thresh, float('-inf'))
80
+ return logits
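A small self-check sketch (illustrative, not part of this diff) comparing the vectorized top-k path above against a per-row torch.topk reference; it assumes nanovllm is importable from this repo layout, and the logits/k values are made up:

import torch
from nanovllm.layers.sampler import apply_top_k_only

logits = torch.tensor([[0.1, 2.0, 1.0, 3.0],
                       [1.5, 0.2, 0.3, 0.4]])
k = torch.tensor([2, 1])

# Reference: mask everything below the k-th largest value, row by row.
expected = logits.clone()
for i in range(expected.size(0)):
    threshold = torch.topk(expected[i], int(k[i])).values[-1]
    expected[i] = expected[i].masked_fill(expected[i] < threshold, float("-inf"))

assert torch.equal(apply_top_k_only(logits.clone(), k), expected)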
81
+
82
+
83
  class Sampler(nn.Module):
84
 
85
  def __init__(self):
 
96
  input_ids: Optional[torch.Tensor] = None,
97
  ):
98
  """
99
+ Sample tokens from logits with optional top-k and top-p filtering.
100
 
101
+ Condition checking is done OUTSIDE the compiled function to avoid
102
+ graph breaks from .any() calls.
 
 
 
 
 
103
  """
 
 
 
 
 
104
  # Apply temperature
105
  logits = logits.float().div_(temperatures.unsqueeze(dim=1))
106
+
107
+ logits = apply_top_k_top_p(
108
+ logits,
109
+ top_ks,
110
+ top_ps,
111
+ )
 
 
112
  probs = torch.softmax(logits, dim=-1)
113
  sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
114
+ return sample_tokens
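The final two lines are the exponential-race form of the Gumbel-max trick; an illustrative comparison against torch.multinomial (a sketch with made-up probabilities, not part of this diff):

import torch

torch.manual_seed(0)
probs = torch.tensor([[0.7, 0.2, 0.1]])
n = 20000

# argmax of p_i / E_i with E_i ~ Exp(1) picks index i with probability p_i
race = torch.stack([
    (probs / torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
    for _ in range(n)
])
baseline = torch.multinomial(probs.repeat(n, 1), num_samples=1)

print(torch.bincount(race.flatten(), minlength=3).float() / n)      # ~[0.70, 0.20, 0.10]
print(torch.bincount(baseline.flatten(), minlength=3).float() / n)  # ~[0.70, 0.20, 0.10]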
acestep/third_parts/nano-vllm/pyproject.toml CHANGED
@@ -15,8 +15,6 @@ dependencies = [
15
  "triton-windows>=3.0.0; sys_platform == 'win32'",
16
  "triton>=3.0.0; sys_platform != 'win32'",
17
  "transformers>=4.51.0",
18
- "flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.3/flash_attn-2.8.3+cu128torch2.8.0cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl; sys_platform == 'win32'",
19
- "flash-attn; sys_platform != 'win32'",
20
  "xxhash",
21
  ]
22
 
 
15
  "triton-windows>=3.0.0; sys_platform == 'win32'",
16
  "triton>=3.0.0; sys_platform != 'win32'",
17
  "transformers>=4.51.0",
 
 
18
  "xxhash",
19
  ]
20
 
profile_inference.py ADDED
@@ -0,0 +1,682 @@
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Enhanced profiling script for ACE-Step inference with deep LLM analysis
4
+
5
+ This script helps diagnose why LLM generation is slow by tracking:
6
+ 1. Total tokens generated vs expected throughput (200 tokens/sec baseline)
7
+ 2. Per-iteration timing to detect compilation overhead or slow operations
8
+ 3. Constrained decoding overhead
9
+ 4. CFG overhead (2x forward passes)
10
+ 5. Model forward time vs sampling/processing time
11
+
12
+ Usage:
13
+ python profile_inference.py # Standard profiling with warmup
14
+ python profile_inference.py --no-warmup # Profile first run (includes compilation)
15
+ python profile_inference.py --llm-debug # Deep LLM performance debugging
16
+ python profile_inference.py --detailed # Add cProfile function-level analysis
17
+
18
+ Inference mode options:
19
+ python profile_inference.py --thinking # Enable CoT for code generation
20
+ python profile_inference.py --use-constrained-decoding # Use FSM constrained decoding
21
+ python profile_inference.py --use-cot-metas # Enable LM to generate metadata via CoT
22
+ """
23
+
24
+ import time
25
+ import argparse
26
+ import sys
27
+ import os
28
+ from contextlib import contextmanager
29
+ from collections import defaultdict
30
+ import json
31
+ from typing import Tuple, Dict, Any, List
32
+ from functools import wraps
33
+
34
+ # Add project root to path
35
+ project_root = os.path.abspath(os.path.dirname(__file__))
36
+ if project_root not in sys.path:
37
+ sys.path.insert(0, project_root)
38
+
39
+ import torch
40
+ from acestep.inference import generate_music, GenerationParams, GenerationConfig
41
+ from acestep.handler import AceStepHandler
42
+ from acestep.llm_inference import LLMHandler
43
+
44
+
45
+ class PreciseTimer:
46
+ """High-precision timer with CUDA synchronization for accurate GPU timing"""
47
+
48
+ def __init__(self, device="cuda"):
49
+ self.device = device
50
+ self.timings = defaultdict(list)
51
+ self.enabled = True
52
+
53
+ def sync(self):
54
+ """Synchronize CUDA operations for accurate timing"""
55
+ if self.enabled and self.device.startswith("cuda") and torch.cuda.is_available():
56
+ torch.cuda.synchronize()
57
+
58
+ @contextmanager
59
+ def time(self, name: str):
60
+ """Time a code section with CUDA synchronization"""
61
+ if not self.enabled:
62
+ yield
63
+ return
64
+
65
+ self.sync()
66
+ start = time.perf_counter()
67
+ try:
68
+ yield
69
+ finally:
70
+ self.sync()
71
+ elapsed = time.perf_counter() - start
72
+ self.timings[name].append(elapsed)
73
+
74
+ def get_total(self, name: str) -> float:
75
+ """Get total accumulated time for a section"""
76
+ return sum(self.timings.get(name, []))
77
+
78
+ def get_mean(self, name: str) -> float:
79
+ """Get mean time per call for a section"""
80
+ times = self.timings.get(name, [])
81
+ return sum(times) / len(times) if times else 0.0
82
+
83
+ def get_count(self, name: str) -> int:
84
+ """Get number of calls for a section"""
85
+ return len(self.timings.get(name, []))
86
+
87
+ def get_all(self, name: str) -> List[float]:
88
+ """Get all timing samples for a section"""
89
+ return self.timings.get(name, [])
90
+
91
+
92
+ class LLMDebugger:
93
+ """Track detailed LLM performance metrics to diagnose slow generation"""
94
+
95
+ def __init__(self):
96
+ self.reset()
97
+
98
+ def reset(self):
99
+ """Reset all metrics"""
100
+ self.total_tokens = 0
101
+ self.generation_start = None
102
+ self.generation_end = None
103
+ self.output_text = ""
104
+ self.prompt_length = 0
105
+
106
+ def start(self, prompt_length: int = 0):
107
+ """Mark generation start"""
108
+ self.generation_start = time.perf_counter()
109
+ self.prompt_length = prompt_length
110
+
111
+ def end(self, output_text: str = ""):
112
+ """Mark generation end and store output"""
113
+ self.generation_end = time.perf_counter()
114
+ self.output_text = output_text
115
+
116
+ def set_token_count(self, count: int):
117
+ """Set total token count"""
118
+ self.total_tokens = count
119
+
120
+ def get_throughput(self) -> float:
121
+ """Calculate actual tokens per second"""
122
+ if self.generation_start and self.generation_end and self.total_tokens > 0:
123
+ total_time = self.generation_end - self.generation_start
124
+ if total_time > 0:
125
+ return self.total_tokens / total_time
126
+ return 0.0
127
+
128
+ def print_analysis(self):
129
+ """Print detailed LLM performance analysis"""
130
+ if not self.generation_start or not self.generation_end:
131
+ return
132
+
133
+ print("\n" + "=" * 100)
134
+ print("🔍 LLM PERFORMANCE DEEP DIVE")
135
+ print("=" * 100)
136
+
137
+ total_time = self.generation_end - self.generation_start
138
+ throughput = self.get_throughput()
139
+
140
+ # Basic metrics table
141
+ print(f"\n{'Metric':<40} {'Value':<20} {'Notes'}")
142
+ print("-" * 100)
143
+ print(f"{'Total Tokens Generated:':<40} {self.total_tokens:<20} (new tokens only)")
144
+ print(f"{'Prompt Length (estimate):':<40} {self.prompt_length:<20} (input tokens)")
145
+ print(f"{'Total Generation Time:':<40} {total_time:<20.3f} seconds")
146
+ print(f"{'Measured Throughput:':<40} {throughput:<20.1f} tokens/sec")
147
+ print(f"{'Expected Throughput:':<40} {'200':<20} tokens/sec (baseline)")
148
+
149
+ # Calculate performance gap
150
+ if throughput > 0:
151
+ slowdown = 200.0 / throughput
152
+ efficiency = (throughput / 200.0) * 100
153
+ print(f"{'Performance vs Baseline:':<40} {efficiency:<20.1f}% of expected")
154
+ print(f"{'Slowdown Factor:':<40} {slowdown:<20.2f}x slower")
155
+
156
+ # Analyze generated output
157
+ if self.output_text:
158
+ print(f"\n{'Output Analysis:':<40}")
159
+ print(f"{' Output length:':<40} {len(self.output_text):<20} characters")
160
+
161
+ # Count audio codes
162
+ import re
163
+ code_pattern = r'<\|audio_code_\d+\|>'
164
+ codes = re.findall(code_pattern, self.output_text)
165
+ if codes:
166
+ print(f"{' Audio codes generated:':<40} {len(codes):<20} codes")
167
+ print(f"{' Expected audio duration:':<40} {f'~{len(codes)/5:.1f}s':<20} (5 codes per second)")
168
+ if total_time > 0:
169
+ print(f"{' Time per audio code:':<40} {f'{total_time/len(codes)*1000:.1f}ms':<20}")
170
+
171
+ # Check for CoT section
172
+ if '<think>' in self.output_text and '</think>' in self.output_text:
173
+ cot_start = self.output_text.find('<think>')
174
+ cot_end = self.output_text.find('</think>') + 8
175
+ cot_section = self.output_text[cot_start:cot_end]
176
+ cot_token_est = len(cot_section) // 4
177
+ print(f"{' CoT section tokens (estimate):':<40} {f'~{cot_token_est}':<20}")
178
+
179
+ # Diagnostic guidance
180
+ print("\n" + "=" * 100)
181
+ print("🔧 DIAGNOSTIC GUIDANCE")
182
+ print("=" * 100)
183
+
184
+ if throughput < 50:
185
+ print("\n⚠️ CRITICAL: Throughput is extremely low (<50 tokens/sec)")
186
+ print("\nThis is at least 4x slower than the 200 tokens/sec baseline. Likely causes:")
187
+ print(" 1. ❗ Constrained decoding FSM overhead")
188
+ print(" → Each token triggers FSM state machine validation")
189
+ print(" → Try: set use_constrained_decoding=False in config")
190
+ print(" 2. ❗ CFG with double forward passes")
191
+ print(" → cfg_scale > 1.0 means running model twice per token")
192
+ print(" → Check: params.lm_cfg_scale value")
193
+ print(" 3. ❗ Running in eager mode without compilation")
194
+ print(" → PyTorch should compile kernels after warmup")
195
+ print(" → Check: torch._dynamo.config settings")
196
+
197
+ elif throughput < 100:
198
+ print("\n⚠️ WARNING: Throughput is low (50-100 tokens/sec)")
199
+ print("\nLikely causes:")
200
+ print(" 1. Constrained decoding overhead (~30-50% slowdown expected)")
201
+ print(" 2. CFG enabled (2x compute per token if cfg_scale > 1.0)")
202
+ print(" 3. Small model or inefficient GPU utilization")
203
+
204
+ elif throughput < 150:
205
+ print("\n⚠️ Throughput is below baseline but acceptable (100-150 tokens/sec)")
206
+ print("\nMinor overhead from:")
207
+ print(" - Constrained decoding: ~20-30% overhead")
208
+ print(" - Profiling instrumentation: ~5-10% overhead")
209
+
210
+ else:
211
+ print(f"\n✓ Throughput is good ({throughput:.1f} tokens/sec)")
212
+ print(" Performance is within acceptable range")
213
+
214
+
215
+ # Global instances
216
+ timer = None
217
+ llm_debugger = None
218
+
219
+
220
+ def wrap_method_with_timing(obj, method_name: str, timing_key: str):
221
+ """Wrap a method with timing instrumentation"""
222
+ original_method = getattr(obj, method_name)
223
+
224
+ @wraps(original_method)
225
+ def timed_wrapper(*args, **kwargs):
226
+ with timer.time(timing_key):
227
+ return original_method(*args, **kwargs)
228
+
229
+ setattr(obj, method_name, timed_wrapper)
230
+ return original_method
231
+
232
+
233
+ def wrap_llm_with_debug_tracking(llm_handler):
234
+ """Wrap LLM generation with detailed performance tracking"""
235
+ original_method = llm_handler.generate_with_stop_condition
236
+
237
+ @wraps(original_method)
238
+ def debug_wrapper(*args, **kwargs):
239
+ # Estimate prompt length
240
+ caption = kwargs.get('caption', args[0] if len(args) > 0 else "")
241
+ lyrics = kwargs.get('lyrics', args[1] if len(args) > 1 else "")
242
+ prompt_estimate = len(caption) + len(lyrics)
243
+ prompt_tokens_estimate = prompt_estimate // 4
244
+
245
+ # Start tracking
246
+ llm_debugger.reset()
247
+ llm_debugger.start(prompt_length=prompt_tokens_estimate)
248
+
249
+ # Call original with timing
250
+ with timer.time('llm_inference'):
251
+ result = original_method(*args, **kwargs)
252
+
253
+ # Extract and analyze output
254
+ output_text = ""
255
+ if isinstance(result, tuple) and len(result) >= 2:
256
+ if isinstance(result[1], list):
257
+ # Batch mode
258
+ output_text = "".join(result[1])
259
+ else:
260
+ # Single mode
261
+ cot_output = ""
262
+ if isinstance(result[0], dict):
263
+ for v in result[0].values():
264
+ if isinstance(v, str):
265
+ cot_output += v
266
+ output_text = cot_output + str(result[1])
267
+
268
+ # Count tokens
269
+ import re
270
+ code_pattern = r'<\|audio_code_\d+\|>'
271
+ codes = re.findall(code_pattern, output_text)
272
+ remaining_text = re.sub(code_pattern, '', output_text)
273
+ cot_tokens_estimate = len(remaining_text) // 4
274
+ total_tokens = len(codes) + cot_tokens_estimate
275
+
276
+ llm_debugger.set_token_count(total_tokens)
277
+ llm_debugger.end(output_text)
278
+
279
+ return result
280
+
281
+ llm_handler.generate_with_stop_condition = debug_wrapper
282
+ return original_method
283
+
284
+
285
+ def instrument_handlers(dit_handler, llm_handler, enable_llm_debug=False):
286
+ """Add timing instrumentation to handler methods"""
287
+ originals = {}
288
+
289
+ # Instrument LLM
290
+ if llm_handler and llm_handler.llm_initialized:
291
+ if enable_llm_debug:
292
+ originals['llm_generate'] = wrap_llm_with_debug_tracking(llm_handler)
293
+ else:
294
+ originals['llm_generate'] = wrap_method_with_timing(
295
+ llm_handler, 'generate_with_stop_condition', 'llm_inference'
296
+ )
297
+
298
+ # Instrument DiT handler
299
+ originals['dit_prepare'] = wrap_method_with_timing(
300
+ dit_handler, 'prepare_batch_data', 'prepare_batch_data'
301
+ )
302
+ originals['dit_generate'] = wrap_method_with_timing(
303
+ dit_handler, 'service_generate', 'dit_inference'
304
+ )
305
+ originals['dit_decode'] = wrap_method_with_timing(
306
+ dit_handler, 'tiled_decode', 'vae_decode'
307
+ )
308
+
309
+ return originals
310
+
311
+
312
+ def restore_handlers(dit_handler, llm_handler, originals):
313
+ """Restore original handler methods after profiling"""
314
+ if llm_handler and 'llm_generate' in originals:
315
+ llm_handler.generate_with_stop_condition = originals['llm_generate']
316
+
317
+ dit_handler.prepare_batch_data = originals['dit_prepare']
318
+ dit_handler.service_generate = originals['dit_generate']
319
+ dit_handler.tiled_decode = originals['dit_decode']
320
+
321
+
322
+ def print_profiling_results(total_time: float, show_llm_debug: bool = False):
323
+ """Print comprehensive profiling results with performance insights"""
324
+ print("\n" + "=" * 100)
325
+ print("🎯 PROFILING RESULTS")
326
+ print("=" * 100)
327
+
328
+ # Define timing categories
329
+ model_sections = {
330
+ 'llm_inference': 'LLM Inference (5Hz Language Model)',
331
+ 'dit_inference': 'DiT Inference (Diffusion Transformer)',
332
+ 'vae_decode': 'VAE Decode (Audio Decoder)',
333
+ }
334
+
335
+ non_model_sections = {
336
+ 'prepare_batch_data': 'Prepare Batch Data (embedding, formatting)',
337
+ }
338
+
339
+ # Calculate totals
340
+ model_time = sum(timer.get_total(k) for k in model_sections.keys())
341
+ non_model_time = sum(timer.get_total(k) for k in non_model_sections.keys())
342
+ other_time = total_time - model_time - non_model_time
343
+
344
+ # Print summary table
345
+ print(f"\n{'CATEGORY':<50} {'TIME (s)':<12} {'%':<8} {'CALLS':<8}")
346
+ print("-" * 100)
347
+
348
+ # Model time breakdown
349
+ print(f"\n{'🤖 MODEL TIME (Total)':<50} {model_time:<12.3f} {100*model_time/total_time:>6.1f}% {'':<8}")
350
+ for key, desc in model_sections.items():
351
+ t = timer.get_total(key)
352
+ c = timer.get_count(key)
353
+ if c > 0:
354
+ mean = timer.get_mean(key)
355
+ pct = 100 * t / total_time
356
+ print(f" {'├─ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
357
+
358
+ # Non-model time breakdown
359
+ print(f"\n{'⚙️ NON-MODEL TIME (Total)':<50} {non_model_time:<12.3f} {100*non_model_time/total_time:>6.1f}% {'':<8}")
360
+ for key, desc in non_model_sections.items():
361
+ t = timer.get_total(key)
362
+ c = timer.get_count(key)
363
+ if c > 0:
364
+ mean = timer.get_mean(key)
365
+ pct = 100 * t / total_time
366
+ print(f" {'├─ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
367
+
368
+ # Other time
369
+ if other_time > 0.01:
370
+ pct = 100 * other_time / total_time
371
+ print(f"\n{'📦 OTHER TIME (I/O, overhead, audio save)':<50} {other_time:<12.3f} {pct:>6.1f}% {'':<8}")
372
+
373
+ print(f"\n{'📊 TOTAL TIME':<50} {total_time:<12.3f} {'100.0%':>6} {'':<8}")
374
+
375
+ # Show LLM detailed analysis if enabled
376
+ if show_llm_debug:
377
+ llm_debugger.print_analysis()
378
+
379
+ # Performance insights
380
+ print("\n" + "=" * 100)
381
+ print("💡 PERFORMANCE INSIGHTS")
382
+ print("=" * 100)
383
+
384
+ llm_t = timer.get_total('llm_inference')
385
+ dit_t = timer.get_total('dit_inference')
386
+ vae_t = timer.get_total('vae_decode')
387
+ prep_t = timer.get_total('prepare_batch_data')
388
+
389
+ # Model time insights
390
+ if model_time > 0:
391
+ print(f"\n✓ Model operations: {model_time:.3f}s ({100*model_time/total_time:.1f}% of total)")
392
+
393
+ if llm_t > 0:
394
+ print(f" - LLM: {llm_t:.3f}s ({100*llm_t/model_time:.1f}% of model time)")
395
+ if dit_t > 0:
396
+ print(f" - DiT: {dit_t:.3f}s ({100*dit_t/model_time:.1f}% of model time)")
397
+ if vae_t > 0:
398
+ print(f" - VAE: {vae_t:.3f}s ({100*vae_t/model_time:.1f}% of model time)")
399
+
400
+ # LLM bottleneck analysis
401
+ if llm_t > dit_t and llm_t > 5.0:
402
+ print(f"\n⚠️ LLM IS THE BOTTLENECK: {llm_t:.3f}s ({100*llm_t/total_time:.1f}% of total)")
403
+ print(f"\n Possible causes:")
404
+ print(f" 1. Generating too many tokens → use --llm-debug to verify")
405
+ print(f" 2. Constrained decoding overhead → FSM validation per token")
406
+ print(f" 3. CFG overhead → cfg_scale > 1.0 = 2x forward passes")
407
+ print(f" 4. First-token latency → warmup should help")
408
+ print(f" 5. KV cache inefficiency → should be ~5-10ms/token")
409
+
410
+ # Non-model insights
411
+ if non_model_time / total_time > 0.1:
412
+ print(f"\n⚠️ Non-model operations: {non_model_time:.3f}s ({100*non_model_time/total_time:.1f}%)")
413
+ if prep_t > 0.1:
414
+ print(f" - Batch preparation: {prep_t:.3f}s")
415
+
416
+ # I/O overhead
417
+ if other_time / total_time > 0.2:
418
+ print(f"\n⚠️ Overhead/I/O: {other_time:.3f}s ({100*other_time/total_time:.1f}%)")
419
+
420
+ # Recommendations
421
+ print("\n" + "=" * 100)
422
+ print("🚀 OPTIMIZATION RECOMMENDATIONS")
423
+ print("=" * 100)
424
+
425
+ if llm_t > dit_t * 2:
426
+ print("\n🎯 Priority: Optimize LLM")
427
+ print(" 1. Run: python profile_inference.py --llm-debug")
428
+ print(" → Shows exact token count and throughput")
429
+ print(" 2. Check constrained decoding overhead")
430
+ print(" 3. Check CFG scaling (lm_cfg_scale parameter)")
431
+ print(" 4. Profile nanovllm engine step() timing")
432
+ print(" 5. Compare vllm vs transformers backends")
433
+
434
+
435
+ def run_profiled_generation(dit_handler, llm_handler, params, config,
436
+ enable_cprofile=False, enable_llm_debug=False):
437
+ """Execute generation with full profiling instrumentation"""
438
+ # Instrument handlers
439
+ originals = instrument_handlers(dit_handler, llm_handler, enable_llm_debug)
440
+
441
+ try:
442
+ print("\n[Profiling] Starting generation...")
443
+ timer.sync()
444
+ total_start = time.perf_counter()
445
+
446
+ # Optional cProfile
447
+ prof = None
448
+ if enable_cprofile:
449
+ import cProfile
450
+ prof = cProfile.Profile()
451
+ prof.enable()
452
+
453
+ # Run generation
454
+ result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
455
+
456
+ # Stop timing
457
+ timer.sync()
458
+ total_time = time.perf_counter() - total_start
459
+
460
+ # Save cProfile if enabled
461
+ if enable_cprofile and prof:
462
+ prof.disable()
463
+
464
+ import pstats
465
+ import io
466
+
467
+ output_file = "profile_cprofile_detailed.txt"
468
+ with open(output_file, 'w') as f:
469
+ ps = pstats.Stats(prof, stream=f)
470
+ ps.sort_stats('cumulative')
471
+ ps.print_stats(100)
472
+
473
+ # Print top functions
474
+ print("\n" + "=" * 100)
475
+ print("📊 TOP 20 FUNCTIONS BY CUMULATIVE TIME (cProfile)")
476
+ print("=" * 100)
477
+ s = io.StringIO()
478
+ ps = pstats.Stats(prof, stream=s)
479
+ ps.sort_stats('cumulative')
480
+ ps.print_stats(20)
481
+ print(s.getvalue())
482
+
483
+ print(f"\nFull report: {output_file}")
484
+
485
+ # Print results
486
+ print_profiling_results(total_time, show_llm_debug=enable_llm_debug)
487
+
488
+ return result, total_time
489
+
490
+ finally:
491
+ restore_handlers(dit_handler, llm_handler, originals)
492
+
493
+
494
+ def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
495
+ """Load configuration from example JSON file"""
496
+ try:
497
+ with open(example_file, 'r', encoding='utf-8') as f:
498
+ data = json.load(f)
499
+
500
+ params = GenerationParams(
501
+ caption=data.get('caption', ''),
502
+ lyrics=data.get('lyrics', ''),
503
+ bpm=data.get('bpm'),
504
+ keyscale=data.get('keyscale', ''),
505
+ timesignature=data.get('timesignature', ''),
506
+ vocal_language=data.get('language', 'unknown'),
507
+ duration=data.get('duration'),
508
+ thinking=data.get('think', False),
509
+ inference_steps=data.get('inference_steps', 8),
510
+ seed=data.get('seed', 42),
511
+ )
512
+
513
+ config = GenerationConfig(batch_size=data.get('batch_size', 1), seeds=[42])
514
+
515
+ return params, config
516
+
517
+ except Exception as e:
518
+ print(f" ❌ Failed to load: {e}")
519
+ return None, None
520
+
521
+
522
+ def main():
523
+ global timer, llm_debugger
524
+
525
+ parser = argparse.ArgumentParser(
526
+ description="Profile ACE-Step inference with LLM debugging"
527
+ )
528
+ parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints")
529
+ parser.add_argument("--config-path", type=str, default="acestep-v15-turbo-rl")
530
+ parser.add_argument("--device", type=str, default="cuda")
531
+ parser.add_argument("--lm-model", type=str, default="acestep-5Hz-lm-0.6B-v3")
532
+ parser.add_argument("--lm-backend", type=str, default="vllm")
533
+ parser.add_argument("--no-warmup", action="store_true")
534
+ parser.add_argument("--detailed", action="store_true")
535
+ parser.add_argument("--llm-debug", action="store_true",
536
+ help="Enable deep LLM debugging (token count, throughput)")
537
+ parser.add_argument("--example", type=str, default="example_05.json")
538
+
539
+ # Inference mode parameters
540
+ parser.add_argument("--thinking", action="store_true",
541
+ help="Enable CoT reasoning for LM to generate audio codes")
542
+ parser.add_argument("--use-constrained-decoding", action="store_true",
543
+ help="Use FSM-based constrained decoding for meta generation")
544
+ parser.add_argument("--use-cot-metas", action="store_true",
545
+ help="Enable LLM to generate music metadata via CoT reasoning")
546
+
547
+ args = parser.parse_args()
548
+
549
+ # Initialize
550
+ timer = PreciseTimer(device=args.device)
551
+ llm_debugger = LLMDebugger()
552
+
553
+ print("=" * 100)
554
+ print("🎵 ACE-Step Inference Profiler (LLM Performance Analysis)")
555
+ print("=" * 100)
556
+ print(f"\nConfiguration:")
557
+ print(f" Device: {args.device}")
558
+ print(f" LLM Backend: {args.lm_backend}")
559
+ print(f" LLM Debug: {'Enabled' if args.llm_debug else 'Disabled'}")
560
+ print(f" Warmup: {'Disabled' if args.no_warmup else 'Enabled'}")
561
+ print(f"\nInference Mode:")
562
+ print(f" Thinking (CoT): {'Enabled' if args.thinking else 'Disabled'}")
563
+ print(f" Constrained Decoding: {'Enabled' if args.use_constrained_decoding else 'Disabled'}")
564
+ print(f" Use CoT for Metas: {'Enabled' if args.use_cot_metas else 'Disabled'}")
565
+
566
+ # Initialize models
567
+ print(f"\nInitializing models...")
568
+
569
+ dit_handler = AceStepHandler()
570
+ llm_handler = LLMHandler()
571
+
572
+ print(" 🎹 Initializing DiT...")
573
+ status_dit, success_dit = dit_handler.initialize_service(
574
+ project_root=project_root,
575
+ config_path=args.config_path,
576
+ device=args.device,
577
+ use_flash_attention=True,
578
+ )
579
+ if not success_dit:
580
+ print(f" ❌ Failed: {status_dit}")
581
+ sys.exit(1)
582
+ print(f" ✓ DiT ready")
583
+
584
+ print(" 🧠 Initializing LLM...")
585
+ if args.thinking or args.use_cot_metas:
586
+ status_llm, success_llm = llm_handler.initialize(
587
+ checkpoint_dir=args.checkpoint_dir,
588
+ lm_model_path=args.lm_model,
589
+ backend=args.lm_backend,
590
+ device=args.device,
591
+ )
592
+ if success_llm:
593
+ print(f" ✓ LLM ready ({args.lm_backend})")
594
+ else:
595
+ print(f" ⚠ Failed: {status_llm}")
596
+ else:
597
+ print(f" ✓ LLM not initialized (neither thinking nor use_cot_metas is enabled)")
598
+
599
+ # Load example
600
+ example_file = os.path.join(project_root, "examples", "text2music", args.example)
601
+ if not os.path.exists(example_file):
602
+ print(f"\n❌ Not found: {example_file}")
603
+ sys.exit(1)
604
+
605
+ print(f"\n📄 Loading: {args.example}")
606
+ params, config = load_example_config(example_file)
607
+
608
+ if not params or not config:
609
+ print("❌ Failed to load config")
610
+ sys.exit(1)
611
+
612
+ print(f" Caption: {params.caption[:60]}...")
613
+ print(f" Batch: {config.batch_size}, Steps: {params.inference_steps}, Thinking: {params.thinking}")
614
+
615
+ # Warmup
616
+ if not args.no_warmup:
617
+ print("\n" + "=" * 100)
618
+ print("🔥 WARMUP RUN")
619
+ print("=" * 100)
620
+
621
+ warmup_params = GenerationParams(
622
+ caption=params.caption,
623
+ lyrics=params.lyrics,
624
+ bpm=params.bpm,
625
+ keyscale=params.keyscale,
626
+ timesignature=params.timesignature,
627
+ vocal_language=params.vocal_language,
628
+ duration=params.duration,
629
+ thinking=args.thinking,
630
+ use_cot_metas=args.use_cot_metas,
631
+ inference_steps=params.inference_steps,
632
+ seed=params.seed,
633
+ )
634
+ warmup_config = GenerationConfig(batch_size=1, seeds=[42])
635
+ warmup_config.use_constrained_decoding = args.use_constrained_decoding
636
+
637
+ warmup_start = time.perf_counter()
638
+ warmup_result = generate_music(dit_handler, llm_handler, warmup_params, warmup_config, save_dir="./")
639
+ warmup_time = time.perf_counter() - warmup_start
640
+
641
+ print(f"\n✓ Warmup: {warmup_time:.2f}s")
642
+ if not warmup_result.success:
643
+ print(f"⚠️ Warning: {warmup_result.error}")
644
+
645
+ # Reset
646
+ timer = PreciseTimer(device=args.device)
647
+ llm_debugger = LLMDebugger()
648
+
649
+ # Profiling run
650
+ print("\n" + "=" * 100)
651
+ print("⏱️ PROFILING RUN")
652
+ print("=" * 100)
653
+
654
+ # Apply inference mode settings
655
+ config.use_constrained_decoding = args.use_constrained_decoding
656
+ # Override thinking and use_cot_metas parameters if specified via CLI
657
+ if args.thinking:
658
+ params.thinking = True
659
+ if args.use_cot_metas:
660
+ params.use_cot_metas = True
661
+
662
+ result, total_time = run_profiled_generation(
663
+ dit_handler, llm_handler, params, config,
664
+ enable_cprofile=args.detailed,
665
+ enable_llm_debug=args.llm_debug
666
+ )
667
+
668
+ if not result.success:
669
+ print(f"\n❌ Failed: {result.error}")
670
+ sys.exit(1)
671
+
672
+ print(f"\n✅ Success! Generated {len(result.audios)} audio file(s)")
673
+
674
+ # Final tips
675
+ if args.detailed:
676
+ print("\n💡 Check profile_cprofile_detailed.txt for function-level analysis")
677
+ elif not args.llm_debug:
678
+ print("\n💡 Run with --llm-debug to see LLM token count and throughput analysis")
679
+
680
+
681
+ if __name__ == "__main__":
682
+ main()
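
For reference, the profiling added above works by temporarily monkey-patching handler methods with timed wrappers and restoring the originals afterwards (instrument_handlers / restore_handlers). The following is a minimal standalone sketch of that pattern, not part of the commit; the Handler class and section names are illustrative stand-ins for the real AceStepHandler/LLMHandler methods:

import time
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps


class SectionTimer:
    """Accumulate wall-clock time per named section."""

    def __init__(self):
        self.timings = defaultdict(list)

    @contextmanager
    def time(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.timings[name].append(time.perf_counter() - start)


def wrap_with_timing(obj, method_name, timer, key):
    """Swap obj.method_name for a timed wrapper; return the original so it can be restored."""
    original = getattr(obj, method_name)

    @wraps(original)
    def timed(*args, **kwargs):
        with timer.time(key):
            return original(*args, **kwargs)

    setattr(obj, method_name, timed)
    return original


class Handler:  # hypothetical stand-in for the real inference handlers
    def decode(self):
        time.sleep(0.05)


timer = SectionTimer()
handler = Handler()
original = wrap_with_timing(handler, "decode", timer, "vae_decode")
try:
    handler.decode()
finally:
    setattr(handler, "decode", original)  # always restore, as restore_handlers() does
print({name: sum(samples) for name, samples in timer.timings.items()})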