ChuxiJ committed on
Commit
ba6c5ba
·
1 Parent(s): 7642e62

fix api server bugs

Browse files
Files changed (1) hide show
  1. acestep/api_server.py +200 -464
acestep/api_server.py CHANGED
@@ -44,6 +44,12 @@ from acestep.constants import (
44
  DEFAULT_DIT_INSTRUCTION,
45
  DEFAULT_LM_INSTRUCTION,
46
  )
 
 
 
 
 
 
47
 
48
 
49
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
@@ -387,6 +393,10 @@ def create_app() -> FastAPI:
387
  app.state.executor = executor
388
  app.state.job_store = store
389
  app.state._python_executable = sys.executable
 
 
 
 
390
 
391
  async def _ensure_initialized() -> None:
392
  h: AceStepHandler = app.state.handler
@@ -443,131 +453,10 @@ def create_app() -> FastAPI:
443
  job_store.mark_running(job_id)
444
 
445
  def _blocking_generate() -> Dict[str, Any]:
446
- def _normalize_optional_int(v: Any) -> Optional[int]:
447
- if v is None:
448
- return None
449
- try:
450
- iv = int(v)
451
- except Exception:
452
- return None
453
- return None if iv == 0 else iv
454
-
455
- def _normalize_optional_float(v: Any) -> Optional[float]:
456
- if v is None:
457
- return None
458
- try:
459
- fv = float(v)
460
- except Exception:
461
- return None
462
- # gradio treats 1.0 as disabled for top_p
463
- return None if fv >= 1.0 else fv
464
-
465
- def _maybe_fill_from_metadata(current: GenerateMusicRequest, meta: Dict[str, Any]) -> tuple[Optional[int], str, str, Optional[float]]:
466
- def _parse_first_float(v: Any) -> Optional[float]:
467
- if v is None:
468
- return None
469
- if isinstance(v, (int, float)):
470
- return float(v)
471
- s = str(v).strip()
472
- if not s or s.upper() == "N/A":
473
- return None
474
- try:
475
- return float(s)
476
- except Exception:
477
- pass
478
- m = re.search(r"[-+]?\d*\.?\d+", s)
479
- if not m:
480
- return None
481
- try:
482
- return float(m.group(0))
483
- except Exception:
484
- return None
485
-
486
- def _parse_first_int(v: Any) -> Optional[int]:
487
- fv = _parse_first_float(v)
488
- if fv is None:
489
- return None
490
- try:
491
- return int(round(fv))
492
- except Exception:
493
- return None
494
-
495
- # Fill only when user did not provide values
496
- bpm_val = current.bpm
497
- if bpm_val is None:
498
- m = meta.get("bpm")
499
- parsed = _parse_first_int(m)
500
- if parsed is not None and parsed > 0:
501
- bpm_val = parsed
502
-
503
- key_scale_val = current.key_scale
504
- if not key_scale_val:
505
- m = meta.get("keyscale", meta.get("key_scale", ""))
506
- if m not in (None, "", "N/A"):
507
- key_scale_val = str(m)
508
-
509
- time_sig_val = current.time_signature
510
- if not time_sig_val:
511
- m = meta.get("timesignature", meta.get("time_signature", ""))
512
- if m not in (None, "", "N/A"):
513
- time_sig_val = str(m)
514
-
515
- dur_val = current.audio_duration
516
- if dur_val is None:
517
- m = meta.get("duration", meta.get("audio_duration"))
518
- parsed = _parse_first_float(m)
519
- if parsed is not None:
520
- dur_val = float(parsed)
521
- if dur_val <= 0:
522
- dur_val = None
523
-
524
- # Avoid truncating lyrical songs when LM predicts a very short duration.
525
- # (Users can still force a short duration by explicitly setting `audio_duration`.)
526
- if dur_val is not None and (current.lyrics or "").strip():
527
- min_dur = float(os.getenv("ACESTEP_LM_MIN_DURATION_SECONDS", "30"))
528
- if dur_val < min_dur:
529
- dur_val = None
530
-
531
- return bpm_val, key_scale_val, time_sig_val, dur_val
532
-
533
- def _estimate_duration_from_lyrics(lyrics: str) -> Optional[float]:
534
- lyrics = (lyrics or "").strip()
535
- if not lyrics:
536
- return None
537
-
538
- # Best-effort heuristic: singing rate ~ 2.2 words/sec for English-like lyrics.
539
- # For languages without spaces, fall back to non-space char count.
540
- words = re.findall(r"[A-Za-z0-9']+", lyrics)
541
- if len(words) >= 8:
542
- words_per_sec = float(os.getenv("ACESTEP_LYRICS_WORDS_PER_SEC", "2.2"))
543
- est = len(words) / max(0.5, words_per_sec)
544
- else:
545
- non_space = len(re.sub(r"\s+", "", lyrics))
546
- chars_per_sec = float(os.getenv("ACESTEP_LYRICS_CHARS_PER_SEC", "12"))
547
- est = non_space / max(4.0, chars_per_sec)
548
-
549
- min_dur = float(os.getenv("ACESTEP_LYRICS_MIN_DURATION_SECONDS", "45"))
550
- max_dur = float(os.getenv("ACESTEP_LYRICS_MAX_DURATION_SECONDS", "180"))
551
- return float(min(max(est, min_dur), max_dur))
552
-
553
- def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
554
- """Ensure a stable `metas` dict (keys always present)."""
555
- meta = meta or {}
556
- out: Dict[str, Any] = dict(meta)
557
-
558
- # Normalize key aliases
559
- if "keyscale" not in out and "key_scale" in out:
560
- out["keyscale"] = out.get("key_scale")
561
- if "timesignature" not in out and "time_signature" in out:
562
- out["timesignature"] = out.get("time_signature")
563
-
564
- # Ensure required keys exist
565
- for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
566
- if out.get(k) in (None, ""):
567
- out[k] = "N/A"
568
- return out
569
-
570
  def _ensure_llm_ready() -> None:
 
