ChuxiJ commited on
Commit
ae5026d
·
1 Parent(s): 3092911

feat: api support lm-dit

Browse files
Files changed (1) hide show
  1. acestep/api_server.py +372 -10
acestep/api_server.py CHANGED
@@ -14,6 +14,7 @@ from __future__ import annotations
14
  import asyncio
15
  import json
16
  import os
 
17
  import sys
18
  import time
19
  import traceback
@@ -33,6 +34,7 @@ from pydantic import BaseModel, Field
33
  from starlette.datastructures import UploadFile as StarletteUploadFile
34
 
35
  from .handler import AceStepHandler
 
36
 
37
 
38
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
@@ -42,6 +44,9 @@ class GenerateMusicRequest(BaseModel):
42
  caption: str = Field(default="", description="Text caption describing the music")
43
  lyrics: str = Field(default="", description="Lyric text")
44
 
 
 
 
45
  bpm: Optional[int] = None
46
  key_scale: str = ""
47
  time_signature: str = ""
@@ -72,6 +77,36 @@ class GenerateMusicRequest(BaseModel):
72
  audio_format: str = "mp3"
73
  use_tiled_decode: bool = True
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  class CreateJobResponse(BaseModel):
77
  job_id: str
@@ -88,6 +123,15 @@ class JobResult(BaseModel):
88
  status_message: str = ""
89
  seed_value: str = ""
90
 
 
 
 
 
 
 
 
 
 
91
 
92
  class JobResponse(BaseModel):
93
  job_id: str
@@ -240,12 +284,46 @@ def create_app() -> FastAPI:
240
  for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
241
  os.environ.pop(proxy_var, None)
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  handler = AceStepHandler()
 
244
  init_lock = asyncio.Lock()
245
  app.state._initialized = False
246
  app.state._init_error = None
247
  app.state._init_lock = init_lock
248
 
 
 
 
 
 
249
  max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
250
  executor = ThreadPoolExecutor(max_workers=max_workers)
251
 
@@ -316,33 +394,305 @@ def create_app() -> FastAPI:
316
  async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
317
  job_store: _JobStore = app.state.job_store
318
  h: AceStepHandler = app.state.handler
 
319
  executor: ThreadPoolExecutor = app.state.executor
320
 
321
  await _ensure_initialized()
322
  job_store.mark_running(job_id)
323
 
324
  def _blocking_generate() -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
326
  captions=req.caption,
327
  lyrics=req.lyrics,
328
- bpm=req.bpm,
329
- key_scale=req.key_scale,
330
- time_signature=req.time_signature,
331
  vocal_language=req.vocal_language,
332
  inference_steps=req.inference_steps,
333
  guidance_scale=req.guidance_scale,
334
  use_random_seed=req.use_random_seed,
335
- seed=req.seed,
336
  reference_audio=req.reference_audio_path,
337
- audio_duration=req.audio_duration,
338
  batch_size=req.batch_size,
339
  src_audio=req.src_audio_path,
340
- audio_code_string=req.audio_code_string,
341
  repainting_start=req.repainting_start,
342
  repainting_end=req.repainting_end,
343
- instruction=req.instruction,
344
- audio_cover_strength=req.audio_cover_strength,
345
- task_type=req.task_type,
346
  use_adg=req.use_adg,
347
  cfg_interval_start=req.cfg_interval_start,
348
  cfg_interval_end=req.cfg_interval_end,
@@ -357,6 +707,7 @@ def create_app() -> FastAPI:
357
  "generation_info": gen_info,
358
  "status_message": status_msg,
359
  "seed_value": seed_value,
 
360
  }
361
 
362
  t0 = time.time()
@@ -428,6 +779,7 @@ def create_app() -> FastAPI:
428
  return GenerateMusicRequest(
429
  caption=str(get("caption", "") or ""),
430
  lyrics=str(get("lyrics", "") or ""),
 
431
  bpm=_to_int(get("bpm"), None),
432
  key_scale=str(get("key_scale", "") or ""),
433
  time_signature=str(get("time_signature", "") or ""),
@@ -443,7 +795,7 @@ def create_app() -> FastAPI:
443
  audio_code_string=str(get("audio_code_string", "") or ""),
444
  repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
445
  repainting_end=_to_float(get("repainting_end"), None),
446
- instruction=str(get("instruction", "Fill the audio semantic mask based on the given conditions:") or ""),
447
  audio_cover_strength=_to_float(get("audio_cover_strength"), 1.0) or 1.0,
448
  task_type=str(get("task_type", "text2music") or "text2music"),
449
  use_adg=_to_bool(get("use_adg"), False),
@@ -451,6 +803,16 @@ def create_app() -> FastAPI:
451
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
452
  audio_format=str(get("audio_format", "mp3") or "mp3"),
453
  use_tiled_decode=_to_bool(get("use_tiled_decode"), True),
 
 
 
 
 
 
 
 
 
 
454
  )
