feat: api-server add sample
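This commit adds a sample mode to the API server: a sample_mode flag on GenerateMusicRequest, a shared _ensure_llm_ready() helper that lazily initializes the 5Hz LM behind a lock, a sample-mode path that asks the LM to auto-generate caption/lyrics/metas when no user prompt is given, and a new POST /v1/music/random endpoint that queues such a job.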
Changed: acestep/api_server.py (+129, -25)
@@ -58,6 +58,8 @@ class GenerateMusicRequest(BaseModel):
     # - thinking=False: do not use LM to generate codes (dit behavior)
     # Regardless of thinking, if some metas are missing, server may use LM to fill them.
     thinking: bool = False
+    # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
+    sample_mode: bool = False
 
     bpm: Optional[int] = None
     # Accept common client keys via manual parsing (see _build_req_from_mapping).
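For illustration, a client mapping that opts into the new flag might look like the sketch below; the endpoint that feeds _build_req_from_mapping is outside this diff, and a later hunk shows that the camelCase alias sampleMode is accepted as well.

    payload = {
        "caption": "",        # empty: sample mode fills this in
        "lyrics": "",         # empty: sample mode fills this in
        "thinking": True,
        "sample_mode": True,  # or "sampleMode"; both parse via _to_bool
    }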
@@ -559,6 +561,36 @@ def create_app() -> FastAPI:
             out[k] = "N/A"
         return out
 
+    def _ensure_llm_ready() -> None:
+        with app.state._llm_init_lock:
+            initialized = getattr(app.state, "_llm_initialized", False)
+            had_error = getattr(app.state, "_llm_init_error", None)
+            if initialized or had_error is not None:
+                return
+
+            project_root = _get_project_root()
+            checkpoint_dir = os.path.join(project_root, "checkpoints")
+            lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B-v3").strip()
+            backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
+            if backend not in {"vllm", "pt"}:
+                backend = "vllm"
+
+            lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
+            lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
+
+            status, ok = llm.initialize(
+                checkpoint_dir=checkpoint_dir,
+                lm_model_path=lm_model_path,
+                backend=backend,
+                device=lm_device,
+                offload_to_cpu=lm_offload,
+                dtype=h.dtype,
+            )
+            if not ok:
+                app.state._llm_init_error = status
+            else:
+                app.state._llm_initialized = True
+
     # Optional: generate 5Hz LM codes server-side
     audio_code_string = req.audio_code_string
     bpm_val = req.bpm
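Centralizing initialization in _ensure_llm_ready() gives every call site the same once-only semantics: app.state._llm_init_lock serializes concurrent requests, a prior success short-circuits through app.state._llm_initialized, and a prior failure is cached in app.state._llm_init_error so later requests fail fast rather than re-running a doomed initialization.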
@@ -580,6 +612,57 @@ def create_app() -> FastAPI:
 
     lm_meta: Optional[Dict[str, Any]] = None
 
+    sample_mode = bool(getattr(req, "sample_mode", False))
+    if sample_mode:
+        _ensure_llm_ready()
+        if getattr(app.state, "_llm_init_error", None):
+            raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
+
+        sample_metadata, sample_status = llm.understand_audio_from_codes(
+            audio_codes="NO USER INPUT",
+            temperature=float(getattr(req, "lm_temperature", _LM_DEFAULT_TEMPERATURE)),
+            cfg_scale=max(1.0, float(getattr(req, "lm_cfg_scale", _LM_DEFAULT_CFG_SCALE))),
+            negative_prompt=str(getattr(req, "lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
+            top_k=_normalize_optional_int(getattr(req, "lm_top_k", None)),
+            top_p=_normalize_optional_float(getattr(req, "lm_top_p", None)),
+            repetition_penalty=float(getattr(req, "lm_repetition_penalty", 1.0)),
+            use_constrained_decoding=bool(getattr(req, "constrained_decoding", True)),
+            constrained_decoding_debug=bool(getattr(req, "constrained_decoding_debug", False)),
+        )
+
+        if not sample_metadata or str(sample_status).startswith("❌"):
+            raise RuntimeError(f"Sample generation failed: {sample_status}")
+
+        req.caption = str(sample_metadata.get("caption", "") or "")
+        req.lyrics = str(sample_metadata.get("lyrics", "") or "")
+        req.bpm = _to_int(sample_metadata.get("bpm"), req.bpm)
+
+        sample_keyscale = sample_metadata.get("keyscale", sample_metadata.get("key_scale", ""))
+        if sample_keyscale:
+            req.key_scale = str(sample_keyscale)
+
+        sample_timesig = sample_metadata.get("timesignature", sample_metadata.get("time_signature", ""))
+        if sample_timesig:
+            req.time_signature = str(sample_timesig)
+
+        sample_duration = _to_float(sample_metadata.get("duration"), None)
+        if sample_duration is not None and sample_duration > 0:
+            req.audio_duration = sample_duration
+
+        lm_meta = sample_metadata
+
+        print(
+            "[api_server] sample mode metadata:",
+            {
+                "caption_len": len(req.caption),
+                "lyrics_len": len(req.lyrics),
+                "bpm": req.bpm,
+                "audio_duration": req.audio_duration,
+                "key_scale": req.key_scale,
+                "time_signature": req.time_signature,
+            },
+        )
+
     # Determine effective batch size (used for per-sample LM code diversity)
     effective_batch_size = req.batch_size
     if effective_batch_size is None:
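For context, the hunk above expects understand_audio_from_codes to return a status string plus a metadata mapping whose keys it copies back onto the request. The shape below is a sketch with invented values, shown only to make the key aliases concrete:

    sample_metadata = {
        "caption": "<generated caption>",  # hypothetical value
        "lyrics": "<generated lyrics>",    # hypothetical value
        "bpm": 120,                        # hypothetical value
        "keyscale": "C major",             # alias: "key_scale"
        "timesignature": "4/4",            # alias: "time_signature"
        "duration": 180.0,                 # seconds; hypothetical value
    }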
@@ -641,31 +724,7 @@ def create_app() -> FastAPI:
     )
 
     if need_lm_metas or need_lm_codes:
-
-        with app.state._llm_init_lock:
-            if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
-                project_root = _get_project_root()
-                checkpoint_dir = os.path.join(project_root, "checkpoints")
-                lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B-v3").strip()
-                backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
-                if backend not in {"vllm", "pt"}:
-                    backend = "vllm"
-
-                lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
-                lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
-
-                status, ok = llm.initialize(
-                    checkpoint_dir=checkpoint_dir,
-                    lm_model_path=lm_model_path,
-                    backend=backend,
-                    device=lm_device,
-                    offload_to_cpu=lm_offload,
-                    dtype=h.dtype,
-                )
-                if not ok:
-                    app.state._llm_init_error = status
-                else:
-                    app.state._llm_initialized = True
+        _ensure_llm_ready()
 
     if getattr(app.state, "_llm_init_error", None):
         # If codes generation is required, fail hard.
@@ -972,6 +1031,7 @@ def create_app() -> FastAPI:
         caption=str(get("caption", "") or ""),
         lyrics=str(get("lyrics", "") or ""),
         thinking=_to_bool(get("thinking"), False),
+        sample_mode=_to_bool(_get_any("sample_mode", "sampleMode"), False),
         bpm=normalized_bpm,
         key_scale=normalized_keyscale,
         time_signature=normalized_timesig,
@@ -1113,6 +1173,50 @@ def create_app() -> FastAPI:
         await q.put((rec.job_id, req))
         return CreateJobResponse(job_id=rec.job_id, status="queued", queue_position=position)
 
+    @app.post("/v1/music/random", response_model=CreateJobResponse)
+    async def create_random_sample_job(request: Request) -> CreateJobResponse:
+        """Create a sample-mode job that auto-generates caption/lyrics via LM."""
+
+        thinking_value: Any = None
+        content_type = (request.headers.get("content-type") or "").lower()
+        body_dict: Dict[str, Any] = {}
+
+        if "json" in content_type:
+            try:
+                payload = await request.json()
+                if isinstance(payload, dict):
+                    body_dict = payload
+            except Exception:
+                body_dict = {}
+
+        if not body_dict and request.query_params:
+            body_dict = dict(request.query_params)
+
+        thinking_value = body_dict.get("thinking")
+        if thinking_value is None:
+            thinking_value = body_dict.get("Thinking")
+
+        thinking_flag = _to_bool(thinking_value, True)
+
+        req = GenerateMusicRequest(
+            caption="",
+            lyrics="",
+            thinking=thinking_flag,
+            sample_mode=True,
+        )
+
+        rec = store.create()
+        q: asyncio.Queue = app.state.job_queue
+        if q.full():
+            raise HTTPException(status_code=429, detail="Server busy: queue is full")
+
+        async with app.state.pending_lock:
+            app.state.pending_ids.append(rec.job_id)
+            position = len(app.state.pending_ids)
+
+        await q.put((rec.job_id, req))
+        return CreateJobResponse(job_id=rec.job_id, status="queued", queue_position=position)
+
     @app.get("/v1/jobs/{job_id}", response_model=JobResponse)
     async def get_job(job_id: str) -> JobResponse:
         rec = store.get(job_id)
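A minimal client sketch for the new endpoint, assuming a server at http://127.0.0.1:8000 and the third-party requests package (both assumptions, not part of this diff):

    import requests

    BASE = "http://127.0.0.1:8000"  # assumed host/port

    # Queue a sample-mode job; omitting "thinking" defaults it to True here.
    job = requests.post(f"{BASE}/v1/music/random", json={"thinking": True}).json()
    print(job["job_id"], job["status"], job["queue_position"])

    # Poll the job record; a full queue would instead return HTTP 429 above.
    record = requests.get(f"{BASE}/v1/jobs/{job['job_id']}").json()

Query parameters work too (e.g. POST /v1/music/random?thinking=false), since the handler falls back to request.query_params when no JSON body is given.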