571
  with app.state._llm_init_lock:
572
  initialized = getattr(app.state, "_llm_initialized", False)
573
  had_error = getattr(app.state, "_llm_init_error", None)
@@ -597,269 +486,207 @@ def create_app() -> FastAPI:
597
  else:
598
  app.state._llm_initialized = True
599
 
600
- # Optional: generate 5Hz LM codes server-side
601
- audio_code_string = req.audio_code_string
602
- bpm_val = req.bpm
603
- key_scale_val = req.key_scale
604
- time_sig_val = req.time_signature
605
- audio_duration_val = req.audio_duration
606
-
607
- thinking = bool(getattr(req, "thinking", False))
608
 
609
- print(
610
- "[api_server] parsed req: "
611
- f"thinking={thinking}, caption_len={len((req.caption or '').strip())}, lyrics_len={len((req.lyrics or '').strip())}, "
612
- f"bpm={req.bpm}, audio_duration={req.audio_duration}, key_scale={req.key_scale!r}, time_signature={req.time_signature!r}"
613
- )
614
 
615
- # If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
616
- # We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
617
- audio_cover_strength_val = float(req.audio_cover_strength)
 
 
618
 
619
- lm_meta: Optional[Dict[str, Any]] = None
 
 
620
 
621
- sample_mode = bool(getattr(req, "sample_mode", False))
622
- if sample_mode:
 
 
 
 
 
 
 
 
623
  _ensure_llm_ready()
624
  if getattr(app.state, "_llm_init_error", None):
625
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
626
 
 
 
 
 
 
 
 
 
 
 