455
 
456
  def _first_value(v: Any) -> Any:
 
14
  import asyncio
15
  import json
16
  import os
17
+ import re
18
  import sys
19
  import time
20
  import traceback
 
34
  from starlette.datastructures import UploadFile as StarletteUploadFile
35
 
36
  from .handler import AceStepHandler
37
+ from .llm_inference import LLMHandler
38
 
39
 
40
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
 
44
  caption: str = Field(default="", description="Text caption describing the music")
45
  lyrics: str = Field(default="", description="Lyric text")
46
 
47
+ # Match feishu bot semantics: `dit` (metas only) vs `llm_dit` (metas + audio codes)
48
+ infer_type: Optional[Literal["dit", "llm_dit"]] = None
49
+
50
  bpm: Optional[int] = None
51
  key_scale: str = ""
52
  time_signature: str = ""
 
77
  audio_format: str = "mp3"
78
  use_tiled_decode: bool = True
79
 
80
+ # 5Hz LM generation (server-side, like gradio's generate_lm_hints_wrapper)
81
+ use_5hz_lm: bool = False
82
+ lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
83
+ lm_backend: Literal["vllm", "pt"] = "vllm"
84
+
85
+ # Align defaults with `acestep/gradio_ui.py` and `feishu_bot/config.py`
86
+ # to improve lyric adherence in lm-dit mode.
87
+ lm_temperature: float = 0.85
88
+ lm_cfg_scale: float = 2.0
89
+ lm_top_k: Optional[int] = None
90
+ lm_top_p: Optional[float] = 0.9
91
+ lm_repetition_penalty: float = 1.0
92
+ lm_negative_prompt: str = "NO USER INPUT"
93
+
94
+
95
+ _LM_DEFAULT_TEMPERATURE = 0.85
96
+ _LM_DEFAULT_CFG_SCALE = 2.0
97
+ _LM_DEFAULT_TOP_P = 0.9
98
+ _DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
99
+ _DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
100
+
101
+
102
+ def _normalize_infer_type(v: Any) -> Optional[str]:
103
+ s = str(v or "").strip().lower()
104
+ if not s:
105
+ return None
106
+ if s in {"dit", "llm_dit"}:
107
+ return s
108
+ return None
109
+
110
 
111
  class CreateJobResponse(BaseModel):
112
  job_id: str
 
123
  status_message: str = ""
124
  seed_value: str = ""
125
 
126
+ # 5Hz LM metadata (present when `use_5hz_lm=true` and server generates codes)
127
+ # Keep a raw-ish dict for clients that expect a `metas` object.
128
+ metas: Dict[str, Any] = Field(default_factory=dict)
129
+ bpm: Optional[int] = None
130
+ duration: Optional[float] = None
131
+ genres: Optional[str] = None
132
+ keyscale: Optional[str] = None
133
+ timesignature: Optional[str] = None
134
+
135
 
136
  class JobResponse(BaseModel):
137
  job_id: str
 
284
  for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
285
  os.environ.pop(proxy_var, None)
286
 
287
+ # Ensure compilation/temp caches do not fill up small default /tmp.
288
+ # Triton/Inductor (and the system compiler) can create large temporary files.
289
+ project_root = _get_project_root()
290
+ cache_root = os.path.join(project_root, ".cache", "acestep")
291
+ tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip()
292
+ triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip()
293
+ inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip()
294
+
295
+ for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]:
296
+ try:
297
+ os.makedirs(p, exist_ok=True)
298
+ except Exception:
299
+ # Best-effort: do not block startup if directory creation fails.
300
+ pass
301
+
302
+ # Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win.
303
+ if os.getenv("ACESTEP_TMPDIR"):
304
+ os.environ["TMPDIR"] = tmp_root
305
+ os.environ["TEMP"] = tmp_root
306
+ os.environ["TMP"] = tmp_root
307
+ else:
308
+ os.environ.setdefault("TMPDIR", tmp_root)
309
+ os.environ.setdefault("TEMP", tmp_root)
310
+ os.environ.setdefault("TMP", tmp_root)
311
+
312
+ os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root)
313
+ os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root)
314
+
315
  handler = AceStepHandler()
