ChuxiJ committed
Commit e1ce8b3 · 1 Parent(s): bbb4f62

feat: api-server add sample

Files changed (1)
  1. acestep/api_server.py +129 -25
acestep/api_server.py CHANGED
@@ -58,6 +58,8 @@ class GenerateMusicRequest(BaseModel):
     # - thinking=False: do not use LM to generate codes (dit behavior)
     # Regardless of thinking, if some metas are missing, server may use LM to fill them.
     thinking: bool = False
+    # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
+    sample_mode: bool = False
 
     bpm: Optional[int] = None
     # Accept common client keys via manual parsing (see _build_req_from_mapping).
@@ -559,6 +561,36 @@ def create_app() -> FastAPI:
                 out[k] = "N/A"
             return out
 
+        def _ensure_llm_ready() -> None:
+            with app.state._llm_init_lock:
+                initialized = getattr(app.state, "_llm_initialized", False)
+                had_error = getattr(app.state, "_llm_init_error", None)
+                if initialized or had_error is not None:
+                    return
+
+                project_root = _get_project_root()
+                checkpoint_dir = os.path.join(project_root, "checkpoints")
+                lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B-v3").strip()
+                backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
+                if backend not in {"vllm", "pt"}:
+                    backend = "vllm"
+
+                lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
+                lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
+
+                status, ok = llm.initialize(
+                    checkpoint_dir=checkpoint_dir,
+                    lm_model_path=lm_model_path,
+                    backend=backend,
+                    device=lm_device,
+                    offload_to_cpu=lm_offload,
+                    dtype=h.dtype,
+                )
+                if not ok:
+                    app.state._llm_init_error = status
+                else:
+                    app.state._llm_initialized = True
+
         # Optional: generate 5Hz LM codes server-side
         audio_code_string = req.audio_code_string
         bpm_val = req.bpm
@@ -580,6 +612,57 @@ def create_app() -> FastAPI:
 
         lm_meta: Optional[Dict[str, Any]] = None
 
+        sample_mode = bool(getattr(req, "sample_mode", False))
+        if sample_mode:
+            _ensure_llm_ready()
+            if getattr(app.state, "_llm_init_error", None):
+                raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
+
+            sample_metadata, sample_status = llm.understand_audio_from_codes(
+                audio_codes="NO USER INPUT",
+                temperature=float(getattr(req, "lm_temperature", _LM_DEFAULT_TEMPERATURE)),
+                cfg_scale=max(1.0, float(getattr(req, "lm_cfg_scale", _LM_DEFAULT_CFG_SCALE))),
+                negative_prompt=str(getattr(req, "lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
+                top_k=_normalize_optional_int(getattr(req, "lm_top_k", None)),
+                top_p=_normalize_optional_float(getattr(req, "lm_top_p", None)),
+                repetition_penalty=float(getattr(req, "lm_repetition_penalty", 1.0)),
+                use_constrained_decoding=bool(getattr(req, "constrained_decoding", True)),
+                constrained_decoding_debug=bool(getattr(req, "constrained_decoding_debug", False)),
+            )
+
+            if not sample_metadata or str(sample_status).startswith("❌"):
+                raise RuntimeError(f"Sample generation failed: {sample_status}")
+
+            req.caption = str(sample_metadata.get("caption", "") or "")
+            req.lyrics = str(sample_metadata.get("lyrics", "") or "")
+            req.bpm = _to_int(sample_metadata.get("bpm"), req.bpm)
+
+            sample_keyscale = sample_metadata.get("keyscale", sample_metadata.get("key_scale", ""))
+            if sample_keyscale:
+                req.key_scale = str(sample_keyscale)
+
+            sample_timesig = sample_metadata.get("timesignature", sample_metadata.get("time_signature", ""))
+            if sample_timesig:
+                req.time_signature = str(sample_timesig)
+
+            sample_duration = _to_float(sample_metadata.get("duration"), None)
+            if sample_duration is not None and sample_duration > 0:
+                req.audio_duration = sample_duration
+
+            lm_meta = sample_metadata
+
+            print(
+                "[api_server] sample mode metadata:",
+                {
+                    "caption_len": len(req.caption),
+                    "lyrics_len": len(req.lyrics),
+                    "bpm": req.bpm,
+                    "audio_duration": req.audio_duration,
+                    "key_scale": req.key_scale,
+                    "time_signature": req.time_signature,
+                },
+            )
+
         # Determine effective batch size (used for per-sample LM code diversity)
         effective_batch_size = req.batch_size
         if effective_batch_size is None:
@@ -641,31 +724,7 @@ def create_app() -> FastAPI:
         )
 
         if need_lm_metas or need_lm_codes:
-            # Lazy init 5Hz LM once
-            with app.state._llm_init_lock:
-                if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
-                    project_root = _get_project_root()
-                    checkpoint_dir = os.path.join(project_root, "checkpoints")
-                    lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B-v3").strip()
-                    backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
-                    if backend not in {"vllm", "pt"}:
-                        backend = "vllm"
-
-                    lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
-                    lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
-
-                    status, ok = llm.initialize(
-                        checkpoint_dir=checkpoint_dir,
-                        lm_model_path=lm_model_path,
-                        backend=backend,
-                        device=lm_device,
-                        offload_to_cpu=lm_offload,
-                        dtype=h.dtype,
-                    )
-                    if not ok:
-                        app.state._llm_init_error = status
-                    else:
-                        app.state._llm_initialized = True
+            _ensure_llm_ready()
 
         if getattr(app.state, "_llm_init_error", None):
             # If codes generation is required, fail hard.
@@ -972,6 +1031,7 @@ def create_app() -> FastAPI:
             caption=str(get("caption", "") or ""),
             lyrics=str(get("lyrics", "") or ""),
            thinking=_to_bool(get("thinking"), False),
+            sample_mode=_to_bool(_get_any("sample_mode", "sampleMode"), False),
             bpm=normalized_bpm,
             key_scale=normalized_keyscale,
             time_signature=normalized_timesig,
@@ -1113,6 +1173,50 @@ def create_app() -> FastAPI:
         await q.put((rec.job_id, req))
         return CreateJobResponse(job_id=rec.job_id, status="queued", queue_position=position)
 
+    @app.post("/v1/music/random", response_model=CreateJobResponse)
+    async def create_random_sample_job(request: Request) -> CreateJobResponse:
+        """Create a sample-mode job that auto-generates caption/lyrics via LM."""
+
+        thinking_value: Any = None
+        content_type = (request.headers.get("content-type") or "").lower()
+        body_dict: Dict[str, Any] = {}
+
+        if "json" in content_type:
+            try:
+                payload = await request.json()
+                if isinstance(payload, dict):
+                    body_dict = payload
+            except Exception:
+                body_dict = {}
+
+        if not body_dict and request.query_params:
+            body_dict = dict(request.query_params)
+
+        thinking_value = body_dict.get("thinking")
+        if thinking_value is None:
+            thinking_value = body_dict.get("Thinking")
+
+        thinking_flag = _to_bool(thinking_value, True)
+
+        req = GenerateMusicRequest(
+            caption="",
+            lyrics="",
+            thinking=thinking_flag,
+            sample_mode=True,
+        )
+
+        rec = store.create()
+        q: asyncio.Queue = app.state.job_queue
+        if q.full():
+            raise HTTPException(status_code=429, detail="Server busy: queue is full")
+
+        async with app.state.pending_lock:
+            app.state.pending_ids.append(rec.job_id)
+            position = len(app.state.pending_ids)
+
+        await q.put((rec.job_id, req))
+        return CreateJobResponse(job_id=rec.job_id, status="queued", queue_position=position)
+
     @app.get("/v1/jobs/{job_id}", response_model=JobResponse)
     async def get_job(job_id: str) -> JobResponse:
         rec = store.get(job_id)
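For reference, a minimal client sketch for the sample-mode route added in this commit. The diff above only confirms POST /v1/music/random (optional "thinking" key, defaulting to True), a 429 response when the queue is full, GET /v1/jobs/{job_id}, and the job_id/status/queue_position fields of CreateJobResponse; the base URL, port, and any job-record fields beyond "status" in the polling loop below are assumptions.

# Hypothetical client for POST /v1/music/random and GET /v1/jobs/{job_id}.
# Assumes the api_server is reachable at http://localhost:8000 (adjust as needed).
import time
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port, not part of this diff

# Create a sample-mode job; "thinking" is optional and defaults to True in the handler.
resp = requests.post(f"{BASE_URL}/v1/music/random", json={"thinking": True}, timeout=30)
resp.raise_for_status()  # a 429 here means the job queue is full
job = resp.json()
print("queued:", job["job_id"], "position:", job.get("queue_position"))

# Poll the job record; the terminal status values are not shown in this diff,
# so the loop just prints each payload and stops once the status is no longer "queued".
for _ in range(120):
    record = requests.get(f"{BASE_URL}/v1/jobs/{job['job_id']}", timeout=30).json()
    print(record)
    if record.get("status") not in {"queued", "running"}:
        break
    time.sleep(5)

Clients that already call the regular generation endpoint can instead send "sample_mode": true (or "sampleMode") in the request body; per the request-parsing hunk above, both keys are mapped onto GenerateMusicRequest.sample_mode.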