627
  sample_metadata, sample_status = llm.understand_audio_from_codes(
628
  audio_codes="NO USER INPUT",
629
- temperature=float(getattr(req, "lm_temperature", _LM_DEFAULT_TEMPERATURE)),
630
- cfg_scale=max(1.0, float(getattr(req, "lm_cfg_scale", _LM_DEFAULT_CFG_SCALE))),
631
- negative_prompt=str(getattr(req, "lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
632
- top_k=_normalize_optional_int(getattr(req, "lm_top_k", None)),
633
- top_p=_normalize_optional_float(getattr(req, "lm_top_p", None)),
634
- repetition_penalty=float(getattr(req, "lm_repetition_penalty", 1.0)),
635
- use_constrained_decoding=bool(getattr(req, "constrained_decoding", True)),
636
- constrained_decoding_debug=bool(getattr(req, "constrained_decoding_debug", False)),
637
  )
638
 
639
  if not sample_metadata or str(sample_status).startswith("❌"):
640
  raise RuntimeError(f"Sample generation failed: {sample_status}")
641
 
642
- req.caption = str(sample_metadata.get("caption", "") or "")
643
- req.lyrics = str(sample_metadata.get("lyrics", "") or "")
644
- req.bpm = _to_int(sample_metadata.get("bpm"), req.bpm)
645
-
646
- sample_keyscale = sample_metadata.get("keyscale", sample_metadata.get("key_scale", ""))
647
- if sample_keyscale:
648
- req.key_scale = str(sample_keyscale)
649
-
650
- sample_timesig = sample_metadata.get("timesignature", sample_metadata.get("time_signature", ""))
651
- if sample_timesig:
652
- req.time_signature = str(sample_timesig)
653
-
654
- sample_duration = _to_float(sample_metadata.get("duration"), None)
655
- if sample_duration is not None and sample_duration > 0:
656
- req.audio_duration = sample_duration
657
-
658
- lm_meta = sample_metadata
659
-
660
- fallback_values: Dict[str, Any] = {}
661
- default_bpm = _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120) or 120
662
- default_duration = _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0) or 120.0
663
- default_key = os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major") or "C Major"
664
- default_timesig = os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4") or "4/4"
665
-
666
- if req.bpm is None or req.bpm <= 0:
667
- req.bpm = default_bpm
668
- fallback_values["bpm"] = default_bpm
669
-
670
- if req.audio_duration is None or req.audio_duration <= 0:
671
- req.audio_duration = default_duration
672
- fallback_values["audio_duration"] = default_duration
673
-
674
- if not (req.key_scale or "").strip():
675
- req.key_scale = default_key
676
- fallback_values["key_scale"] = default_key
677
-
678
- if not (req.time_signature or "").strip():
679
- req.time_signature = default_timesig
680
- fallback_values["time_signature"] = default_timesig
681
-
682
- if fallback_values:
683
- print("[api_server] sample mode fallback values:", fallback_values)
684
-
685
- print(
686
- "[api_server] sample mode metadata:",
687
- {
688
- "caption_len": len(req.caption),
689
- "lyrics_len": len(req.lyrics),
690
- "bpm": req.bpm,
691
- "audio_duration": req.audio_duration,
692
- "key_scale": req.key_scale,
693
- "time_signature": req.time_signature,
694
- },
695
- )
696
 
697
- # Determine effective batch size (used for per-sample LM code diversity)
698
- effective_batch_size = req.batch_size
699
- if effective_batch_size is None:
700
- try:
701
- effective_batch_size = int(getattr(h, "batch_size", 1))
702
- except Exception:
703
- effective_batch_size = 1
704
- effective_batch_size = max(1, int(effective_batch_size))
705
-
706
- has_codes = bool(audio_code_string and str(audio_code_string).strip())
707
- need_lm_codes = bool(thinking) and (not has_codes)
708
-
709
- use_constrained_decoding = bool(getattr(req, "constrained_decoding", True))
710
- constrained_decoding_debug = bool(getattr(req, "constrained_decoding_debug", False))
711
- use_cot_caption = bool(getattr(req, "use_cot_caption", True))
712
- use_cot_language = bool(getattr(req, "use_cot_language", True))
713
- is_format_caption = bool(getattr(req, "is_format_caption", False))
714
-
715
- # pass them into constrained decoding so LM injects them directly
716
- # (i.e. does not re-infer / override those fields).
717
- user_metadata: Dict[str, Optional[str]] = {}
718
-
719
- def _set_user_meta(field: str, value: Optional[Any]) -> None:
720
- if value is None:
721
- return
722
- s = str(value).strip()
723
- if not s or s.upper() == "N/A":
724
- return
725
- user_metadata[field] = s
726
-
727
- _set_user_meta("bpm", int(bpm_val) if bpm_val is not None else None)
728
- _set_user_meta("duration", float(audio_duration_val) if audio_duration_val is not None else None)
729
- _set_user_meta("keyscale", key_scale_val if (key_scale_val or "").strip() else None)
730
- _set_user_meta("timesignature", time_sig_val if (time_sig_val or "").strip() else None)
731
-
732
- def _has_meta(field: str) -> bool:
733
- v = user_metadata.get(field)
734
- return bool((v or "").strip())
735
-
736
- need_lm_metas = not (
737
- _has_meta("bpm")
738
- and _has_meta("duration")
739
- and _has_meta("keyscale")
740
- and _has_meta("timesignature")
741
  )
742
 