316
+ llm_handler = LLMHandler()
317
  init_lock = asyncio.Lock()
318
  app.state._initialized = False
319
  app.state._init_error = None
320
  app.state._init_lock = init_lock
321
 
322
+ app.state.llm_handler = llm_handler
323
+ app.state._llm_initialized = False
324
+ app.state._llm_init_error = None
325
+ app.state._llm_init_lock = Lock()
326
+
327
  max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
328
  executor = ThreadPoolExecutor(max_workers=max_workers)
329
 
 
394
  async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
395
  job_store: _JobStore = app.state.job_store
396
  h: AceStepHandler = app.state.handler
397
+ llm: LLMHandler = app.state.llm_handler
398
  executor: ThreadPoolExecutor = app.state.executor
399
 
400
  await _ensure_initialized()
401
  job_store.mark_running(job_id)
402
 
403
  def _blocking_generate() -> Dict[str, Any]:
404
+ def _normalize_optional_int(v: Any) -> Optional[int]:
405
+ if v is None:
406
+ return None
407
+ try:
408
+ iv = int(v)
409
+ except Exception:
410
+ return None
411
+ return None if iv == 0 else iv
412
+
413
+ def _normalize_optional_float(v: Any) -> Optional[float]:
414
+ if v is None:
415
+ return None
416
+ try:
417
+ fv = float(v)
418
+ except Exception:
419
+ return None
420
+ # gradio treats 1.0 as disabled for top_p
421
+ return None if fv >= 1.0 else fv
422
+
423
+ def _maybe_fill_from_metadata(current: GenerateMusicRequest, meta: Dict[str, Any]) -> tuple[Optional[int], str, str, Optional[float]]:
424
+ # Fill only when user did not provide values
425
+ bpm_val = current.bpm
426
+ if bpm_val is None:
427
+ try:
428
+ m = meta.get("bpm")
429
+ if m not in (None, "", "N/A"):
430
+ bpm_val = int(float(m))
431
+ except Exception:
432
+ bpm_val = current.bpm
433
+
434
+ key_scale_val = current.key_scale
435
+ if not key_scale_val:
436
+ m = meta.get("keyscale", meta.get("key_scale", ""))
437
+ if m not in (None, "", "N/A"):
438
+ key_scale_val = str(m)
439
+
440
+ time_sig_val = current.time_signature
441
+ if not time_sig_val:
442
+ m = meta.get("timesignature", meta.get("time_signature", ""))
443
+ if m not in (None, "", "N/A"):
444
+ time_sig_val = str(m)
445
+
446
+ dur_val = current.audio_duration
447
+ if dur_val is None:
448
+ m = meta.get("duration", meta.get("audio_duration"))
449
+ try:
450
+ if m not in (None, "", "N/A"):
451
+ dur_val = float(m)
452
+ if dur_val <= 0:
453
+ dur_val = None
454
+ # Avoid truncating lyrical songs when LM predicts a very short duration.
455
+ # (Users can still force a short duration by explicitly setting `audio_duration`.)
456
+ if dur_val is not None and (current.lyrics or "").strip():
457
+ min_dur = float(os.getenv("ACESTEP_LM_MIN_DURATION_SECONDS", "30"))
458
+ if dur_val < min_dur:
459
+ dur_val = None
460
+ except Exception:
461
+ dur_val = current.audio_duration
462
+
463
+ return bpm_val, key_scale_val, time_sig_val, dur_val
464
+
465
+ def _estimate_duration_from_lyrics(lyrics: str) -> Optional[float]:
466
+ lyrics = (lyrics or "").strip()
467
+ if not lyrics:
468
+ return None
469
+
470
+ # Best-effort heuristic: singing rate ~ 2.2 words/sec for English-like lyrics.
471
+ # For languages without spaces, fall back to non-space char count.
472
+ words = re.findall(r"[A-Za-z0-9']+", lyrics)
473
+ if len(words) >= 8:
474
+ words_per_sec = float(os.getenv("ACESTEP_LYRICS_WORDS_PER_SEC", "2.2"))
475
+ est = len(words) / max(0.5, words_per_sec)
476
+ else:
477
+ non_space = len(re.sub(r"\s+", "", lyrics))
478
+ chars_per_sec = float(os.getenv("ACESTEP_LYRICS_CHARS_PER_SEC", "12"))
479
+ est = non_space / max(4.0, chars_per_sec)
480
+
481
+ min_dur = float(os.getenv("ACESTEP_LYRICS_MIN_DURATION_SECONDS", "45"))
482
+ max_dur = float(os.getenv("ACESTEP_LYRICS_MAX_DURATION_SECONDS", "180"))
483
+ return float(min(max(est, min_dur), max_dur))
484
+
485
+ def _extract_lm_fields(meta: Dict[str, Any]) -> Dict[str, Any]:
486
+ def _none_if_na(v: Any) -> Any:
487
+ if v is None:
488
+ return None
489
+ if isinstance(v, str) and v.strip() in {"", "N/A"}:
490
+ return None
491
+ return v
492
+
493
+ out: Dict[str, Any] = {}
494
+
495
+ bpm_raw = _none_if_na(meta.get("bpm"))
496
+ try:
497
+ out["bpm"] = int(float(bpm_raw)) if bpm_raw is not None else None
498
+ except Exception:
499
+ out["bpm"] = None
500
+
501
+ dur_raw = _none_if_na(meta.get("duration"))
502
+ try:
503
+ out["duration"] = float(dur_raw) if dur_raw is not None else None
504
+ except Exception:
505
+ out["duration"] = None
506
+
507
+ genres_raw = _none_if_na(meta.get("genres"))
508
+ out["genres"] = str(genres_raw) if genres_raw is not None else None
509
+
510
+ keyscale_raw = _none_if_na(meta.get("keyscale", meta.get("key_scale")))
511
+ out["keyscale"] = str(keyscale_raw) if keyscale_raw is not None else None
512
+
513
+ ts_raw = _none_if_na(meta.get("timesignature", meta.get("time_signature")))
514
+ out["timesignature"] = str(ts_raw) if ts_raw is not None else None
515
+
516
+ return out
517
+
518
+ def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
519
+ """Ensure a stable `metas` dict (keys always present)."""
520
+ meta = meta or {}
521
+ out: Dict[str, Any] = dict(meta)
522
+
523
+ # Normalize key aliases
524
+ if "keyscale" not in out and "key_scale" in out:
525
+ out["keyscale"] = out.get("key_scale")
526
+ if "timesignature" not in out and "time_signature" in out:
527
+ out["timesignature"] = out.get("time_signature")
528
+
529
+ # Ensure required keys exist
530
+ for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
531
+ if out.get(k) in (None, ""):
532
+ out[k] = "N/A"
533
+ return out
534
+
535
+ # Optional: generate 5Hz LM codes server-side
536
+ audio_code_string = req.audio_code_string
537
+ bpm_val = req.bpm
538
+ key_scale_val = req.key_scale
539
+ time_sig_val = req.time_signature
540
+ audio_duration_val = req.audio_duration
541
+
542
+ # Infer type semantics: `dit` => metas only, `llm_dit` => metas + audio codes.
543
+ # Default to llm_dit only when we actually have (or will generate) codes.
544
+ explicit_infer = (req.infer_type or "").strip().lower() in {"dit", "llm_dit"}
545
+ infer_type = (req.infer_type or "").strip().lower()
546
+ if infer_type not in {"dit", "llm_dit"}:
547
+ has_codes = bool(audio_code_string and str(audio_code_string).strip())
548
+ infer_type = "llm_dit" if (req.use_5hz_lm or has_codes) else "dit"
549
+
550
+ # If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
551
+ # We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
552
+ audio_cover_strength_val = float(req.audio_cover_strength)
553
+
554
+ lm_fields: Dict[str, Any] = {}
555
+
556
+ # Determine effective batch size (used for per-sample LM code diversity)
557
+ effective_batch_size = req.batch_size
558
+ if effective_batch_size is None:
559
+ try:
560
+ effective_batch_size = int(getattr(h, "batch_size", 1))
561
+ except Exception:
562
+ effective_batch_size = 1
563
+ effective_batch_size = max(1, int(effective_batch_size))
564
+
565
+ if req.use_5hz_lm and not (audio_code_string and str(audio_code_string).strip()):
566
+ # Lazy init 5Hz LM once
567
+ with app.state._llm_init_lock:
568
+ if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
569
+ project_root = _get_project_root()
570
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
571
+ lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
572
+ backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
573
+ if backend not in {"vllm", "pt"}:
574
+ backend = "vllm"
575
+
576
+ lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
577
+ lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
578
+
579
+ status, ok = llm.initialize(
580
+ checkpoint_dir=checkpoint_dir,
581
+ lm_model_path=lm_model_path,
582
+ backend=backend,
583
+ device=lm_device,
584
+ offload_to_cpu=lm_offload,
585
+ dtype=h.dtype,
586
+ )
587
+ if not ok:
588
+ app.state._llm_init_error = status
589
+ else:
590
+ app.state._llm_initialized = True
591
+
592
+ if getattr(app.state, "_llm_init_error", None):
593
+ raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
594
+
595
+ def _lm_call() -> tuple[Dict[str, Any], str, str]:
596
+ return llm.generate_with_stop_condition(
597
+ caption=req.caption,
598
+ lyrics=req.lyrics,
599
+ infer_type=infer_type,
600
+ temperature=float(req.lm_temperature),
601
+ cfg_scale=max(1.0, float(req.lm_cfg_scale)),
602
+ negative_prompt=str(req.lm_negative_prompt or "NO USER INPUT"),
603
+ top_k=_normalize_optional_int(req.lm_top_k),
604
+ top_p=_normalize_optional_float(req.lm_top_p),
605
+ repetition_penalty=float(req.lm_repetition_penalty),
606
+ )
607
+
608
+ meta, codes, status = _lm_call()
609
+
610
+ if infer_type == "llm_dit":
611
+ if not codes:
612
+ raise RuntimeError(f"5Hz LM generation failed: {status}")
613
+
614
+ # LM once per job; rely on DiT seeds for batch diversity.
615
+ # For convenience, replicate the same codes across the batch.
616
+ if effective_batch_size > 1:
617
+ # use the same codes for all in the batch
618
+ audio_code_string = [codes] * effective_batch_size
619
+
620
+ # If needed in future: call LM multiple times for more diverse codes.
621
+ # codes_list: list[str] = [codes]
622
+ # for _ in range(effective_batch_size - 1):
623
+ # _m2, _c2, _s2 = _lm_call()
624
+ # if not _c2:
625
+ # raise RuntimeError(f"5Hz LM generation failed: {_s2}")
626
+ # codes_list.append(_c2)
627
+ # audio_code_string = codes_list
628
+ else:
629
+ audio_code_string = codes
630
+
631
+ lm_fields = {
632
+ "metas": _normalize_metas(meta),
633
+ **_extract_lm_fields(meta),
634
+ }
635
+ bpm_val, key_scale_val, time_sig_val, audio_duration_val = _maybe_fill_from_metadata(req, meta)
636
+
637
+ # If user provided long lyrics but LM didn't provide a usable duration, estimate a longer duration.
638
+ if infer_type == "llm_dit" and audio_duration_val is None and (req.audio_duration is None):
639
+ est = _estimate_duration_from_lyrics(req.lyrics)
640
+ if est is not None:
641
+ audio_duration_val = est
642
+
643
+ # Optional: auto-tune LM cover strength (opt-in) to avoid suppressing lyric/vocal conditioning.
644
+ if infer_type == "llm_dit" and audio_cover_strength_val >= 0.999 and (req.lyrics or "").strip():
645
+ tuned = os.getenv("ACESTEP_LM_COVER_STRENGTH")
646
+ if tuned is not None and tuned.strip() != "":
647
+ audio_cover_strength_val = float(tuned)
648
+
649
+ # Align behavior with feishu bot:
650
+ # - dit: metas only (ignore audio codes), keep text2music.
651
+ # - llm_dit: metas + audio codes, run in cover mode with LM instruction.
652
+ instruction_val = req.instruction
653
+ task_type_val = (req.task_type or "").strip() or "text2music"
654
+
655
+ if infer_type == "dit":
656
+ audio_code_string = ""
657
+ if task_type_val == "cover":
658
+ task_type_val = "text2music"
659
+ if (instruction_val or "").strip() in {"", _DEFAULT_LM_INSTRUCTION}:
660
+ instruction_val = _DEFAULT_DIT_INSTRUCTION
661
+
662
+ if infer_type == "llm_dit":
663
+ task_type_val = "cover"
664
+ if (instruction_val or "").strip() in {"", _DEFAULT_DIT_INSTRUCTION}:
665
+ instruction_val = _DEFAULT_LM_INSTRUCTION
666
+
667
+ if not (audio_code_string and str(audio_code_string).strip()):
668
+ if explicit_infer or req.use_5hz_lm:
669
+ raise RuntimeError("llm_dit requires non-empty audio codes: provide 'audio_code_string' or set 'use_5hz_lm=true'.")
670
+ # If not explicitly requested, fall back to dit semantics.
671
+ infer_type = "dit"
672
+ task_type_val = "text2music"
673
+ instruction_val = _DEFAULT_DIT_INSTRUCTION
674
+
675
  first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
