Spaces:
Running
on
A100
Running
on
A100
feat: api support lm-dit
Browse files- acestep/api_server.py +372 -10
acestep/api_server.py
CHANGED
|
@@ -14,6 +14,7 @@ from __future__ import annotations
|
|
| 14 |
import asyncio
|
| 15 |
import json
|
| 16 |
import os
|
|
|
|
| 17 |
import sys
|
| 18 |
import time
|
| 19 |
import traceback
|
|
@@ -33,6 +34,7 @@ from pydantic import BaseModel, Field
|
|
| 33 |
from starlette.datastructures import UploadFile as StarletteUploadFile
|
| 34 |
|
| 35 |
from .handler import AceStepHandler
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
JobStatus = Literal["queued", "running", "succeeded", "failed"]
|
|
@@ -42,6 +44,9 @@ class GenerateMusicRequest(BaseModel):
|
|
| 42 |
caption: str = Field(default="", description="Text caption describing the music")
|
| 43 |
lyrics: str = Field(default="", description="Lyric text")
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
bpm: Optional[int] = None
|
| 46 |
key_scale: str = ""
|
| 47 |
time_signature: str = ""
|
|
@@ -72,6 +77,36 @@ class GenerateMusicRequest(BaseModel):
|
|
| 72 |
audio_format: str = "mp3"
|
| 73 |
use_tiled_decode: bool = True
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
class CreateJobResponse(BaseModel):
|
| 77 |
job_id: str
|
|
@@ -88,6 +123,15 @@ class JobResult(BaseModel):
|
|
| 88 |
status_message: str = ""
|
| 89 |
seed_value: str = ""
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
class JobResponse(BaseModel):
|
| 93 |
job_id: str
|
|
@@ -240,12 +284,46 @@ def create_app() -> FastAPI:
|
|
| 240 |
for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
| 241 |
os.environ.pop(proxy_var, None)
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
handler = AceStepHandler()
|
|
|
|
| 244 |
init_lock = asyncio.Lock()
|
| 245 |
app.state._initialized = False
|
| 246 |
app.state._init_error = None
|
| 247 |
app.state._init_lock = init_lock
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
|
| 250 |
executor = ThreadPoolExecutor(max_workers=max_workers)
|
| 251 |
|
|
@@ -316,33 +394,305 @@ def create_app() -> FastAPI:
|
|
| 316 |
async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
|
| 317 |
job_store: _JobStore = app.state.job_store
|
| 318 |
h: AceStepHandler = app.state.handler
|
|
|
|
| 319 |
executor: ThreadPoolExecutor = app.state.executor
|
| 320 |
|
| 321 |
await _ensure_initialized()
|
| 322 |
job_store.mark_running(job_id)
|
| 323 |
|
| 324 |
def _blocking_generate() -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
|
| 326 |
captions=req.caption,
|
| 327 |
lyrics=req.lyrics,
|
| 328 |
-
bpm=
|
| 329 |
-
key_scale=
|
| 330 |
-
time_signature=
|
| 331 |
vocal_language=req.vocal_language,
|
| 332 |
inference_steps=req.inference_steps,
|
| 333 |
guidance_scale=req.guidance_scale,
|
| 334 |
use_random_seed=req.use_random_seed,
|
| 335 |
-
seed=req.seed,
|
| 336 |
reference_audio=req.reference_audio_path,
|
| 337 |
-
audio_duration=
|
| 338 |
batch_size=req.batch_size,
|
| 339 |
src_audio=req.src_audio_path,
|
| 340 |
-
audio_code_string=
|
| 341 |
repainting_start=req.repainting_start,
|
| 342 |
repainting_end=req.repainting_end,
|
| 343 |
-
instruction=
|
| 344 |
-
audio_cover_strength=
|
| 345 |
-
task_type=
|
| 346 |
use_adg=req.use_adg,
|
| 347 |
cfg_interval_start=req.cfg_interval_start,
|
| 348 |
cfg_interval_end=req.cfg_interval_end,
|
|
@@ -357,6 +707,7 @@ def create_app() -> FastAPI:
|
|
| 357 |
"generation_info": gen_info,
|
| 358 |
"status_message": status_msg,
|
| 359 |
"seed_value": seed_value,
|
|
|
|
| 360 |
}
|
| 361 |
|
| 362 |
t0 = time.time()
|
|
@@ -428,6 +779,7 @@ def create_app() -> FastAPI:
|
|
| 428 |
return GenerateMusicRequest(
|
| 429 |
caption=str(get("caption", "") or ""),
|
| 430 |
lyrics=str(get("lyrics", "") or ""),
|
|
|
|
| 431 |
bpm=_to_int(get("bpm"), None),
|
| 432 |
key_scale=str(get("key_scale", "") or ""),
|
| 433 |
time_signature=str(get("time_signature", "") or ""),
|
|
@@ -443,7 +795,7 @@ def create_app() -> FastAPI:
|
|
| 443 |
audio_code_string=str(get("audio_code_string", "") or ""),
|
| 444 |
repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
|
| 445 |
repainting_end=_to_float(get("repainting_end"), None),
|
| 446 |
-
instruction=str(get("instruction",
|
| 447 |
audio_cover_strength=_to_float(get("audio_cover_strength"), 1.0) or 1.0,
|
| 448 |
task_type=str(get("task_type", "text2music") or "text2music"),
|
| 449 |
use_adg=_to_bool(get("use_adg"), False),
|
|
@@ -451,6 +803,16 @@ def create_app() -> FastAPI:
|
|
| 451 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
| 452 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 453 |
use_tiled_decode=_to_bool(get("use_tiled_decode"), True),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
)
|
| 455 |
|
| 456 |
def _first_value(v: Any) -> Any:
|
|
|
|
| 14 |
import asyncio
|
| 15 |
import json
|
| 16 |
import os
|
| 17 |
+
import re
|
| 18 |
import sys
|
| 19 |
import time
|
| 20 |
import traceback
|
|
|
|
| 34 |
from starlette.datastructures import UploadFile as StarletteUploadFile
|
| 35 |
|
| 36 |
from .handler import AceStepHandler
|
| 37 |
+
from .llm_inference import LLMHandler
|
| 38 |
|
| 39 |
|
| 40 |
JobStatus = Literal["queued", "running", "succeeded", "failed"]
|
|
|
|
| 44 |
caption: str = Field(default="", description="Text caption describing the music")
|
| 45 |
lyrics: str = Field(default="", description="Lyric text")
|
| 46 |
|
| 47 |
+
# Match feishu bot semantics: `dit` (metas only) vs `llm_dit` (metas + audio codes)
|
| 48 |
+
infer_type: Optional[Literal["dit", "llm_dit"]] = None
|
| 49 |
+
|
| 50 |
bpm: Optional[int] = None
|
| 51 |
key_scale: str = ""
|
| 52 |
time_signature: str = ""
|
|
|
|
| 77 |
audio_format: str = "mp3"
|
| 78 |
use_tiled_decode: bool = True
|
| 79 |
|
| 80 |
+
# 5Hz LM generation (server-side, like gradio's generate_lm_hints_wrapper)
|
| 81 |
+
use_5hz_lm: bool = False
|
| 82 |
+
lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
|
| 83 |
+
lm_backend: Literal["vllm", "pt"] = "vllm"
|
| 84 |
+
|
| 85 |
+
# Align defaults with `acestep/gradio_ui.py` and `feishu_bot/config.py`
|
| 86 |
+
# to improve lyric adherence in lm-dit mode.
|
| 87 |
+
lm_temperature: float = 0.85
|
| 88 |
+
lm_cfg_scale: float = 2.0
|
| 89 |
+
lm_top_k: Optional[int] = None
|
| 90 |
+
lm_top_p: Optional[float] = 0.9
|
| 91 |
+
lm_repetition_penalty: float = 1.0
|
| 92 |
+
lm_negative_prompt: str = "NO USER INPUT"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
_LM_DEFAULT_TEMPERATURE = 0.85
|
| 96 |
+
_LM_DEFAULT_CFG_SCALE = 2.0
|
| 97 |
+
_LM_DEFAULT_TOP_P = 0.9
|
| 98 |
+
_DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
| 99 |
+
_DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _normalize_infer_type(v: Any) -> Optional[str]:
|
| 103 |
+
s = str(v or "").strip().lower()
|
| 104 |
+
if not s:
|
| 105 |
+
return None
|
| 106 |
+
if s in {"dit", "llm_dit"}:
|
| 107 |
+
return s
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
|
| 111 |
class CreateJobResponse(BaseModel):
|
| 112 |
job_id: str
|
|
|
|
| 123 |
status_message: str = ""
|
| 124 |
seed_value: str = ""
|
| 125 |
|
| 126 |
+
# 5Hz LM metadata (present when `use_5hz_lm=true` and server generates codes)
|
| 127 |
+
# Keep a raw-ish dict for clients that expect a `metas` object.
|
| 128 |
+
metas: Dict[str, Any] = Field(default_factory=dict)
|
| 129 |
+
bpm: Optional[int] = None
|
| 130 |
+
duration: Optional[float] = None
|
| 131 |
+
genres: Optional[str] = None
|
| 132 |
+
keyscale: Optional[str] = None
|
| 133 |
+
timesignature: Optional[str] = None
|
| 134 |
+
|
| 135 |
|
| 136 |
class JobResponse(BaseModel):
|
| 137 |
job_id: str
|
|
|
|
| 284 |
for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
| 285 |
os.environ.pop(proxy_var, None)
|
| 286 |
|
| 287 |
+
# Ensure compilation/temp caches do not fill up small default /tmp.
|
| 288 |
+
# Triton/Inductor (and the system compiler) can create large temporary files.
|
| 289 |
+
project_root = _get_project_root()
|
| 290 |
+
cache_root = os.path.join(project_root, ".cache", "acestep")
|
| 291 |
+
tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip()
|
| 292 |
+
triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip()
|
| 293 |
+
inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip()
|
| 294 |
+
|
| 295 |
+
for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]:
|
| 296 |
+
try:
|
| 297 |
+
os.makedirs(p, exist_ok=True)
|
| 298 |
+
except Exception:
|
| 299 |
+
# Best-effort: do not block startup if directory creation fails.
|
| 300 |
+
pass
|
| 301 |
+
|
| 302 |
+
# Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win.
|
| 303 |
+
if os.getenv("ACESTEP_TMPDIR"):
|
| 304 |
+
os.environ["TMPDIR"] = tmp_root
|
| 305 |
+
os.environ["TEMP"] = tmp_root
|
| 306 |
+
os.environ["TMP"] = tmp_root
|
| 307 |
+
else:
|
| 308 |
+
os.environ.setdefault("TMPDIR", tmp_root)
|
| 309 |
+
os.environ.setdefault("TEMP", tmp_root)
|
| 310 |
+
os.environ.setdefault("TMP", tmp_root)
|
| 311 |
+
|
| 312 |
+
os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root)
|
| 313 |
+
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root)
|
| 314 |
+
|
| 315 |
handler = AceStepHandler()
|
| 316 |
+
llm_handler = LLMHandler()
|
| 317 |
init_lock = asyncio.Lock()
|
| 318 |
app.state._initialized = False
|
| 319 |
app.state._init_error = None
|
| 320 |
app.state._init_lock = init_lock
|
| 321 |
|
| 322 |
+
app.state.llm_handler = llm_handler
|
| 323 |
+
app.state._llm_initialized = False
|
| 324 |
+
app.state._llm_init_error = None
|
| 325 |
+
app.state._llm_init_lock = Lock()
|
| 326 |
+
|
| 327 |
max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
|
| 328 |
executor = ThreadPoolExecutor(max_workers=max_workers)
|
| 329 |
|
|
|
|
| 394 |
async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
|
| 395 |
job_store: _JobStore = app.state.job_store
|
| 396 |
h: AceStepHandler = app.state.handler
|
| 397 |
+
llm: LLMHandler = app.state.llm_handler
|
| 398 |
executor: ThreadPoolExecutor = app.state.executor
|
| 399 |
|
| 400 |
await _ensure_initialized()
|
| 401 |
job_store.mark_running(job_id)
|
| 402 |
|
| 403 |
def _blocking_generate() -> Dict[str, Any]:
|
| 404 |
+
def _normalize_optional_int(v: Any) -> Optional[int]:
|
| 405 |
+
if v is None:
|
| 406 |
+
return None
|
| 407 |
+
try:
|
| 408 |
+
iv = int(v)
|
| 409 |
+
except Exception:
|
| 410 |
+
return None
|
| 411 |
+
return None if iv == 0 else iv
|
| 412 |
+
|
| 413 |
+
def _normalize_optional_float(v: Any) -> Optional[float]:
|
| 414 |
+
if v is None:
|
| 415 |
+
return None
|
| 416 |
+
try:
|
| 417 |
+
fv = float(v)
|
| 418 |
+
except Exception:
|
| 419 |
+
return None
|
| 420 |
+
# gradio treats 1.0 as disabled for top_p
|
| 421 |
+
return None if fv >= 1.0 else fv
|
| 422 |
+
|
| 423 |
+
def _maybe_fill_from_metadata(current: GenerateMusicRequest, meta: Dict[str, Any]) -> tuple[Optional[int], str, str, Optional[float]]:
|
| 424 |
+
# Fill only when user did not provide values
|
| 425 |
+
bpm_val = current.bpm
|
| 426 |
+
if bpm_val is None:
|
| 427 |
+
try:
|
| 428 |
+
m = meta.get("bpm")
|
| 429 |
+
if m not in (None, "", "N/A"):
|
| 430 |
+
bpm_val = int(float(m))
|
| 431 |
+
except Exception:
|
| 432 |
+
bpm_val = current.bpm
|
| 433 |
+
|
| 434 |
+
key_scale_val = current.key_scale
|
| 435 |
+
if not key_scale_val:
|
| 436 |
+
m = meta.get("keyscale", meta.get("key_scale", ""))
|
| 437 |
+
if m not in (None, "", "N/A"):
|
| 438 |
+
key_scale_val = str(m)
|
| 439 |
+
|
| 440 |
+
time_sig_val = current.time_signature
|
| 441 |
+
if not time_sig_val:
|
| 442 |
+
m = meta.get("timesignature", meta.get("time_signature", ""))
|
| 443 |
+
if m not in (None, "", "N/A"):
|
| 444 |
+
time_sig_val = str(m)
|
| 445 |
+
|
| 446 |
+
dur_val = current.audio_duration
|
| 447 |
+
if dur_val is None:
|
| 448 |
+
m = meta.get("duration", meta.get("audio_duration"))
|
| 449 |
+
try:
|
| 450 |
+
if m not in (None, "", "N/A"):
|
| 451 |
+
dur_val = float(m)
|
| 452 |
+
if dur_val <= 0:
|
| 453 |
+
dur_val = None
|
| 454 |
+
# Avoid truncating lyrical songs when LM predicts a very short duration.
|
| 455 |
+
# (Users can still force a short duration by explicitly setting `audio_duration`.)
|
| 456 |
+
if dur_val is not None and (current.lyrics or "").strip():
|
| 457 |
+
min_dur = float(os.getenv("ACESTEP_LM_MIN_DURATION_SECONDS", "30"))
|
| 458 |
+
if dur_val < min_dur:
|
| 459 |
+
dur_val = None
|
| 460 |
+
except Exception:
|
| 461 |
+
dur_val = current.audio_duration
|
| 462 |
+
|
| 463 |
+
return bpm_val, key_scale_val, time_sig_val, dur_val
|
| 464 |
+
|
| 465 |
+
def _estimate_duration_from_lyrics(lyrics: str) -> Optional[float]:
|
| 466 |
+
lyrics = (lyrics or "").strip()
|
| 467 |
+
if not lyrics:
|
| 468 |
+
return None
|
| 469 |
+
|
| 470 |
+
# Best-effort heuristic: singing rate ~ 2.2 words/sec for English-like lyrics.
|
| 471 |
+
# For languages without spaces, fall back to non-space char count.
|
| 472 |
+
words = re.findall(r"[A-Za-z0-9']+", lyrics)
|
| 473 |
+
if len(words) >= 8:
|
| 474 |
+
words_per_sec = float(os.getenv("ACESTEP_LYRICS_WORDS_PER_SEC", "2.2"))
|
| 475 |
+
est = len(words) / max(0.5, words_per_sec)
|
| 476 |
+
else:
|
| 477 |
+
non_space = len(re.sub(r"\s+", "", lyrics))
|
| 478 |
+
chars_per_sec = float(os.getenv("ACESTEP_LYRICS_CHARS_PER_SEC", "12"))
|
| 479 |
+
est = non_space / max(4.0, chars_per_sec)
|
| 480 |
+
|
| 481 |
+
min_dur = float(os.getenv("ACESTEP_LYRICS_MIN_DURATION_SECONDS", "45"))
|
| 482 |
+
max_dur = float(os.getenv("ACESTEP_LYRICS_MAX_DURATION_SECONDS", "180"))
|
| 483 |
+
return float(min(max(est, min_dur), max_dur))
|
| 484 |
+
|
| 485 |
+
def _extract_lm_fields(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 486 |
+
def _none_if_na(v: Any) -> Any:
|
| 487 |
+
if v is None:
|
| 488 |
+
return None
|
| 489 |
+
if isinstance(v, str) and v.strip() in {"", "N/A"}:
|
| 490 |
+
return None
|
| 491 |
+
return v
|
| 492 |
+
|
| 493 |
+
out: Dict[str, Any] = {}
|
| 494 |
+
|
| 495 |
+
bpm_raw = _none_if_na(meta.get("bpm"))
|
| 496 |
+
try:
|
| 497 |
+
out["bpm"] = int(float(bpm_raw)) if bpm_raw is not None else None
|
| 498 |
+
except Exception:
|
| 499 |
+
out["bpm"] = None
|
| 500 |
+
|
| 501 |
+
dur_raw = _none_if_na(meta.get("duration"))
|
| 502 |
+
try:
|
| 503 |
+
out["duration"] = float(dur_raw) if dur_raw is not None else None
|
| 504 |
+
except Exception:
|
| 505 |
+
out["duration"] = None
|
| 506 |
+
|
| 507 |
+
genres_raw = _none_if_na(meta.get("genres"))
|
| 508 |
+
out["genres"] = str(genres_raw) if genres_raw is not None else None
|
| 509 |
+
|
| 510 |
+
keyscale_raw = _none_if_na(meta.get("keyscale", meta.get("key_scale")))
|
| 511 |
+
out["keyscale"] = str(keyscale_raw) if keyscale_raw is not None else None
|
| 512 |
+
|
| 513 |
+
ts_raw = _none_if_na(meta.get("timesignature", meta.get("time_signature")))
|
| 514 |
+
out["timesignature"] = str(ts_raw) if ts_raw is not None else None
|
| 515 |
+
|
| 516 |
+
return out
|
| 517 |
+
|
| 518 |
+
def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 519 |
+
"""Ensure a stable `metas` dict (keys always present)."""
|
| 520 |
+
meta = meta or {}
|
| 521 |
+
out: Dict[str, Any] = dict(meta)
|
| 522 |
+
|
| 523 |
+
# Normalize key aliases
|
| 524 |
+
if "keyscale" not in out and "key_scale" in out:
|
| 525 |
+
out["keyscale"] = out.get("key_scale")
|
| 526 |
+
if "timesignature" not in out and "time_signature" in out:
|
| 527 |
+
out["timesignature"] = out.get("time_signature")
|
| 528 |
+
|
| 529 |
+
# Ensure required keys exist
|
| 530 |
+
for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
|
| 531 |
+
if out.get(k) in (None, ""):
|
| 532 |
+
out[k] = "N/A"
|
| 533 |
+
return out
|
| 534 |
+
|
| 535 |
+
# Optional: generate 5Hz LM codes server-side
|
| 536 |
+
audio_code_string = req.audio_code_string
|
| 537 |
+
bpm_val = req.bpm
|
| 538 |
+
key_scale_val = req.key_scale
|
| 539 |
+
time_sig_val = req.time_signature
|
| 540 |
+
audio_duration_val = req.audio_duration
|
| 541 |
+
|
| 542 |
+
# Infer type semantics: `dit` => metas only, `llm_dit` => metas + audio codes.
|
| 543 |
+
# Default to llm_dit only when we actually have (or will generate) codes.
|
| 544 |
+
explicit_infer = (req.infer_type or "").strip().lower() in {"dit", "llm_dit"}
|
| 545 |
+
infer_type = (req.infer_type or "").strip().lower()
|
| 546 |
+
if infer_type not in {"dit", "llm_dit"}:
|
| 547 |
+
has_codes = bool(audio_code_string and str(audio_code_string).strip())
|
| 548 |
+
infer_type = "llm_dit" if (req.use_5hz_lm or has_codes) else "dit"
|
| 549 |
+
|
| 550 |
+
# If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
|
| 551 |
+
# We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
|
| 552 |
+
audio_cover_strength_val = float(req.audio_cover_strength)
|
| 553 |
+
|
| 554 |
+
lm_fields: Dict[str, Any] = {}
|
| 555 |
+
|
| 556 |
+
# Determine effective batch size (used for per-sample LM code diversity)
|
| 557 |
+
effective_batch_size = req.batch_size
|
| 558 |
+
if effective_batch_size is None:
|
| 559 |
+
try:
|
| 560 |
+
effective_batch_size = int(getattr(h, "batch_size", 1))
|
| 561 |
+
except Exception:
|
| 562 |
+
effective_batch_size = 1
|
| 563 |
+
effective_batch_size = max(1, int(effective_batch_size))
|
| 564 |
+
|
| 565 |
+
if req.use_5hz_lm and not (audio_code_string and str(audio_code_string).strip()):
|
| 566 |
+
# Lazy init 5Hz LM once
|
| 567 |
+
with app.state._llm_init_lock:
|
| 568 |
+
if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
|
| 569 |
+
project_root = _get_project_root()
|
| 570 |
+
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 571 |
+
lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
|
| 572 |
+
backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
|
| 573 |
+
if backend not in {"vllm", "pt"}:
|
| 574 |
+
backend = "vllm"
|
| 575 |
+
|
| 576 |
+
lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
|
| 577 |
+
lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
|
| 578 |
+
|
| 579 |
+
status, ok = llm.initialize(
|
| 580 |
+
checkpoint_dir=checkpoint_dir,
|
| 581 |
+
lm_model_path=lm_model_path,
|
| 582 |
+
backend=backend,
|
| 583 |
+
device=lm_device,
|
| 584 |
+
offload_to_cpu=lm_offload,
|
| 585 |
+
dtype=h.dtype,
|
| 586 |
+
)
|
| 587 |
+
if not ok:
|
| 588 |
+
app.state._llm_init_error = status
|
| 589 |
+
else:
|
| 590 |
+
app.state._llm_initialized = True
|
| 591 |
+
|
| 592 |
+
if getattr(app.state, "_llm_init_error", None):
|
| 593 |
+
raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
|
| 594 |
+
|
| 595 |
+
def _lm_call() -> tuple[Dict[str, Any], str, str]:
|
| 596 |
+
return llm.generate_with_stop_condition(
|
| 597 |
+
caption=req.caption,
|
| 598 |
+
lyrics=req.lyrics,
|
| 599 |
+
infer_type=infer_type,
|
| 600 |
+
temperature=float(req.lm_temperature),
|
| 601 |
+
cfg_scale=max(1.0, float(req.lm_cfg_scale)),
|
| 602 |
+
negative_prompt=str(req.lm_negative_prompt or "NO USER INPUT"),
|
| 603 |
+
top_k=_normalize_optional_int(req.lm_top_k),
|
| 604 |
+
top_p=_normalize_optional_float(req.lm_top_p),
|
| 605 |
+
repetition_penalty=float(req.lm_repetition_penalty),
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
meta, codes, status = _lm_call()
|
| 609 |
+
|
| 610 |
+
if infer_type == "llm_dit":
|
| 611 |
+
if not codes:
|
| 612 |
+
raise RuntimeError(f"5Hz LM generation failed: {status}")
|
| 613 |
+
|
| 614 |
+
# LM once per job; rely on DiT seeds for batch diversity.
|
| 615 |
+
# For convenience, replicate the same codes across the batch.
|
| 616 |
+
if effective_batch_size > 1:
|
| 617 |
+
# use the same codes for all in the batch
|
| 618 |
+
audio_code_string = [codes] * effective_batch_size
|
| 619 |
+
|
| 620 |
+
# If needed in future: call LM multiple times for more diverse codes.
|
| 621 |
+
# codes_list: list[str] = [codes]
|
| 622 |
+
# for _ in range(effective_batch_size - 1):
|
| 623 |
+
# _m2, _c2, _s2 = _lm_call()
|
| 624 |
+
# if not _c2:
|
| 625 |
+
# raise RuntimeError(f"5Hz LM generation failed: {_s2}")
|
| 626 |
+
# codes_list.append(_c2)
|
| 627 |
+
# audio_code_string = codes_list
|
| 628 |
+
else:
|
| 629 |
+
audio_code_string = codes
|
| 630 |
+
|
| 631 |
+
lm_fields = {
|
| 632 |
+
"metas": _normalize_metas(meta),
|
| 633 |
+
**_extract_lm_fields(meta),
|
| 634 |
+
}
|
| 635 |
+
bpm_val, key_scale_val, time_sig_val, audio_duration_val = _maybe_fill_from_metadata(req, meta)
|
| 636 |
+
|
| 637 |
+
# If user provided long lyrics but LM didn't provide a usable duration, estimate a longer duration.
|
| 638 |
+
if infer_type == "llm_dit" and audio_duration_val is None and (req.audio_duration is None):
|
| 639 |
+
est = _estimate_duration_from_lyrics(req.lyrics)
|
| 640 |
+
if est is not None:
|
| 641 |
+
audio_duration_val = est
|
| 642 |
+
|
| 643 |
+
# Optional: auto-tune LM cover strength (opt-in) to avoid suppressing lyric/vocal conditioning.
|
| 644 |
+
if infer_type == "llm_dit" and audio_cover_strength_val >= 0.999 and (req.lyrics or "").strip():
|
| 645 |
+
tuned = os.getenv("ACESTEP_LM_COVER_STRENGTH")
|
| 646 |
+
if tuned is not None and tuned.strip() != "":
|
| 647 |
+
audio_cover_strength_val = float(tuned)
|
| 648 |
+
|
| 649 |
+
# Align behavior with feishu bot:
|
| 650 |
+
# - dit: metas only (ignore audio codes), keep text2music.
|
| 651 |
+
# - llm_dit: metas + audio codes, run in cover mode with LM instruction.
|
| 652 |
+
instruction_val = req.instruction
|
| 653 |
+
task_type_val = (req.task_type or "").strip() or "text2music"
|
| 654 |
+
|
| 655 |
+
if infer_type == "dit":
|
| 656 |
+
audio_code_string = ""
|
| 657 |
+
if task_type_val == "cover":
|
| 658 |
+
task_type_val = "text2music"
|
| 659 |
+
if (instruction_val or "").strip() in {"", _DEFAULT_LM_INSTRUCTION}:
|
| 660 |
+
instruction_val = _DEFAULT_DIT_INSTRUCTION
|
| 661 |
+
|
| 662 |
+
if infer_type == "llm_dit":
|
| 663 |
+
task_type_val = "cover"
|
| 664 |
+
if (instruction_val or "").strip() in {"", _DEFAULT_DIT_INSTRUCTION}:
|
| 665 |
+
instruction_val = _DEFAULT_LM_INSTRUCTION
|
| 666 |
+
|
| 667 |
+
if not (audio_code_string and str(audio_code_string).strip()):
|
| 668 |
+
if explicit_infer or req.use_5hz_lm:
|
| 669 |
+
raise RuntimeError("llm_dit requires non-empty audio codes: provide 'audio_code_string' or set 'use_5hz_lm=true'.")
|
| 670 |
+
# If not explicitly requested, fall back to dit semantics.
|
| 671 |
+
infer_type = "dit"
|
| 672 |
+
task_type_val = "text2music"
|
| 673 |
+
instruction_val = _DEFAULT_DIT_INSTRUCTION
|
| 674 |
+
|
| 675 |
first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
|
| 676 |
captions=req.caption,
|
| 677 |
lyrics=req.lyrics,
|
| 678 |
+
bpm=bpm_val,
|
| 679 |
+
key_scale=key_scale_val,
|
| 680 |
+
time_signature=time_sig_val,
|
| 681 |
vocal_language=req.vocal_language,
|
| 682 |
inference_steps=req.inference_steps,
|
| 683 |
guidance_scale=req.guidance_scale,
|
| 684 |
use_random_seed=req.use_random_seed,
|
| 685 |
+
seed=("-1" if (req.use_random_seed and int(req.seed) < 0) else str(req.seed)),
|
| 686 |
reference_audio=req.reference_audio_path,
|
| 687 |
+
audio_duration=audio_duration_val,
|
| 688 |
batch_size=req.batch_size,
|
| 689 |
src_audio=req.src_audio_path,
|
| 690 |
+
audio_code_string=audio_code_string,
|
| 691 |
repainting_start=req.repainting_start,
|
| 692 |
repainting_end=req.repainting_end,
|
| 693 |
+
instruction=instruction_val,
|
| 694 |
+
audio_cover_strength=audio_cover_strength_val,
|
| 695 |
+
task_type=task_type_val,
|
| 696 |
use_adg=req.use_adg,
|
| 697 |
cfg_interval_start=req.cfg_interval_start,
|
| 698 |
cfg_interval_end=req.cfg_interval_end,
|
|
|
|
| 707 |
"generation_info": gen_info,
|
| 708 |
"status_message": status_msg,
|
| 709 |
"seed_value": seed_value,
|
| 710 |
+
**lm_fields,
|
| 711 |
}
|
| 712 |
|
| 713 |
t0 = time.time()
|
|
|
|
| 779 |
return GenerateMusicRequest(
|
| 780 |
caption=str(get("caption", "") or ""),
|
| 781 |
lyrics=str(get("lyrics", "") or ""),
|
| 782 |
+
infer_type=_normalize_infer_type(get("infer_type")),
|
| 783 |
bpm=_to_int(get("bpm"), None),
|
| 784 |
key_scale=str(get("key_scale", "") or ""),
|
| 785 |
time_signature=str(get("time_signature", "") or ""),
|
|
|
|
| 795 |
audio_code_string=str(get("audio_code_string", "") or ""),
|
| 796 |
repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
|
| 797 |
repainting_end=_to_float(get("repainting_end"), None),
|
| 798 |
+
instruction=str(get("instruction", _DEFAULT_DIT_INSTRUCTION) or ""),
|
| 799 |
audio_cover_strength=_to_float(get("audio_cover_strength"), 1.0) or 1.0,
|
| 800 |
task_type=str(get("task_type", "text2music") or "text2music"),
|
| 801 |
use_adg=_to_bool(get("use_adg"), False),
|
|
|
|
| 803 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
| 804 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 805 |
use_tiled_decode=_to_bool(get("use_tiled_decode"), True),
|
| 806 |
+
|
| 807 |
+
use_5hz_lm=_to_bool(get("use_5hz_lm"), False),
|
| 808 |
+
lm_model_path=str(get("lm_model_path") or "").strip() or None,
|
| 809 |
+
lm_backend=str(get("lm_backend", "vllm") or "vllm"),
|
| 810 |
+
lm_temperature=_to_float(get("lm_temperature"), _LM_DEFAULT_TEMPERATURE) or _LM_DEFAULT_TEMPERATURE,
|
| 811 |
+
lm_cfg_scale=_to_float(get("lm_cfg_scale"), _LM_DEFAULT_CFG_SCALE) or _LM_DEFAULT_CFG_SCALE,
|
| 812 |
+
lm_top_k=_to_int(get("lm_top_k"), None),
|
| 813 |
+
lm_top_p=_to_float(get("lm_top_p"), _LM_DEFAULT_TOP_P),
|
| 814 |
+
lm_repetition_penalty=_to_float(get("lm_repetition_penalty"), 1.0) or 1.0,
|
| 815 |
+
lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
|
| 816 |
)
|
| 817 |
|
| 818 |
def _first_value(v: Any) -> Any:
|