743
- lm_target_duration: Optional[float] = None
744
- if need_lm_codes:
745
- # If user specified a duration, constrain codes generation length accordingly.
746
- if audio_duration_val is not None and float(audio_duration_val) > 0:
747
- lm_target_duration = float(audio_duration_val)
748
-
749
- print(
750
- "[api_server] LM调用参数: "
751
- f"user_metadata_keys={sorted(user_metadata.keys())}, target_duration={lm_target_duration}, "
752
- f"need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}, "
753
- f"use_constrained_decoding={use_constrained_decoding}, use_cot_caption={use_cot_caption}, "
754
- f"use_cot_language={use_cot_language}, is_format_caption={is_format_caption}"
 
 
 
 
 
 
755
  )
 
 
 
756
 
757
- if need_lm_metas or need_lm_codes:
758
- _ensure_llm_ready()
759
 
760
- if getattr(app.state, "_llm_init_error", None):
761
- # If codes generation is required, fail hard.
762
- if need_lm_codes:
763
- raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
764
- # Otherwise, skip LM best-effort (fallback to default/meta-less behavior)
765
- else:
766
- lm_infer = "llm_dit" if need_lm_codes else "dit"
767
-
768
- def _lm_call() -> tuple[Dict[str, Any], str, str]:
769
- return llm.generate_with_stop_condition(
770
- caption=req.caption,
771
- lyrics=req.lyrics,
772
- infer_type=lm_infer,
773
- temperature=float(req.lm_temperature),
774
- cfg_scale=max(1.0, float(req.lm_cfg_scale)),
775
- negative_prompt=str(req.lm_negative_prompt or "NO USER INPUT"),
776
- top_k=_normalize_optional_int(req.lm_top_k),
777
- top_p=_normalize_optional_float(req.lm_top_p),
778
- repetition_penalty=float(req.lm_repetition_penalty),
779
- target_duration=lm_target_duration,
780
- user_metadata=(user_metadata or None),
781
- use_constrained_decoding=use_constrained_decoding,
782
- constrained_decoding_debug=constrained_decoding_debug,
783
- use_cot_caption=use_cot_caption,
784
- use_cot_language=use_cot_language,
785
- is_format_caption=is_format_caption,
786
- )
787
-
788
- meta, codes, status = _lm_call()
789
- lm_meta = meta
790
-
791
- if need_lm_codes:
792
- if not codes:
793
- raise RuntimeError(f"5Hz LM generation failed: {status}")
794
-
795
- # LM once per job; rely on DiT seeds for batch diversity.
796
- # For convenience, replicate the same codes across the batch.
797
- if effective_batch_size > 1:
798
- audio_code_string = [codes] * effective_batch_size
799
- else:
800
- audio_code_string = codes
801
-
802
- # Fill only missing fields (user-provided values win)
803
- bpm_val, key_scale_val, time_sig_val, audio_duration_val = _maybe_fill_from_metadata(req, meta)
804
-
805
- # If user provided lyrics but LM didn't provide a usable duration, estimate a longer duration.
806
- if audio_duration_val is None and (req.audio_duration is None):
807
- est = _estimate_duration_from_lyrics(req.lyrics)
808
- if est is not None:
809
- audio_duration_val = est
810
-
811
- # Optional: auto-tune LM cover strength (opt-in) to avoid suppressing lyric/vocal conditioning.
812
- if thinking and audio_cover_strength_val >= 0.999 and (req.lyrics or "").strip():
813
- tuned = os.getenv("ACESTEP_LM_COVER_STRENGTH")
814
- if tuned is not None and tuned.strip() != "":
815
- audio_cover_strength_val = float(tuned)
816
-
817
- # Align behavior:
818
- # - thinking=False: metas only (ignore audio codes), keep text2music.
819
- # - thinking=True: metas + audio codes, run in cover mode with LM instruction.
820
- instruction_val = req.instruction
821
- task_type_val = (req.task_type or "").strip() or "text2music"
822
-
823
- if not thinking:
824
- audio_code_string = ""
825
- if task_type_val == "cover":
826
- task_type_val = "text2music"
827
- if (instruction_val or "").strip() in {"", _DEFAULT_LM_INSTRUCTION}:
828
- instruction_val = _DEFAULT_DIT_INSTRUCTION
829
-
830
- if thinking:
831
- task_type_val = "cover"
832
- if (instruction_val or "").strip() in {"", _DEFAULT_DIT_INSTRUCTION}:
833
- instruction_val = _DEFAULT_LM_INSTRUCTION
834
-
835
- if not (audio_code_string and str(audio_code_string).strip()):
836
- # thinking=True requires codes generation.
837
- raise RuntimeError("thinking=true requires non-empty audio codes (LM generation failed).")
838
-
839
- # Response metas MUST reflect the actual values used by DiT.
840
- metas_out = _normalize_metas(lm_meta or {})
841
- if bpm_val is not None and int(bpm_val) > 0:
842
- metas_out["bpm"] = int(bpm_val)
843
- if audio_duration_val is not None and float(audio_duration_val) > 0:
844
- metas_out["duration"] = float(audio_duration_val)
845
- if (key_scale_val or "").strip():
846
- metas_out["keyscale"] = str(key_scale_val)
847
- if (time_sig_val or "").strip():
848
- metas_out["timesignature"] = str(time_sig_val)
849
-
850
- def _ensure_text_meta(field: str, fallback: Optional[str]) -> None:
851
- existing = metas_out.get(field)
852
- if isinstance(existing, str):
853
- stripped = existing.strip()
854
- if stripped and stripped.upper() != "N/A":
855
- return
856
- if fallback is None:
857
- return
858
- if fallback.strip():
859
- metas_out[field] = fallback
860
 
861
- _ensure_text_meta("caption", req.caption)
862
- _ensure_text_meta("lyrics", req.lyrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
 
864
  def _none_if_na_str(v: Any) -> Optional[str]:
865
  if v is None:
@@ -868,54 +695,17 @@ def create_app() -> FastAPI:
868
  if s in {"", "N/A"}:
869
  return None
870
  return s
871
- result = h.generate_music(
872
- captions=req.caption,
873
- lyrics=req.lyrics,
874
- bpm=bpm_val,
875
- key_scale=key_scale_val,
876
- time_signature=time_sig_val,
877
- vocal_language=req.vocal_language,
878
- inference_steps=req.inference_steps,
879
- guidance_scale=req.guidance_scale,
880
- use_random_seed=req.use_random_seed,
881
- seed=("-1" if (req.use_random_seed and int(req.seed) < 0) else str(req.seed)),
882
- reference_audio=req.reference_audio_path,
883
- audio_duration=audio_duration_val,
884
- batch_size=req.batch_size,
885
- src_audio=req.src_audio_path,
886
- audio_code_string=audio_code_string,
887
- repainting_start=req.repainting_start,
888
- repainting_end=req.repainting_end,
889
- instruction=instruction_val,
890
- audio_cover_strength=audio_cover_strength_val,
891
- task_type=task_type_val,
892
- use_adg=req.use_adg,
893
- cfg_interval_start=req.cfg_interval_start,
894
- cfg_interval_end=req.cfg_interval_end,
895
- audio_format=req.audio_format,
896
- use_tiled_decode=req.use_tiled_decode,
897
- progress=None,
898
- )
899
-
900
- # Extract values from new dict structure
901
- audios = result.get("audios", [])
902
- audio_paths = [audio.get("path") for audio in audios]
903
- first = audio_paths[0] if len(audio_paths) > 0 else None
904
- second = audio_paths[1] if len(audio_paths) > 1 else None
905
- gen_info = result.get("generation_info", "")
906
- status_msg = result.get("status_message", "")
907
- seed_value = result.get("extra_outputs", {}).get("seed_value", "")
908
-
909
  return {
910
- "first_audio_path": _path_to_audio_url(first) if first else None,
911
- "second_audio_path": _path_to_audio_url(second) if second else None,
912
- "audio_paths": [_path_to_audio_url(p) for p in (audio_paths or [])],
913
- "generation_info": gen_info,
914
- "status_message": status_msg,
915
  "seed_value": seed_value,
916
  "metas": metas_out,
917
- "bpm": int(bpm_val) if bpm_val is not None else None,
918
- "duration": float(audio_duration_val) if audio_duration_val is not None else None,
919
  "genres": _none_if_na_str(metas_out.get("genres")),
920
  "keyscale": _none_if_na_str(metas_out.get("keyscale")),
921
  "timesignature": _none_if_na_str(metas_out.get("timesignature")),
@@ -1020,53 +810,6 @@ def create_app() -> FastAPI:
1020
 
1021
  return default
1022
 
1023
- # Debug: print what keys we actually received (helps explain empty parsed values)
1024
- try:
1025
- top_keys = list(getattr(mapping, "keys", lambda: [])())
1026
- except Exception:
1027
- top_keys = []
1028
- try:
1029
- nested_probe = (
1030
- get("metas", None)
1031
- or get("meta", None)
1032
- or get("metadata", None)
1033
- or get("user_metadata", None)
1034
- or get("userMetadata", None)
1035
- )
1036
- if isinstance(nested_probe, str):
1037
- sp = nested_probe.strip()
1038
- if sp.startswith("{") and sp.endswith("}"):
1039
- try:
1040
- nested_probe = json.loads(sp)
1041
- except Exception:
1042
- nested_probe = None
1043
- nested_keys = list(nested_probe.keys()) if isinstance(nested_probe, dict) else []
1044
- except Exception:
1045
- nested_keys = []
1046
- print(f"[api_server] request keys: top={sorted(top_keys)}, nested={sorted(nested_keys)}")
1047
-
1048
- # Debug: print raw values/types for common meta fields (top-level + common aliases)
1049
- try:
1050
- probe_keys = [
1051
- "thinking",
1052
- "bpm",
1053
- "audio_duration",
1054
- "duration",
1055
- "audioDuration",
1056
- "key_scale",
1057
- "keyscale",
1058
- "keyScale",
1059
- "time_signature",
1060
- "timesignature",
1061
- "timeSignature",
1062
- ]
1063
- raw = {k: get(k, None) for k in probe_keys}
1064
- raw_types = {k: (type(v).__name__ if v is not None else None) for k, v in raw.items()}
1065
- print(f"[api_server] request raw: {raw}")
1066
- print(f"[api_server] request raw types: {raw_types}")
1067
- except Exception:
1068
- pass
1069
-
1070
  normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
1071
  normalized_bpm = _to_int(_get_any("bpm"), None)
1072
  normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
@@ -1076,12 +819,6 @@ def create_app() -> FastAPI:
1076
  if normalized_audio_duration is None:
1077
  normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)
1078
 
1079
- print(
1080
- "[api_server] normalized: "
1081
- f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
1082
- f"audio_duration={normalized_audio_duration}, key_scale={normalized_keyscale!r}, time_signature={normalized_timesig!r}"
1083
- )
1084
-
1085
  return GenerateMusicRequest(
1086
  caption=str(get("caption", "") or ""),
1087
  lyrics=str(get("lyrics", "") or ""),
@@ -1120,7 +857,6 @@ def create_app() -> FastAPI:
1120
  lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
1121
  constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
1122
  constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
1123
- # Accept common aliases, including hyphenated keys from some clients.
1124
  use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
1125
  use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
1126
  is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),
 
44
  DEFAULT_DIT_INSTRUCTION,
45
  DEFAULT_LM_INSTRUCTION,
46
  )
47
+ from acestep.inference import (
48
+ GenerationParams,
49
+ GenerationConfig,
50
+ generate_music,
51
+ )
52
+ from acestep.gradio_ui.events.results_handlers import _build_generation_info
53
 
54
 
55
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
 
393
  app.state.executor = executor
394
  app.state.job_store = store
395
  app.state._python_executable = sys.executable
396
+
397
+ # Temporary directory for saving generated audio files
398
+ app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
399
+ os.makedirs(app.state.temp_audio_dir, exist_ok=True)
400
 
401
  async def _ensure_initialized() -> None:
402
  h: AceStepHandler = app.state.handler
 
453
  job_store.mark_running(job_id)
454
 
455
  def _blocking_generate() -> Dict[str, Any]:
456
+ """Generate music using unified inference logic from acestep.inference"""
457
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  def _ensure_llm_ready() -> None:
459
+ """Ensure LLM handler is initialized when needed"""
460
  with app.state._llm_init_lock:
461
  initialized = getattr(app.state, "_llm_initialized", False)
462
  had_error = getattr(app.state, "_llm_init_error", None)
 
486
  else:
487
  app.state._llm_initialized = True
488
 
489
+ def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
490
+ """Ensure a stable `metas` dict (keys always present)."""
491
+ meta = meta or {}
492
+ out: Dict[str, Any] = dict(meta)
 
 
 
 
493
 
494
+ # Normalize key aliases
495
+ if "keyscale" not in out and "key_scale" in out:
496
+ out["keyscale"] = out.get("key_scale")
497
+ if "timesignature" not in out and "time_signature" in out:
498
+ out["timesignature"] = out.get("time_signature")
499
 
500
+ # Ensure required keys exist
501
+ for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
502
+ if out.get(k) in (None, ""):
503
+ out[k] = "N/A"
504
+ return out
505
 
506
+ # Normalize LM sampling parameters
507
+ lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
508
+ lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
509
 
510
+ # Determine if LLM is needed
511
+ thinking = bool(req.thinking)
512
+ sample_mode = bool(req.sample_mode)
513
+ need_llm = thinking or sample_mode
514
+
515
+ print(f"[api_server] Request params: req.thinking={req.thinking}, req.sample_mode={req.sample_mode}")
516
+ print(f"[api_server] Determined: thinking={thinking}, sample_mode={sample_mode}, need_llm={need_llm}")
517
+
518
+ # Ensure LLM is ready if needed
519
+ if need_llm:
520
  _ensure_llm_ready()
521
  if getattr(app.state, "_llm_init_error", None):
522
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
523
 
524
+ # Handle sample mode: generate random caption/lyrics first
525
+ caption = req.caption
526
+ lyrics = req.lyrics
527
+ bpm = req.bpm
528
+ key_scale = req.key_scale
529
+ time_signature = req.time_signature
530
+ audio_duration = req.audio_duration
531
+
532
+ if sample_mode:
533
+ print("[api_server] Sample mode: generating random caption/lyrics via LM")
534
  sample_metadata, sample_status = llm.understand_audio_from_codes(
535
  audio_codes="NO USER INPUT",
536
+ temperature=req.lm_temperature,
537
+ cfg_scale=max(1.0, req.lm_cfg_scale),
538
+ negative_prompt=req.lm_negative_prompt,
539
+ top_k=lm_top_k if lm_top_k > 0 else None,
540
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
541
+ repetition_penalty=req.lm_repetition_penalty,
542
+ use_constrained_decoding=req.constrained_decoding,
543
+ constrained_decoding_debug=req.constrained_decoding_debug,
544
  )
545
 
546
  if not sample_metadata or str(sample_status).startswith("❌"):
547
  raise RuntimeError(f"Sample generation failed: {sample_status}")
548
 
549
+ # Use generated values with fallback defaults
550
+ caption = sample_metadata.get("caption", "")
551
+ lyrics = sample_metadata.get("lyrics", "")
552
+ bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
553
+ key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
554
+ time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
555
+ audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
556
+
557
+ print(f"[api_server] Sample generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}")
558
+
559
+ print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
560
+ print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
561
+
562
+ # Build GenerationParams using unified interface
563
+ # Note: thinking controls LM code generation, sample_mode only affects CoT metas
564
+ params = GenerationParams(
565
+ task_type=req.task_type,
566
+ instruction=req.instruction,
567
+ reference_audio=req.reference_audio_path,
568
+ src_audio=req.src_audio_path,
569
+ audio_codes=req.audio_code_string,
570
+ caption=caption,
571
+ lyrics=lyrics,
572
+ instrumental=False,
573
+ vocal_language=req.vocal_language,
574
+ bpm=bpm,
575
+ keyscale=key_scale,
576
+ timesignature=time_signature,
577
+ duration=audio_duration if audio_duration else -1.0,
578
+ inference_steps=req.inference_steps,
579
+ seed=req.seed,
580
+ guidance_scale=req.guidance_scale,
581
+ use_adg=req.use_adg,
582
+ cfg_interval_start=req.cfg_interval_start,
583
+ cfg_interval_end=req.cfg_interval_end,
584
+ repainting_start=req.repainting_start,
585
+ repainting_end=req.repainting_end if req.repainting_end else -1,
586
+ audio_cover_strength=req.audio_cover_strength,
587
+ # LM parameters
588
+ thinking=thinking, # Use LM for code generation when thinking=True
589
+ lm_temperature=req.lm_temperature,
590
+ lm_cfg_scale=req.lm_cfg_scale,
591
+ lm_top_k=lm_top_k,
592
+ lm_top_p=lm_top_p,
593
+ lm_negative_prompt=req.lm_negative_prompt,
594
+ use_cot_metas=not sample_mode, # Sample mode already generated metas, don't regenerate
595
+ use_cot_caption=req.use_cot_caption,
596
+ use_cot_language=req.use_cot_language,
597
+ use_constrained_decoding=req.constrained_decoding,
598
+ )
 
 
 
 
599
 
600
+ # Build GenerationConfig - default to 2 audios like gradio_ui
601
+ batch_size = req.batch_size if req.batch_size is not None else 2
602
+ config = GenerationConfig(
603
+ batch_size=batch_size,
604
+ use_random_seed=req.use_random_seed,
605
+ seeds=None, # Let unified logic handle seed generation
606
+ audio_format=req.audio_format,
607
+ constrained_decoding_debug=req.constrained_decoding_debug,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
  )
609
 
610
+ # Check LLM initialization status
611
+ llm_is_initialized = getattr(app.state, "_llm_initialized", False)
612
+ llm_to_pass = llm if llm_is_initialized else None
613
+
614
+ print(f"[api_server] Generating music with unified interface:")
615
+ print(f" - thinking={params.thinking}")
616
+ print(f" - batch_size={batch_size}")
617
+ print(f" - llm_initialized={llm_is_initialized}")
618
+ print(f" - llm_handler={'Available' if llm_to_pass else 'None'}")
619
+
620
+ # Generate music using unified interface
621
+ result = generate_music(
622
+ dit_handler=h,
623
+ llm_handler=llm_to_pass,
624
+ params=params,
625
+ config=config,
626
+ save_dir=app.state.temp_audio_dir,
627
+ progress=None,
628
  )
629
+
630
+ print(f"[api_server] Generation completed. Success={result.success}, Audios={len(result.audios)}")
631
+ print(f"[api_server] Time costs keys: {list(result.extra_outputs.get('time_costs', {}).keys())}")
632
 
633
+ if not result.success:
634
+ raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
635
 
636
+ # Extract results
637
+ audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
638
+ first_audio = audio_paths[0] if len(audio_paths) > 0 else None
639
+ second_audio = audio_paths[1] if len(audio_paths) > 1 else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640
 
641
+ # Get metadata from LM or CoT results
642
+ lm_metadata = result.extra_outputs.get("lm_metadata", {})
643
+ metas_out = _normalize_metas(lm_metadata)
644
+
645
+ # Update metas with actual values used
646
+ if params.cot_bpm:
647
+ metas_out["bpm"] = params.cot_bpm
648
+ elif bpm:
649
+ metas_out["bpm"] = bpm
650
+
651
+ if params.cot_duration:
652
+ metas_out["duration"] = params.cot_duration
653
+ elif audio_duration:
654
+ metas_out["duration"] = audio_duration
655
+
656
+ if params.cot_keyscale:
657
+ metas_out["keyscale"] = params.cot_keyscale
658
+ elif key_scale:
659
+ metas_out["keyscale"] = key_scale
660
+
661
+ if params.cot_timesignature:
662
+ metas_out["timesignature"] = params.cot_timesignature
663
+ elif time_signature:
664
+ metas_out["timesignature"] = time_signature
665
+
666
+ # Ensure caption and lyrics are in metas
667
+ if caption:
668
+ metas_out["caption"] = caption
669
+ if lyrics:
670
+ metas_out["lyrics"] = lyrics
671
+
672
+ # Extract seed values for response (comma-separated for multiple audios)
673
+ seed_values = []
674
+ for audio in result.audios:
675
+ audio_params = audio.get("params", {})
676
+ seed = audio_params.get("seed")
677
+ if seed is not None:
678
+ seed_values.append(str(seed))
679
+ seed_value = ",".join(seed_values) if seed_values else ""
680
+
681
+ # Build generation_info using the helper function (like gradio_ui)
682
+ time_costs = result.extra_outputs.get("time_costs", {})
683
+ generation_info = _build_generation_info(
684
+ lm_metadata=lm_metadata,
685
+ time_costs=time_costs,
686
+ seed_value=seed_value,
687
+ inference_steps=req.inference_steps,
688
+ num_audios=len(result.audios),
689
+ )
690
 
691
  def _none_if_na_str(v: Any) -> Optional[str]:
692
  if v is None:
 
695
  if s in {"", "N/A"}:
696
  return None
697
  return s
698
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699
  return {
700
+ "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
701
+ "second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
702
+ "audio_paths": [_path_to_audio_url(p) for p in audio_paths],
703
+ "generation_info": generation_info,
704
+ "status_message": result.status_message,
705
  "seed_value": seed_value,
706
  "metas": metas_out,
707
+ "bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
708
+ "duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
709
  "genres": _none_if_na_str(metas_out.get("genres")),
710
  "keyscale": _none_if_na_str(metas_out.get("keyscale")),
711
  "timesignature": _none_if_na_str(metas_out.get("timesignature")),
 
810
 
811
  return default
812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
  normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
814
  normalized_bpm = _to_int(_get_any("bpm"), None)
815
  normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
 
819
  if normalized_audio_duration is None:
820
  normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)
821
 
 
 
 
 
 
 
822
  return GenerateMusicRequest(
823
  caption=str(get("caption", "") or ""),
824
  lyrics=str(get("lyrics", "") or ""),
 
857
  lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
858
  constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
859
  constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
 
860
  use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
861
  use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
862
  is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),