676
  captions=req.caption,
677
  lyrics=req.lyrics,
678
+ bpm=bpm_val,
679
+ key_scale=key_scale_val,
680
+ time_signature=time_sig_val,
681
  vocal_language=req.vocal_language,
682
  inference_steps=req.inference_steps,
683
  guidance_scale=req.guidance_scale,
684
  use_random_seed=req.use_random_seed,
685
+ seed=("-1" if (req.use_random_seed and int(req.seed) < 0) else str(req.seed)),
686
  reference_audio=req.reference_audio_path,
687
+ audio_duration=audio_duration_val,
688
  batch_size=req.batch_size,
689
  src_audio=req.src_audio_path,
690
+ audio_code_string=audio_code_string,
691
  repainting_start=req.repainting_start,
692
  repainting_end=req.repainting_end,
693
+ instruction=instruction_val,
694
+ audio_cover_strength=audio_cover_strength_val,
695
+ task_type=task_type_val,
696
  use_adg=req.use_adg,
697
  cfg_interval_start=req.cfg_interval_start,
698
  cfg_interval_end=req.cfg_interval_end,
 
707
  "generation_info": gen_info,
708
  "status_message": status_msg,
709
  "seed_value": seed_value,
710
+ **lm_fields,
711
  }
712
 
713
  t0 = time.time()
 
779
  return GenerateMusicRequest(
780
  caption=str(get("caption", "") or ""),
781
  lyrics=str(get("lyrics", "") or ""),
782
+ infer_type=_normalize_infer_type(get("infer_type")),
783
  bpm=_to_int(get("bpm"), None),
784
  key_scale=str(get("key_scale", "") or ""),
785
  time_signature=str(get("time_signature", "") or ""),
 
795
  audio_code_string=str(get("audio_code_string", "") or ""),
796
  repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
797
  repainting_end=_to_float(get("repainting_end"), None),
798
+ instruction=str(get("instruction", _DEFAULT_DIT_INSTRUCTION) or ""),
799
  audio_cover_strength=_to_float(get("audio_cover_strength"), 1.0) or 1.0,
800
  task_type=str(get("task_type", "text2music") or "text2music"),
801
  use_adg=_to_bool(get("use_adg"), False),
 
803
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
804
  audio_format=str(get("audio_format", "mp3") or "mp3"),
805
  use_tiled_decode=_to_bool(get("use_tiled_decode"), True),
806
+
807
+ use_5hz_lm=_to_bool(get("use_5hz_lm"), False),
808
+ lm_model_path=str(get("lm_model_path") or "").strip() or None,
809
+ lm_backend=str(get("lm_backend", "vllm") or "vllm"),
810
+ lm_temperature=_to_float(get("lm_temperature"), _LM_DEFAULT_TEMPERATURE) or _LM_DEFAULT_TEMPERATURE,
811
+ lm_cfg_scale=_to_float(get("lm_cfg_scale"), _LM_DEFAULT_CFG_SCALE) or _LM_DEFAULT_CFG_SCALE,
812
+ lm_top_k=_to_int(get("lm_top_k"), None),
813
+ lm_top_p=_to_float(get("lm_top_p"), _LM_DEFAULT_TOP_P),
814
+ lm_repetition_penalty=_to_float(get("lm_repetition_penalty"), 1.0) or 1.0,
815
+ lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
816
  )
817
 
818
  def _first_value(v: Any) -> Any: