Merge pull request #1 from ace-step/refact_add_inference
- acestep/api_server.py +204 -458
- acestep/audio_utils.py +320 -0
- acestep/constrained_logits_processor.py +76 -97
- acestep/gradio_ui/event.py +0 -0
- acestep/gradio_ui/events/__init__.py +78 -41
- acestep/gradio_ui/events/results_handlers.py +355 -461
- acestep/gradio_ui/interfaces/result.py +16 -8
- acestep/handler.py +328 -318
- acestep/inference.py +477 -785
- acestep/llm_inference.py +640 -603
- acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +92 -64
- acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +87 -47
- acestep/third_parts/nano-vllm/pyproject.toml +0 -2
- profile_inference.py +682 -0
acestep/api_server.py
CHANGED
@@ -44,6 +44,12 @@ from acestep.constants import (
     DEFAULT_DIT_INSTRUCTION,
     DEFAULT_LM_INSTRUCTION,
 )
+from acestep.inference import (
+    GenerationParams,
+    GenerationConfig,
+    generate_music,
+)
+from acestep.gradio_ui.events.results_handlers import _build_generation_info


 JobStatus = Literal["queued", "running", "succeeded", "failed"]

@@ -387,6 +393,10 @@ def create_app() -> FastAPI:
     app.state.executor = executor
     app.state.job_store = store
     app.state._python_executable = sys.executable
+
+    # Temporary directory for saving generated audio files
+    app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
+    os.makedirs(app.state.temp_audio_dir, exist_ok=True)

     async def _ensure_initialized() -> None:
         h: AceStepHandler = app.state.handler

@@ -443,131 +453,10 @@ def create_app() -> FastAPI:
         job_store.mark_running(job_id)

         def _blocking_generate() -> Dict[str, Any]:
-                return None
-            try:
-                iv = int(v)
-            except Exception:
-                return None
-            return None if iv == 0 else iv
-
-        def _normalize_optional_float(v: Any) -> Optional[float]:
-            if v is None:
-                return None
-            try:
-                fv = float(v)
-            except Exception:
-                return None
-            # gradio treats 1.0 as disabled for top_p
-            return None if fv >= 1.0 else fv
-
-        def _maybe_fill_from_metadata(current: GenerateMusicRequest, meta: Dict[str, Any]) -> tuple[Optional[int], str, str, Optional[float]]:
-            def _parse_first_float(v: Any) -> Optional[float]:
-                if v is None:
-                    return None
-                if isinstance(v, (int, float)):
-                    return float(v)
-                s = str(v).strip()
-                if not s or s.upper() == "N/A":
-                    return None
-                try:
-                    return float(s)
-                except Exception:
-                    pass
-                m = re.search(r"[-+]?\d*\.?\d+", s)
-                if not m:
-                    return None
-                try:
-                    return float(m.group(0))
-                except Exception:
-                    return None
-
-            def _parse_first_int(v: Any) -> Optional[int]:
-                fv = _parse_first_float(v)
-                if fv is None:
-                    return None
-                try:
-                    return int(round(fv))
-                except Exception:
-                    return None
-
-            # Fill only when user did not provide values
-            bpm_val = current.bpm
-            if bpm_val is None:
-                m = meta.get("bpm")
-                parsed = _parse_first_int(m)
-                if parsed is not None and parsed > 0:
-                    bpm_val = parsed
-
-            key_scale_val = current.key_scale
-            if not key_scale_val:
-                m = meta.get("keyscale", meta.get("key_scale", ""))
-                if m not in (None, "", "N/A"):
-                    key_scale_val = str(m)
-
-            time_sig_val = current.time_signature
-            if not time_sig_val:
-                m = meta.get("timesignature", meta.get("time_signature", ""))
-                if m not in (None, "", "N/A"):
-                    time_sig_val = str(m)
-
-            dur_val = current.audio_duration
-            if dur_val is None:
-                m = meta.get("duration", meta.get("audio_duration"))
-                parsed = _parse_first_float(m)
-                if parsed is not None:
-                    dur_val = float(parsed)
-                    if dur_val <= 0:
-                        dur_val = None
-
-            # Avoid truncating lyrical songs when LM predicts a very short duration.
-            # (Users can still force a short duration by explicitly setting `audio_duration`.)
-            if dur_val is not None and (current.lyrics or "").strip():
-                min_dur = float(os.getenv("ACESTEP_LM_MIN_DURATION_SECONDS", "30"))
-                if dur_val < min_dur:
-                    dur_val = None
-
-            return bpm_val, key_scale_val, time_sig_val, dur_val
-
-        def _estimate_duration_from_lyrics(lyrics: str) -> Optional[float]:
-            lyrics = (lyrics or "").strip()
-            if not lyrics:
-                return None
-
-            # Best-effort heuristic: singing rate ~ 2.2 words/sec for English-like lyrics.
-            # For languages without spaces, fall back to non-space char count.
-            words = re.findall(r"[A-Za-z0-9']+", lyrics)
-            if len(words) >= 8:
-                words_per_sec = float(os.getenv("ACESTEP_LYRICS_WORDS_PER_SEC", "2.2"))
-                est = len(words) / max(0.5, words_per_sec)
-            else:
-                non_space = len(re.sub(r"\s+", "", lyrics))
-                chars_per_sec = float(os.getenv("ACESTEP_LYRICS_CHARS_PER_SEC", "12"))
-                est = non_space / max(4.0, chars_per_sec)
-
-            min_dur = float(os.getenv("ACESTEP_LYRICS_MIN_DURATION_SECONDS", "45"))
-            max_dur = float(os.getenv("ACESTEP_LYRICS_MAX_DURATION_SECONDS", "180"))
-            return float(min(max(est, min_dur), max_dur))
-
-        def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
-            """Ensure a stable `metas` dict (keys always present)."""
-            meta = meta or {}
-            out: Dict[str, Any] = dict(meta)
-
-            # Normalize key aliases
-            if "keyscale" not in out and "key_scale" in out:
-                out["keyscale"] = out.get("key_scale")
-            if "timesignature" not in out and "time_signature" in out:
-                out["timesignature"] = out.get("time_signature")
-
-            # Ensure required keys exist
-            for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
-                if out.get(k) in (None, ""):
-                    out[k] = "N/A"
-            return out
-
+            """Generate music using unified inference logic from acestep.inference"""
+
         def _ensure_llm_ready() -> None:
+            """Ensure LLM handler is initialized when needed"""
             with app.state._llm_init_lock:
                 initialized = getattr(app.state, "_llm_initialized", False)
                 had_error = getattr(app.state, "_llm_init_error", None)

@@ -597,269 +486,207 @@ def create_app() -> FastAPI:
             else:
                 app.state._llm_initialized = True

-            time_sig_val = req.time_signature
-            audio_duration_val = req.audio_duration
-
-            thinking = bool(getattr(req, "thinking", False))
-
-            print(
-                "[api_server] parsed req: "
-                f"thinking={thinking}, caption_len={len((req.caption or '').strip())}, lyrics_len={len((req.lyrics or '').strip())}, "
-                f"bpm={req.bpm}, audio_duration={req.audio_duration}, key_scale={req.key_scale!r}, time_signature={req.time_signature!r}"
-            )
-
-            if
+        def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
+            """Ensure a stable `metas` dict (keys always present)."""
+            meta = meta or {}
+            out: Dict[str, Any] = dict(meta)
+
+            # Normalize key aliases
+            if "keyscale" not in out and "key_scale" in out:
+                out["keyscale"] = out.get("key_scale")
+            if "timesignature" not in out and "time_signature" in out:
+                out["timesignature"] = out.get("time_signature")
+
+            # Ensure required keys exist
+            for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
+                if out.get(k) in (None, ""):
+                    out[k] = "N/A"
+            return out
+
+            # Normalize LM sampling parameters
+            lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
+            lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
+
+            # Determine if LLM is needed
+            thinking = bool(req.thinking)
+            sample_mode = bool(req.sample_mode)
+            need_llm = thinking or sample_mode
+
+            print(f"[api_server] Request params: req.thinking={req.thinking}, req.sample_mode={req.sample_mode}")
+            print(f"[api_server] Determined: thinking={thinking}, sample_mode={sample_mode}, need_llm={need_llm}")
+
+            # Ensure LLM is ready if needed
+            if need_llm:
                 _ensure_llm_ready()
                 if getattr(app.state, "_llm_init_error", None):
                     raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")

+            # Handle sample mode: generate random caption/lyrics first
+            caption = req.caption
+            lyrics = req.lyrics
+            bpm = req.bpm
+            key_scale = req.key_scale
+            time_signature = req.time_signature
+            audio_duration = req.audio_duration
+
+            if sample_mode:
+                print("[api_server] Sample mode: generating random caption/lyrics via LM")
                 sample_metadata, sample_status = llm.understand_audio_from_codes(
                     audio_codes="NO USER INPUT",
-                    temperature=
-                    cfg_scale=max(1.0,
-                    negative_prompt=
-                    top_k=
-                    top_p=
-                    repetition_penalty=
-                    use_constrained_decoding=
-                    constrained_decoding_debug=
+                    temperature=req.lm_temperature,
+                    cfg_scale=max(1.0, req.lm_cfg_scale),
+                    negative_prompt=req.lm_negative_prompt,
+                    top_k=lm_top_k if lm_top_k > 0 else None,
+                    top_p=lm_top_p if lm_top_p < 1.0 else None,
+                    repetition_penalty=req.lm_repetition_penalty,
+                    use_constrained_decoding=req.constrained_decoding,
+                    constrained_decoding_debug=req.constrained_decoding_debug,
                 )

                 if not sample_metadata or str(sample_status).startswith("❌"):
                     raise RuntimeError(f"Sample generation failed: {sample_status}")

-                    "audio_duration": req.audio_duration,
-                    "key_scale": req.key_scale,
-                    "time_signature": req.time_signature,
-                },
-            )
-
-            # Determine effective batch size (used for per-sample LM code diversity)
-            effective_batch_size = req.batch_size
-            if effective_batch_size is None:
-                try:
-                    effective_batch_size = int(getattr(h, "batch_size", 1))
-                except Exception:
-                    effective_batch_size = 1
-            effective_batch_size = max(1, int(effective_batch_size))
-
-            has_codes = bool(audio_code_string and str(audio_code_string).strip())
-            need_lm_codes = bool(thinking) and (not has_codes)
-
-            use_constrained_decoding = bool(getattr(req, "constrained_decoding", True))
-            constrained_decoding_debug = bool(getattr(req, "constrained_decoding_debug", False))
-            use_cot_caption = bool(getattr(req, "use_cot_caption", True))
-            use_cot_language = bool(getattr(req, "use_cot_language", True))
-            is_format_caption = bool(getattr(req, "is_format_caption", False))
-
-            # pass them into constrained decoding so LM injects them directly
-            # (i.e. does not re-infer / override those fields).
-            user_metadata: Dict[str, Optional[str]] = {}
-
-            def _set_user_meta(field: str, value: Optional[Any]) -> None:
-                if value is None:
-                    return
-                s = str(value).strip()
-                if not s or s.upper() == "N/A":
-                    return
-                user_metadata[field] = s
-
-            _set_user_meta("bpm", int(bpm_val) if bpm_val is not None else None)
-            _set_user_meta("duration", float(audio_duration_val) if audio_duration_val is not None else None)
-            _set_user_meta("keyscale", key_scale_val if (key_scale_val or "").strip() else None)
-            _set_user_meta("timesignature", time_sig_val if (time_sig_val or "").strip() else None)
-
-            def _has_meta(field: str) -> bool:
-                v = user_metadata.get(field)
-                return bool((v or "").strip())
-
-            need_lm_metas = not (
-                _has_meta("bpm")
-                and _has_meta("duration")
-                and _has_meta("keyscale")
-                and _has_meta("timesignature")
+                # Use generated values with fallback defaults
+                caption = sample_metadata.get("caption", "")
+                lyrics = sample_metadata.get("lyrics", "")
+                bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
+                key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
+                time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
+                audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
+
+                print(f"[api_server] Sample generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}")
+
+            print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
+            print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
+
+            # Build GenerationParams using unified interface
+            # Note: thinking controls LM code generation, sample_mode only affects CoT metas
+            params = GenerationParams(
+                task_type=req.task_type,
+                instruction=req.instruction,
+                reference_audio=req.reference_audio_path,
+                src_audio=req.src_audio_path,
+                audio_codes=req.audio_code_string,
+                caption=caption,
+                lyrics=lyrics,
+                instrumental=False,
+                vocal_language=req.vocal_language,
+                bpm=bpm,
+                keyscale=key_scale,
+                timesignature=time_signature,
+                duration=audio_duration if audio_duration else -1.0,
+                inference_steps=req.inference_steps,
+                seed=req.seed,
+                guidance_scale=req.guidance_scale,
+                use_adg=req.use_adg,
+                cfg_interval_start=req.cfg_interval_start,
+                cfg_interval_end=req.cfg_interval_end,
+                repainting_start=req.repainting_start,
+                repainting_end=req.repainting_end if req.repainting_end else -1,
+                audio_cover_strength=req.audio_cover_strength,
+                # LM parameters
+                thinking=thinking,  # Use LM for code generation when thinking=True
+                lm_temperature=req.lm_temperature,
+                lm_cfg_scale=req.lm_cfg_scale,
+                lm_top_k=lm_top_k,
+                lm_top_p=lm_top_p,
+                lm_negative_prompt=req.lm_negative_prompt,
+                use_cot_metas=not sample_mode,  # Sample mode already generated metas, don't regenerate
+                use_cot_caption=req.use_cot_caption,
+                use_cot_language=req.use_cot_language,
+                use_constrained_decoding=req.constrained_decoding,
             )

-                f"user_metadata_keys={sorted(user_metadata.keys())}, target_duration={lm_target_duration}, "
-                f"need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}, "
-                f"use_constrained_decoding={use_constrained_decoding}, use_cot_caption={use_cot_caption}, "
-                f"use_cot_language={use_cot_language}, is_format_caption={is_format_caption}"
+            # Build GenerationConfig - default to 2 audios like gradio_ui
+            batch_size = req.batch_size if req.batch_size is not None else 2
+            config = GenerationConfig(
+                batch_size=batch_size,
+                use_random_seed=req.use_random_seed,
+                seeds=None,  # Let unified logic handle seed generation
+                audio_format=req.audio_format,
+                constrained_decoding_debug=req.constrained_decoding_debug,
             )

-                raise RuntimeError("thinking=true requires non-empty audio codes (LM generation failed).")
-
-            # Response metas MUST reflect the actual values used by DiT.
-            metas_out = _normalize_metas(lm_meta or {})
-            if bpm_val is not None and int(bpm_val) > 0:
-                metas_out["bpm"] = int(bpm_val)
-            if audio_duration_val is not None and float(audio_duration_val) > 0:
-                metas_out["duration"] = float(audio_duration_val)
-            if (key_scale_val or "").strip():
-                metas_out["keyscale"] = str(key_scale_val)
-            if (time_sig_val or "").strip():
-                metas_out["timesignature"] = str(time_sig_val)
-
-            def _ensure_text_meta(field: str, fallback: Optional[str]) -> None:
-                existing = metas_out.get(field)
-                if isinstance(existing, str):
-                    stripped = existing.strip()
-                    if stripped and stripped.upper() != "N/A":
-                        return
-                if fallback is None:
-                    return
-                if fallback.strip():
-                    metas_out[field] = fallback
-
-            _ensure_text_meta("caption", req.caption)
-            _ensure_text_meta("lyrics", req.lyrics)
+            # Check LLM initialization status
+            llm_is_initialized = getattr(app.state, "_llm_initialized", False)
+            llm_to_pass = llm if llm_is_initialized else None
+
+            print(f"[api_server] Generating music with unified interface:")
+            print(f"  - thinking={params.thinking}")
+            print(f"  - batch_size={batch_size}")
+            print(f"  - llm_initialized={llm_is_initialized}")
+            print(f"  - llm_handler={'Available' if llm_to_pass else 'None'}")
+
+            # Generate music using unified interface
+            result = generate_music(
+                dit_handler=h,
+                llm_handler=llm_to_pass,
+                params=params,
+                config=config,
+                save_dir=app.state.temp_audio_dir,
+                progress=None,
+            )
+
+            print(f"[api_server] Generation completed. Success={result.success}, Audios={len(result.audios)}")
+            print(f"[api_server] Time costs keys: {list(result.extra_outputs.get('time_costs', {}).keys())}")
+
+            if not result.success:
+                raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
+
+            # Extract results
+            audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
+            first_audio = audio_paths[0] if len(audio_paths) > 0 else None
+            second_audio = audio_paths[1] if len(audio_paths) > 1 else None
+
+            # Get metadata from LM or CoT results
+            lm_metadata = result.extra_outputs.get("lm_metadata", {})
+            metas_out = _normalize_metas(lm_metadata)
+
+            # Update metas with actual values used
+            if params.cot_bpm:
+                metas_out["bpm"] = params.cot_bpm
+            elif bpm:
+                metas_out["bpm"] = bpm
+
+            if params.cot_duration:
+                metas_out["duration"] = params.cot_duration
+            elif audio_duration:
+                metas_out["duration"] = audio_duration
+
+            if params.cot_keyscale:
+                metas_out["keyscale"] = params.cot_keyscale
+            elif key_scale:
+                metas_out["keyscale"] = key_scale
+
+            if params.cot_timesignature:
+                metas_out["timesignature"] = params.cot_timesignature
+            elif time_signature:
+                metas_out["timesignature"] = time_signature
+
+            # Ensure caption and lyrics are in metas
+            if caption:
+                metas_out["caption"] = caption
+            if lyrics:
+                metas_out["lyrics"] = lyrics
+
+            # Extract seed values for response (comma-separated for multiple audios)
+            seed_values = []
+            for audio in result.audios:
+                audio_params = audio.get("params", {})
+                seed = audio_params.get("seed")
+                if seed is not None:
+                    seed_values.append(str(seed))
+            seed_value = ",".join(seed_values) if seed_values else ""
+
+            # Build generation_info using the helper function (like gradio_ui)
+            time_costs = result.extra_outputs.get("time_costs", {})
+            generation_info = _build_generation_info(
+                lm_metadata=lm_metadata,
+                time_costs=time_costs,
+                seed_value=seed_value,
+                inference_steps=req.inference_steps,
+                num_audios=len(result.audios),
+            )

             def _none_if_na_str(v: Any) -> Optional[str]:
                 if v is None:

@@ -868,44 +695,17 @@ def create_app() -> FastAPI:
                 if s in {"", "N/A"}:
                     return None
                 return s
-
-                captions=req.caption,
-                lyrics=req.lyrics,
-                bpm=bpm_val,
-                key_scale=key_scale_val,
-                time_signature=time_sig_val,
-                vocal_language=req.vocal_language,
-                inference_steps=req.inference_steps,
-                guidance_scale=req.guidance_scale,
-                use_random_seed=req.use_random_seed,
-                seed=("-1" if (req.use_random_seed and int(req.seed) < 0) else str(req.seed)),
-                reference_audio=req.reference_audio_path,
-                audio_duration=audio_duration_val,
-                batch_size=req.batch_size,
-                src_audio=req.src_audio_path,
-                audio_code_string=audio_code_string,
-                repainting_start=req.repainting_start,
-                repainting_end=req.repainting_end,
-                instruction=instruction_val,
-                audio_cover_strength=audio_cover_strength_val,
-                task_type=task_type_val,
-                use_adg=req.use_adg,
-                cfg_interval_start=req.cfg_interval_start,
-                cfg_interval_end=req.cfg_interval_end,
-                audio_format=req.audio_format,
-                use_tiled_decode=req.use_tiled_decode,
-                progress=None,
-            )
+
             return {
-                "first_audio_path": _path_to_audio_url(
-                "second_audio_path": _path_to_audio_url(
-                "audio_paths": [_path_to_audio_url(p) for p in
-                "generation_info":
-                "status_message":
+                "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
+                "second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
+                "audio_paths": [_path_to_audio_url(p) for p in audio_paths],
+                "generation_info": generation_info,
+                "status_message": result.status_message,
                 "seed_value": seed_value,
                 "metas": metas_out,
-                "bpm":
-                "duration":
+                "bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
+                "duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
                 "genres": _none_if_na_str(metas_out.get("genres")),
                 "keyscale": _none_if_na_str(metas_out.get("keyscale")),
                 "timesignature": _none_if_na_str(metas_out.get("timesignature")),

@@ -1010,53 +810,6 @@ def create_app() -> FastAPI:
             return default

-        # Debug: print what keys we actually received (helps explain empty parsed values)
-        try:
-            top_keys = list(getattr(mapping, "keys", lambda: [])())
-        except Exception:
-            top_keys = []
-        try:
-            nested_probe = (
-                get("metas", None)
-                or get("meta", None)
-                or get("metadata", None)
-                or get("user_metadata", None)
-                or get("userMetadata", None)
-            )
-            if isinstance(nested_probe, str):
-                sp = nested_probe.strip()
-                if sp.startswith("{") and sp.endswith("}"):
-                    try:
-                        nested_probe = json.loads(sp)
-                    except Exception:
-                        nested_probe = None
-            nested_keys = list(nested_probe.keys()) if isinstance(nested_probe, dict) else []
-        except Exception:
-            nested_keys = []
-        print(f"[api_server] request keys: top={sorted(top_keys)}, nested={sorted(nested_keys)}")
-
-        # Debug: print raw values/types for common meta fields (top-level + common aliases)
-        try:
-            probe_keys = [
-                "thinking",
-                "bpm",
-                "audio_duration",
-                "duration",
-                "audioDuration",
-                "key_scale",
-                "keyscale",
-                "keyScale",
-                "time_signature",
-                "timesignature",
-                "timeSignature",
-            ]
-            raw = {k: get(k, None) for k in probe_keys}
-            raw_types = {k: (type(v).__name__ if v is not None else None) for k, v in raw.items()}
-            print(f"[api_server] request raw: {raw}")
-            print(f"[api_server] request raw types: {raw_types}")
-        except Exception:
-            pass
-
         normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
         normalized_bpm = _to_int(_get_any("bpm"), None)
         normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")

@@ -1066,12 +819,6 @@ def create_app() -> FastAPI:
         if normalized_audio_duration is None:
             normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)

-        print(
-            "[api_server] normalized: "
-            f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
-            f"audio_duration={normalized_audio_duration}, key_scale={normalized_keyscale!r}, time_signature={normalized_timesig!r}"
-        )
-
         return GenerateMusicRequest(
             caption=str(get("caption", "") or ""),
             lyrics=str(get("lyrics", "") or ""),

@@ -1110,7 +857,6 @@ def create_app() -> FastAPI:
             lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
             constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
             constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
-            # Accept common aliases, including hyphenated keys from some clients.
             use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
             use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
             is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),
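The endpoint now delegates all generation to the unified `generate_music` interface instead of hand-rolled per-request logic. A minimal standalone sketch of that call path follows; it assumes the fields of `GenerationParams` and `GenerationConfig` not listed here have defaults, that `dit_handler` is an already-initialized `AceStepHandler`, and the caption, duration, and save directory are illustrative values only.

from acestep.inference import GenerationParams, GenerationConfig, generate_music

params = GenerationParams(
    caption="warm lo-fi piano with soft vinyl crackle",  # illustrative prompt
    lyrics="",
    duration=60.0,
    thinking=False,            # skip LM code generation; DiT-only path
)
config = GenerationConfig(
    batch_size=2,              # the API defaults to 2 audios, like gradio_ui
    use_random_seed=True,
    audio_format="flac",
)

result = generate_music(
    dit_handler=dit_handler,   # assumed: an initialized AceStepHandler
    llm_handler=None,          # only needed when thinking or sample_mode is used
    params=params,
    config=config,
    save_dir="/tmp/api_audio", # illustrative path
    progress=None,
)
if result.success:
    audio_paths = [a["path"] for a in result.audios if a.get("path")]

The same object is consumed by the response builder: `result.audios` carries the per-sample paths and seeds, and `result.extra_outputs` carries `lm_metadata` and `time_costs`.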
acestep/audio_utils.py
ADDED
@@ -0,0 +1,320 @@
"""
Audio saving and transcoding utility module

Independent audio file operations outside of handler, supporting:
- Save audio tensor/numpy to files (default FLAC format, fast)
- Format conversion (FLAC/WAV/MP3)
- Batch processing
"""

import os
import hashlib
import json
from pathlib import Path
from typing import Union, Optional, List, Tuple
import torch
import numpy as np
import torchaudio
from loguru import logger


class AudioSaver:
    """Audio saving and transcoding utility class"""

    def __init__(self, default_format: str = "flac"):
        """
        Initialize audio saver

        Args:
            default_format: Default save format ('flac', 'wav', 'mp3')
        """
        self.default_format = default_format.lower()
        if self.default_format not in ["flac", "wav", "mp3"]:
            logger.warning(f"Unsupported format {default_format}, using 'flac'")
            self.default_format = "flac"

    def save_audio(
        self,
        audio_data: Union[torch.Tensor, np.ndarray],
        output_path: Union[str, Path],
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> str:
        """
        Save audio data to file

        Args:
            audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
            output_path: Output file path (extension can be omitted)
            sample_rate: Sample rate
            format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
            channels_first: If True, tensor format is [channels, samples], else [samples, channels]

        Returns:
            Actual saved file path
        """
        format = (format or self.default_format).lower()
        if format not in ["flac", "wav", "mp3"]:
            logger.warning(f"Unsupported format {format}, using {self.default_format}")
            format = self.default_format

        # Ensure output path has correct extension
        output_path = Path(output_path)
        if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
            output_path = output_path.with_suffix(f'.{format}')

        # Convert to torch tensor
        if isinstance(audio_data, np.ndarray):
            if channels_first:
                # numpy [samples, channels] -> tensor [channels, samples]
                audio_tensor = torch.from_numpy(audio_data.T).float()
            else:
                # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
                audio_tensor = torch.from_numpy(audio_data).float()
                if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T
        else:
            # torch tensor
            audio_tensor = audio_data.cpu().float()
            if not channels_first and audio_tensor.dim() == 2:
                # [samples, channels] -> [channels, samples]
                if audio_tensor.shape[0] > audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T

        # Ensure memory is contiguous
        audio_tensor = audio_tensor.contiguous()

        # Select backend and save
        try:
            if format == "mp3":
                # MP3 uses ffmpeg backend
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='ffmpeg',
                )
            elif format in ["flac", "wav"]:
                # FLAC and WAV use soundfile backend (fastest)
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='soundfile',
                )
            else:
                # Other formats use default backend
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                )

            logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
            return str(output_path)

        except Exception as e:
            logger.error(f"[AudioSaver] Failed to save audio: {e}")
            raise

    def convert_audio(
        self,
        input_path: Union[str, Path],
        output_path: Union[str, Path],
        output_format: str,
        remove_input: bool = False,
    ) -> str:
        """
        Convert audio format

        Args:
            input_path: Input audio file path
            output_path: Output audio file path
            output_format: Target format ('flac', 'wav', 'mp3')
            remove_input: Whether to delete input file

        Returns:
            Output file path
        """
        input_path = Path(input_path)
        output_path = Path(output_path)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Load audio
        audio_tensor, sample_rate = torchaudio.load(str(input_path))

        # Save as new format
        output_path = self.save_audio(
            audio_tensor,
            output_path,
            sample_rate=sample_rate,
            format=output_format,
            channels_first=True
        )

        # Delete input file if needed
        if remove_input:
            input_path.unlink()
            logger.debug(f"[AudioSaver] Removed input file: {input_path}")

        return output_path

    def save_batch(
        self,
        audio_batch: Union[List[torch.Tensor], torch.Tensor],
        output_dir: Union[str, Path],
        file_prefix: str = "audio",
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> List[str]:
        """
        Save audio batch

        Args:
            audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
            output_dir: Output directory
            file_prefix: File prefix
            sample_rate: Sample rate
            format: Audio format
            channels_first: Tensor format flag

        Returns:
            List of saved file paths
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Process batch
        if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
            # [batch, channels, samples]
            audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
        elif isinstance(audio_batch, list):
            audio_list = audio_batch
        else:
            audio_list = [audio_batch]

        saved_paths = []
        for i, audio in enumerate(audio_list):
            output_path = output_dir / f"{file_prefix}_{i:04d}"
            saved_path = self.save_audio(
                audio,
                output_path,
                sample_rate=sample_rate,
                format=format,
                channels_first=channels_first
            )
            saved_paths.append(saved_path)

        return saved_paths


def get_audio_file_hash(audio_file) -> str:
    """
    Get hash identifier for an audio file.

    Args:
        audio_file: Path to audio file (str) or file-like object

    Returns:
        Hash string or empty string
    """
    if audio_file is None:
        return ""

    try:
        if isinstance(audio_file, str):
            if os.path.exists(audio_file):
                with open(audio_file, 'rb') as f:
                    return hashlib.md5(f.read()).hexdigest()
            return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
        elif hasattr(audio_file, 'name'):
            return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
    except Exception:
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()


def generate_uuid_from_params(params_dict) -> str:
    """
    Generate deterministic UUID from generation parameters.
    Same parameters will always generate the same UUID.

    Args:
        params_dict: Dictionary of parameters

    Returns:
        UUID string
    """
    params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
    hash_obj = hashlib.sha256(params_json.encode('utf-8'))
    hash_hex = hash_obj.hexdigest()
    uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
    return uuid_str


def generate_uuid_from_audio_data(
    audio_data: Union[torch.Tensor, np.ndarray],
    seed: Optional[int] = None
) -> str:
    """
    Generate UUID from audio data (for caching/deduplication)

    Args:
        audio_data: Audio data
        seed: Optional seed value

    Returns:
        UUID string
    """
    if isinstance(audio_data, torch.Tensor):
        # Convert to numpy and calculate hash
        audio_np = audio_data.cpu().numpy()
    else:
        audio_np = audio_data

    # Calculate data hash
    data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()

    if seed is not None:
        combined = f"{data_hash}_{seed}"
        return hashlib.md5(combined.encode()).hexdigest()

    return data_hash


# Global default instance
_default_saver = AudioSaver(default_format="flac")


def save_audio(
    audio_data: Union[torch.Tensor, np.ndarray],
    output_path: Union[str, Path],
    sample_rate: int = 48000,
    format: Optional[str] = None,
    channels_first: bool = True,
) -> str:
    """
    Convenience function: save audio (using default configuration)

    Args:
        audio_data: Audio data
        output_path: Output path
        sample_rate: Sample rate
        format: Format (default flac)
        channels_first: Tensor format flag

    Returns:
        Saved file path
    """
    return _default_saver.save_audio(
        audio_data, output_path, sample_rate, format, channels_first
    )
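A brief usage sketch of the new helpers, assuming a 48 kHz stereo tensor in `[channels, samples]` layout; the paths and parameter values are illustrative only, and the MP3 conversion relies on the ffmpeg backend noted in the module.

import torch
from acestep.audio_utils import AudioSaver, save_audio, generate_uuid_from_params

audio = torch.zeros(2, 96000)  # two seconds of stereo silence at 48 kHz

saver = AudioSaver(default_format="flac")
flac_path = saver.save_audio(audio, "/tmp/demo_take", sample_rate=48000)  # -> /tmp/demo_take.flac
mp3_path = saver.convert_audio(flac_path, "/tmp/demo_take.mp3", "mp3")    # transcode; needs ffmpeg backend

# Module-level convenience wrapper backed by the default FLAC saver
save_audio(audio, "/tmp/demo_copy", sample_rate=48000)

# Deterministic UUID: the same parameter dict always yields the same ID
job_id = generate_uuid_from_params({"caption": "demo", "seed": 42, "duration": 60.0})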
acestep/constrained_logits_processor.py
CHANGED
@@ -571,6 +571,33 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         if self.debug:
             logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")

     def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
         """
         Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.

@@ -1484,10 +1511,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             if self.debug:
                 logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
         else:
-            # Force EOS token when target codes count is reached
-            scores =
             if self.debug:
                 logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
         return self._apply_temperature_scaling(scores)

@@ -1609,20 +1636,15 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         input_ids: torch.LongTensor,
         scores: torch.FloatTensor,
     ) -> torch.FloatTensor:
-        """Process a single sequence and return modified scores."""

         # Check if we have tokens in queue for user-provided field
         # If so, inject the next token directly
         if self.user_field_token_queue:
-            mask = torch.full_like(scores, float('-inf'))
             next_token = self.user_field_token_queue[0]
-            scores = scores + mask
             return scores

-        # Create mask (all -inf initially)
-        mask = torch.full_like(scores, float('-inf'))
-
         if self.state in self.fixed_strings:
             # Fixed string state: force specific tokens
             fixed_str = self.fixed_strings[self.state]

@@ -1633,28 +1655,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             # This happens when we're about to complete the </think> tag
             if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                 # Check if the next token would complete the fixed string
-                # We check if position_in_state + length of next token would complete it
-                # Since we don't know which token will be selected, we check if we're close to completion
-                # Actually, a better approach: check if this is the last character(s) of the fixed string
                 remaining_chars = len(fixed_str) - self.position_in_state
                 # If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
                 if remaining_chars <= 10:
                     # Force EOS token to stop generation
                     if self.eos_token_id is not None:
-                        scores = scores + mask
                         if self.debug:
                             logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
                         return scores

-            # Apply mask
-            scores = scores + mask
-
-            # Update position tracking
-            # We need to check if the selected token completes the fixed string
-            # This will be done in update_state() after token selection
         else:
             # Position exceeds string, move to next state
             # If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,

@@ -1662,8 +1674,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                 # Force EOS token to stop generation
                 if self.eos_token_id is not None:
-                    scores = scores + mask
                     if self.debug:
                         logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
                     return scores

@@ -1676,7 +1687,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             if self.debug:
                 logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
             return scores

         elif self.state == FSMState.BPM_VALUE:
             # Check if field is user-provided and we haven't started injecting yet

@@ -1690,22 +1703,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "bpm"
                 # Inject first token
-                scores = scores + mask
                 return scores

             # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
             allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
-            for t in allowed:
-                mask[0, t] = 0

             # Also allow newline if current token sequence prefix allows it
-            # Check if current token sequence is in prefix tree and allows newline
             token_prefix = tuple(self.accumulated_token_ids)
             if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:

-            scores

         elif self.state == FSMState.CAPTION_VALUE:
             # Caption field generation with YAML format support:

@@ -1724,8 +1733,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "caption"
                 # Inject first token
-                scores = scores + mask
                 return scores

             # Check if we should transition after a newline (non-indented line = new field)

@@ -1757,7 +1765,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 # The field name detection will happen in update_state()
                 return scores

-            # Block backticks (code blocks)
             if self.backtick_token is not None:
                 scores[0, self.backtick_token] = float('-inf')

@@ -1773,8 +1781,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             if self.caption_token_count >= 512:
                 # Force end by only allowing newline
                 if self.newline_token is not None:
-                    scores = scores + mask
                     return scores

             # Allow natural generation (with blocked audio codes and backticks)

@@ -1791,8 +1798,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "duration"
                 # Inject first token
-                scores = scores + mask
                 return scores

             # If target_duration is set, force generate that exact value

@@ -1804,26 +1810,22 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                     # Force the next digit
                     next_digit = int(target_str[current_pos])
                     if next_digit in self.digit_tokens:
                 else:
                     # All digits generated, force newline
                     if self.newline_token:
-                    scores = scores + mask
             else:
                 # Normal duration generation with range constraint
                 # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
                 allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
-                for t in allowed:
-                    mask[0, t] = 0

                 # Also allow newline if current token sequence prefix allows it
                 token_prefix = tuple(self.accumulated_token_ids)
                 if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:

-                scores

         elif self.state == FSMState.GENRES_VALUE:
             # Check if field is user-provided and we haven't started injecting yet

@@ -1836,8 +1838,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "genres"
                 # Inject first token
-                scores = scores + mask
                 return scores

             # Try to hot-reload genres vocab if file has changed

@@ -1848,24 +1849,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):

             if allowed:
                 # Use vocabulary-constrained decoding
-                    mask[0, t] = 0
-                scores = scores + mask
             elif self.genres_vocab:
                 # Vocab is loaded but no valid continuation found
                 # Force newline to end the field
                 if self.newline_token:
-                    mask[0, self.newline_token] = 0
                     if self.debug:
                         logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
             else:
                 # Fallback: no vocab loaded, use probability-based ending
                 if self._should_end_text_field(scores):
                     if self.newline_token:
                     self._transition_to_next_state()
-                scores = scores + mask
             else:
                 # Allow any token except newline if we don't have content yet
                 if not self.accumulated_value.strip():

@@ -1884,8 +1881,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "keyscale"
                 # Inject first token
-                scores = scores + mask
                 return scores

             # Check if current token sequence is complete (allows newline)

@@ -1893,21 +1889,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
                 # Complete keyscale, allow newline
                 if self.newline_token:
-                    scores = scores + mask
             else:
                 # Not complete, allow valid continuation tokens
                 allowed = self._get_allowed_keyscale_tokens()
                 if allowed:
-                        mask[0, t] = 0
-                    scores = scores + mask
                 else:
                     # No valid tokens found - force newline to end field
                     # This handles edge cases where keyscale format is unexpected
                     if self.newline_token:
-                        scores = scores + mask

         elif self.state == FSMState.LANGUAGE_VALUE:
             # Language field: Use top-1 probability language (greedy selection)

@@ -1925,8 +1917,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "language"
                 # Inject first token
-                scores = scores + mask
                 return scores

             # If we haven't started generating language yet (empty accumulated_token_ids),

@@ -1938,19 +1929,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             candidate_tokens = list(self.language_prefix_tree[empty_prefix])

             if candidate_tokens:
-                # Find the token with highest probability (top-1)
-                    temp_mask[0, t] = 0
-                temp_scores = scores + temp_mask

                 # Get the highest probability token among candidates

-                # Only allow this top-1 token, block all others
-                scores = scores + mask

                 if self.debug:
                     top_token_text = self.tokenizer.decode([top_token_id])

@@ -1958,13 +1947,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             else:
                 # No valid first tokens found - force newline
                 if self.newline_token:
scores = scores + mask
|
| 1963 |
else:
|
| 1964 |
# Empty prefix not in tree - force newline
|
| 1965 |
if self.newline_token:
|
| 1966 |
-
|
| 1967 |
-
scores = scores + mask
|
| 1968 |
else:
|
| 1969 |
# We've started generating a language, continue with prefix tree constraints
|
| 1970 |
# Check if current token sequence is complete (allows newline)
|
|
@@ -1972,20 +1959,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1972 |
if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
|
| 1973 |
# Complete language, allow newline
|
| 1974 |
if self.newline_token:
|
| 1975 |
-
|
| 1976 |
-
scores = scores + mask
|
| 1977 |
else:
|
| 1978 |
# Not complete, allow valid continuation tokens
|
| 1979 |
allowed = self._get_allowed_language_tokens()
|
| 1980 |
if allowed:
|
| 1981 |
-
|
| 1982 |
-
mask[0, t] = 0
|
| 1983 |
-
scores = scores + mask
|
| 1984 |
else:
|
| 1985 |
# No valid tokens found - force newline to end field
|
| 1986 |
if self.newline_token:
|
| 1987 |
-
|
| 1988 |
-
scores = scores + mask
|
| 1989 |
|
| 1990 |
elif self.state == FSMState.TIMESIG_VALUE:
|
| 1991 |
# Check if field is user-provided and we haven't started injecting yet
|
|
@@ -1998,8 +1981,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1998 |
self.user_field_token_queue = value_tokens
|
| 1999 |
self.current_user_field = "timesignature"
|
| 2000 |
# Inject first token
|
| 2001 |
-
|
| 2002 |
-
scores = scores + mask
|
| 2003 |
return scores
|
| 2004 |
|
| 2005 |
# Check if current token sequence is complete (allows newline)
|
|
@@ -2007,14 +1989,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 2007 |
if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
|
| 2008 |
# Complete value, allow newline
|
| 2009 |
if self.newline_token:
|
| 2010 |
-
|
| 2011 |
-
scores = scores + mask
|
| 2012 |
else:
|
| 2013 |
# Not complete, allow valid continuation tokens
|
| 2014 |
allowed = self._get_allowed_timesig_tokens()
|
| 2015 |
-
|
| 2016 |
-
mask[0, t] = 0
|
| 2017 |
-
scores = scores + mask
|
| 2018 |
|
| 2019 |
return scores
|
| 2020 |
|
|
|
|
(Updated version, lines 571-603)
         if self.debug:
             logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")

+    def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None:
+        """
+        Apply whitelist constraint inplace: only allow specified tokens, block all others.
+
+        This is more efficient than creating a mask tensor because:
+        1. No memory allocation for mask
+        2. No tensor addition operation
+
+        Args:
+            scores: [1, vocab_size] scores tensor to modify inplace
+            allowed_tokens: List of token IDs to allow (all others will be set to -inf)
+        """
+        if not allowed_tokens:
+            # No tokens allowed, set all to -inf
+            scores.fill_(float('-inf'))
+            return
+
+        # Save the original values of allowed tokens
+        allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
+        saved_values = scores[0, allowed_indices].clone()
+
+        # Set all scores to -inf
+        scores.fill_(float('-inf'))
+
+        # Restore allowed token values
+        scores[0, allowed_indices] = saved_values
+
     def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
         """
         Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.

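The removed mask-based update and the new in-place whitelist produce the same constrained logits; a minimal standalone sketch (not part of the diff, with a made-up vocabulary size and allowed set) that checks the equivalence:

import torch

vocab_size = 16
allowed = [3, 7, 11]
scores = torch.randn(1, vocab_size)

# Old approach: allocate a full mask and add it (extra memory plus an addition kernel).
mask = torch.full_like(scores, float('-inf'))
mask[0, allowed] = 0
masked = scores + mask

# New approach: save the allowed logits, flood the tensor with -inf, restore them.
idx = torch.tensor(allowed, dtype=torch.long)
saved = scores[0, idx].clone()
scores.fill_(float('-inf'))
scores[0, idx] = saved

assert torch.equal(masked[0, idx], scores[0, idx])  # allowed logits identical either way
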
(Updated version, lines 1511-1520)
         if self.debug:
             logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
     else:
+        # Force EOS token when target codes count is reached - inplace
+        eos_scores = scores[:, self.eos_token_id].clone()
+        scores.fill_(float('-inf'))
+        scores[:, self.eos_token_id] = eos_scores
         if self.debug:
             logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
     return self._apply_temperature_scaling(scores)

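The hunk above applies the same save/fill/restore idiom across the whole batch dimension rather than a single row; a minimal sketch (not part of the diff, shapes and token id are illustrative):

import torch

scores = torch.randn(4, 32)              # 4 sequences, vocabulary of 32
eos_id = 2

eos_scores = scores[:, eos_id].clone()   # keep the EOS logit of every row
scores.fill_(float('-inf'))              # block everything...
scores[:, eos_id] = eos_scores           # ...except EOS, so sampling can only pick it

assert torch.isfinite(scores).sum() == scores.size(0)  # exactly one finite logit per row
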
(Updated version, lines 1636-1718)
         input_ids: torch.LongTensor,
         scores: torch.FloatTensor,
     ) -> torch.FloatTensor:
+        """Process a single sequence and return modified scores (inplace when possible)."""

         # Check if we have tokens in queue for user-provided field
         # If so, inject the next token directly
         if self.user_field_token_queue:
             next_token = self.user_field_token_queue[0]
+            self._apply_whitelist_inplace(scores, [next_token])
             return scores

         if self.state in self.fixed_strings:
             # Fixed string state: force specific tokens
             fixed_str = self.fixed_strings[self.state]

                 # This happens when we're about to complete the </think> tag
                 if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                     # Check if the next token would complete the fixed string
                     remaining_chars = len(fixed_str) - self.position_in_state
                     # If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
                     if remaining_chars <= 10:
                         # Force EOS token to stop generation
                         if self.eos_token_id is not None:
+                            self._apply_whitelist_inplace(scores, [self.eos_token_id])
                             if self.debug:
                                 logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
                             return scores

+                # Apply whitelist constraint inplace
+                self._apply_whitelist_inplace(scores, allowed)
             else:
                 # Position exceeds string, move to next state
                 # If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,

                 if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                     # Force EOS token to stop generation
                     if self.eos_token_id is not None:
+                        self._apply_whitelist_inplace(scores, [self.eos_token_id])
                         if self.debug:
                             logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
                         return scores

                     if self.debug:
                         logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
                     return scores
+                # For recursion, reset scores to zero (no constraints from previous state)
+                scores.zero_()
+                return self._process_single_sequence(input_ids, scores)

         elif self.state == FSMState.BPM_VALUE:
             # Check if field is user-provided and we haven't started injecting yet

                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "bpm"
                 # Inject first token
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores

             # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
             allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)

             # Also allow newline if current token sequence prefix allows it
             token_prefix = tuple(self.accumulated_token_ids)
             if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
+                allowed = allowed + [self.newline_token]

+            self._apply_whitelist_inplace(scores, allowed)

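The prefix trees consulted above (bpm_prefix_tree, duration_prefix_tree, ...) are built elsewhere and not shown in this diff; a minimal sketch of the idea behind them and behind _get_allowed_numeric_tokens, assuming an HF-style tokenizer and an illustrative BPM range:

from collections import defaultdict

def build_prefix_tree(tokenizer, values, newline_token_id):
    # Map every prefix of every valid value's token sequence to the set of tokens
    # that may follow it; the empty prefix maps to all valid first tokens.
    tree = defaultdict(set)
    for value in values:
        token_ids = tokenizer.encode(value, add_special_tokens=False)
        for i, tok in enumerate(token_ids):
            tree[tuple(token_ids[:i])].add(tok)
        # A fully generated value may be terminated by a newline.
        tree[tuple(token_ids)].add(newline_token_id)
    return tree

# Usage sketch (values and helper names are assumptions, not the repo's API):
# bpm_tree = build_prefix_tree(tokenizer, [str(v) for v in range(40, 201)], newline_id)
# allowed = sorted(bpm_tree.get(tuple(accumulated_token_ids), set()))
# At each decoding step the processor whitelists exactly this "allowed" set.
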
(Updated version, lines 1719-1999)
         elif self.state == FSMState.CAPTION_VALUE:
             # Caption field generation with YAML format support:

                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "caption"
                 # Inject first token
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores

             # Check if we should transition after a newline (non-indented line = new field)

                 # The field name detection will happen in update_state()
                 return scores

+            # Block backticks (code blocks) - inplace
             if self.backtick_token is not None:
                 scores[0, self.backtick_token] = float('-inf')

             if self.caption_token_count >= 512:
                 # Force end by only allowing newline
                 if self.newline_token is not None:
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
                 return scores

             # Allow natural generation (with blocked audio codes and backticks)

                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "duration"
                 # Inject first token
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores

             # If target_duration is set, force generate that exact value

                     # Force the next digit
                     next_digit = int(target_str[current_pos])
                     if next_digit in self.digit_tokens:
+                        self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
                 else:
                     # All digits generated, force newline
                     if self.newline_token:
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
             else:
                 # Normal duration generation with range constraint
                 # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
                 allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)

                 # Also allow newline if current token sequence prefix allows it
                 token_prefix = tuple(self.accumulated_token_ids)
                 if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
+                    allowed = allowed + [self.newline_token]

+                self._apply_whitelist_inplace(scores, allowed)

         elif self.state == FSMState.GENRES_VALUE:
             # Check if field is user-provided and we haven't started injecting yet

                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "genres"
                 # Inject first token
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores

             # Try to hot-reload genres vocab if file has changed

             if allowed:
                 # Use vocabulary-constrained decoding
+                self._apply_whitelist_inplace(scores, allowed)
             elif self.genres_vocab:
                 # Vocab is loaded but no valid continuation found
                 # Force newline to end the field
                 if self.newline_token:
                     if self.debug:
                         logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
             else:
                 # Fallback: no vocab loaded, use probability-based ending
                 if self._should_end_text_field(scores):
                     if self.newline_token:
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
                     self._transition_to_next_state()
                 else:
                     # Allow any token except newline if we don't have content yet
                     if not self.accumulated_value.strip():

                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "keyscale"
                 # Inject first token
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores

             # Check if current token sequence is complete (allows newline)

             if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
                 # Complete keyscale, allow newline
                 if self.newline_token:
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
             else:
                 # Not complete, allow valid continuation tokens
                 allowed = self._get_allowed_keyscale_tokens()
                 if allowed:
+                    self._apply_whitelist_inplace(scores, allowed)
                 else:
                     # No valid tokens found - force newline to end field
                     # This handles edge cases where keyscale format is unexpected
                     if self.newline_token:
+                        self._apply_whitelist_inplace(scores, [self.newline_token])

         elif self.state == FSMState.LANGUAGE_VALUE:
             # Language field: Use top-1 probability language (greedy selection)

                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "language"
                 # Inject first token
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores

             # If we haven't started generating language yet (empty accumulated_token_ids),

                     candidate_tokens = list(self.language_prefix_tree[empty_prefix])

                     if candidate_tokens:
+                        # Find the token with highest probability (top-1) among candidates
+                        # Use tensor indexing to get scores of candidate tokens directly
+                        candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
+                        candidate_scores = scores[0, candidate_indices]

                         # Get the highest probability token among candidates
+                        best_idx = torch.argmax(candidate_scores).item()
+                        top_token_id = candidate_tokens[best_idx]

+                        # Only allow this top-1 token, block all others
+                        self._apply_whitelist_inplace(scores, [top_token_id])

                         if self.debug:
                             top_token_text = self.tokenizer.decode([top_token_id])

                     else:
                         # No valid first tokens found - force newline
                         if self.newline_token:
+                            self._apply_whitelist_inplace(scores, [self.newline_token])
                 else:
                     # Empty prefix not in tree - force newline
                     if self.newline_token:
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
             else:
                 # We've started generating a language, continue with prefix tree constraints
                 # Check if current token sequence is complete (allows newline)

                 if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
                     # Complete language, allow newline
                     if self.newline_token:
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
                 else:
                     # Not complete, allow valid continuation tokens
                     allowed = self._get_allowed_language_tokens()
                     if allowed:
+                        self._apply_whitelist_inplace(scores, allowed)
                     else:
                         # No valid tokens found - force newline to end field
                         if self.newline_token:
+                            self._apply_whitelist_inplace(scores, [self.newline_token])

         elif self.state == FSMState.TIMESIG_VALUE:
             # Check if field is user-provided and we haven't started injecting yet

                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "timesignature"
                 # Inject first token
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores

             # Check if current token sequence is complete (allows newline)

             if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
                 # Complete value, allow newline
                 if self.newline_token:
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
             else:
                 # Not complete, allow valid continuation tokens
                 allowed = self._get_allowed_timesig_tokens()
+                self._apply_whitelist_inplace(scores, allowed)

         return scores

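The greedy top-1 pick in the LANGUAGE_VALUE branch above reduces to an argmax over the gathered candidate logits; a minimal sketch (not part of the diff, values are made up):

import torch

scores = torch.randn(1, 32)
candidate_tokens = [5, 9, 17]      # illustrative first-token candidates from the prefix tree

idx = torch.tensor(candidate_tokens, dtype=torch.long)
best = candidate_tokens[torch.argmax(scores[0, idx]).item()]
# "best" is then the only token left unblocked, e.g. via _apply_whitelist_inplace(scores, [best])
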
acestep/gradio_ui/event.py
DELETED
The diff for this file is too large to render. See raw diff.

acestep/gradio_ui/events/__init__.py
CHANGED
@@ -254,48 +254,84 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
     ]
 )

-    # Save buttons for
-
-
-
         inputs=[
             results_section[f"generated_audio_{btn_idx}"],
-
-            generation_section["captions"],
-            generation_section["lyrics"],
-            generation_section["vocal_language"],
-            generation_section["bpm"],
-            generation_section["key_scale"],
-            generation_section["time_signature"],
-            generation_section["audio_duration"],
-            generation_section["batch_size_input"],
-            generation_section["inference_steps"],
-            generation_section["guidance_scale"],
-            generation_section["seed"],
-            generation_section["random_seed_checkbox"],
-            generation_section["use_adg"],
-            generation_section["cfg_interval_start"],
-            generation_section["cfg_interval_end"],
-            generation_section["audio_format"],
-            generation_section["lm_temperature"],
-            generation_section["lm_cfg_scale"],
-            generation_section["lm_top_k"],
-            generation_section["lm_top_p"],
-            generation_section["lm_negative_prompt"],
-            generation_section["use_cot_caption"],
-            generation_section["use_cot_language"],
-            generation_section["audio_cover_strength"],
-            generation_section["think_checkbox"],
-            generation_section["text2music_audio_code_string"],
-            generation_section["repainting_start"],
-            generation_section["repainting_end"],
-            generation_section["track_name"],
-            generation_section["complete_track_classes"],
-            results_section["lm_metadata_state"],
         ],
-
-
-
     # ========== Send to SRC Handlers ==========
     for btn_idx in range(1, 9):
         results_section[f"send_to_src_btn_{btn_idx}"].click(

@@ -331,10 +367,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
     ],
     outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
 )
-
     # ========== Generation Handler ==========
     generation_section["generate_btn"].click(
-        fn=
         inputs=[
             generation_section["captions"],
             generation_section["lyrics"],
(Updated version, lines 254-377)
     ]
 )

+    # Save buttons for all 8 audio outputs
+    download_existing_js = """(current_audio, batch_files) => {
+        // Debug: print what the input actually is
+        console.log("👉 [Debug] Current Audio Input:", current_audio);
+
+        // 1. Safety check
+        if (!current_audio) {
+            console.warn("⚠️ No audio selected or audio is empty.");
+            return;
+        }
+        if (!batch_files || !Array.isArray(batch_files)) {
+            console.warn("⚠️ Batch file list is empty/not ready.");
+            return;
+        }
+
+        // 2. Smartly extract path string
+        let pathString = "";
+
+        if (typeof current_audio === "string") {
+            // Case A: direct path string received
+            pathString = current_audio;
+        } else if (typeof current_audio === "object") {
+            // Case B: an object is received, try common properties
+            // Gradio file objects usually have path, url, or name
+            pathString = current_audio.path || current_audio.name || current_audio.url || "";
+        }
+
+        if (!pathString) {
+            console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
+            return;
+        }
+
+        // 3. Extract Key (UUID)
+        // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
+        let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
+        let key = filename.split('.')[0]; // get UUID without extension
+
+        console.log(`🔑 Key extracted: ${key}`);
+
+        // 4. Find matching file(s) in the list
+        let targets = batch_files.filter(f => {
+            // Also extract names from batch_files objects
+            // f usually contains name (backend path) and orig_name (download name)
+            const fPath = f.name || f.path || "";
+            return fPath.includes(key);
+        });
+
+        if (targets.length === 0) {
+            console.warn("❌ No matching files found in batch list for key:", key);
+            alert("Batch list does not contain this file yet. Please wait for generation to finish.");
+            return;
+        }
+
+        // 5. Trigger download(s)
+        console.log(`🎯 Found ${targets.length} files to download.`);
+        targets.forEach((f, index) => {
+            setTimeout(() => {
+                const a = document.createElement('a');
+                // Prefer url (frontend-accessible link), otherwise try data
+                a.href = f.url || f.data;
+                a.download = f.orig_name || "download";
+                a.style.display = 'none';
+                document.body.appendChild(a);
+                a.click();
+                document.body.removeChild(a);
+            }, index * 1000); // 1000ms interval between downloads to avoid browser blocking
+        });
+    }
+    """
+    for btn_idx in range(1, 9):
+        results_section[f"save_btn_{btn_idx}"].click(
+            fn=None,
            inputs=[
                results_section[f"generated_audio_{btn_idx}"],
+                results_section["generated_audio_batch"],
            ],
+            js=download_existing_js  # Run the above JS
+        )
     # ========== Send to SRC Handlers ==========
     for btn_idx in range(1, 9):
         results_section[f"send_to_src_btn_{btn_idx}"].click(

     ],
     outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
 )
+    def generation_wrapper(*args):
+        yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
     # ========== Generation Handler ==========
     generation_section["generate_btn"].click(
+        fn=generation_wrapper,
         inputs=[
             generation_section["captions"],
             generation_section["lyrics"],
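The save-button wiring above is purely client-side: with fn=None and a js= callback the click never reaches the Python backend, so downloads do not block the generation queue. A minimal sketch of the same pattern (not the repo's actual components; names and the stub JS body are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    audio = gr.Audio(label="Generated audio")
    batch = gr.File(label="Batch files", file_count="multiple")
    save_btn = gr.Button("Save")
    save_btn.click(
        fn=None,                                        # no server round-trip
        inputs=[audio, batch],                          # passed to the JS function as positional args
        js="(a, files) => { console.log(a, files); }",  # stand-in for download_existing_js
    )
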
acestep/gradio_ui/events/results_handlers.py
CHANGED

@@ -10,9 +10,123 @@ import tempfile
 import shutil
 import zipfile
 import time as time_module
 import gradio as gr
 from loguru import logger
 from acestep.gradio_ui.i18n import t

 def store_batch_in_queue(

@@ -66,99 +180,6 @@ def update_navigation_buttons(current_batch, total_batches):
     can_go_next = current_batch < total_batches - 1
     return can_go_previous, can_go_next

-
-def save_audio_and_metadata(
-    audio_path, task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
-    batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
-    use_adg, cfg_interval_start, cfg_interval_end, audio_format,
-    lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
-    use_cot_caption, use_cot_language, audio_cover_strength,
-    think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
-    track_name, complete_track_classes, lm_metadata
-):
-    """Save audio file and its metadata as a zip package"""
-    if audio_path is None:
-        gr.Warning(t("messages.no_audio_to_save"))
-        return None
-
-    try:
-        # Create metadata dictionary
-        metadata = {
-            "saved_at": datetime.datetime.now().isoformat(),
-            "task_type": task_type,
-            "caption": captions or "",
-            "lyrics": lyrics or "",
-            "vocal_language": vocal_language,
-            "bpm": bpm if bpm is not None else None,
-            "keyscale": key_scale or "",
-            "timesignature": time_signature or "",
-            "duration": audio_duration if audio_duration is not None else -1,
-            "batch_size": batch_size_input,
-            "inference_steps": inference_steps,
-            "guidance_scale": guidance_scale,
-            "seed": seed,
-            "random_seed": False,  # Disable random seed for reproducibility
-            "use_adg": use_adg,
-            "cfg_interval_start": cfg_interval_start,
-            "cfg_interval_end": cfg_interval_end,
-            "audio_format": audio_format,
-            "lm_temperature": lm_temperature,
-            "lm_cfg_scale": lm_cfg_scale,
-            "lm_top_k": lm_top_k,
-            "lm_top_p": lm_top_p,
-            "lm_negative_prompt": lm_negative_prompt,
-            "use_cot_caption": use_cot_caption,
-            "use_cot_language": use_cot_language,
-            "audio_cover_strength": audio_cover_strength,
-            "think": think_checkbox,
-            "audio_codes": text2music_audio_code_string or "",
-            "repainting_start": repainting_start,
-            "repainting_end": repainting_end,
-            "track_name": track_name,
-            "complete_track_classes": complete_track_classes or [],
-        }
-
-        # Add LM-generated metadata if available
-        if lm_metadata:
-            metadata["lm_generated_metadata"] = lm_metadata
-
-        # Generate timestamp and base name
-        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
-        # Extract audio filename extension
-        audio_ext = os.path.splitext(audio_path)[1]
-
-        # Create temporary directory for packaging
-        temp_dir = tempfile.mkdtemp()
-
-        # Save JSON metadata
-        json_path = os.path.join(temp_dir, f"metadata_{timestamp}.json")
-        with open(json_path, 'w', encoding='utf-8') as f:
-            json.dump(metadata, f, indent=2, ensure_ascii=False)
-
-        # Copy audio file
-        audio_copy_path = os.path.join(temp_dir, f"audio_{timestamp}{audio_ext}")
-        shutil.copy2(audio_path, audio_copy_path)
-
-        # Create zip file
-        zip_path = os.path.join(tempfile.gettempdir(), f"music_package_{timestamp}.zip")
-        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
-            zipf.write(audio_copy_path, os.path.basename(audio_copy_path))
-            zipf.write(json_path, os.path.basename(json_path))
-
-        # Clean up temp directory
-        shutil.rmtree(temp_dir)
-
-        gr.Info(t("messages.save_success", filename=os.path.basename(zip_path)))
-        return zip_path
-
-    except Exception as e:
-        gr.Warning(t("messages.save_failed", error=str(e)))
-        import traceback
-        traceback.print_exc()
-        return None
-
-
 def send_audio_to_src_with_metadata(audio_file, lm_metadata):
     """Send generated audio file to src_audio input and populate metadata fields

@@ -254,366 +275,209 @@ def generate_with_progress(
     auto_score,
     score_scale,
     lm_batch_chunk_size,
-    progress=gr.Progress(track_tqdm=True)
 ):
     """Generate audio with progress tracking"""
-
    (removed lines 260-272, which referenced think_checkbox, are not recoverable in this rendering)
 )

    (removed lines 275-328 are not recoverable in this rendering)
-                chunk_size = chunk_end - chunk_start
-                chunk_seeds = actual_seed_list[chunk_start:chunk_end]
-
-                logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
-
-                # Generate batch
-                metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
-                    caption=captions or "",
-                    lyrics=lyrics or "",
-                    batch_size=chunk_size,
-                    infer_type="llm_dit",
-                    temperature=lm_temperature,
-                    cfg_scale=lm_cfg_scale,
-                    negative_prompt=lm_negative_prompt,
-                    top_k=top_k_value,
-                    top_p=top_p_value,
-                    user_metadata=user_metadata_to_pass,
-                    use_cot_caption=use_cot_caption,
-                    use_cot_language=use_cot_language,
-                    is_format_caption=is_format_caption,
-                    constrained_decoding_debug=constrained_decoding_debug,
-                    seeds=chunk_seeds,
-                )
-
-                all_metadata_list.extend(metadata_list)
-                all_audio_codes_list.extend(audio_codes_list)
-
-            # Use first metadata as representative (all are same)
-            lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None

    (removed lines 359-387 are mostly not recoverable in this rendering)
-            pass
-        else:
-            # SEQUENTIAL LM GENERATION (current behavior, when allow_lm_batch is False)
-            # Phase 1: Generate CoT metadata
-            phase1_start = time_module.time()
-            metadata, _, status = llm_handler.generate_with_stop_condition(
-                caption=captions or "",
-                lyrics=lyrics or "",
-                infer_type="dit",  # Only generate metadata in Phase 1
-                temperature=lm_temperature,
-                cfg_scale=lm_cfg_scale,
-                negative_prompt=lm_negative_prompt,
-                top_k=top_k_value,
-                top_p=top_p_value,
-                user_metadata=user_metadata_to_pass,
-                use_cot_caption=use_cot_caption,
-                use_cot_language=use_cot_language,
-                is_format_caption=is_format_caption,
-                constrained_decoding_debug=constrained_decoding_debug,
-            )
-            lm_phase1_time = time_module.time() - phase1_start
-            logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
-
-            # Phase 2: Generate audio codes
-            phase2_start = time_module.time()
-            metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
-                caption=captions or "",
-                lyrics=lyrics or "",
-                infer_type="llm_dit",  # Generate both metadata and codes
-                temperature=lm_temperature,
-                cfg_scale=lm_cfg_scale,
-                negative_prompt=lm_negative_prompt,
-                top_k=top_k_value,
-                top_p=top_p_value,
-                user_metadata=user_metadata_to_pass,
-                use_cot_caption=use_cot_caption,
-                use_cot_language=use_cot_language,
-                is_format_caption=is_format_caption,
-                constrained_decoding_debug=constrained_decoding_debug,
-            )
-            lm_phase2_time = time_module.time() - phase2_start
-            logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
-
-            # Store LM-generated metadata and audio codes for display
-            lm_generated_metadata = metadata
-            if audio_codes:
-                audio_code_string_to_use = audio_codes
-                lm_generated_audio_codes = audio_codes
-        # Update metadata fields only if they are empty/None (user didn't provide them)
-        if bpm is None and metadata.get('bpm'):
-            bpm_value = metadata.get('bpm')
-            if bpm_value != "N/A" and bpm_value != "":
-                try:
-                    bpm = int(bpm_value)
-                except:
-                    pass
-        if not key_scale and metadata.get('keyscale'):
-            key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
-            if key_scale_value != "N/A":
-                key_scale = key_scale_value
-        if not time_signature and metadata.get('timesignature'):
-            time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
-            if time_signature_value != "N/A":
-                time_signature = time_signature_value
-        if audio_duration is None or audio_duration <= 0:
-            audio_duration_value = metadata.get('duration', -1)
-            if audio_duration_value != "N/A" and audio_duration_value != "":
-                try:
-                    audio_duration = float(audio_duration_value)
-                except:
-                    pass
-
-    # Call generate_music and get results
-    result = dit_handler.generate_music(
-        captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
-        time_signature=time_signature, vocal_language=vocal_language,
-        inference_steps=inference_steps, guidance_scale=guidance_scale,
-        use_random_seed=random_seed_checkbox, seed=seed,
-        reference_audio=reference_audio, audio_duration=audio_duration,
-        batch_size=batch_size_input, src_audio=src_audio,
-        audio_code_string=audio_code_string_to_use,
-        repainting_start=repainting_start, repainting_end=repainting_end,
-        instruction=instruction_display_gen, audio_cover_strength=audio_cover_strength,
-        task_type=task_type, use_adg=use_adg,
-        cfg_interval_start=cfg_interval_start, cfg_interval_end=cfg_interval_end,
-        audio_format=audio_format, lm_temperature=lm_temperature,
-        progress=progress
-    )
-
-    # Extract results
-    first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
-        align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
-
-    # Extract LM timing from status if available and prepend to generation_info
-    if status:
-        import re
-        # Try to extract timing info from status using regex
-        # Expected format: "Phase1: X.XXs" and "Phase2: X.XXs"
-        phase1_match = re.search(r'Phase1:\s*([\d.]+)s', status)
-        phase2_match = re.search(r'Phase2:\s*([\d.]+)s', status)
-
-        if phase1_match or phase2_match:
-            lm_timing_section = "\n\n**🤖 LM Timing:**\n"
-            lm_total = 0.0
-            if phase1_match:
-                phase1_time = float(phase1_match.group(1))
-                lm_timing_section += f" - Phase 1 (CoT Metadata): {phase1_time:.2f}s\n"
-                lm_total += phase1_time
-            if phase2_match:
-                phase2_time = float(phase2_match.group(1))
-                lm_timing_section += f" - Phase 2 (Audio Codes): {phase2_time:.2f}s\n"
-                lm_total += phase2_time
-            if lm_total > 0:
-                lm_timing_section += f" - Total LM Time: {lm_total:.2f}s\n"
-            generation_info = lm_timing_section + "\n" + generation_info
-
-    # Append LM-generated metadata to generation_info if available
-    if lm_generated_metadata:
-        metadata_lines = []
-        if lm_generated_metadata.get('bpm'):
-            metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
-        if lm_generated_metadata.get('caption'):
-            metadata_lines.append(f"- **User Query Rewritten Caption:** {lm_generated_metadata['caption']}")
-        if lm_generated_metadata.get('duration'):
-            metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
-        if lm_generated_metadata.get('keyscale'):
-            metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
-        if lm_generated_metadata.get('language'):
-            metadata_lines.append(f"- **Language:** {lm_generated_metadata['language']}")
-        if lm_generated_metadata.get('timesignature'):
-            metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
-
-        if metadata_lines:
-            metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
-            generation_info = metadata_section + "\n\n" + generation_info
-
-    # Update audio codes in UI if LM generated them
-    codes_outputs = [""] * 8  # Codes for 8 components
-    if should_use_lm_batch and lm_generated_audio_codes_list:
-        # Batch mode: update individual codes inputs
-        for idx in range(min(len(lm_generated_audio_codes_list), 8)):
-            codes_outputs[idx] = lm_generated_audio_codes_list[idx]
-        # For single codes input, show first one
-        updated_audio_codes = lm_generated_audio_codes_list[0] if lm_generated_audio_codes_list else text2music_audio_code_string
-    else:
-        # Single mode: update main codes input
-        updated_audio_codes = lm_generated_audio_codes if lm_generated_audio_codes else text2music_audio_code_string
-
-    # AUTO-SCORING
-    score_displays = [""] * 8  # Scores for 8 components
-    if auto_score and all_audio_paths:
-        logger.info(f"Auto-scoring enabled, calculating quality scores for {batch_size_input} generated audios...")
-
-        # Determine which audio codes to use for scoring
-        if should_use_lm_batch and lm_generated_audio_codes_list:
-            codes_list = lm_generated_audio_codes_list
-        elif audio_code_string_to_use and isinstance(audio_code_string_to_use, list):
-            codes_list = audio_code_string_to_use
         else:
-
    (removed lines 548-572 are not recoverable in this rendering)
-        for idx in range(min(len(all_audio_paths), 8)):
-            audio_outputs[idx] = all_audio_paths[idx]

    (removed lines 576-579 are not recoverable in this rendering)
-        audio_outputs[3],   # generated_audio_4
-        audio_outputs[4],   # generated_audio_5
-        audio_outputs[5],   # generated_audio_6
-        audio_outputs[6],   # generated_audio_7
-        audio_outputs[7],   # generated_audio_8
-        all_audio_paths,    # generated_audio_batch
         generation_info,
-
         seed_value_for_ui,
-        align_score_1,
    (removed lines 590-596 are not recoverable in this rendering)
-        score_displays[2],  # score_display_3
-        score_displays[3],  # score_display_4
-        score_displays[4],  # score_display_5
-        score_displays[5],  # score_display_6
-        score_displays[6],  # score_display_7
-        score_displays[7],  # score_display_8
-        updated_audio_codes,  # Update main audio codes in UI
-        codes_outputs[0],   # text2music_audio_code_string_1
-        codes_outputs[1],   # text2music_audio_code_string_2
-        codes_outputs[2],   # text2music_audio_code_string_3
-        codes_outputs[3],   # text2music_audio_code_string_4
-        codes_outputs[4],   # text2music_audio_code_string_5
-        codes_outputs[5],   # text2music_audio_code_string_6
-        codes_outputs[6],   # text2music_audio_code_string_7
-        codes_outputs[7],   # text2music_audio_code_string_8
-        lm_generated_metadata,  # Store metadata for "Send to src audio" buttons
-        is_format_caption,  # Keep is_format_caption unchanged
     )


 def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
     """
     Calculate PMI-based quality score for generated audio.

@@ -756,7 +620,9 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
     if stored_allow_lm_batch and isinstance(stored_codes, list):
         # Batch mode: use specific sample's codes
         if 0 <= sample_idx - 1 < len(stored_codes):
-
     else:
         # Single mode: all samples use same codes
         audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""

@@ -868,7 +734,7 @@ def generate_with_batch_management(
     Wrapper for generate_with_progress that adds batch queue management
     """
     # Call the original generation function
-
         dit_handler, llm_handler,
         captions, lyrics, bpm, key_scale, time_signature, vocal_language,
         inference_steps, guidance_scale, random_seed_checkbox, seed,

@@ -885,23 +751,41 @@
         lm_batch_chunk_size,
         progress
     )
-
-
-
     generation_info = result[9]
     seed_value_for_ui = result[11]
-    lm_generated_metadata = result[
     # Extract codes
     generated_codes_single = result[26]
     generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
-
     # Determine which codes to store based on mode
     if allow_lm_batch and batch_size_input >= 2:
         codes_to_store = generated_codes_batch[:int(batch_size_input)]
     else:
         codes_to_store = generated_codes_single
-
     # Save parameters for history
     saved_params = {
         "captions": captions,

@@ -947,6 +831,7 @@
     }

     # Next batch parameters (with cleared codes & random seed)
     next_params = saved_params.copy()
     next_params["text2music_audio_code_string"] = ""
     next_params["random_seed_checkbox"] = True

@@ -979,9 +864,10 @@
     next_batch_status_text = ""
     if autogen_checkbox:
         next_batch_status_text = t("messages.autogen_enabled")
-
-    #
-
     current_batch_index,
     total_batches,
     batch_queue,

@@ -1097,7 +983,8 @@ def generate_next_batch_background(
     params.setdefault("complete_track_classes", [])

     # Call generate_with_progress with the saved parameters
-
         dit_handler,
         llm_handler,
         captions=params.get("captions"),

@@ -1142,15 +1029,20 @@
         progress=progress
     )

-    #
-
-
-
-

     # Extract codes
-    generated_codes_single =
-    generated_codes_batch = [
     # Determine which codes to store
     batch_size = params.get("batch_size_input", 2)

@@ -1240,8 +1132,9 @@ def navigate_to_previous_batch(current_batch_index, batch_queue):

     # Prepare audio outputs (up to 8)
     audio_outputs = [None] * 8
-    for
-
     # Update batch indicator
     total_batches = len(batch_queue)

@@ -1286,8 +1179,9 @@ def navigate_to_next_batch(autogen_enabled, current_batch_index, total_batches,

     # Prepare audio outputs (up to 8)
     audio_outputs = [None] * 8
-    for
-
     # Update batch indicator
     batch_indicator_text = update_batch_indicator(new_batch_index, total_batches)

(Updated version, lines 10-132)
 import shutil
 import zipfile
 import time as time_module
+from typing import Dict, Any, Optional
 import gradio as gr
 from loguru import logger
 from acestep.gradio_ui.i18n import t
+from acestep.inference import generate_music, GenerationParams, GenerationConfig
+from acestep.audio_utils import save_audio
+
+
+def _build_generation_info(
+    lm_metadata: Optional[Dict[str, Any]],
+    time_costs: Dict[str, float],
+    seed_value: str,
+    inference_steps: int,
+    num_audios: int,
+) -> str:
+    """Build generation info string from result data.
+
+    Args:
+        lm_metadata: LM-generated metadata dictionary
+        time_costs: Unified time costs dictionary
+        seed_value: Seed value string
+        inference_steps: Number of inference steps
+        num_audios: Number of generated audios
+
+    Returns:
+        Formatted generation info string
+    """
+    info_parts = []
+
+    # Part 1: LM-generated metadata (if available)
+    if lm_metadata:
+        metadata_lines = []
+        if lm_metadata.get('bpm'):
+            metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
+        if lm_metadata.get('caption'):
+            metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
+        if lm_metadata.get('lyrics'):
+            metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
+        if lm_metadata.get('duration'):
+            metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
+        if lm_metadata.get('keyscale'):
+            metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
+        if lm_metadata.get('language'):
+            metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
+        if lm_metadata.get('timesignature'):
+            metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
+
+        if metadata_lines:
+            metadata_section = "**🤖 LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
+            info_parts.append(metadata_section)
+
+    # Part 2: Time costs (formatted and beautified)
+    if time_costs:
+        time_lines = []
+
+        # LM time costs
+        lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
+        lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
+        lm_total = time_costs.get('lm_total_time', 0.0)
+
+        if lm_total > 0:
+            time_lines.append("**🧠 LM Time:**")
+            if lm_phase1 > 0:
+                time_lines.append(f" - Phase 1 (CoT): {lm_phase1:.2f}s")
+            if lm_phase2 > 0:
+                time_lines.append(f" - Phase 2 (Codes): {lm_phase2:.2f}s")
+            time_lines.append(f" - Total: {lm_total:.2f}s")
+
+        # DiT time costs
+        dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
+        dit_model = time_costs.get('dit_model_time_cost', 0.0)
+        dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
+        dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
+        dit_total = time_costs.get('dit_total_time_cost', 0.0)
+        if dit_total > 0:
+            time_lines.append("\n**🎵 DiT Time:**")
+            if dit_encoder > 0:
+                time_lines.append(f" - Encoder: {dit_encoder:.2f}s")
+            if dit_model > 0:
+                time_lines.append(f" - Model: {dit_model:.2f}s")
+            if dit_vae_decode > 0:
+                time_lines.append(f" - VAE Decode: {dit_vae_decode:.2f}s")
+            if dit_offload > 0:
+                time_lines.append(f" - Offload: {dit_offload:.2f}s")
+            time_lines.append(f" - Total: {dit_total:.2f}s")
+
+        # Post-processing time costs
+        audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
+        auto_score_time = time_costs.get('auto_score_time', 0.0)
+
+        if audio_conversion_time > 0 or auto_score_time > 0:
+            time_lines.append("\n**🔧 Post-processing Time:**")
+            if audio_conversion_time > 0:
+                time_lines.append(f" - Audio Conversion: {audio_conversion_time:.2f}s")
+            if auto_score_time > 0:
+                time_lines.append(f" - Auto Score: {auto_score_time:.2f}s")
+
+        # Pipeline total
+        pipeline_total = time_costs.get('pipeline_total_time', 0.0)
+        if pipeline_total > 0:
+            time_lines.append(f"\n**⏱️ Pipeline Total: {pipeline_total:.2f}s**")
+
+        if time_lines:
+            time_section = "\n".join(time_lines)
+            info_parts.append(time_section)
+
+    # Part 3: Generation summary
+    summary_lines = [
+        "**🎵 Generation Complete**",
+        f" - **Seeds:** {seed_value}",
+        f" - **Steps:** {inference_steps}",
+        f" - **Audio Count:** {num_audios} audio(s)",
+    ]
+    info_parts.append("\n".join(summary_lines))
+
+    # Combine all parts
+    return "\n\n".join(info_parts)


 def store_batch_in_queue(
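A minimal sketch (not from the diff, with made-up metadata and timing values) of how _build_generation_info is meant to be called and what it returns:

info = _build_generation_info(
    lm_metadata={"bpm": 120, "caption": "warm lo-fi piano", "duration": 30},
    time_costs={"lm_phase1_time": 1.2, "lm_phase2_time": 4.5, "lm_total_time": 5.7,
                "dit_total_time_cost": 8.3, "pipeline_total_time": 14.0},
    seed_value="1234",
    inference_steps=8,
    num_audios=2,
)
# -> a markdown string with the "LM-Generated Metadata" block, the timing breakdown,
#    and a "Generation Complete" summary, ready to display in a gr.Markdown component.
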
(Updated version, lines 180-185)
     can_go_next = current_batch < total_batches - 1
     return can_go_previous, can_go_next


 def send_audio_to_src_with_metadata(audio_file, lm_metadata):
     """Send generated audio file to src_audio input and populate metadata fields

(Updated version, lines 275-366)
     auto_score,
     score_scale,
     lm_batch_chunk_size,
+    progress=gr.Progress(track_tqdm=True),
 ):
     """Generate audio with progress tracking"""
+
+    # step 1: prepare inputs
+    # generate_music, GenerationParams, GenerationConfig
+    gen_params = GenerationParams(
+        task_type=task_type,
+        instruction=instruction_display_gen,
+        reference_audio=reference_audio,
+        src_audio=src_audio,
+        audio_codes=text2music_audio_code_string if not think_checkbox else "",
+        caption=captions or "",
+        lyrics=lyrics or "",
+        instrumental=False,
+        vocal_language=vocal_language,
+        bpm=bpm,
+        keyscale=key_scale,
+        timesignature=time_signature,
+        duration=audio_duration,
+        inference_steps=inference_steps,
+        guidance_scale=guidance_scale,
+        use_adg=use_adg,
+        cfg_interval_start=cfg_interval_start,
+        cfg_interval_end=cfg_interval_end,
+        repainting_start=repainting_start,
+        repainting_end=repainting_end,
+        audio_cover_strength=audio_cover_strength,
+        thinking=think_checkbox,
+        lm_temperature=lm_temperature,
+        lm_cfg_scale=lm_cfg_scale,
+        lm_top_k=lm_top_k,
+        lm_top_p=lm_top_p,
+        lm_negative_prompt=lm_negative_prompt,
+        use_cot_metas=use_cot_metas,
+        use_cot_caption=use_cot_caption,
+        use_cot_language=use_cot_language,
+        use_constrained_decoding=True,
+    )
+    # seed string to list
+    if isinstance(seed, str) and seed.strip():
+        if "," in seed:
+            seed_list = [int(s.strip()) for s in seed.split(",")]
+        else:
+            seed_list = [int(seed.strip())]
+    else:
+        seed_list = None
+    gen_config = GenerationConfig(
+        batch_size=batch_size_input,
+        allow_lm_batch=allow_lm_batch,
+        use_random_seed=random_seed_checkbox,
+        seeds=seed_list,
+        lm_batch_chunk_size=lm_batch_chunk_size,
+        constrained_decoding_debug=constrained_decoding_debug,
+        audio_format=audio_format,
+    )
+    result = generate_music(
+        dit_handler,
+        llm_handler,
+        params=gen_params,
+        config=gen_config,
+        progress=progress,
     )

+    audio_outputs = [None] * 8
+    all_audio_paths = []
+    final_codes_list = [""] * 8
+    final_scores_list = [""] * 8
+
+    # Build generation_info from result data
+    status_message = result.status_message
+    seed_value_for_ui = result.extra_outputs.get("seed_value", "")
+    lm_generated_metadata = result.extra_outputs.get("lm_metadata", {})
+    time_costs = result.extra_outputs.get("time_costs", {}).copy()
+
+    # Initialize post-processing timing
+    audio_conversion_start_time = time_module.time()
+    total_auto_score_time = 0.0
+
+    align_score_1 = ""
+    align_text_1 = ""
+    align_plot_1 = None
+    align_score_2 = ""
+    align_text_2 = ""
+    align_plot_2 = None
+    updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
+
+    # Build initial generation_info (will be updated with post-processing times at the end)
+    generation_info = _build_generation_info(
generation_info = _build_generation_info(
|
| 367 |
+
lm_metadata=lm_generated_metadata,
|
| 368 |
+
time_costs=time_costs,
|
| 369 |
+
seed_value=seed_value_for_ui,
|
| 370 |
+
inference_steps=inference_steps,
|
| 371 |
+
num_audios=len(result.audios) if result.success else 0,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
if not result.success:
|
| 375 |
+
yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 26
|
| 376 |
+
return
|
| 377 |
+
|
| 378 |
+
audios = result.audios
|
| 379 |
+
progress(0.99, "Converting audio to mp3...")
|
| 380 |
+
for i in range(8):
|
| 381 |
+
if i < len(audios):
|
| 382 |
+
key = audios[i]["key"]
|
| 383 |
+
audio_tensor = audios[i]["tensor"]
|
| 384 |
+
sample_rate = audios[i]["sample_rate"]
|
| 385 |
+
audio_params = audios[i]["params"]
|
| 386 |
+
temp_dir = tempfile.mkdtemp(f"acestep_gradio_results/")
|
| 387 |
+
os.makedirs(temp_dir, exist_ok=True)
|
| 388 |
+
json_path = os.path.join(temp_dir, f"{key}.json")
|
| 389 |
+
audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
|
| 390 |
+
save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
|
| 391 |
+
with open(json_path, 'w', encoding='utf-8') as f:
|
| 392 |
+
json.dump(audio_params, f, indent=2, ensure_ascii=False)
|
| 393 |
+
audio_outputs[i] = audio_path
|
| 394 |
+
all_audio_paths.append(audio_path)
|
| 395 |
+
all_audio_paths.append(json_path)
|
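As a quick illustration of the seed handling above (a standalone sketch, not the PR's own helper): a comma-separated seed string becomes a list of ints, an empty string becomes None so the generator picks seeds itself.

# Sketch only; mirrors the seed-string parsing in the diff above.
def parse_seeds(seed):
    if isinstance(seed, str) and seed.strip():
        if "," in seed:
            return [int(s.strip()) for s in seed.split(",")]
        return [int(seed.strip())]
    return None

assert parse_seeds("1, 2, 3") == [1, 2, 3]
assert parse_seeds("") is None
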
+
+            code_str = audio_params.get("audio_codes", "")
+            final_codes_list[i] = code_str
+
+            scores_ui_updates = [gr.skip()] * 8
+            score_str = "Done!"
+            if auto_score:
+                auto_score_start = time_module.time()
+                score_str = calculate_score_handler(llm_handler, code_str, captions, lyrics, lm_generated_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale)
+                auto_score_end = time_module.time()
+                total_auto_score_time += (auto_score_end - auto_score_start)
+            scores_ui_updates[i] = score_str
+            final_scores_list[i] = score_str
+
+            status_message = f"Encoding & Ready: {i+1}/{len(audios)}"
+            current_audio_updates = [gr.skip()] * 8
+            current_audio_updates[i] = audio_path
+
+            audio_codes_ui_updates = [gr.skip()] * 8
+            audio_codes_ui_updates[i] = code_str
+            yield (
+                current_audio_updates[0], current_audio_updates[1], current_audio_updates[2], current_audio_updates[3],
+                current_audio_updates[4], current_audio_updates[5], current_audio_updates[6], current_audio_updates[7],
+                all_audio_paths,  # Real-time update of Batch File list
+                generation_info,
+                status_message,
+                seed_value_for_ui,
+                # Align plot placeholders (assume no need to update in real time)
+                gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+                # Scores
+                scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
+                updated_audio_codes,
+                # Codes
+                audio_codes_ui_updates[0], audio_codes_ui_updates[1], audio_codes_ui_updates[2], audio_codes_ui_updates[3],
+                audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
+                lm_generated_metadata,
+                is_format_caption,
            )
        else:
+            # If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
+            pass
+        time_module.sleep(0.1)
+
+    # Record audio conversion time
+    audio_conversion_end_time = time_module.time()
+    audio_conversion_time = audio_conversion_end_time - audio_conversion_start_time
+
+    # Add post-processing times to time_costs
+    if audio_conversion_time > 0:
+        time_costs['audio_conversion_time'] = audio_conversion_time
+    if total_auto_score_time > 0:
+        time_costs['auto_score_time'] = total_auto_score_time
+
+    # Update pipeline total time to include post-processing
+    if 'pipeline_total_time' in time_costs:
+        time_costs['pipeline_total_time'] += audio_conversion_time + total_auto_score_time
+
+    # Rebuild generation_info with complete timing information
+    generation_info = _build_generation_info(
+        lm_metadata=lm_generated_metadata,
+        time_costs=time_costs,
+        seed_value=seed_value_for_ui,
+        inference_steps=inference_steps,
+        num_audios=len(result.audios),
+    )

+    yield (
+        gr.skip(), gr.skip(), gr.skip(), gr.skip(),  # Audio 1-4: SKIP
+        gr.skip(), gr.skip(), gr.skip(), gr.skip(),  # Audio 5-8: SKIP
+        all_audio_paths,
        generation_info,
+        "Generation Complete",
        seed_value_for_ui,
+        align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2,
+        final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
+        final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
+        updated_audio_codes,
+        final_codes_list[0], final_codes_list[1], final_codes_list[2], final_codes_list[3],
+        final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
+        lm_generated_metadata,
+        is_format_caption,
    )

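The per-item loop above streams one audio at a time while leaving every other output untouched via gr.skip(). A stripped-down sketch of that streaming pattern (component names and paths are hypothetical, this is not the PR's handler) might look like this, assuming a Gradio version that provides gr.skip():

# Sketch of the streaming-yield pattern used above (hypothetical paths, assumes gr.skip()).
import gradio as gr

def stream_results(n):
    paths = []
    for i in range(n):
        paths.append(f"/tmp/audio_{i}.mp3")  # hypothetical file path
        updates = [gr.skip()] * 8            # skip the slots that did not change
        updates[i] = paths[-1]               # only the freshly finished slot updates
        yield (*updates, paths, f"Encoding & Ready: {i + 1}/{n}")
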
def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
    """
    Calculate PMI-based quality score for generated audio.

    if stored_allow_lm_batch and isinstance(stored_codes, list):
        # Batch mode: use specific sample's codes
        if 0 <= sample_idx - 1 < len(stored_codes):
+            code_item = stored_codes[sample_idx - 1]
+            # Ensure it's a string (handle cases where dict was mistakenly stored)
+            audio_codes_str = code_item if isinstance(code_item, str) else ""
    else:
        # Single mode: all samples use same codes
        audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""

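A small standalone sketch of the selection logic shown above, with hypothetical inputs: batch mode indexes into a per-sample list using a 1-based sample index, single mode falls back to one shared string, and anything that is not a string collapses to "".

# Sketch only; mirrors the stored-codes selection above with hypothetical data.
def pick_codes(stored_codes, allow_lm_batch, sample_idx):
    if allow_lm_batch and isinstance(stored_codes, list):
        if 0 <= sample_idx - 1 < len(stored_codes):
            item = stored_codes[sample_idx - 1]
            return item if isinstance(item, str) else ""
        return ""
    return stored_codes if isinstance(stored_codes, str) else ""

assert pick_codes(["<|audio_code_1|>", "<|audio_code_2|>"], True, 2) == "<|audio_code_2|>"
assert pick_codes("<|audio_code_1|>", False, 1) == "<|audio_code_1|>"
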
    Wrapper for generate_with_progress that adds batch queue management
    """
    # Call the original generation function
+    generator = generate_with_progress(
        dit_handler, llm_handler,
        captions, lyrics, bpm, key_scale, time_signature, vocal_language,
        inference_steps, guidance_scale, random_seed_checkbox, seed,

        lm_batch_chunk_size,
        progress
    )
+    final_result_from_inner = None
+    for partial_result in generator:
+        final_result_from_inner = partial_result
+        # current_batch_index, total_batches, batch_queue, next_params,
+        # batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
+        yield partial_result + (
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
+        )
+    result = final_result_from_inner
+    all_audio_paths = result[8]
+
+    if all_audio_paths is None:
+        yield result + (
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
+        )
+        return
+
+    # Extract results from generation (accessed by index into the result tuple)
    generation_info = result[9]
    seed_value_for_ui = result[11]
+    lm_generated_metadata = result[35]  # Fixed: lm_metadata is at index 35, not 34

    # Extract codes
    generated_codes_single = result[26]
    generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
+
    # Determine which codes to store based on mode
    if allow_lm_batch and batch_size_input >= 2:
        codes_to_store = generated_codes_batch[:int(batch_size_input)]
    else:
        codes_to_store = generated_codes_single
+
    # Save parameters for history
    saved_params = {
        "captions": captions,

    }

    # Next batch parameters (with cleared codes & random seed)
+    # Next batch parameters
    next_params = saved_params.copy()
    next_params["text2music_audio_code_string"] = ""
    next_params["random_seed_checkbox"] = True

    next_batch_status_text = ""
    if autogen_checkbox:
        next_batch_status_text = t("messages.autogen_enabled")
+
+    # 4. Yield final result (includes Batch UI updates)
+    # The result here is already a tuple structure
+    yield result + (
        current_batch_index,
        total_batches,
        batch_queue,

    params.setdefault("complete_track_classes", [])

    # Call generate_with_progress with the saved parameters
+    # Note: generate_with_progress is a generator, need to iterate through it
+    generator = generate_with_progress(
        dit_handler,
        llm_handler,
        captions=params.get("captions"),

        progress=progress
    )

+    # Consume generator to get final result (similar to generate_with_batch_management)
+    final_result = None
+    for partial_result in generator:
+        final_result = partial_result
+
+    # Extract results from final_result
+    all_audio_paths = final_result[8]  # generated_audio_batch
+    generation_info = final_result[9]
+    seed_value_for_ui = final_result[11]
+    lm_generated_metadata = final_result[35]  # Fixed: lm_metadata is at index 35, not 34

    # Extract codes
+    generated_codes_single = final_result[26]
+    generated_codes_batch = [final_result[27], final_result[28], final_result[29], final_result[30], final_result[31], final_result[32], final_result[33], final_result[34]]

    # Determine which codes to store
    batch_size = params.get("batch_size_input", 2)

    # Prepare audio outputs (up to 8)
    audio_outputs = [None] * 8
+    real_audio_paths = [p for p in audio_paths if not p.lower().endswith('.json')]
+    for idx in range(min(len(real_audio_paths), 8)):
+        audio_outputs[idx] = real_audio_paths[idx]

    # Update batch indicator
    total_batches = len(batch_queue)

    # Prepare audio outputs (up to 8)
    audio_outputs = [None] * 8
+    real_audio_paths = [p for p in audio_paths if not p.lower().endswith('.json')]
+    for idx in range(min(len(real_audio_paths), 8)):
+        audio_outputs[idx] = real_audio_paths[idx]

    # Update batch indicator
    batch_indicator_text = update_batch_indicator(new_batch_index, total_batches)

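Each generated audio stores a sidecar .json of its parameters next to the audio file, so views that only want playable files filter the path list the way the batch-restore code above does. A tiny sketch with hypothetical paths:

# Sketch: keep only playable audio files, dropping the .json parameter sidecars.
audio_paths = ["/tmp/a.mp3", "/tmp/a.json", "/tmp/b.mp3", "/tmp/b.json"]  # hypothetical
real_audio_paths = [p for p in audio_paths if not p.lower().endswith('.json')]
assert real_audio_paths == ["/tmp/a.mp3", "/tmp/b.mp3"]
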
acestep/gradio_ui/interfaces/result.py
CHANGED

@@ -28,7 +28,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_1 = gr.Audio(
                 label=t("results.generated_music", n=1),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_1 = gr.Button(
@@ -58,7+59,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_2 = gr.Audio(
                 label=t("results.generated_music", n=2),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_2 = gr.Button(
@@ -88,7+90,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_3 = gr.Audio(
                 label=t("results.generated_music", n=3),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_3 = gr.Button(
@@ -118,7+121,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_4 = gr.Audio(
                 label=t("results.generated_music", n=4),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_4 = gr.Button(
@@ -151,7 +155,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_5 = gr.Audio(
                 label=t("results.generated_music", n=5),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -166,7 +171,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_6 = gr.Audio(
                 label=t("results.generated_music", n=6),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -181,7 +187,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_7 = gr.Audio(
                 label=t("results.generated_music", n=7),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -196,7 +203,8 @@ def create_results_section(dit_handler) -> dict:
             generated_audio_8 = gr.Audio(
                 label=t("results.generated_music", n=8),
                 type="filepath",
-                interactive=False
+                interactive=False,
+                show_download_button=False
             )
             with gr.Row(equal_height=True):
                 send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)

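In effect, every result slot becomes a read-only gr.Audio without Gradio's built-in download button. A minimal sketch of the resulting component configuration (the label text here is illustrative, not the translated string used by the UI):

# Sketch of the resulting component configuration (illustrative label).
import gradio as gr

generated_audio = gr.Audio(
    label="Generated Music 1",
    type="filepath",
    interactive=False,
    show_download_button=False,
)
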
acestep/handler.py
CHANGED

@@ -10,6 +10,8 @@ import traceback
 import re
 import random
 import uuid
 from contextlib import contextmanager
 from typing import Optional, Dict, Any, Tuple, List, Union

@@ -37,16 +39,12 @@ warnings.filterwarnings("ignore")
 class AceStepHandler:
     """ACE-Step Business Logic Handler"""

-    def __init__(self
         self.model = None
         self.config = None
         self.device = "cpu"
         self.dtype = torch.float32  # Will be set based on device in initialize_service
-
-            self.temp_dir = tempfile.mkdtemp()
-        else:
-            self.temp_dir = save_root
-
         # VAE for audio encoding/decoding
         self.vae = None

@@ -81,8 +79,7 @@ class AceStepHandler:
     def get_available_checkpoints(self) -> str:
         """Return project root directory path"""
         # Get project root (handler.py is in acestep/, so go up two levels to project root)
-
-        project_root = os.path.dirname(os.path.dirname(current_file))
         # default checkpoints
         checkpoint_dir = os.path.join(project_root, "checkpoints")
         if os.path.exists(checkpoint_dir):

@@ -93,8 +90,7 @@ class AceStepHandler:
     def get_available_acestep_v15_models(self) -> List[str]:
         """Scan and return all model directory names starting with 'acestep-v15-'"""
         # Get project root
-
-        project_root = os.path.dirname(os.path.dirname(current_file))
         checkpoint_dir = os.path.join(project_root, "checkpoints")

         models = []

@@ -171,8 +167,7 @@ class AceStepHandler:

         # Auto-detect project root (independent of passed project_root parameter)
-
-        actual_project_root = os.path.dirname(os.path.dirname(current_file))
         checkpoint_dir = os.path.join(actual_project_root, "checkpoints")

         # 1. Load main model

@@ -187,7 +182,7 @@
         attn_implementation = "sdpa"

         try:
-            logger.info(f"Attempting to load model with attention implementation: {attn_implementation}")
             self.model = AutoModel.from_pretrained(
                 acestep_v15_checkpoint_path,
                 trust_remote_code=True,

@@ -195,9 +190,9 @@
                 dtype="bfloat16"
             )
         except Exception as e:
-            logger.warning(f"Failed to load model with {attn_implementation}: {e}")
             if attn_implementation == "sdpa":
-                logger.info("Falling back to eager attention")
                 attn_implementation = "eager"
                 self.model = AutoModel.from_pretrained(
                     acestep_v15_checkpoint_path,

@@ -215,7 +210,7 @@
         else:
             # If offload_to_cpu is True, check if we should keep DiT on GPU
             if not self.offload_dit_to_cpu:
-                logger.info(f"Keeping main model on {device} (persistent)")
                 self.model = self.model.to(device).to(self.dtype)
             else:
                 self.model = self.model.to("cpu").to(self.dtype)

@@ -239,7 +234,7 @@
                 raise ValueError(f"Unsupported quantization type: {self.quantization}")

             quantize_(self.model, quant_config)
-            logger.info(f"DiT quantized with: {self.quantization}")

         silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")

@@ -260,7 +255,7 @@
         if os.path.exists(vae_checkpoint_path):
             self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
             # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
-            vae_dtype =
             if not self.offload_to_cpu:
                 self.vae = self.vae.to(device).to(vae_dtype)
             else:

@@ -302,6 +297,7 @@

         except Exception as e:
             error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
             return error_msg, False

     @contextmanager

@@ -326,7 +322,7 @@
         try:
             param = next(model.parameters())
             if param.device.type == "cpu":
-                logger.info(f"Moving {model_name} to {self.device} (persistent)")
                 model.to(self.device).to(self.dtype)
                 if hasattr(self, "silence_latent"):
                     self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)

@@ -341,10 +337,10 @@
             return

         # Load to GPU
-        logger.info(f"Loading {model_name} to {self.device}")
         start_time = time.time()
         if model_name == "vae":
-            vae_dtype =
             model.to(self.device).to(vae_dtype)
         else:
             model.to(self.device).to(self.dtype)

@@ -354,13 +350,13 @@

         load_time = time.time() - start_time
         self.current_offload_cost += load_time
-        logger.info(f"Loaded {model_name} to {self.device} in {load_time:.4f}s")

         try:
             yield
         finally:
             # Offload to CPU
-            logger.info(f"Offloading {model_name} to CPU")
             start_time = time.time()
             model.to("cpu")

@@ -370,7 +366,7 @@
             torch.cuda.empty_cache()
             offload_time = time.time() - start_time
             self.current_offload_cost += offload_time
-            logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")

     def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
         """Process target audio"""

@@ -386,23 +382,12 @@
             else:
                 audio = torch.from_numpy(audio_np.T)

-
-            audio = audio[:2]
-
-            # Resample if needed
-            if sr != 48000:
-                import torch.nn.functional as F
-                ratio = 48000 / sr
-                new_length = int(audio.shape[-1] * ratio)
-                audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
-
-            audio = torch.clamp(audio, -1.0, 1.0)

             return audio
         except Exception as e:
-            logger.
             return None

     def _parse_audio_code_string(self, code_str: str) -> List[int]:

@@ -411,7 +396,8 @@
             return []
         try:
             return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
-        except Exception:
             return []

     def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:

@@ -538,9 +524,7 @@
         )
         """
         # Align instruction formatting with _prepare_batch
-        final_instruction = instruction or DEFAULT_DIT_INSTRUCTION
-        if not final_instruction.endswith(":"):
-            final_instruction = final_instruction + ":"

         # Extract caption and language from metas if available (from LM CoT output)
         # Fallback to user-provided values if not in metas

@@ -571,7 +555,7 @@

         parsed_meta = self._parse_metas([metas])[0]
         caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
-        lyrics_input =
         return caption_input, lyrics_input

     def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -614,7 +598,7 @@
                 return match.group(1).strip()
             return caption
         except Exception as e:
-            logger.
             return caption

     def prepare_seeds(self, actual_batch_size, seed, use_random_seed):

@@ -638,7 +622,8 @@
                 else:
                     try:
                         seed_list.append(int(float(s)))
-                    except (ValueError, TypeError):
                         seed_list.append(-1)
         elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
             # If seed is None or negative, use -1 for all items

@@ -679,7 +664,176 @@
         return actual_seed_list, seed_value_for_ui

     def prepare_metadata(self, bpm, key_scale, time_signature):
-
         metadata_dict = {}
         if bpm:
             metadata_dict["bpm"] = bpm

@@ -695,10 +849,12 @@
             metadata_dict["timesignature"] = time_signature
         else:
             metadata_dict["timesignature"] = "N/A"
         return metadata_dict
-
-    def is_silence(self, audio):
-        return torch.all(audio.abs() < 1e-6)

     def generate_instruction(
         self,

@@ -745,23 +901,12 @@
             # Load audio file
             audio, sr = torchaudio.load(audio_file)

-            logger.
-            logger.
-            logger.
-
-            # Convert to stereo (duplicate channel if mono)
-            if audio.shape[0] == 1:
-                audio = torch.cat([audio, audio], dim=0)
-
-            # Keep only first 2 channels
-            audio = audio[:2]

-            #
-
-            audio = torchaudio.transforms.Resample(sr, 48000)(audio)
-
-            # Clamp values to [-1.0, 1.0]
-            audio = torch.clamp(audio, -1.0, 1.0)

             is_silence = self.is_silence(audio)
             if is_silence:

@@ -800,7 +945,7 @@
             return audio

         except Exception as e:
-            logger.
             return None

     def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:

@@ -811,24 +956,13 @@
             # Load audio file
             audio, sr = torchaudio.load(audio_file)

-            #
-
-                audio = torch.cat([audio, audio], dim=0)
-
-            # Keep only first 2 channels
-            audio = audio[:2]
-
-            # Resample to 48kHz if needed
-            if sr != 48000:
-                audio = torchaudio.transforms.Resample(sr, 48000)(audio)
-
-            # Clamp values to [-1.0, 1.0]
-            audio = torch.clamp(audio, -1.0, 1.0)

             return audio

         except Exception as e:
-            logger.
             return None

     def convert_src_audio_to_codes(self, audio_file) -> str:

@@ -856,19 +990,12 @@
             # Encode audio to latents using VAE
             with torch.no_grad():
                 with self._load_model_context("vae"):
-                    # Prepare audio for VAE: [channels, samples] -> [1, channels, samples]
-                    vae_input = processed_audio.unsqueeze(0).to(self.device).to(self.vae.dtype)
-
                     # Check if audio is silence
-                    if self.is_silence(
                         return "❌ Audio file appears to be silent"

-                    # Encode to latents
-                    latents = self.
-                    # Cast back to model dtype
-                    latents = latents.to(self.dtype)
-                    # Transpose: [1, d, T] -> [1, T, d] -> [T, d]
-                    latents = latents.squeeze(0).transpose(0, 1)  # [T, d]

                     # Create attention mask for latents
                     attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)

@@ -893,7 +1020,7 @@

         except Exception as e:
             error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
-            logger.
             return error_msg

     def prepare_batch_data(

@@ -922,26 +1049,7 @@
             calculated_duration = audio_duration

         # Build metadata dict - use "N/A" as default for empty fields
-        metadata_dict =
-        if bpm:
-            metadata_dict["bpm"] = bpm
-        else:
-            metadata_dict["bpm"] = "N/A"
-
-        if key_scale.strip():
-            metadata_dict["keyscale"] = key_scale
-        else:
-            metadata_dict["keyscale"] = "N/A"
-
-        if time_signature.strip() and time_signature != "N/A" and time_signature:
-            metadata_dict["timesignature"] = time_signature
-        else:
-            metadata_dict["timesignature"] = "N/A"
-
-        # Add duration to metadata if available (inference service format: "30 seconds")
-        if calculated_duration is not None:
-            metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
-        # If duration not set, inference service will use default (30 seconds)

         # Format metadata - inference service accepts dict and will convert to string
         # Create a copy for each batch item (in case we modify it)

@@ -977,7 +1085,7 @@
                 target_wavs = torch.zeros(2, frames)
             return target_wavs
         except Exception as e:
-            logger.
             # Fallback to 30 seconds if error
             return torch.zeros(2, 30 * 48000)

@@ -1158,16 +1266,8 @@
         """
         batch_size = len(captions)

-        #
-
-            audio_code_hints = [None] * batch_size
-        elif len(audio_code_hints) != batch_size:
-            if len(audio_code_hints) == 1:
-                audio_code_hints = audio_code_hints * batch_size
-            else:
-                audio_code_hints = audio_code_hints[:batch_size]
-                while len(audio_code_hints) < batch_size:
-                    audio_code_hints.append(None)

         for ii, refer_audio_list in enumerate(refer_audios):
             if isinstance(refer_audio_list, list):

@@ -1179,17 +1279,6 @@
         if vocal_languages is None:
             vocal_languages = self._create_fallback_vocal_languages(batch_size)

-        # Normalize audio_code_hints to batch list
-        if audio_code_hints is None:
-            audio_code_hints = [None] * batch_size
-        elif not isinstance(audio_code_hints, list):
-            audio_code_hints = [audio_code_hints] * batch_size
-        elif len(audio_code_hints) == 1 and batch_size > 1:
-            audio_code_hints = audio_code_hints * batch_size
-        else:
-            audio_code_hints = (audio_code_hints + [None] * batch_size)[:batch_size]
-        audio_code_hints = [hint if isinstance(hint, str) and hint.strip() else None for hint in audio_code_hints]
-
         # Parse metas with fallbacks
         parsed_metas = self._parse_metas(metas)

@@ -1223,13 +1312,9 @@
                 expected_latent_length = current_wav.shape[-1] // 1920
                 target_latent = self.silence_latent[0, :expected_latent_length, :]
             else:
-                #
                 logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
-
-                target_latent = self.vae.encode(vae_input).latent_dist.sample()
-                # Cast back to model dtype
-                target_latent = target_latent.to(self.dtype)
-                target_latent = target_latent.squeeze(0).transpose(0, 1)
             target_latents_list.append(target_latent)
             latent_lengths.append(target_latent.shape[0])

@@ -1268,18 +1353,7 @@

         # Process instructions early so we can use them for task type detection
         # Use custom instructions if provided, otherwise use default
-
-        instructions = [DEFAULT_DIT_INSTRUCTION] * batch_size
-
-        # Ensure instructions list has the same length as batch_size
-        if len(instructions) != batch_size:
-            if len(instructions) == 1:
-                instructions = instructions * batch_size
-            else:
-                # Pad or truncate to match batch_size
-                instructions = instructions[:batch_size]
-                while len(instructions) < batch_size:
-                    instructions.append(DEFAULT_DIT_INSTRUCTION)

         # Generate chunk_masks and spans based on repainting parameters
         # Also determine if this is a cover task (target audio provided without repainting)

@@ -1428,6 +1502,10 @@
         else:
             precomputed_lm_hints_25Hz = None

         # Format text_inputs
         text_inputs = []
         text_token_idss = []

@@ -1437,26 +1515,10 @@

         for i in range(batch_size):
             # Use custom instruction for this batch item
-            instruction = instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION
-
-            # Extract caption and language from metas if available (from LM CoT output)
-            # Fallback to user-provided values if not in metas
-            actual_caption = captions[i]
-            actual_language = vocal_languages[i]
-
-            # Check if metas contains caption/language from LM CoT
-            if i < len(parsed_metas) and parsed_metas[i]:
-                meta_dict = parsed_metas[i]
-                if isinstance(meta_dict, dict):
-                    # Extract caption from metas if available
-                    if 'caption' in meta_dict and meta_dict['caption']:
-                        actual_caption = str(meta_dict['caption'])
-                    # Extract language from metas if available
-                    if 'language' in meta_dict and meta_dict['language']:
-                        actual_language = str(meta_dict['language'])

             # Format text prompt with custom instruction (using LM-generated caption if available)
             text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])

@@ -1473,7 +1535,7 @@
             text_attention_mask = text_inputs_dict.attention_mask[0].bool()

             # Format and tokenize lyrics (using LM-generated language if available)
-            lyrics_text =
             lyrics_inputs_dict = self.text_tokenizer(
                 lyrics_text,
                 padding="longest",

@@ -1495,36 +1557,12 @@

         # Pad tokenized sequences
         max_text_length = max(len(seq) for seq in text_token_idss)
-        padded_text_token_idss =
-
-                seq, (0, max_text_length - len(seq)), 'constant',
-                self.text_tokenizer.pad_token_id
-            )
-            for seq in text_token_idss
-        ])
-
-        padded_text_attention_masks = torch.stack([
-            torch.nn.functional.pad(
-                seq, (0, max_text_length - len(seq)), 'constant', 0
-            )
-            for seq in text_attention_masks
-        ])

         max_lyric_length = max(len(seq) for seq in lyric_token_idss)
-        padded_lyric_token_idss =
-
-                seq, (0, max_lyric_length - len(seq)), 'constant',
-                self.text_tokenizer.pad_token_id
-            )
-            for seq in lyric_token_idss
-        ])
-
-        padded_lyric_attention_masks = torch.stack([
-            torch.nn.functional.pad(
-                seq, (0, max_lyric_length - len(seq)), 'constant', 0
-            )
-            for seq in lyric_attention_masks
-        ])

         padded_non_cover_text_input_ids = None
         padded_non_cover_text_attention_masks = None

@@ -1533,14 +1571,10 @@
             non_cover_text_attention_masks = []
             for i in range(batch_size):
                 # Use custom instruction for this batch item
-                instruction = DEFAULT_DIT_INSTRUCTION

                 # Extract caption from metas if available (from LM CoT output)
-                actual_caption =
-                if i < len(parsed_metas) and parsed_metas[i]:
-                    meta_dict = parsed_metas[i]
-                    if isinstance(meta_dict, dict) and 'caption' in meta_dict and meta_dict['caption']:
-                        actual_caption = str(meta_dict['caption'])

                 # Format text prompt with custom instruction (using LM-generated caption if available)
                 text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])

@@ -1558,19 +1592,8 @@
                 non_cover_text_input_ids.append(text_token_ids)
                 non_cover_text_attention_masks.append(non_cover_text_attention_mask)

-            padded_non_cover_text_input_ids =
-
-                    seq, (0, max_text_length - len(seq)), 'constant',
-                    self.text_tokenizer.pad_token_id
-                )
-                for seq in non_cover_text_input_ids
-            ])
-            padded_non_cover_text_attention_masks = torch.stack([
-                torch.nn.functional.pad(
-                    seq, (0, max_text_length - len(seq)), 'constant', 0
-                )
-                for seq in non_cover_text_attention_masks
-            ])

         if audio_cover_strength < 1.0:
             assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"

@@ -1804,7 +1827,7 @@
         if self.config.is_turbo:
             # Limit inference steps to maximum 8
             if infer_steps > 8:
-                logger.warning(f"dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
                 infer_steps = 8
             # CFG parameters are not adjustable for dmd_gan (they will be ignored)
             # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model

@@ -1827,30 +1850,12 @@
         if isinstance(repainting_end, (int, float)):
             repainting_end = [repainting_end]

-        # Convert instructions to list
-        if isinstance(instructions, str):
-            instructions = [instructions]
-        elif instructions is None:
-            instructions = None
-
-        # Convert audio_code_hints to list
-        if isinstance(audio_code_hints, str):
-            audio_code_hints = [audio_code_hints]
-        elif audio_code_hints is None:
-            audio_code_hints = None
-
         # Get batch size from captions
         batch_size = len(captions)

-        #
-        if
-
-            if len(audio_code_hints) == 1:
-                audio_code_hints = audio_code_hints * batch_size
-            else:
-                audio_code_hints = audio_code_hints[:batch_size]
-                while len(audio_code_hints) < batch_size:
-                    audio_code_hints.append(None)

         # Convert seed to list format
         if seed is None:

@@ -1947,6 +1952,14 @@
         logger.info("[service_generate] Generating audio...")
         with self._load_model_context("model"):
             outputs = self.model.generate_audio(**generate_kwargs)

         return outputs

     def tiled_decode(self, latents, chunk_size=512, overlap=64):

@@ -2042,25 +2055,33 @@
         use_adg: bool = False,
         cfg_interval_start: float = 0.0,
         cfg_interval_end: float = 1.0,
-        audio_format: str = "mp3",
-        lm_temperature: float = 0.6,
         use_tiled_decode: bool = True,
         progress=None
-    ) ->
        """
        Main interface for music generation

        Returns:
-
-
-
        """
        if progress is None:
            def progress(*args, **kwargs):
                pass

        if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
-            return

        def _has_audio_codes(v: Union[str, List[str]]) -> bool:
            if isinstance(v, list):

@@ -2079,7 +2100,7 @@

        logger.info("[generate_music] Starting generation...")
        if progress:
-            progress(0.
        logger.info("[generate_music] Preparing inputs...")

        # Reset offload cost

@@ -2101,8 +2122,6 @@
        repainting_end = None

        try:
-            progress(0.1, desc="Preparing inputs...")
-
            # 1. Process reference audio
            refer_audios = None
            if reference_audio is not None:

@@ -2154,7 +2173,7 @@
                can_use_repainting
            )

-            progress(0.

            # Prepare audio_code_hints - use if audio_code_string is provided
            # This works for both text2music (auto-switched to cover) and cover tasks

@@ -2191,8 +2210,8 @@
            pred_latents = outputs["target_latents"]  # [batch, latent_length, latent_dim]
            time_costs = outputs["time_costs"]
            time_costs["offload_time_cost"] = self.current_offload_cost
-            logger.
-            logger.
            if progress:
                progress(0.8, desc="Decoding audio...")
            logger.info("[generate_music] Decoding latents with VAE...")

@@ -2221,75 +2240,66 @@
            # Update offload cost one last time to include VAE offloading
            time_costs["offload_time_cost"] = self.current_offload_cost

-            logger.info("[generate_music] VAE decode completed.
            if progress:
-                progress(0.

-            #

-            saved_files = []
-            saved_uuids = []  # Store UUIDs for each file
            for i in range(actual_batch_size):
-                #
-
-                audio_np = pred_wavs[i].cpu().float().numpy().T
-                sf.write(audio_file, audio_np, self.sample_rate)
-                saved_files.append(audio_file)
-                saved_uuids.append(file_uuid)
-
-            # Prepare return values
-            first_audio = saved_files[0] if len(saved_files) > 0 else None
-            second_audio = saved_files[1] if len(saved_files) > 1 else None
-
-            # Format time costs if available
-            time_costs_str = ""
-            if time_costs:
-                if isinstance(time_costs, dict):
-                    time_costs_str = "\n\n**⏱️ Time Costs:**\n"
-                    for key, value in time_costs.items():
-                        # Format key: encoder_time_cost -> Encoder
-                        formatted_key = key.replace("_time_cost", "").replace("_", " ").title()
-                        time_costs_str += f"  - {formatted_key}: {value:.2f}s\n"
-                elif isinstance(time_costs, (int, float)):
-                    time_costs_str = f"\n\n**⏱️ Time Cost:** {time_costs:.2f}s"
-
-            generation_info = f"""**🎵 Generation Complete**
-
-**Seeds:** {seed_value_for_ui}
-**Steps:** {inference_steps}
-**Files:** {len(saved_files)} audio(s){time_costs_str}"""
            status_message = f"✅ Generation completed successfully!"
-            logger.info(f"[generate_music] Done! Generated {len(
-
-            #
-            )

        except Exception as e:
            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
-
-

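The deleted inline padding above (and its replacement on the new side of this diff) follows a common pattern: right-pad each variable-length token id sequence to the batch maximum with the tokenizer's pad id, pad the attention masks with zeros, then stack into a batch tensor. A standalone sketch under that assumption, with toy tensors:

# Sketch of the pad-and-stack pattern referenced above (illustrative tensors only).
import torch
import torch.nn.functional as F

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
pad_token_id = 0
max_len = max(len(s) for s in seqs)
padded = torch.stack([F.pad(s, (0, max_len - len(s)), 'constant', pad_token_id) for s in seqs])
masks = torch.stack([F.pad(torch.ones_like(s), (0, max_len - len(s)), 'constant', 0) for s in seqs])
print(padded.shape, masks.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
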
| 10 |
import re
|
| 11 |
import random
|
| 12 |
import uuid
|
| 13 |
+
import hashlib
|
| 14 |
+
import json
|
| 15 |
from contextlib import contextmanager
|
| 16 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 17 |
|
|
|
|
| 39 |
class AceStepHandler:
|
| 40 |
"""ACE-Step Business Logic Handler"""
|
| 41 |
|
| 42 |
+
def __init__(self):
|
| 43 |
self.model = None
|
| 44 |
self.config = None
|
| 45 |
self.device = "cpu"
|
| 46 |
self.dtype = torch.float32 # Will be set based on device in initialize_service
|
| 47 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# VAE for audio encoding/decoding
|
| 49 |
self.vae = None
|
| 50 |
|
|
|
|
| 79 |
def get_available_checkpoints(self) -> str:
|
| 80 |
"""Return project root directory path"""
|
| 81 |
# Get project root (handler.py is in acestep/, so go up two levels to project root)
|
| 82 |
+
project_root = self._get_project_root()
|
|
|
|
| 83 |
# default checkpoints
|
| 84 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 85 |
if os.path.exists(checkpoint_dir):
|
|
|
|
| 90 |
def get_available_acestep_v15_models(self) -> List[str]:
|
| 91 |
"""Scan and return all model directory names starting with 'acestep-v15-'"""
|
| 92 |
# Get project root
|
| 93 |
+
project_root = self._get_project_root()
|
|
|
|
| 94 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 95 |
|
| 96 |
models = []
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
# Auto-detect project root (independent of passed project_root parameter)
|
| 170 |
+
actual_project_root = self._get_project_root()
|
|
|
|
| 171 |
checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
|
| 172 |
|
| 173 |
# 1. Load main model
|
|
|
|
| 182 |
attn_implementation = "sdpa"
|
| 183 |
|
| 184 |
try:
|
| 185 |
+
logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
|
| 186 |
self.model = AutoModel.from_pretrained(
|
| 187 |
acestep_v15_checkpoint_path,
|
| 188 |
trust_remote_code=True,
|
|
|
|
| 190 |
dtype="bfloat16"
|
| 191 |
)
|
| 192 |
except Exception as e:
|
| 193 |
+
logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
|
| 194 |
if attn_implementation == "sdpa":
|
| 195 |
+
logger.info("[initialize_service] Falling back to eager attention")
|
| 196 |
attn_implementation = "eager"
|
| 197 |
self.model = AutoModel.from_pretrained(
|
| 198 |
acestep_v15_checkpoint_path,
|
|
|
|
| 210 |
else:
|
| 211 |
# If offload_to_cpu is True, check if we should keep DiT on GPU
|
| 212 |
if not self.offload_dit_to_cpu:
|
| 213 |
+
logger.info(f"[initialize_service] Keeping main model on {device} (persistent)")
|
| 214 |
self.model = self.model.to(device).to(self.dtype)
|
| 215 |
else:
|
| 216 |
self.model = self.model.to("cpu").to(self.dtype)
|
|
|
|
| 234 |
raise ValueError(f"Unsupported quantization type: {self.quantization}")
|
| 235 |
|
| 236 |
quantize_(self.model, quant_config)
|
| 237 |
+
logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
|
| 238 |
|
| 239 |
|
| 240 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
|
|
|
| 255 |
if os.path.exists(vae_checkpoint_path):
|
| 256 |
self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
|
| 257 |
# Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
|
| 258 |
+
vae_dtype = self._get_vae_dtype(device)
|
| 259 |
if not self.offload_to_cpu:
|
| 260 |
self.vae = self.vae.to(device).to(vae_dtype)
|
| 261 |
else:
|
|
|
|
| 297 |
|
| 298 |
except Exception as e:
|
| 299 |
error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 300 |
+
logger.exception("[initialize_service] Error initializing model")
|
| 301 |
return error_msg, False
|
| 302 |
|
| 303 |
@contextmanager
|
|
|
|
| 322 |
try:
|
| 323 |
param = next(model.parameters())
|
| 324 |
if param.device.type == "cpu":
|
| 325 |
+
logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
|
| 326 |
model.to(self.device).to(self.dtype)
|
| 327 |
if hasattr(self, "silence_latent"):
|
| 328 |
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
|
|
|
|
| 337 |
return
|
| 338 |
|
| 339 |
# Load to GPU
|
| 340 |
+
logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
|
| 341 |
start_time = time.time()
|
| 342 |
if model_name == "vae":
|
| 343 |
+
vae_dtype = self._get_vae_dtype()
|
| 344 |
model.to(self.device).to(vae_dtype)
|
| 345 |
else:
|
| 346 |
model.to(self.device).to(self.dtype)
|
|
|
|
| 350 |
|
| 351 |
load_time = time.time() - start_time
|
| 352 |
self.current_offload_cost += load_time
|
| 353 |
+
logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")
|
| 354 |
|
| 355 |
try:
|
| 356 |
yield
|
| 357 |
finally:
|
| 358 |
# Offload to CPU
|
| 359 |
+
logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
|
| 360 |
start_time = time.time()
|
| 361 |
model.to("cpu")
|
| 362 |
|
|
|
|
| 366 |
torch.cuda.empty_cache()
|
| 367 |
offload_time = time.time() - start_time
|
| 368 |
self.current_offload_cost += offload_time
|
| 369 |
+
logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")
|
| 370 |
|
| 371 |   def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
| 372 |       """Process target audio"""
| 382 |           else:
| 383 |               audio = torch.from_numpy(audio_np.T)
| 384 |
| 385 | +         # Normalize to stereo 48kHz
| 386 | +         audio = self._normalize_audio_to_stereo_48k(audio, sr)
| 387 |
| 388 |           return audio
| 389 |       except Exception as e:
| 390 | +         logger.exception("[process_target_audio] Error processing target audio")
| 391 |           return None
| 392 |
| 393 |   def _parse_audio_code_string(self, code_str: str) -> List[int]:
| 396 |           return []
| 397 |       try:
| 398 |           return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
| 399 | +     except Exception as e:
| 400 | +         logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
| 401 |           return []
| 402 |
| 403 |   def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
| 524 |       )
| 525 |       """
| 526 |       # Align instruction formatting with _prepare_batch
| 527 | +     final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)
| 528 |
| 529 |       # Extract caption and language from metas if available (from LM CoT output)
| 530 |       # Fallback to user-provided values if not in metas
| 555 |
| 556 |       parsed_meta = self._parse_metas([metas])[0]
| 557 |       caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
| 558 | +     lyrics_input = self._format_lyrics(lyrics, actual_language)
| 559 |       return caption_input, lyrics_input
| 560 |
| 561 |   def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
| 598 |           return match.group(1).strip()
| 599 |       return caption
| 600 |   except Exception as e:
| 601 | +     logger.exception("[extract_caption_from_sft_format] Error extracting caption")
| 602 |       return caption
| 603 |
| 604 |   def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
| 622 |           else:
| 623 |               try:
| 624 |                   seed_list.append(int(float(s)))
| 625 | +             except (ValueError, TypeError) as e:
| 626 | +                 logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}")
| 627 |                   seed_list.append(-1)
| 628 |       elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
| 629 |           # If seed is None or negative, use -1 for all items
| 664 |       return actual_seed_list, seed_value_for_ui
| 665 |
| 666 |   def prepare_metadata(self, bpm, key_scale, time_signature):
| 667 | +     """Build metadata dict - use "N/A" as default for empty fields."""
| 668 | +     return self._build_metadata_dict(bpm, key_scale, time_signature)
| 669 | +
| 670 | + def is_silence(self, audio):
| 671 | +     return torch.all(audio.abs() < 1e-6)
| 672 | +
| 673 | + def _get_project_root(self) -> str:
| 674 | +     """Get project root directory path."""
| 675 | +     current_file = os.path.abspath(__file__)
| 676 | +     return os.path.dirname(os.path.dirname(current_file))
| 677 | +
| 678 | + def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
| 679 | +     """Get VAE dtype based on device."""
| 680 | +     device = device or self.device
| 681 | +     return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
| 682 | +
| 683 | + def _format_instruction(self, instruction: str) -> str:
| 684 | +     """Format instruction to ensure it ends with colon."""
| 685 | +     if not instruction.endswith(":"):
| 686 | +         instruction = instruction + ":"
| 687 | +     return instruction
| 688 | +
| 689 | + def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
| 690 | +     """
| 691 | +     Normalize audio to stereo 48kHz format.
| 692 | +
| 693 | +     Args:
| 694 | +         audio: Audio tensor [channels, samples] or [samples]
| 695 | +         sr: Sample rate
| 696 | +
| 697 | +     Returns:
| 698 | +         Normalized audio tensor [2, samples] at 48kHz
| 699 | +     """
| 700 | +     # Convert to stereo (duplicate channel if mono)
| 701 | +     if audio.shape[0] == 1:
| 702 | +         audio = torch.cat([audio, audio], dim=0)
| 703 | +
| 704 | +     # Keep only first 2 channels
| 705 | +     audio = audio[:2]
| 706 | +
| 707 | +     # Resample to 48kHz if needed
| 708 | +     if sr != 48000:
| 709 | +         audio = torchaudio.transforms.Resample(sr, 48000)(audio)
| 710 | +
| 711 | +     # Clamp values to [-1.0, 1.0]
| 712 | +     audio = torch.clamp(audio, -1.0, 1.0)
| 713 | +
| 714 | +     return audio
| 715 | +
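A quick way to sanity-check the normalization helper above outside the handler, using a synthetic mono 22.05 kHz clip (the same three steps: duplicate the channel, resample, clamp):

```python
import torch
import torchaudio

mono = torch.randn(1, 22050)  # one second of mono audio at 22.05 kHz

audio = torch.cat([mono, mono], dim=0)[:2]                   # mono -> stereo
audio = torchaudio.transforms.Resample(22050, 48000)(audio)  # -> 48 kHz
audio = torch.clamp(audio, -1.0, 1.0)                        # keep samples in range

print(audio.shape)  # torch.Size([2, 48000])
```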
| 716 | + def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
| 717 | +     """Normalize audio_code_hints to list of correct length."""
| 718 | +     if audio_code_hints is None:
| 719 | +         normalized = [None] * batch_size
| 720 | +     elif isinstance(audio_code_hints, str):
| 721 | +         normalized = [audio_code_hints] * batch_size
| 722 | +     elif len(audio_code_hints) == 1 and batch_size > 1:
| 723 | +         normalized = audio_code_hints * batch_size
| 724 | +     elif len(audio_code_hints) != batch_size:
| 725 | +         # Pad or truncate to match batch_size
| 726 | +         normalized = list(audio_code_hints[:batch_size])
| 727 | +         while len(normalized) < batch_size:
| 728 | +             normalized.append(None)
| 729 | +     else:
| 730 | +         normalized = list(audio_code_hints)
| 731 | +
| 732 | +     # Clean up: convert empty strings to None
| 733 | +     normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized]
| 734 | +     return normalized
| 735 | +
| 736 | + def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]:
| 737 | +     """Normalize instructions to list of correct length."""
| 738 | +     if instructions is None:
| 739 | +         default_instruction = default or DEFAULT_DIT_INSTRUCTION
| 740 | +         return [default_instruction] * batch_size
| 741 | +     elif isinstance(instructions, str):
| 742 | +         return [instructions] * batch_size
| 743 | +     elif len(instructions) == 1:
| 744 | +         return instructions * batch_size
| 745 | +     elif len(instructions) != batch_size:
| 746 | +         # Pad or truncate to match batch_size
| 747 | +         normalized = list(instructions[:batch_size])
| 748 | +         default_instruction = default or DEFAULT_DIT_INSTRUCTION
| 749 | +         while len(normalized) < batch_size:
| 750 | +             normalized.append(default_instruction)
| 751 | +         return normalized
| 752 | +     else:
| 753 | +         return list(instructions)
| 754 | +
| 755 | + def _format_lyrics(self, lyrics: str, language: str) -> str:
| 756 | +     """Format lyrics text with language header."""
| 757 | +     return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>"
| 758 | +
| 759 | + def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor:
| 760 | +     """Pad sequences to same length."""
| 761 | +     return torch.stack([
| 762 | +         torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value)
| 763 | +         for seq in sequences
| 764 | +     ])
| 765 | +
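`_pad_sequences` above right-pads variable-length 1-D token tensors and stacks them into one batch tensor; a small worked example with arbitrary token ids:

```python
import torch
import torch.nn.functional as F

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
max_len = max(len(s) for s in seqs)

padded = torch.stack([F.pad(s, (0, max_len - len(s)), "constant", 0) for s in seqs])
print(padded)        # tensor([[5, 6, 7], [8, 9, 0]])
print(padded.shape)  # torch.Size([2, 3])
```

The attention masks are padded the same way with `pad_value=0`, so padded positions stay masked.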
| 766 | + def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]:
| 767 | +     """Extract caption and language from metas with fallback to provided values."""
| 768 | +     actual_captions = list(captions)
| 769 | +     actual_languages = list(vocal_languages)
| 770 | +
| 771 | +     for i, meta in enumerate(metas):
| 772 | +         if i >= len(actual_captions):
| 773 | +             break
| 774 | +
| 775 | +         meta_dict = None
| 776 | +         if isinstance(meta, str):
| 777 | +             parsed = self._parse_metas([meta])
| 778 | +             if parsed and isinstance(parsed[0], dict):
| 779 | +                 meta_dict = parsed[0]
| 780 | +         elif isinstance(meta, dict):
| 781 | +             meta_dict = meta
| 782 | +
| 783 | +         if meta_dict:
| 784 | +             if 'caption' in meta_dict and meta_dict['caption']:
| 785 | +                 actual_captions[i] = str(meta_dict['caption'])
| 786 | +             if 'language' in meta_dict and meta_dict['language']:
| 787 | +                 actual_languages[i] = str(meta_dict['language'])
| 788 | +
| 789 | +     return actual_captions, actual_languages
| 790 | +
| 791 | + def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor:
| 792 | +     """
| 793 | +     Encode audio to latents using VAE.
| 794 | +
| 795 | +     Args:
| 796 | +         audio: Audio tensor [channels, samples] or [batch, channels, samples]
| 797 | +
| 798 | +     Returns:
| 799 | +         Latents tensor [T, D] or [batch, T, D]
| 800 | +     """
| 801 | +     # Ensure batch dimension
| 802 | +     if audio.dim() == 2:
| 803 | +         audio = audio.unsqueeze(0)
| 804 | +
| 805 | +     # Ensure input is in VAE's dtype
| 806 | +     vae_input = audio.to(self.device).to(self.vae.dtype)
| 807 | +
| 808 | +     # Encode to latents
| 809 | +     with torch.no_grad():
| 810 | +         latents = self.vae.encode(vae_input).latent_dist.sample()
| 811 | +
| 812 | +     # Cast back to model dtype
| 813 | +     latents = latents.to(self.dtype)
| 814 | +
| 815 | +     # Transpose: [batch, d, T] -> [batch, T, d]
| 816 | +     latents = latents.transpose(1, 2)
| 817 | +
| 818 | +     # Remove batch dimension if input didn't have it
| 819 | +     if audio.dim() == 2:
| 820 | +         latents = latents.squeeze(0)
| 821 | +
| 822 | +     return latents
| 823 | +
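Most of `_encode_audio_to_latents` above is dtype and shape bookkeeping around `vae.encode`; the transpose from the VAE's `[batch, D, T]` layout to the `[batch, T, D]` layout the rest of the pipeline expects is the part that is easy to get wrong, so here it is in isolation (the latent dimension and frame count are made up for the example):

```python
import torch

batch, latent_dim, frames = 1, 64, 250
vae_latents = torch.randn(batch, latent_dim, frames)  # [batch, D, T], as a VAE returns it

latents = vae_latents.transpose(1, 2)                 # [batch, T, D]
print(latents.shape)                                  # torch.Size([1, 250, 64])

single = latents.squeeze(0)                           # [T, D] for a single, unbatched clip
print(single.shape)                                   # torch.Size([250, 64])
```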
| 824 | + def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]:
| 825 | +     """
| 826 | +     Build metadata dictionary with default values.
| 827 | +
| 828 | +     Args:
| 829 | +         bpm: BPM value (optional)
| 830 | +         key_scale: Key/scale string
| 831 | +         time_signature: Time signature string
| 832 | +         duration: Duration in seconds (optional)
| 833 | +
| 834 | +     Returns:
| 835 | +         Metadata dictionary
| 836 | +     """
| 837 |       metadata_dict = {}
| 838 |       if bpm:
| 839 |           metadata_dict["bpm"] = bpm
| 849 |           metadata_dict["timesignature"] = time_signature
| 850 |       else:
| 851 |           metadata_dict["timesignature"] = "N/A"
| 852 | +
| 853 | +     # Add duration if provided
| 854 | +     if duration is not None:
| 855 | +         metadata_dict["duration"] = f"{int(duration)} seconds"
| 856 | +
| 857 |       return metadata_dict
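For reference, the builder above returns a plain dict; an illustrative input/output pair (the `keyscale` branch is elided from this hunk, so that key name is inferred from the rest of the file):

```python
# _build_metadata_dict(bpm=120, key_scale="C major", time_signature="", duration=30.0)
metadata = {
    "bpm": 120,
    "keyscale": "C major",     # assumed key name; the branch is outside this hunk
    "timesignature": "N/A",    # empty input falls back to "N/A"
    "duration": "30 seconds",  # only present when a duration is passed
}
```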
| 858 |
| 859 |   def generate_instruction(
| 860 |       self,
| 901 |       # Load audio file
| 902 |       audio, sr = torchaudio.load(audio_file)
| 903 |
| 904 | +     logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
| 905 | +     logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
| 906 | +     logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
| 907 |
| 908 | +     # Normalize to stereo 48kHz
| 909 | +     audio = self._normalize_audio_to_stereo_48k(audio, sr)
| 910 |
| 911 |       is_silence = self.is_silence(audio)
| 912 |       if is_silence:
| 945 |       return audio
| 946 |
| 947 |   except Exception as e:
| 948 | +     logger.exception("[process_reference_audio] Error processing reference audio")
| 949 |       return None
| 950 |
| 951 |   def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
| 956 |       # Load audio file
| 957 |       audio, sr = torchaudio.load(audio_file)
| 958 |
| 959 | +     # Normalize to stereo 48kHz
| 960 | +     audio = self._normalize_audio_to_stereo_48k(audio, sr)
| 961 |
| 962 |       return audio
| 963 |
| 964 |   except Exception as e:
| 965 | +     logger.exception("[process_src_audio] Error processing source audio")
| 966 |       return None
| 967 |
| 968 |
def convert_src_audio_to_codes(self, audio_file) -> str:
|
|
|
|
| 990 |
# Encode audio to latents using VAE
|
| 991 |
with torch.no_grad():
|
| 992 |
with self._load_model_context("vae"):
|
|
|
|
|
|
|
|
|
|
| 993 |
# Check if audio is silence
|
| 994 |
+
if self.is_silence(processed_audio.unsqueeze(0)):
|
| 995 |
return "❌ Audio file appears to be silent"
|
| 996 |
|
| 997 |
+
# Encode to latents using helper method
|
| 998 |
+
latents = self._encode_audio_to_latents(processed_audio) # [T, d]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 999 |
|
| 1000 |
# Create attention mask for latents
|
| 1001 |
attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
|
|
|
|
| 1020 |
|
| 1021 |
except Exception as e:
|
| 1022 |
error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
|
| 1023 |
+
logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
|
| 1024 |
return error_msg
|
| 1025 |
|
| 1026 |
def prepare_batch_data(
|
|
|
|
| 1049 |
calculated_duration = audio_duration
|
| 1050 |
|
| 1051 |
# Build metadata dict - use "N/A" as default for empty fields
|
| 1052 |
+
metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration)
|
|
|
|
|
|
|
|
|
|
|
|
| 1053 |
|
| 1054 |
# Format metadata - inference service accepts dict and will convert to string
|
| 1055 |
# Create a copy for each batch item (in case we modify it)
|
|
|
|
| 1085 |
target_wavs = torch.zeros(2, frames)
|
| 1086 |
return target_wavs
|
| 1087 |
except Exception as e:
|
| 1088 |
+
logger.exception("[create_target_wavs] Error creating target audio")
|
| 1089 |
# Fallback to 30 seconds if error
|
| 1090 |
return torch.zeros(2, 30 * 48000)
|
| 1091 |
|
|
|
|
| 1266 |
"""
|
| 1267 |
batch_size = len(captions)
|
| 1268 |
|
| 1269 |
+
# Normalize audio_code_hints to batch list
|
| 1270 |
+
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1271 |
|
| 1272 |
for ii, refer_audio_list in enumerate(refer_audios):
|
| 1273 |
if isinstance(refer_audio_list, list):
|
|
|
|
| 1279 |
if vocal_languages is None:
|
| 1280 |
vocal_languages = self._create_fallback_vocal_languages(batch_size)
|
| 1281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
# Parse metas with fallbacks
|
| 1283 |
parsed_metas = self._parse_metas(metas)
|
| 1284 |
|
|
|
|
| 1312 |
expected_latent_length = current_wav.shape[-1] // 1920
|
| 1313 |
target_latent = self.silence_latent[0, :expected_latent_length, :]
|
| 1314 |
else:
|
| 1315 |
+
# Encode using helper method
|
| 1316 |
logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
|
| 1317 |
+
target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) # Remove batch dim for helper
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1318 |
target_latents_list.append(target_latent)
|
| 1319 |
latent_lengths.append(target_latent.shape[0])
|
| 1320 |
|
|
|
|
| 1353 |
|
| 1354 |
# Process instructions early so we can use them for task type detection
|
| 1355 |
# Use custom instructions if provided, otherwise use default
|
| 1356 |
+
instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1357 |
|
| 1358 |
# Generate chunk_masks and spans based on repainting parameters
|
| 1359 |
# Also determine if this is a cover task (target audio provided without repainting)
|
|
|
|
| 1502 |
else:
|
| 1503 |
precomputed_lm_hints_25Hz = None
|
| 1504 |
|
| 1505 |
+
# Extract caption and language from metas if available (from LM CoT output)
|
| 1506 |
+
# Fallback to user-provided values if not in metas
|
| 1507 |
+
actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages)
|
| 1508 |
+
|
| 1509 |
# Format text_inputs
|
| 1510 |
text_inputs = []
|
| 1511 |
text_token_idss = []
|
|
|
|
| 1515 |
|
| 1516 |
for i in range(batch_size):
|
| 1517 |
# Use custom instruction for this batch item
|
| 1518 |
+
instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION)
|
| 1519 |
+
|
| 1520 |
+
actual_caption = actual_captions[i]
|
| 1521 |
+
actual_language = actual_languages[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1522 |
|
| 1523 |
# Format text prompt with custom instruction (using LM-generated caption if available)
|
| 1524 |
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
|
|
|
|
| 1535 |
text_attention_mask = text_inputs_dict.attention_mask[0].bool()
|
| 1536 |
|
| 1537 |
# Format and tokenize lyrics (using LM-generated language if available)
|
| 1538 |
+
lyrics_text = self._format_lyrics(lyrics[i], actual_language)
|
| 1539 |
lyrics_inputs_dict = self.text_tokenizer(
|
| 1540 |
lyrics_text,
|
| 1541 |
padding="longest",
|
|
|
|
| 1557 |
|
| 1558 |
# Pad tokenized sequences
|
| 1559 |
max_text_length = max(len(seq) for seq in text_token_idss)
|
| 1560 |
+
padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, self.text_tokenizer.pad_token_id)
|
| 1561 |
+
padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1562 |
|
| 1563 |
max_lyric_length = max(len(seq) for seq in lyric_token_idss)
|
| 1564 |
+
padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id)
|
| 1565 |
+
padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
|
| 1567 |
padded_non_cover_text_input_ids = None
|
| 1568 |
padded_non_cover_text_attention_masks = None
|
|
|
|
| 1571 |
non_cover_text_attention_masks = []
|
| 1572 |
for i in range(batch_size):
|
| 1573 |
# Use custom instruction for this batch item
|
| 1574 |
+
instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION)
|
| 1575 |
|
| 1576 |
# Extract caption from metas if available (from LM CoT output)
|
| 1577 |
+
actual_caption = actual_captions[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1578 |
|
| 1579 |
# Format text prompt with custom instruction (using LM-generated caption if available)
|
| 1580 |
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
|
|
|
|
| 1592 |
non_cover_text_input_ids.append(text_token_ids)
|
| 1593 |
non_cover_text_attention_masks.append(non_cover_text_attention_mask)
|
| 1594 |
|
| 1595 |
+
padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id)
|
| 1596 |
+
padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1597 |
|
| 1598 |
if audio_cover_strength < 1.0:
|
| 1599 |
assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
|
|
|
|
| 1827 |
if self.config.is_turbo:
|
| 1828 |
# Limit inference steps to maximum 8
|
| 1829 |
if infer_steps > 8:
|
| 1830 |
+
logger.warning(f"[service_generate] dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
|
| 1831 |
infer_steps = 8
|
| 1832 |
# CFG parameters are not adjustable for dmd_gan (they will be ignored)
|
| 1833 |
# Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
|
|
|
|
| 1850 |
if isinstance(repainting_end, (int, float)):
|
| 1851 |
repainting_end = [repainting_end]
|
| 1852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1853 |
# Get batch size from captions
|
| 1854 |
batch_size = len(captions)
|
| 1855 |
|
| 1856 |
+
# Normalize instructions and audio_code_hints to match batch size
|
| 1857 |
+
instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None
|
| 1858 |
+
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1859 |
|
| 1860 |
# Convert seed to list format
|
| 1861 |
if seed is None:
|
|
|
|
| 1952 |
logger.info("[service_generate] Generating audio...")
|
| 1953 |
with self._load_model_context("model"):
|
| 1954 |
outputs = self.model.generate_audio(**generate_kwargs)
|
| 1955 |
+
|
| 1956 |
+
# Add intermediate information to outputs for extra_outputs
|
| 1957 |
+
outputs["src_latents"] = src_latents
|
| 1958 |
+
outputs["target_latents_input"] = target_latents # Input target latents (before generation)
|
| 1959 |
+
outputs["chunk_masks"] = chunk_mask
|
| 1960 |
+
outputs["spans"] = spans
|
| 1961 |
+
outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
|
| 1962 |
+
|
| 1963 |
return outputs
|
| 1964 |
|
| 1965 |
def tiled_decode(self, latents, chunk_size=512, overlap=64):
|
|
|
|
| 2055 |
use_adg: bool = False,
|
| 2056 |
cfg_interval_start: float = 0.0,
|
| 2057 |
cfg_interval_end: float = 1.0,
|
|
|
|
|
|
|
| 2058 |
use_tiled_decode: bool = True,
|
| 2059 |
progress=None
|
| 2060 |
+
) -> Dict[str, Any]:
|
| 2061 |
"""
|
| 2062 |
Main interface for music generation
|
| 2063 |
|
| 2064 |
Returns:
|
| 2065 |
+
Dictionary containing:
|
| 2066 |
+
- audios: List of audio dictionaries with path, key, params
|
| 2067 |
+
- generation_info: Markdown-formatted generation information
|
| 2068 |
+
- status_message: Status message
|
| 2069 |
+
- extra_outputs: Dictionary with latents, masks, time_costs, etc.
|
| 2070 |
+
- success: Whether generation completed successfully
|
| 2071 |
+
- error: Error message if generation failed
|
| 2072 |
"""
|
| 2073 |
if progress is None:
|
| 2074 |
def progress(*args, **kwargs):
|
| 2075 |
pass
|
| 2076 |
|
| 2077 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 2078 |
+
return {
|
| 2079 |
+
"audios": [],
|
| 2080 |
+
"status_message": "❌ Model not fully initialized. Please initialize all components first.",
|
| 2081 |
+
"extra_outputs": {},
|
| 2082 |
+
"success": False,
|
| 2083 |
+
"error": "Model not fully initialized",
|
| 2084 |
+
}
|
| 2085 |
|
| 2086 |
def _has_audio_codes(v: Union[str, List[str]]) -> bool:
|
| 2087 |
if isinstance(v, list):
|
|
|
|
| 2100 |
|
| 2101 |
logger.info("[generate_music] Starting generation...")
|
| 2102 |
if progress:
|
| 2103 |
+
progress(0.51, desc="Preparing inputs...")
|
| 2104 |
logger.info("[generate_music] Preparing inputs...")
|
| 2105 |
|
| 2106 |
# Reset offload cost
|
|
|
|
| 2122 |
repainting_end = None
|
| 2123 |
|
| 2124 |
try:
|
|
|
|
|
|
|
| 2125 |
# 1. Process reference audio
|
| 2126 |
refer_audios = None
|
| 2127 |
if reference_audio is not None:
|
|
|
|
| 2173 |
can_use_repainting
|
| 2174 |
)
|
| 2175 |
|
| 2176 |
+
progress(0.52, desc=f"Generating music (batch size: {actual_batch_size})...")
|
| 2177 |
|
| 2178 |
# Prepare audio_code_hints - use if audio_code_string is provided
|
| 2179 |
# This works for both text2music (auto-switched to cover) and cover tasks
|
|
|
|
| 2210 |
pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
|
| 2211 |
time_costs = outputs["time_costs"]
|
| 2212 |
time_costs["offload_time_cost"] = self.current_offload_cost
|
| 2213 |
+
logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
|
| 2214 |
+
logger.debug(f"[generate_music] time_costs: {time_costs}")
|
| 2215 |
if progress:
|
| 2216 |
progress(0.8, desc="Decoding audio...")
|
| 2217 |
logger.info("[generate_music] Decoding latents with VAE...")
|
|
|
|
| 2240 |
# Update offload cost one last time to include VAE offloading
|
| 2241 |
time_costs["offload_time_cost"] = self.current_offload_cost
|
| 2242 |
|
| 2243 |
+
logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
|
| 2244 |
if progress:
|
| 2245 |
+
progress(0.99, desc="Preparing audio data...")
|
| 2246 |
|
| 2247 |
+
# Prepare audio tensors (no file I/O here, no UUID generation)
|
| 2248 |
+
# pred_wavs is already [batch, channels, samples] format
|
| 2249 |
+
# Move to CPU and convert to float32 for return
|
| 2250 |
+
audio_tensors = []
|
| 2251 |
|
|
|
|
|
|
|
| 2252 |
for i in range(actual_batch_size):
|
| 2253 |
+
# Extract audio tensor: [channels, samples] format, CPU, float32
|
| 2254 |
+
audio_tensor = pred_wavs[i].cpu().float()
|
| 2255 |
+
audio_tensors.append(audio_tensor)
|
| 2256 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2257 |
status_message = f"✅ Generation completed successfully!"
|
| 2258 |
+
logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
|
| 2259 |
+
|
| 2260 |
+
# Extract intermediate information from outputs
|
| 2261 |
+
src_latents = outputs.get("src_latents") # [batch, T, D]
|
| 2262 |
+
target_latents_input = outputs.get("target_latents_input") # [batch, T, D]
|
| 2263 |
+
chunk_masks = outputs.get("chunk_masks") # [batch, T]
|
| 2264 |
+
spans = outputs.get("spans", []) # List of tuples
|
| 2265 |
+
latent_masks = outputs.get("latent_masks") # [batch, T]
|
| 2266 |
+
|
| 2267 |
+
# Move latents to CPU to save memory (they can be large)
|
| 2268 |
+
extra_outputs = {
|
| 2269 |
+
"pred_latents": pred_latents.cpu() if pred_latents is not None else None,
|
| 2270 |
+
"target_latents": target_latents_input.cpu() if target_latents_input is not None else None,
|
| 2271 |
+
"src_latents": src_latents.cpu() if src_latents is not None else None,
|
| 2272 |
+
"chunk_masks": chunk_masks.cpu() if chunk_masks is not None else None,
|
| 2273 |
+
"latent_masks": latent_masks.cpu() if latent_masks is not None else None,
|
| 2274 |
+
"spans": spans,
|
| 2275 |
+
"time_costs": time_costs,
|
| 2276 |
+
"seed_value": seed_value_for_ui,
|
| 2277 |
+
}
|
| 2278 |
+
|
| 2279 |
+
# Build audios list with tensor data (no file paths, no UUIDs, handled outside)
|
| 2280 |
+
audios = []
|
| 2281 |
+
for idx, audio_tensor in enumerate(audio_tensors):
|
| 2282 |
+
audio_dict = {
|
| 2283 |
+
"tensor": audio_tensor, # torch.Tensor [channels, samples], CPU, float32
|
| 2284 |
+
"sample_rate": self.sample_rate,
|
| 2285 |
+
}
|
| 2286 |
+
audios.append(audio_dict)
|
| 2287 |
+
|
| 2288 |
+
return {
|
| 2289 |
+
"audios": audios,
|
| 2290 |
+
"status_message": status_message,
|
| 2291 |
+
"extra_outputs": extra_outputs,
|
| 2292 |
+
"success": True,
|
| 2293 |
+
"error": None,
|
| 2294 |
+
}
|
| 2295 |
|
| 2296 |
except Exception as e:
|
| 2297 |
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
|
| 2298 |
+
logger.exception("[generate_music] Generation failed")
|
| 2299 |
+
return {
|
| 2300 |
+
"audios": [],
|
| 2301 |
+
"status_message": error_msg,
|
| 2302 |
+
"extra_outputs": {},
|
| 2303 |
+
"success": False,
|
| 2304 |
+
"error": str(e),
|
| 2305 |
+
}
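Since `generate_music` now returns in-memory tensors instead of file paths, callers own the file I/O. A minimal sketch of consuming the returned dict (the output directory and file naming are assumptions, not part of this diff):

```python
import os

import torchaudio

def save_generation(result: dict, out_dir: str = "outputs") -> list:
    """Write each returned audio tensor to disk and collect the paths."""
    if not result.get("success"):
        raise RuntimeError(result.get("error") or result.get("status_message", "generation failed"))

    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, audio in enumerate(result["audios"]):
        path = os.path.join(out_dir, f"sample_{i}.wav")
        # audio["tensor"] is [channels, samples], float32, already on CPU
        torchaudio.save(path, audio["tensor"], audio["sample_rate"])
        paths.append(path)
    return paths
```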
|
acestep/inference.py
CHANGED
|
@@ -7,105 +7,100 @@ backward-compatible Gradio UI support.
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import math
|
|
|
|
|
|
|
| 10 |
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 11 |
from dataclasses import dataclass, field, asdict
|
| 12 |
from loguru import logger
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
@dataclass
|
| 17 |
-
class
|
| 18 |
-
"""Configuration for music generation.
|
| 19 |
|
| 20 |
Attributes:
|
| 21 |
# Text Inputs
|
| 22 |
-
caption:
|
| 23 |
-
lyrics: Lyrics
|
|
|
|
| 24 |
|
| 25 |
# Music Metadata
|
| 26 |
-
bpm:
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
vocal_language: Language code for vocals
|
| 30 |
-
|
| 31 |
|
| 32 |
# Generation Parameters
|
| 33 |
-
inference_steps: Number of
|
| 34 |
-
guidance_scale:
|
| 35 |
-
|
| 36 |
-
seed: Random seed for reproducibility (-1 for random)
|
| 37 |
-
batch_size: Number of samples to generate (1-8)
|
| 38 |
|
| 39 |
# Advanced DiT Parameters
|
| 40 |
-
use_adg:
|
| 41 |
-
cfg_interval_start:
|
| 42 |
-
cfg_interval_end:
|
| 43 |
-
audio_format: Output audio format ("mp3", "wav", "flac")
|
| 44 |
|
| 45 |
# Task-Specific Parameters
|
| 46 |
-
task_type:
|
| 47 |
-
reference_audio: Path to reference audio file
|
| 48 |
-
src_audio: Path to source audio file
|
| 49 |
-
|
| 50 |
-
repainting_start:
|
| 51 |
-
repainting_end:
|
| 52 |
-
audio_cover_strength: Strength of audio
|
| 53 |
-
instruction:
|
| 54 |
|
| 55 |
-
# 5Hz Language Model Parameters
|
| 56 |
-
|
| 57 |
-
lm_temperature:
|
| 58 |
-
lm_cfg_scale:
|
| 59 |
-
lm_top_k:
|
| 60 |
-
lm_top_p:
|
| 61 |
-
lm_negative_prompt: Negative prompt for
|
| 62 |
-
use_cot_metas:
|
| 63 |
-
use_cot_caption:
|
| 64 |
-
use_cot_language:
|
| 65 |
-
is_format_caption: Whether caption is already formatted
|
| 66 |
-
constrained_decoding_debug: Enable debug logging for constrained decoding
|
| 67 |
-
|
| 68 |
-
# Batch LM Generation
|
| 69 |
-
allow_lm_batch: Allow batch LM code generation (faster for batch_size >= 2)
|
| 70 |
-
lm_batch_chunk_size: Maximum batch size per LM inference chunk (GPU memory constraint)
|
| 71 |
"""
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
# Text Inputs
|
| 74 |
caption: str = ""
|
| 75 |
lyrics: str = ""
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
key_scale: str = ""
|
| 80 |
-
time_signature: str = ""
|
| 81 |
vocal_language: str = "unknown"
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
| 85 |
inference_steps: int = 8
|
| 86 |
-
guidance_scale: float = 7.0
|
| 87 |
-
use_random_seed: bool = True
|
| 88 |
seed: int = -1
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# Advanced DiT Parameters
|
| 92 |
use_adg: bool = False
|
| 93 |
cfg_interval_start: float = 0.0
|
| 94 |
cfg_interval_end: float = 1.0
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# Task-Specific Parameters
|
| 98 |
-
task_type: str = "text2music"
|
| 99 |
-
reference_audio: Optional[str] = None
|
| 100 |
-
src_audio: Optional[str] = None
|
| 101 |
-
audio_code_string: Union[str, List[str]] = ""
|
| 102 |
repainting_start: float = 0.0
|
| 103 |
repainting_end: float = -1
|
| 104 |
audio_cover_strength: float = 1.0
|
| 105 |
-
|
| 106 |
-
|
| 107 |
# 5Hz Language Model Parameters
|
| 108 |
-
|
| 109 |
lm_temperature: float = 0.85
|
| 110 |
lm_cfg_scale: float = 2.0
|
| 111 |
lm_top_k: int = 0
|
|
@@ -113,13 +108,50 @@ class GenerationConfig:
|
|
| 113 |
lm_negative_prompt: str = "NO USER INPUT"
|
| 114 |
use_cot_metas: bool = True
|
| 115 |
use_cot_caption: bool = True
|
|
|
|
| 116 |
use_cot_language: bool = True
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
allow_lm_batch: bool = False
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
|
| 125 |
@dataclass
|
|
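The `GenerationConfig` fields above survive the trim unchanged; a minimal usage sketch against fields visible in this diff (both handlers are assumed to be initialized elsewhere, and any field not set keeps its dataclass default):

```python
from acestep.inference import GenerationConfig, generate_music

config = GenerationConfig(
    caption="warm lo-fi hip hop with vinyl crackle",
    lyrics="",
    inference_steps=8,
    seed=-1,                 # -1 lets the pipeline pick a random seed
    lm_temperature=0.85,
)

# dit_handler / llm_handler: already-initialized AceStepHandler and LLMHandler instances
result = generate_music(dit_handler, llm_handler, config)
if result.success:
    print(result.status_message)
    print(sorted(result.to_dict().keys()))
```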
@@ -128,801 +160,461 @@ class GenerationResult:
|
|
| 128 |
|
| 129 |
Attributes:
|
| 130 |
# Audio Outputs
|
| 131 |
-
|
| 132 |
-
first_audio: Path to first generated audio (backward compatibility)
|
| 133 |
-
second_audio: Path to second generated audio (backward compatibility)
|
| 134 |
-
|
| 135 |
-
# Generation Information
|
| 136 |
-
generation_info: Markdown-formatted generation information
|
| 137 |
status_message: Status message from generation
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
# LM-Generated Metadata (if applicable)
|
| 141 |
-
lm_metadata: Metadata generated by language model (dict or None)
|
| 142 |
-
|
| 143 |
-
# Audio-Text Alignment Scores (if available)
|
| 144 |
-
align_score_1: First alignment score
|
| 145 |
-
align_text_1: First alignment text description
|
| 146 |
-
align_plot_1: First alignment plot image
|
| 147 |
-
align_score_2: Second alignment score
|
| 148 |
-
align_text_2: Second alignment text description
|
| 149 |
-
align_plot_2: Second alignment plot image
|
| 150 |
-
|
| 151 |
-
# Success Status
|
| 152 |
success: Whether generation completed successfully
|
| 153 |
error: Error message if generation failed
|
| 154 |
"""
|
| 155 |
-
|
| 156 |
# Audio Outputs
|
| 157 |
-
|
| 158 |
-
first_audio: Optional[str] = None
|
| 159 |
-
second_audio: Optional[str] = None
|
| 160 |
-
|
| 161 |
# Generation Information
|
| 162 |
-
generation_info: str = ""
|
| 163 |
status_message: str = ""
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# LM-Generated Metadata
|
| 167 |
-
lm_metadata: Optional[Dict[str, Any]] = None
|
| 168 |
-
|
| 169 |
-
# Audio-Text Alignment Scores
|
| 170 |
-
align_score_1: Optional[float] = None
|
| 171 |
-
align_text_1: Optional[str] = None
|
| 172 |
-
align_plot_1: Optional[Any] = None
|
| 173 |
-
align_score_2: Optional[float] = None
|
| 174 |
-
align_text_2: Optional[str] = None
|
| 175 |
-
align_plot_2: Optional[Any] = None
|
| 176 |
-
|
| 177 |
# Success Status
|
| 178 |
success: bool = True
|
| 179 |
error: Optional[str] = None
|
| 180 |
-
|
| 181 |
def to_dict(self) -> Dict[str, Any]:
|
| 182 |
"""Convert result to dictionary for JSON serialization."""
|
| 183 |
return asdict(self)
|
| 184 |
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
def generate_music(
|
| 187 |
dit_handler,
|
| 188 |
llm_handler,
|
|
|
|
| 189 |
config: GenerationConfig,
|
|
|
|
|
|
|
| 190 |
) -> GenerationResult:
|
| 191 |
"""Generate music using ACE-Step model with optional LM reasoning.
|
| 192 |
|
| 193 |
-
This is the main inference API for music generation. It supports various task types
|
| 194 |
-
(text2music, cover, repaint, etc.) and can optionally use a 5Hz Language Model for
|
| 195 |
-
Chain-of-Thought reasoning to generate metadata and audio codes.
|
| 196 |
-
|
| 197 |
Args:
|
| 198 |
dit_handler: Initialized DiT model handler (AceStepHandler instance)
|
| 199 |
llm_handler: Initialized LLM handler (LLMHandler instance)
|
|
|
|
| 200 |
config: Generation configuration (GenerationConfig instance)
|
| 201 |
|
| 202 |
Returns:
|
| 203 |
-
GenerationResult
|
| 204 |
-
|
| 205 |
-
Example:
|
| 206 |
-
>>> from acestep.handler import AceStepHandler
|
| 207 |
-
>>> from acestep.llm_inference import LLMHandler
|
| 208 |
-
>>> from acestep.inference import GenerationConfig, generate_music
|
| 209 |
-
>>>
|
| 210 |
-
>>> # Initialize handlers
|
| 211 |
-
>>> dit_handler = AceStepHandler()
|
| 212 |
-
>>> llm_handler = LLMHandler()
|
| 213 |
-
>>> dit_handler.initialize_service(...)
|
| 214 |
-
>>> llm_handler.initialize(...)
|
| 215 |
-
>>>
|
| 216 |
-
>>> # Configure generation
|
| 217 |
-
>>> config = GenerationConfig(
|
| 218 |
-
... caption="upbeat electronic dance music",
|
| 219 |
-
... bpm=128,
|
| 220 |
-
... audio_duration=30,
|
| 221 |
-
... batch_size=2,
|
| 222 |
-
... )
|
| 223 |
-
>>>
|
| 224 |
-
>>> # Generate music
|
| 225 |
-
>>> result = generate_music(dit_handler, llm_handler, config)
|
| 226 |
-
>>> print(f"Generated {len(result.audio_paths)} audio files")
|
| 227 |
-
>>> for path in result.audio_paths:
|
| 228 |
-
... print(f"Audio: {path}")
|
| 229 |
"""
|
| 230 |
-
|
| 231 |
try:
|
| 232 |
# Phase 1: LM-based metadata and code generation (if enabled)
|
| 233 |
-
audio_code_string_to_use =
|
| 234 |
lm_generated_metadata = None
|
| 235 |
-
lm_generated_audio_codes = None
|
| 236 |
lm_generated_audio_codes_list = []
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
# Extract mutable copies of metadata (will be updated by LM if needed)
|
| 239 |
-
bpm =
|
| 240 |
-
key_scale =
|
| 241 |
-
time_signature =
|
| 242 |
-
audio_duration =
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
# LM-based Chain-of-Thought reasoning
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
| 259 |
# Build user_metadata from user-provided values
|
| 260 |
user_metadata = {}
|
| 261 |
if bpm is not None:
|
| 262 |
try:
|
| 263 |
bpm_value = float(bpm)
|
| 264 |
if bpm_value > 0:
|
| 265 |
-
user_metadata['bpm'] =
|
| 266 |
except (ValueError, TypeError):
|
| 267 |
pass
|
| 268 |
-
|
| 269 |
if key_scale and key_scale.strip():
|
| 270 |
key_scale_clean = key_scale.strip()
|
| 271 |
if key_scale_clean.lower() not in ["n/a", ""]:
|
| 272 |
user_metadata['keyscale'] = key_scale_clean
|
| 273 |
-
|
| 274 |
if time_signature and time_signature.strip():
|
| 275 |
time_sig_clean = time_signature.strip()
|
| 276 |
if time_sig_clean.lower() not in ["n/a", ""]:
|
| 277 |
user_metadata['timesignature'] = time_sig_clean
|
| 278 |
-
|
| 279 |
if audio_duration is not None:
|
| 280 |
try:
|
| 281 |
duration_value = float(audio_duration)
|
| 282 |
if duration_value > 0:
|
| 283 |
-
user_metadata['duration'] =
|
| 284 |
except (ValueError, TypeError):
|
| 285 |
pass
|
| 286 |
-
|
| 287 |
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 288 |
-
|
| 289 |
-
#
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
top_k=top_k_value,
|
| 321 |
-
top_p=top_p_value,
|
| 322 |
-
user_metadata=user_metadata_to_pass,
|
| 323 |
-
use_cot_caption=config.use_cot_caption,
|
| 324 |
-
use_cot_language=config.use_cot_language,
|
| 325 |
-
is_format_caption=config.is_format_caption,
|
| 326 |
-
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 327 |
-
seeds=chunk_seeds,
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
-
all_metadata_list.extend(metadata_list)
|
| 331 |
-
all_audio_codes_list.extend(audio_codes_list)
|
| 332 |
-
|
| 333 |
-
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
|
| 334 |
-
lm_generated_audio_codes_list = all_audio_codes_list
|
| 335 |
-
audio_code_string_to_use = all_audio_codes_list
|
| 336 |
-
|
| 337 |
-
# Update metadata from LM if not provided by user
|
| 338 |
-
if lm_generated_metadata:
|
| 339 |
-
bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
|
| 340 |
-
lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
|
| 341 |
-
)
|
| 342 |
-
|
| 343 |
-
else:
|
| 344 |
-
# Sequential LM generation (current behavior)
|
| 345 |
-
# Phase 1: Generate CoT metadata
|
| 346 |
-
phase1_start = time_module.time()
|
| 347 |
-
metadata, _, status = llm_handler.generate_with_stop_condition(
|
| 348 |
-
caption=config.caption or "",
|
| 349 |
-
lyrics=config.lyrics or "",
|
| 350 |
-
infer_type="dit",
|
| 351 |
-
temperature=config.lm_temperature,
|
| 352 |
-
cfg_scale=config.lm_cfg_scale,
|
| 353 |
-
negative_prompt=config.lm_negative_prompt,
|
| 354 |
-
top_k=top_k_value,
|
| 355 |
-
top_p=top_p_value,
|
| 356 |
-
user_metadata=user_metadata_to_pass,
|
| 357 |
-
use_cot_caption=config.use_cot_caption,
|
| 358 |
-
use_cot_language=config.use_cot_language,
|
| 359 |
-
is_format_caption=config.is_format_caption,
|
| 360 |
-
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 361 |
-
)
|
| 362 |
-
lm_phase1_time = time_module.time() - phase1_start
|
| 363 |
-
logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
|
| 364 |
-
|
| 365 |
-
# Phase 2: Generate audio codes
|
| 366 |
-
phase2_start = time_module.time()
|
| 367 |
-
metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
|
| 368 |
-
caption=config.caption or "",
|
| 369 |
-
lyrics=config.lyrics or "",
|
| 370 |
-
infer_type="llm_dit",
|
| 371 |
-
temperature=config.lm_temperature,
|
| 372 |
-
cfg_scale=config.lm_cfg_scale,
|
| 373 |
-
negative_prompt=config.lm_negative_prompt,
|
| 374 |
top_k=top_k_value,
|
| 375 |
top_p=top_p_value,
|
| 376 |
user_metadata=user_metadata_to_pass,
|
| 377 |
-
use_cot_caption=
|
| 378 |
-
use_cot_language=
|
| 379 |
-
|
|
|
|
| 380 |
constrained_decoding_debug=config.constrained_decoding_debug,
|
|
|
|
|
|
|
|
|
|
| 381 |
)
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
|
|
|
| 393 |
)
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
# Phase 2: DiT music generation
|
|
|
|
| 396 |
result = dit_handler.generate_music(
|
| 397 |
-
captions=
|
| 398 |
-
lyrics=
|
| 399 |
bpm=bpm,
|
| 400 |
key_scale=key_scale,
|
| 401 |
time_signature=time_signature,
|
| 402 |
-
vocal_language=
|
| 403 |
-
inference_steps=
|
| 404 |
-
guidance_scale=
|
| 405 |
use_random_seed=config.use_random_seed,
|
| 406 |
-
seed=config.seed
|
| 407 |
-
reference_audio=
|
| 408 |
audio_duration=audio_duration,
|
| 409 |
-
batch_size=config.batch_size,
|
| 410 |
-
src_audio=
|
| 411 |
audio_code_string=audio_code_string_to_use,
|
| 412 |
-
repainting_start=
|
| 413 |
-
repainting_end=
|
| 414 |
-
instruction=
|
| 415 |
-
audio_cover_strength=
|
| 416 |
-
task_type=
|
| 417 |
-
use_adg=
|
| 418 |
-
cfg_interval_start=
|
| 419 |
-
cfg_interval_end=
|
| 420 |
-
|
| 421 |
-
lm_temperature=config.lm_temperature,
|
| 422 |
)
|
| 423 |
-
|
| 424 |
-
#
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
return GenerationResult(
|
| 435 |
-
|
| 436 |
-
first_audio=first_audio,
|
| 437 |
-
second_audio=second_audio,
|
| 438 |
-
generation_info=generation_info,
|
| 439 |
status_message=status_message,
|
| 440 |
-
|
| 441 |
-
lm_metadata=lm_generated_metadata,
|
| 442 |
-
align_score_1=align_score_1,
|
| 443 |
-
align_text_1=align_text_1,
|
| 444 |
-
align_plot_1=align_plot_1,
|
| 445 |
-
align_score_2=align_score_2,
|
| 446 |
-
align_text_2=align_text_2,
|
| 447 |
-
align_plot_2=align_plot_2,
|
| 448 |
success=True,
|
| 449 |
error=None,
|
| 450 |
)
|
| 451 |
-
|
| 452 |
except Exception as e:
|
| 453 |
logger.exception("Music generation failed")
|
| 454 |
return GenerationResult(
|
|
|
|
|
|
|
|
|
|
| 455 |
success=False,
|
| 456 |
error=str(e),
|
| 457 |
-
generation_info=f"❌ Generation failed: {str(e)}",
|
| 458 |
-
status_message=f"Error: {str(e)}",
|
| 459 |
)
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
def _update_metadata_from_lm(
|
| 463 |
-
metadata: Dict[str, Any],
|
| 464 |
-
bpm: Optional[int],
|
| 465 |
-
key_scale: str,
|
| 466 |
-
time_signature: str,
|
| 467 |
-
audio_duration: Optional[float],
|
| 468 |
-
) -> Tuple[Optional[int], str, str, Optional[float]]:
|
| 469 |
-
"""Update metadata fields from LM output if not provided by user."""
|
| 470 |
-
|
| 471 |
-
if bpm is None and metadata.get('bpm'):
|
| 472 |
-
bpm_value = metadata.get('bpm')
|
| 473 |
-
if bpm_value not in ["N/A", ""]:
|
| 474 |
-
try:
|
| 475 |
-
bpm = int(bpm_value)
|
| 476 |
-
except (ValueError, TypeError):
|
| 477 |
-
pass
|
| 478 |
-
|
| 479 |
-
if not key_scale and metadata.get('keyscale'):
|
| 480 |
-
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
|
| 481 |
-
if key_scale_value != "N/A":
|
| 482 |
-
key_scale = key_scale_value
|
| 483 |
-
|
| 484 |
-
if not time_signature and metadata.get('timesignature'):
|
| 485 |
-
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
|
| 486 |
-
if time_signature_value != "N/A":
|
| 487 |
-
time_signature = time_signature_value
|
| 488 |
-
|
| 489 |
-
if audio_duration is None or audio_duration <= 0:
|
| 490 |
-
audio_duration_value = metadata.get('duration', -1)
|
| 491 |
-
if audio_duration_value not in ["N/A", ""]:
|
| 492 |
-
try:
|
| 493 |
-
audio_duration = float(audio_duration_value)
|
| 494 |
-
except (ValueError, TypeError):
|
| 495 |
-
pass
|
| 496 |
-
|
| 497 |
-
return bpm, key_scale, time_signature, audio_duration
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any]) -> str:
|
| 501 |
-
"""Append LM-generated metadata to generation info string."""
|
| 502 |
-
|
| 503 |
-
metadata_lines = []
|
| 504 |
-
if metadata.get('bpm'):
|
| 505 |
-
metadata_lines.append(f"- **BPM:** {metadata['bpm']}")
|
| 506 |
-
if metadata.get('caption'):
|
| 507 |
-
metadata_lines.append(f"- **Refined Caption:** {metadata['caption']}")
|
| 508 |
-
if metadata.get('duration'):
|
| 509 |
-
metadata_lines.append(f"- **Duration:** {metadata['duration']} seconds")
|
| 510 |
-
if metadata.get('keyscale'):
|
| 511 |
-
metadata_lines.append(f"- **Key Scale:** {metadata['keyscale']}")
|
| 512 |
-
if metadata.get('language'):
|
| 513 |
-
metadata_lines.append(f"- **Language:** {metadata['language']}")
|
| 514 |
-
if metadata.get('timesignature'):
|
| 515 |
-
metadata_lines.append(f"- **Time Signature:** {metadata['timesignature']}")
|
| 516 |
-
|
| 517 |
-
if metadata_lines:
|
| 518 |
-
metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
|
| 519 |
-
return metadata_section + "\n\n" + generation_info
|
| 520 |
-
|
| 521 |
-
return generation_info
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
# ============================================================================
|
| 525 |
-
# LEGACY GRADIO UI COMPATIBILITY LAYER
|
| 526 |
-
# ============================================================================
|
| 527 |
-
|
| 528 |
-
def generate(
|
| 529 |
-
dit_handler,
|
| 530 |
-
llm_handler,
|
| 531 |
-
captions,
|
| 532 |
-
lyrics,
|
| 533 |
-
bpm,
|
| 534 |
-
key_scale,
|
| 535 |
-
time_signature,
|
| 536 |
-
vocal_language,
|
| 537 |
-
inference_steps,
|
| 538 |
-
guidance_scale,
|
| 539 |
-
random_seed_checkbox,
|
| 540 |
-
seed,
|
| 541 |
-
reference_audio,
|
| 542 |
-
audio_duration,
|
| 543 |
-
batch_size_input,
|
| 544 |
-
src_audio,
|
| 545 |
-
text2music_audio_code_string,
|
| 546 |
-
repainting_start,
|
| 547 |
-
repainting_end,
|
| 548 |
-
instruction_display_gen,
|
| 549 |
-
audio_cover_strength,
|
| 550 |
-
task_type,
|
| 551 |
-
use_adg,
|
| 552 |
-
cfg_interval_start,
|
| 553 |
-
cfg_interval_end,
|
| 554 |
-
audio_format,
|
| 555 |
-
lm_temperature,
|
| 556 |
-
think_checkbox,
|
| 557 |
-
lm_cfg_scale,
|
| 558 |
-
lm_top_k,
|
| 559 |
-
lm_top_p,
|
| 560 |
-
lm_negative_prompt,
|
| 561 |
-
use_cot_metas,
|
| 562 |
-
use_cot_caption,
|
| 563 |
-
use_cot_language,
|
| 564 |
-
is_format_caption,
|
| 565 |
-
constrained_decoding_debug,
|
| 566 |
-
allow_lm_batch,
|
| 567 |
-
lm_batch_chunk_size,
|
| 568 |
-
):
|
| 569 |
-
"""Legacy Gradio UI compatibility wrapper.
|
| 570 |
-
|
| 571 |
-
This function maintains backward compatibility with the Gradio UI.
|
| 572 |
-
For new integrations, use generate_music() with GenerationConfig instead.
|
| 573 |
-
|
| 574 |
-
Returns:
|
| 575 |
-
Tuple with 28 elements for Gradio UI component updates
|
| 576 |
-
"""
|
| 577 |
-
|
| 578 |
-
# Convert legacy parameters to new config
|
| 579 |
-
config = GenerationConfig(
|
| 580 |
-
caption=captions,
|
| 581 |
-
lyrics=lyrics,
|
| 582 |
-
bpm=bpm,
|
| 583 |
-
key_scale=key_scale,
|
| 584 |
-
time_signature=time_signature,
|
| 585 |
-
vocal_language=vocal_language,
|
| 586 |
-
audio_duration=audio_duration,
|
| 587 |
-
inference_steps=inference_steps,
|
| 588 |
-
guidance_scale=guidance_scale,
|
| 589 |
-
use_random_seed=random_seed_checkbox,
|
| 590 |
-
seed=seed,
|
| 591 |
-
batch_size=batch_size_input,
|
| 592 |
-
use_adg=use_adg,
|
| 593 |
-
cfg_interval_start=cfg_interval_start,
|
| 594 |
-
cfg_interval_end=cfg_interval_end,
|
| 595 |
-
audio_format=audio_format,
|
| 596 |
-
task_type=task_type,
|
| 597 |
-
reference_audio=reference_audio,
|
| 598 |
-
src_audio=src_audio,
|
| 599 |
-
audio_code_string=text2music_audio_code_string,
|
| 600 |
-
repainting_start=repainting_start,
|
| 601 |
-
repainting_end=repainting_end,
|
| 602 |
-
audio_cover_strength=audio_cover_strength,
|
| 603 |
-
instruction=instruction_display_gen,
|
| 604 |
-
use_llm_thinking=think_checkbox,
|
| 605 |
-
lm_temperature=lm_temperature,
|
| 606 |
-
lm_cfg_scale=lm_cfg_scale,
|
| 607 |
-
lm_top_k=lm_top_k,
|
| 608 |
-
lm_top_p=lm_top_p,
|
| 609 |
-
lm_negative_prompt=lm_negative_prompt,
|
| 610 |
-
use_cot_metas=use_cot_metas,
|
| 611 |
-
use_cot_caption=use_cot_caption,
|
| 612 |
-
use_cot_language=use_cot_language,
|
| 613 |
-
is_format_caption=is_format_caption,
|
| 614 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 615 |
-
allow_lm_batch=allow_lm_batch,
|
| 616 |
-
lm_batch_chunk_size=lm_batch_chunk_size,
|
| 617 |
-
)
|
| 618 |
-
|
| 619 |
-
# Call new API
|
| 620 |
-
result = generate_music(dit_handler, llm_handler, config)
|
| 621 |
-
|
| 622 |
-
# Determine which codes to update in UI
|
| 623 |
-
if config.allow_lm_batch and result.lm_metadata:
|
| 624 |
-
# Batch mode: extract codes from metadata if available
|
| 625 |
-
lm_codes_list = result.lm_metadata.get('audio_codes_list', [])
|
| 626 |
-
updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
|
| 627 |
-
codes_outputs = (lm_codes_list + [""] * 8)[:8]
|
| 628 |
-
else:
|
| 629 |
-
# Single mode
|
| 630 |
-
lm_codes = result.lm_metadata.get('audio_codes', '') if result.lm_metadata else ''
|
| 631 |
-
updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
|
| 632 |
-
codes_outputs = [""] * 8
|
| 633 |
-
|
| 634 |
-
# Prepare audio outputs (up to 8)
|
| 635 |
-
audio_outputs = (result.audio_paths + [None] * 8)[:8]
|
| 636 |
-
|
| 637 |
-
# Return tuple for Gradio UI (28 elements)
|
| 638 |
-
return (
|
| 639 |
-
audio_outputs[0], # generated_audio_1
|
| 640 |
-
audio_outputs[1], # generated_audio_2
|
| 641 |
-
audio_outputs[2], # generated_audio_3
|
| 642 |
-
audio_outputs[3], # generated_audio_4
|
| 643 |
-
audio_outputs[4], # generated_audio_5
|
| 644 |
-
audio_outputs[5], # generated_audio_6
|
| 645 |
-
audio_outputs[6], # generated_audio_7
|
| 646 |
-
audio_outputs[7], # generated_audio_8
|
| 647 |
-
result.audio_paths, # generated_audio_batch
|
| 648 |
-
result.generation_info,
|
| 649 |
-
result.status_message,
|
| 650 |
-
result.seed_value,
|
| 651 |
-
result.align_score_1,
|
| 652 |
-
result.align_text_1,
|
| 653 |
-
result.align_plot_1,
|
| 654 |
-
result.align_score_2,
|
| 655 |
-
result.align_text_2,
|
| 656 |
-
result.align_plot_2,
|
| 657 |
-
updated_audio_codes, # Update main audio codes in UI
|
| 658 |
-
codes_outputs[0], # text2music_audio_code_string_1
|
| 659 |
-
codes_outputs[1], # text2music_audio_code_string_2
|
| 660 |
-
codes_outputs[2], # text2music_audio_code_string_3
|
| 661 |
-
codes_outputs[3], # text2music_audio_code_string_4
|
| 662 |
-
codes_outputs[4], # text2music_audio_code_string_5
|
| 663 |
-
codes_outputs[5], # text2music_audio_code_string_6
|
| 664 |
-
codes_outputs[6], # text2music_audio_code_string_7
|
| 665 |
-
codes_outputs[7], # text2music_audio_code_string_8
|
| 666 |
-
result.lm_metadata, # Store metadata for "Send to src audio" buttons
|
| 667 |
-
is_format_caption, # Keep is_format_caption unchanged
|
| 668 |
-
)
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
# ============================================================================
|
| 672 |
-
# TESTING & EXAMPLES
|
| 673 |
-
# ============================================================================
|
| 674 |
-
|
| 675 |
-
if __name__ == "__main__":
|
| 676 |
-
"""
|
| 677 |
-
Test suite for the inference API.
|
| 678 |
-
Demonstrates various usage scenarios and validates functionality.
|
| 679 |
-
|
| 680 |
-
Usage:
|
| 681 |
-
python -m acestep.inference
|
| 682 |
-
"""
|
| 683 |
-
|
| 684 |
-
import os
|
| 685 |
-
import json
|
| 686 |
-
from acestep.handler import AceStepHandler
|
| 687 |
-
from acestep.llm_inference import LLMHandler
|
| 688 |
-
|
| 689 |
-
# Initialize paths
|
| 690 |
-
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 691 |
-
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 692 |
-
|
| 693 |
-
print("=" * 80)
|
| 694 |
-
print("ACE-Step Inference API Test Suite")
|
| 695 |
-
print("=" * 80)
|
| 696 |
-
|
| 697 |
-
# ========================================================================
|
| 698 |
-
# Initialize Handlers
|
| 699 |
-
# ========================================================================
|
| 700 |
-
print("\n[1/3] Initializing handlers...")
|
| 701 |
-
dit_handler = AceStepHandler(save_root="./")
|
| 702 |
-
-    llm_handler = LLMHandler()
-
-    try:
-        # Initialize DiT handler
-        print(" - Initializing DiT model...")
-        status_dit, success_dit = dit_handler.initialize_service(
-            project_root=project_root,
-            config_path="acestep-v15-turbo-rl",
-            device="cuda",
-        )
-        if not success_dit:
-            print(f" ❌ DiT initialization failed: {status_dit}")
-            exit(1)
-        print(f" ✓ DiT model initialized successfully")
-
-        # Initialize LLM handler
-        print(" - Initializing 5Hz LM model...")
-        status_llm, success_llm = llm_handler.initialize(
-            checkpoint_dir=checkpoint_dir,
-            lm_model_path="acestep-5Hz-lm-0.6B-v3",
-            backend="vllm",
-            device="cuda",
-        )
-        if success_llm:
-            print(f" ✓ LM model initialized successfully")
-        else:
-            print(f" ⚠ LM initialization failed (will skip LM tests): {status_llm}")
-
-    except Exception as e:
-        print(f" ❌ Initialization error: {e}")
-        exit(1)
-
-    # ========================================================================
-    # Helper Functions
-    # ========================================================================
-    def load_example_config(example_file: str) -> GenerationConfig:
-        """Load configuration from an example JSON file."""
-        try:
-            with open(example_file, 'r', encoding='utf-8') as f:
-                data = json.load(f)
-
-            # Convert example format to GenerationConfig
-            # Handle time signature format (example uses "4" instead of "4/4")
-            time_sig = data.get('timesignature', '')
-            if time_sig and '/' not in time_sig:
-                time_sig = f"{time_sig}/4"  # Default to /4 if only numerator given
-
-            config = GenerationConfig(
-                caption=data.get('caption', ''),
-                lyrics=data.get('lyrics', ''),
-                bpm=data.get('bpm'),
-                key_scale=data.get('keyscale', ''),
-                time_signature=time_sig,
-                vocal_language=data.get('language', 'unknown'),
-                audio_duration=data.get('duration'),
-                use_llm_thinking=data.get('think', False),
-                batch_size=data.get('batch_size', 1),
-                inference_steps=data.get('inference_steps', 8),
-            )
-            return config
-
-        except Exception as e:
-            print(f" ⚠ Failed to load example file: {e}")
-            return None
-
-    # ========================================================================
-    # Test Cases
-    # ========================================================================
-    test_results = []
-
-    def run_test(test_name: str, config: GenerationConfig, expected_outputs: int = 1):
-        """Run a single test case and collect results."""
-        print(f"\n{'=' * 80}")
-        print(f"Test: {test_name}")
-        print(f"{'=' * 80}")
-
-        # Display configuration
-        print("\nConfiguration:")
-        print(f"  Task Type: {config.task_type}")
-        print(f"  Caption: {config.caption[:60]}..." if len(config.caption) > 60 else f"  Caption: {config.caption}")
-        if config.lyrics:
-            print(f"  Lyrics: {config.lyrics[:60]}..." if len(config.lyrics) > 60 else f"  Lyrics: {config.lyrics}")
-        if config.bpm:
-            print(f"  BPM: {config.bpm}")
-        if config.key_scale:
-            print(f"  Key Scale: {config.key_scale}")
-        if config.time_signature:
-            print(f"  Time Signature: {config.time_signature}")
-        if config.audio_duration:
-            print(f"  Duration: {config.audio_duration}s")
-        print(f"  Batch Size: {config.batch_size}")
-        print(f"  Inference Steps: {config.inference_steps}")
-        print(f"  Use LLM Thinking: {config.use_llm_thinking}")
-
-        # Run generation
-        print("\nGenerating...")
-        import time
-        start_time = time.time()
-
-        result = generate_music(dit_handler, llm_handler, config)
-
-        elapsed_time = time.time() - start_time
-
-        # Display results
-        print("\nResults:")
-        print(f"  Success: {'✓' if result.success else '✗'}")
-
-        if result.success:
-            print(f"  Generated Files: {len(result.audio_paths)}")
-            for i, path in enumerate(result.audio_paths, 1):
-                if os.path.exists(path):
-                    file_size = os.path.getsize(path) / (1024 * 1024)  # MB
-                    print(f"    [{i}] {os.path.basename(path)} ({file_size:.2f} MB)")
-                else:
-                    print(f"    [{i}] {os.path.basename(path)} (file not found)")
-
-            print(f"  Seed: {result.seed_value}")
-            print(f"  Generation Time: {elapsed_time:.2f}s")
-
-            # Display LM metadata if available
-            if result.lm_metadata:
-                print(f"\n  LM-Generated Metadata:")
-                for key, value in result.lm_metadata.items():
-                    if key not in ['audio_codes', 'audio_codes_list']:  # Skip large code strings
-                        print(f"    {key}: {value}")
-
-            # Validate outputs
-            if len(result.audio_paths) != expected_outputs:
-                print(f"  ⚠ Warning: Expected {expected_outputs} outputs, got {len(result.audio_paths)}")
-                success = False
-            else:
-                success = True
-
-        else:
-            print(f"  Error: {result.error}")
-            success = False
-
-        # Store test result
-        test_results.append({
-            "test_name": test_name,
-            "success": success,
-            "generation_success": result.success,
-            "num_outputs": len(result.audio_paths) if result.success else 0,
-            "expected_outputs": expected_outputs,
-            "elapsed_time": elapsed_time,
-            "error": result.error if not result.success else None,
-        })
-
-        return result
-
-    # ========================================================================
-    # Test: Production Example (from examples directory)
-    # ========================================================================
-    print("\n[2/3] Running Test...")
-
-    # Load production example (J-Rock song from examples/text2music/example_05.json)
-    example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
-
-    if not os.path.exists(example_file):
-        print(f"\n ❌ Example file not found: {example_file}")
-        print("   Please ensure the examples directory exists.")
-        exit(1)
-
-    print(f" Loading example: {os.path.basename(example_file)}")
-    config = load_example_config(example_file)
-
-    if not config:
-        print(" ❌ Failed to load example configuration")
-        exit(1)
-
-    # Reduce duration for faster testing (original is 200s)
-    print(f" Original duration: {config.audio_duration}s")
-    config.audio_duration = 30
-    config.use_random_seed = False
-    config.seed = 42
-    print(f" Test duration: {config.audio_duration}s (reduced for testing)")
-
-    run_test("Production Example (J-Rock Song)", config, expected_outputs=1)
-
-    # ========================================================================
-    # Test Summary
-    # ========================================================================
-    print("\n[3/3] Test Summary")
-    print("=" * 80)
-
-    if len(test_results) == 0:
-        print("No tests were run.")
-        exit(1)
-
-    result = test_results[0]
-
-    print(f"\nTest: {result['test_name']}")
-    print(f"Status: {'✓ PASS' if result['success'] else '✗ FAIL'}")
-    print(f"Generation: {'Success' if result['generation_success'] else 'Failed'}")
-    print(f"Outputs: {result['num_outputs']}/{result['expected_outputs']}")
-    print(f"Time: {result['elapsed_time']:.2f}s")
-
-    if result["error"]:
-        print(f"Error: {result['error']}")
-
-    # Save test results to JSON
-    results_file = os.path.join(project_root, "test_results.json")
-    try:
-        with open(results_file, "w") as f:
-            json.dump({
-                "test_name": result['test_name'],
-                "success": result['success'],
-                "generation_success": result['generation_success'],
-                "num_outputs": result['num_outputs'],
-                "expected_outputs": result['expected_outputs'],
-                "elapsed_time": result['elapsed_time'],
-                "error": result['error'],
-            }, f, indent=2)
-        print(f"\n✓ Test results saved to: {results_file}")
-    except Exception as e:
-        print(f"\n⚠ Failed to save test results: {e}")
-
-    # Exit with appropriate code
-    print("\n" + "=" * 80)
-    if result['success']:
-        print("Test passed! ✓")
-        print("=" * 80)
-        exit(0)
-    else:
-        print("Test failed! ✗")
-        print("=" * 80)
-        exit(1)
"""

import math
import os
import tempfile
from typing import Optional, Union, List, Dict, Any, Tuple
from dataclasses import dataclass, field, asdict
from loguru import logger

from acestep.audio_utils import AudioSaver, generate_uuid_from_params


@dataclass
class GenerationParams:
    """Configuration for music generation parameters.

    Attributes:
        # Text Inputs
        caption: A short text prompt describing the desired music (main prompt). < 512 characters
        lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
        instrumental: If True, generate instrumental music regardless of lyrics.

        # Music Metadata
        bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
        keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
        timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
        vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". See acestep/constants.py:VALID_LANGUAGES.
        duration: Target audio length in seconds. If <0 or None, the model chooses automatically. 10 ~ 600

        # Generation Parameters
        inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32-100 for the base model).
        guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only supported for the non-turbo model.
        seed: Integer seed for reproducibility. -1 means use a random seed each time.

        # Advanced DiT Parameters
        use_adg: Whether to use Adaptive Dual Guidance (only works for the base model).
        cfg_interval_start: Start ratio (0.0-1.0) to apply CFG.
        cfg_interval_end: End ratio (0.0-1.0) to apply CFG.

        # Task-Specific Parameters
        task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
        reference_audio: Path to a reference audio file for style transfer or cover tasks.
        src_audio: Path to a source audio file for audio-to-audio tasks.
        audio_codes: Audio semantic codes as a string (advanced use, for code-controlled generation).
        repainting_start: For repaint/lego tasks: start time in seconds of the region to repaint.
        repainting_end: For repaint/lego tasks: end time in seconds of the region to repaint (-1 for until end).
        audio_cover_strength: Strength of reference audio/codes influence (range 0.0-1.0). Set smaller (e.g., 0.2) for style-transfer tasks.
        instruction: Optional task instruction prompt. If empty, auto-generated by the system.

        # 5Hz Language Model Parameters for CoT reasoning
        thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
        lm_temperature: Sampling temperature for the LLM (0.0-2.0). Higher = more creative/varied results.
        lm_cfg_scale: Classifier-free guidance scale for the LLM.
        lm_top_k: LLM top-k sampling (0 = disabled).
        lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
        lm_negative_prompt: Negative prompt to use for the LLM (for control).
        use_cot_metas: Whether to let the LLM generate music metadata via CoT reasoning.
        use_cot_caption: Whether to let the LLM rewrite or format the input caption via CoT reasoning.
        use_cot_language: Whether to let the LLM detect the vocal language via CoT.
    """
    # Required Inputs
    task_type: str = "text2music"
    instruction: str = "Fill the audio semantic mask based on the given conditions:"

    # Audio Uploads
    reference_audio: Optional[str] = None
    src_audio: Optional[str] = None

    # LM Codes Hints
    audio_codes: str = ""

    # Text Inputs
    caption: str = ""
    lyrics: str = ""
    instrumental: bool = False

    # Metadata
    vocal_language: str = "unknown"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = -1.0

    # Advanced Settings
    inference_steps: int = 8
    seed: int = -1
    guidance_scale: float = 7.0
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0

    repainting_start: float = 0.0
    repainting_end: float = -1
    audio_cover_strength: float = 1.0

    # 5Hz Language Model Parameters
    thinking: bool = True
    lm_temperature: float = 0.85
    lm_cfg_scale: float = 2.0
    lm_top_k: int = 0
    lm_negative_prompt: str = "NO USER INPUT"
    use_cot_metas: bool = True
    use_cot_caption: bool = True
    use_cot_lyrics: bool = False  # TODO: not used yet
    use_cot_language: bool = True
    use_constrained_decoding: bool = True

    cot_bpm: Optional[int] = None
    cot_keyscale: str = ""
    cot_timesignature: str = ""
    cot_duration: Optional[float] = None
    cot_vocal_language: str = "unknown"
    cot_caption: str = ""
    cot_lyrics: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)


@dataclass
class GenerationConfig:
    """Configuration for music generation.

    Attributes:
        batch_size: Number of audio samples to generate
        allow_lm_batch: Whether to allow batch processing in the LM
        use_random_seed: Whether to use a random seed
        seeds: Seed(s) for batch generation. Can be:
            - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
            - List[int]: List of seeds, padded with random seeds if fewer than batch_size
            - int: Single seed value (will be converted to a list and padded)
        lm_batch_chunk_size: Batch chunk size for LM processing
        constrained_decoding_debug: Whether to enable constrained decoding debug output
        audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
    """
    batch_size: int = 2
    allow_lm_batch: bool = False
    use_random_seed: bool = True
    seeds: Optional[List[int]] = None
    lm_batch_chunk_size: int = 8
    constrained_decoding_debug: bool = False
    audio_format: str = "flac"  # Default to FLAC for fast saving

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)

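For reference, a minimal usage sketch of the two dataclasses above; the values shown are illustrative examples, not repository defaults.

# Sketch: build a request for a single text2music call (example values only).
params = GenerationParams(
    caption="Upbeat synth-pop with bright female vocals",
    lyrics="[Instrumental]",
    duration=60.0,
    thinking=True,   # let the 5Hz LM propose bpm/keyscale/duration via CoT
)
config = GenerationConfig(
    batch_size=2,
    seeds=[42],      # padded with random seeds up to batch_size
    audio_format="flac",
)
print(params.to_dict()["task_type"], config.to_dict()["batch_size"])
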
@dataclass
|
|
|
|
| 160 |
|
| 161 |
Attributes:
|
| 162 |
# Audio Outputs
|
| 163 |
+
audios: List of audio dictionaries with paths, keys, params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
status_message: Status message from generation
|
| 165 |
+
extra_outputs: Extra outputs from generation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
success: Whether generation completed successfully
|
| 167 |
error: Error message if generation failed
|
| 168 |
"""
|
| 169 |
+
|
| 170 |
# Audio Outputs
|
| 171 |
+
audios: List[Dict[str, Any]] = field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
| 172 |
# Generation Information
|
|
|
|
| 173 |
status_message: str = ""
|
| 174 |
+
extra_outputs: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
# Success Status
|
| 176 |
success: bool = True
|
| 177 |
error: Optional[str] = None
|
| 178 |
+
|
| 179 |
def to_dict(self) -> Dict[str, Any]:
|
| 180 |
"""Convert result to dictionary for JSON serialization."""
|
| 181 |
return asdict(self)
|
| 182 |
|
| 183 |
|
def _update_metadata_from_lm(
    metadata: Dict[str, Any],
    bpm: Optional[int],
    key_scale: str,
    time_signature: str,
    audio_duration: Optional[float],
    vocal_language: str,
    caption: str,
    lyrics: str,
) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
    """Update metadata fields from LM output if not provided by user."""

    if bpm is None and metadata.get('bpm'):
        bpm_value = metadata.get('bpm')
        if bpm_value not in ["N/A", ""]:
            try:
                bpm = int(bpm_value)
            except (ValueError, TypeError):
                pass

    if not key_scale and metadata.get('keyscale'):
        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
        if key_scale_value != "N/A":
            key_scale = key_scale_value

    if not time_signature and metadata.get('timesignature'):
        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
        if time_signature_value != "N/A":
            time_signature = time_signature_value

    if audio_duration is None or audio_duration <= 0:
        audio_duration_value = metadata.get('duration', -1)
        if audio_duration_value not in ["N/A", ""]:
            try:
                audio_duration = float(audio_duration_value)
            except (ValueError, TypeError):
                pass

    if not vocal_language and metadata.get('vocal_language'):
        vocal_language = metadata.get('vocal_language')
    if not caption and metadata.get('caption'):
        caption = metadata.get('caption')
    if not lyrics and metadata.get('lyrics'):
        lyrics = metadata.get('lyrics')
    return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics

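A small sketch of the fallback rule implemented by _update_metadata_from_lm: user-supplied fields win, empty or unset fields are filled from the LM metadata (the metadata dict below is made up for illustration).

lm_meta = {"bpm": 128, "keyscale": "A minor", "duration": 180}
bpm, keyscale, timesig, duration, language, caption, lyrics = _update_metadata_from_lm(
    metadata=lm_meta,
    bpm=None,             # unset -> filled from LM (128)
    key_scale="C Major",  # user-provided -> kept as "C Major"
    time_signature="",    # not present in lm_meta -> stays ""
    audio_duration=-1.0,  # <= 0 -> filled from LM (180.0)
    vocal_language="unknown",
    caption="",
    lyrics="",
)
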
def generate_music(
    dit_handler,
    llm_handler,
    params: GenerationParams,
    config: GenerationConfig,
    save_dir: Optional[str] = None,
    progress=None,
) -> GenerationResult:
    """Generate music using ACE-Step model with optional LM reasoning.

    Args:
        dit_handler: Initialized DiT model handler (AceStepHandler instance)
        llm_handler: Initialized LLM handler (LLMHandler instance)
        params: Generation parameters (GenerationParams instance)
        config: Generation configuration (GenerationConfig instance)

    Returns:
        GenerationResult with generated audio files and metadata
    """
    try:
        # Phase 1: LM-based metadata and code generation (if enabled)
        audio_code_string_to_use = params.audio_codes
        lm_generated_metadata = None
        lm_generated_audio_codes_list = []
        lm_total_time_costs = {
            "phase1_time": 0.0,
            "phase2_time": 0.0,
            "total_time": 0.0,
        }

        # Extract mutable copies of metadata (will be updated by LM if needed)
        bpm = params.bpm
        key_scale = params.keyscale
        time_signature = params.timesignature
        audio_duration = params.duration
        dit_input_caption = params.caption
        dit_input_vocal_language = params.vocal_language
        dit_input_lyrics = params.lyrics
        # Determine if we need to generate audio codes.
        # If the user has provided audio_codes, we don't need to generate them;
        # otherwise, check if we need audio codes (llm_dit mode) or just metas (dit mode).
        user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())

        # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas are needed.
        # For now, "llm_dit" is used when the user hasn't provided codes,
        # and "dit" when the user has provided codes (only metas needed).
        # Note: this logic can be refined based on specific requirements.
        need_audio_codes = not user_provided_audio_codes

        # Chunk-based LM generation is always used for consistency.
        # Determine the actual batch size for chunk processing.
        actual_batch_size = config.batch_size if config.batch_size is not None else 1

        # Prepare seeds for batch generation.
        # Use config.seeds if provided, otherwise fall back to params.seed.
        # Convert config.seeds (None, int, or List[int]) to the format that prepare_seeds accepts.
        seed_for_generation = ""
        if config.seeds is not None and len(config.seeds) > 0:
            if isinstance(config.seeds, list):
                # Convert List[int] to a comma-separated string
                seed_for_generation = ",".join(str(s) for s in config.seeds)

        # Use dit_handler.prepare_seeds to handle seed list generation and padding.
        # This handles all the logic: padding with random seeds if needed, etc.
        actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)

        # LM-based Chain-of-Thought reasoning
        use_lm = params.thinking and llm_handler.llm_initialized
        lm_status = []
        if use_lm:
            # Convert sampling parameters - handle None values safely
            top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
            top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p

            # Build user_metadata from user-provided values
            user_metadata = {}
            if bpm is not None:
                try:
                    bpm_value = float(bpm)
                    if bpm_value > 0:
                        user_metadata['bpm'] = int(bpm_value)
                except (ValueError, TypeError):
                    pass

            if key_scale and key_scale.strip():
                key_scale_clean = key_scale.strip()
                if key_scale_clean.lower() not in ["n/a", ""]:
                    user_metadata['keyscale'] = key_scale_clean

            if time_signature and time_signature.strip():
                time_sig_clean = time_signature.strip()
                if time_sig_clean.lower() not in ["n/a", ""]:
                    user_metadata['timesignature'] = time_sig_clean

            if audio_duration is not None:
                try:
                    duration_value = float(audio_duration)
                    if duration_value > 0:
                        user_metadata['duration'] = int(duration_value)
                except (ValueError, TypeError):
                    pass

            user_metadata_to_pass = user_metadata if user_metadata else None

            # Determine infer_type based on whether we need audio codes:
            # - "llm_dit": generates both metas and audio codes (two-phase internally)
            # - "dit": generates only metas (single phase)
            infer_type = "llm_dit" if need_audio_codes else "dit"

            # Use chunk size from config, or default to batch_size if not set
            max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
            num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)

            all_metadata_list = []
            all_audio_codes_list = []

            for chunk_idx in range(num_chunks):
                chunk_start = chunk_idx * max_inference_batch_size
                chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
                chunk_size = chunk_end - chunk_start
                chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None

                logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
                            f"(size: {chunk_size}, seeds: {chunk_seeds})")

                # Use the determined infer_type:
                # - "llm_dit" will internally run two phases (metas + codes)
                # - "dit" will only run phase 1 (metas only)
                result = llm_handler.generate_with_stop_condition(
                    caption=params.caption or "",
                    lyrics=params.lyrics or "",
                    infer_type=infer_type,
                    temperature=params.lm_temperature,
                    cfg_scale=params.lm_cfg_scale,
                    negative_prompt=params.lm_negative_prompt,
                    top_k=top_k_value,
                    top_p=top_p_value,
                    user_metadata=user_metadata_to_pass,
                    use_cot_caption=params.use_cot_caption,
                    use_cot_language=params.use_cot_language,
                    use_cot_metas=params.use_cot_metas,
                    use_constrained_decoding=params.use_constrained_decoding,
                    constrained_decoding_debug=config.constrained_decoding_debug,
                    batch_size=chunk_size,
                    seeds=chunk_seeds,
                    progress=progress,
                )

                # Check if LM generation failed
                if not result.get("success", False):
                    error_msg = result.get("error", "Unknown LM error")
                    lm_status.append(f"❌ LM Error: {error_msg}")
                    # Return early with error
                    return GenerationResult(
                        audios=[],
                        status_message=f"❌ LM generation failed: {error_msg}",
                        extra_outputs={},
                        success=False,
                        error=error_msg,
                    )

                # Extract metadata and audio_codes from the result dict
                if chunk_size > 1:
                    metadata_list = result.get("metadata", [])
                    audio_codes_list = result.get("audio_codes", [])
                    all_metadata_list.extend(metadata_list)
                    all_audio_codes_list.extend(audio_codes_list)
                else:
                    metadata = result.get("metadata", {})
                    audio_codes = result.get("audio_codes", "")
                    all_metadata_list.append(metadata)
                    all_audio_codes_list.append(audio_codes)

                # Collect time costs from LM extra_outputs
                lm_extra = result.get("extra_outputs", {})
                lm_chunk_time_costs = lm_extra.get("time_costs", {})
                if lm_chunk_time_costs:
                    # Accumulate time costs from all chunks
                    for key in ["phase1_time", "phase2_time", "total_time"]:
                        if key in lm_chunk_time_costs:
                            lm_total_time_costs[key] += lm_chunk_time_costs[key]

                    time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
                    lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")

            lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
            lm_generated_audio_codes_list = all_audio_codes_list

            # Set audio_code_string_to_use based on infer_type
            if infer_type == "llm_dit":
                # If batch mode, use the list; otherwise use a single string
                if actual_batch_size > 1:
                    audio_code_string_to_use = all_audio_codes_list
                else:
                    audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
            else:
                # For "dit" mode, keep user-provided codes or empty
                audio_code_string_to_use = params.audio_codes

            # Update metadata from LM if not provided by user
            if lm_generated_metadata:
                bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
                    metadata=lm_generated_metadata,
                    bpm=bpm,
                    key_scale=key_scale,
                    time_signature=time_signature,
                    audio_duration=audio_duration,
                    vocal_language=dit_input_vocal_language,
                    caption=dit_input_caption,
                    lyrics=dit_input_lyrics)
                if not params.bpm:
                    params.cot_bpm = bpm
                if not params.keyscale:
                    params.cot_keyscale = key_scale
                if not params.timesignature:
                    params.cot_timesignature = time_signature
                if not params.duration:
                    params.cot_duration = audio_duration
                if not params.vocal_language:
                    params.cot_vocal_language = vocal_language
                if not params.caption:
                    params.cot_caption = caption
                if not params.lyrics:
                    params.cot_lyrics = lyrics

                # Set CoT caption and language if needed
                if params.use_cot_caption:
                    dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
                if params.use_cot_language:
                    dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)

        # Phase 2: DiT music generation
        # Use seed_for_generation (from config.seeds or params.seed) instead of params.seed for the actual generation
        result = dit_handler.generate_music(
            captions=dit_input_caption,
            lyrics=dit_input_lyrics,
            bpm=bpm,
            key_scale=key_scale,
            time_signature=time_signature,
            vocal_language=dit_input_vocal_language,
            inference_steps=params.inference_steps,
            guidance_scale=params.guidance_scale,
            use_random_seed=config.use_random_seed,
            seed=seed_for_generation,  # config.seeds (or params.seed fallback) rather than params.seed directly
            reference_audio=params.reference_audio,
            audio_duration=audio_duration,
            batch_size=config.batch_size if config.batch_size is not None else 1,
            src_audio=params.src_audio,
            audio_code_string=audio_code_string_to_use,
            repainting_start=params.repainting_start,
            repainting_end=params.repainting_end,
            instruction=params.instruction,
            audio_cover_strength=params.audio_cover_strength,
            task_type=params.task_type,
            use_adg=params.use_adg,
            cfg_interval_start=params.cfg_interval_start,
            cfg_interval_end=params.cfg_interval_end,
            progress=progress,
        )

        # Check if generation failed
        if not result.get("success", False):
            return GenerationResult(
                audios=[],
                status_message=result.get("status_message", ""),
                extra_outputs={},
                success=False,
                error=result.get("error"),
            )

        # Extract results from the dit_handler.generate_music dict
        dit_audios = result.get("audios", [])
        status_message = result.get("status_message", "")
        dit_extra_outputs = result.get("extra_outputs", {})

        # Use the seed list already prepared above (from config.seeds or params.seed fallback);
        # actual_seed_list was computed earlier using dit_handler.prepare_seeds
        seed_list = actual_seed_list

        # Get the base params dictionary
        base_params_dict = params.to_dict()

        # Save audio files using AudioSaver (format from config)
        audio_format = config.audio_format if config.audio_format else "flac"
        audio_saver = AudioSaver(default_format=audio_format)

        # Create the save directory if one was provided
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

        # Build the audios list for GenerationResult with params, and save files.
        # Audio saving and UUID generation are handled here, outside of the handler.
        audios = []
        for idx, dit_audio in enumerate(dit_audios):
            # Create a copy of the params dict for this audio
            audio_params = base_params_dict.copy()

            # Update audio-specific values
            audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None

            # Add audio codes if batch mode
            if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
                audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]

            # Get the audio tensor and metadata
            audio_tensor = dit_audio.get("tensor")
            sample_rate = dit_audio.get("sample_rate", 48000)

            # Generate a UUID for this audio (moved from the handler)
            batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
            audio_code_str = lm_generated_audio_codes_list[idx] if (
                lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
            if isinstance(audio_code_str, list):
                audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""

            audio_key = generate_uuid_from_params(audio_params)

            # Save the audio file (handled outside the handler)
            audio_path = None
            if audio_tensor is not None and save_dir is not None:
                try:
                    audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
                    audio_path = audio_saver.save_audio(audio_tensor,
                                                        audio_file,
                                                        sample_rate=sample_rate,
                                                        format=audio_format,
                                                        channels_first=True)
                except Exception as e:
                    logger.error(f"[generate_music] Failed to save audio file: {e}")
                    audio_path = ""  # Fall back to an empty path

            audio_dict = {
                "path": audio_path or "",  # File path (saved here, not in the handler)
                "tensor": audio_tensor,  # Audio tensor [channels, samples], CPU, float32
                "key": audio_key,
                "sample_rate": sample_rate,
                "params": audio_params,
            }

            audios.append(audio_dict)

        # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
        extra_outputs = dit_extra_outputs.copy()
        extra_outputs["lm_metadata"] = lm_generated_metadata

        # Merge time_costs from both LM and DiT into a unified dictionary
        unified_time_costs = {}

        # Add LM time costs (if the LM was used)
        if use_lm and lm_total_time_costs:
            for key, value in lm_total_time_costs.items():
                unified_time_costs[f"lm_{key}"] = value

        # Add DiT time costs (if available)
        dit_time_costs = dit_extra_outputs.get("time_costs", {})
        if dit_time_costs:
            for key, value in dit_time_costs.items():
                unified_time_costs[f"dit_{key}"] = value

        # Calculate total pipeline time
        if unified_time_costs:
            lm_total = unified_time_costs.get("lm_total_time", 0.0)
            dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
            unified_time_costs["pipeline_total_time"] = lm_total + dit_total

        # Update extra_outputs with the unified time_costs
        extra_outputs["time_costs"] = unified_time_costs

        if lm_status:
            status_message = "\n".join(lm_status) + "\n" + status_message

        # Create and return the GenerationResult
        return GenerationResult(
            audios=audios,
            status_message=status_message,
            extra_outputs=extra_outputs,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music generation failed")
        return GenerationResult(
            audios=[],
            status_message=f"Error: {str(e)}",
            extra_outputs={},
            success=False,
            error=str(e),
        )
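An end-to-end usage sketch of the new generate_music entry point. The import paths and handler construction below are assumptions based on the modules touched by this PR; the real initialization arguments live in acestep/handler.py and acestep/llm_inference.py.

from acestep.handler import AceStepHandler
from acestep.llm_inference import LLMHandler
from acestep.inference import GenerationParams, GenerationConfig, generate_music

dit_handler = AceStepHandler()   # assumed constructor; initialize the DiT service before use
llm_handler = LLMHandler()       # optional 5Hz LM; generate_music falls back to DiT-only if it is not initialized

params = GenerationParams(caption="Lo-fi hip hop, mellow piano", duration=45.0)
config = GenerationConfig(batch_size=1, audio_format="flac")

result = generate_music(dit_handler, llm_handler, params, config, save_dir="./outputs")
if result.success:
    for audio in result.audios:
        print(audio["path"], audio["params"]["seed"])
else:
    print("Generation failed:", result.error)
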
acestep/llm_inference.py
CHANGED
|
@@ -5,7 +5,8 @@ Handles all LM-related operations including initialization and generation
|
|
| 5 |
import os
|
| 6 |
import traceback
|
| 7 |
import time
|
| 8 |
-
|
|
|
|
| 9 |
from contextlib import contextmanager
|
| 10 |
|
| 11 |
import yaml
|
|
@@ -85,6 +86,189 @@ class LLMHandler:
|
|
| 85 |
except Exception as e:
|
| 86 |
return 0.9, False
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
def initialize(
|
| 89 |
self,
|
| 90 |
checkpoint_dir: str,
|
|
@@ -126,6 +310,7 @@ class LLMHandler:
|
|
| 126 |
|
| 127 |
logger.info("loading 5Hz LM tokenizer...")
|
| 128 |
start_time = time.time()
|
|
|
|
| 129 |
llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
|
| 130 |
logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
|
| 131 |
self.llm_tokenizer = llm_tokenizer
|
|
@@ -150,41 +335,21 @@ class LLMHandler:
|
|
| 150 |
# vllm initialization failed, fallback to PyTorch
|
| 151 |
if not self.llm_initialized:
|
| 152 |
logger.warning("vllm initialization failed, falling back to PyTorch backend")
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
else:
|
| 158 |
-
self.llm = self.llm.to("cpu").to(self.dtype)
|
| 159 |
-
self.llm.eval()
|
| 160 |
-
self.llm_backend = "pt"
|
| 161 |
-
self.llm_initialized = True
|
| 162 |
-
logger.info("5Hz LM initialized successfully using PyTorch backend (fallback)")
|
| 163 |
-
status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
|
| 164 |
-
except Exception as e:
|
| 165 |
-
return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
|
| 166 |
# If vllm initialization succeeded, self.llm_initialized should already be True
|
| 167 |
else:
|
| 168 |
# Use PyTorch backend (pt)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
self.llm = self.llm.to(device).to(self.dtype)
|
| 173 |
-
else:
|
| 174 |
-
self.llm = self.llm.to("cpu").to(self.dtype)
|
| 175 |
-
self.llm.eval()
|
| 176 |
-
self.llm_backend = "pt"
|
| 177 |
-
self.llm_initialized = True
|
| 178 |
-
logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
|
| 179 |
-
status_msg = f"✅ 5Hz LM initialized successfully\nModel: {full_lm_model_path}\nBackend: PyTorch\nDevice: {device}"
|
| 180 |
-
except Exception as e:
|
| 181 |
-
return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
|
| 182 |
|
| 183 |
return status_msg, True
|
| 184 |
|
| 185 |
except Exception as e:
|
| 186 |
-
|
| 187 |
-
return error_msg, False
|
| 188 |
|
| 189 |
def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
|
| 190 |
"""Initialize 5Hz LM model using vllm backend"""
|
|
@@ -230,12 +395,11 @@ class LLMHandler:
|
|
| 230 |
return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
|
| 231 |
except Exception as e:
|
| 232 |
self.llm_initialized = False
|
| 233 |
-
|
| 234 |
-
return error_msg
|
| 235 |
|
| 236 |
-
def
|
| 237 |
self,
|
| 238 |
-
|
| 239 |
temperature: float,
|
| 240 |
cfg_scale: float,
|
| 241 |
negative_prompt: str,
|
|
@@ -244,7 +408,7 @@ class LLMHandler:
|
|
| 244 |
repetition_penalty: float,
|
| 245 |
use_constrained_decoding: bool = True,
|
| 246 |
constrained_decoding_debug: bool = False,
|
| 247 |
-
metadata_temperature: Optional[float] =
|
| 248 |
codes_temperature: Optional[float] = None,
|
| 249 |
target_duration: Optional[float] = None,
|
| 250 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
|
@@ -256,37 +420,40 @@ class LLMHandler:
|
|
| 256 |
caption: str = "",
|
| 257 |
lyrics: str = "",
|
| 258 |
cot_text: str = "",
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
from nanovllm import SamplingParams
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
# Determine effective temperature for sampler
|
| 264 |
-
|
|
|
|
|
|
|
| 265 |
effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
|
| 266 |
|
| 267 |
-
#
|
| 268 |
-
constrained_processor =
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
# Set skip_caption and skip_language based on flags
|
| 283 |
-
self.constrained_processor.set_skip_genres(skip_genres)
|
| 284 |
-
self.constrained_processor.set_skip_caption(skip_caption)
|
| 285 |
-
self.constrained_processor.set_skip_language(skip_language)
|
| 286 |
-
# Set generation phase for phase-aware processing
|
| 287 |
-
self.constrained_processor.set_generation_phase(generation_phase)
|
| 288 |
-
|
| 289 |
-
constrained_processor = self.constrained_processor
|
| 290 |
|
| 291 |
sampling_params = SamplingParams(
|
| 292 |
max_tokens=self.max_model_len - 64,
|
|
@@ -301,119 +468,25 @@ class LLMHandler:
|
|
| 301 |
|
| 302 |
if cfg_scale > 1.0:
|
| 303 |
# Build unconditional prompt based on generation phase
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
else:
|
| 312 |
-
# CoT phase: unconditional prompt
|
| 313 |
-
# If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
|
| 314 |
-
formatted_unconditional_prompt = self.build_formatted_prompt(
|
| 315 |
-
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
|
| 316 |
-
)
|
| 317 |
-
|
| 318 |
-
outputs = self.llm.generate(
|
| 319 |
-
[formatted_prompt],
|
| 320 |
-
sampling_params,
|
| 321 |
-
unconditional_prompts=[formatted_unconditional_prompt],
|
| 322 |
-
)
|
| 323 |
-
else:
|
| 324 |
-
outputs = self.llm.generate([formatted_prompt], sampling_params)
|
| 325 |
-
|
| 326 |
-
# Extract text (retain original selection order/logic)
|
| 327 |
-
if isinstance(outputs, list) and len(outputs) > 0:
|
| 328 |
-
if hasattr(outputs[0], "outputs") and len(outputs[0].outputs) > 0:
|
| 329 |
-
output_text = outputs[0].outputs[0].text
|
| 330 |
-
elif hasattr(outputs[0], "text"):
|
| 331 |
-
output_text = outputs[0].text
|
| 332 |
-
elif isinstance(outputs[0], dict) and "text" in outputs[0]:
|
| 333 |
-
output_text = outputs[0]["text"]
|
| 334 |
-
else:
|
| 335 |
-
output_text = str(outputs[0])
|
| 336 |
-
else:
|
| 337 |
-
output_text = str(outputs)
|
| 338 |
-
|
| 339 |
-
return output_text
|
| 340 |
-
|
| 341 |
-
def _run_vllm_batch(
|
| 342 |
-
self,
|
| 343 |
-
formatted_prompts: List[str],
|
| 344 |
-
temperature: float,
|
| 345 |
-
cfg_scale: float,
|
| 346 |
-
negative_prompt: str,
|
| 347 |
-
top_k: Optional[int],
|
| 348 |
-
top_p: Optional[float],
|
| 349 |
-
repetition_penalty: float,
|
| 350 |
-
use_constrained_decoding: bool = True,
|
| 351 |
-
constrained_decoding_debug: bool = False,
|
| 352 |
-
target_duration: Optional[float] = None,
|
| 353 |
-
generation_phase: str = "codes",
|
| 354 |
-
caption: str = "",
|
| 355 |
-
lyrics: str = "",
|
| 356 |
-
cot_text: str = "",
|
| 357 |
-
seeds: Optional[List[int]] = None,
|
| 358 |
-
) -> List[str]:
|
| 359 |
-
"""Batch generation using vllm backend"""
|
| 360 |
-
from nanovllm import SamplingParams
|
| 361 |
-
|
| 362 |
-
batch_size = len(formatted_prompts)
|
| 363 |
-
|
| 364 |
-
# Determine effective temperature for sampler
|
| 365 |
-
effective_sampler_temp = temperature
|
| 366 |
-
|
| 367 |
-
# Use shared constrained processor if enabled
|
| 368 |
-
# Note: vllm batch mode uses same processor for all items
|
| 369 |
-
constrained_processor = None
|
| 370 |
-
if use_constrained_decoding:
|
| 371 |
-
# Reset processor state for new generation
|
| 372 |
-
self.constrained_processor.reset()
|
| 373 |
-
|
| 374 |
-
self.constrained_processor.enabled = use_constrained_decoding
|
| 375 |
-
self.constrained_processor.debug = constrained_decoding_debug
|
| 376 |
-
self.constrained_processor.metadata_temperature = None
|
| 377 |
-
self.constrained_processor.codes_temperature = None
|
| 378 |
-
self.constrained_processor.set_target_duration(target_duration)
|
| 379 |
-
self.constrained_processor.set_user_metadata(None)
|
| 380 |
-
self.constrained_processor.set_stop_at_reasoning(False)
|
| 381 |
-
self.constrained_processor.set_skip_genres(True)
|
| 382 |
-
self.constrained_processor.set_skip_caption(True)
|
| 383 |
-
self.constrained_processor.set_skip_language(True)
|
| 384 |
-
self.constrained_processor.set_generation_phase(generation_phase)
|
| 385 |
-
|
| 386 |
-
constrained_processor = self.constrained_processor
|
| 387 |
-
|
| 388 |
-
# Build sampling params
|
| 389 |
-
sampling_params = SamplingParams(
|
| 390 |
-
max_tokens=self.max_model_len - 64,
|
| 391 |
-
temperature=effective_sampler_temp,
|
| 392 |
-
cfg_scale=cfg_scale,
|
| 393 |
-
top_k=top_k,
|
| 394 |
-
top_p=top_p,
|
| 395 |
-
repetition_penalty=repetition_penalty,
|
| 396 |
-
logits_processor=constrained_processor,
|
| 397 |
-
logits_processor_update_state=constrained_processor.update_state if constrained_processor else None,
|
| 398 |
-
)
|
| 399 |
-
|
| 400 |
-
# Generate with or without CFG
|
| 401 |
-
if cfg_scale > 1.0:
|
| 402 |
-
# Build unconditional prompts
|
| 403 |
-
formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
|
| 404 |
-
caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
|
| 405 |
)
|
| 406 |
unconditional_prompts = [formatted_unconditional_prompt] * batch_size
|
| 407 |
|
| 408 |
outputs = self.llm.generate(
|
| 409 |
-
|
| 410 |
sampling_params,
|
| 411 |
unconditional_prompts=unconditional_prompts,
|
| 412 |
)
|
| 413 |
else:
|
| 414 |
-
outputs = self.llm.generate(
|
| 415 |
-
|
| 416 |
-
# Extract text from
|
| 417 |
output_texts = []
|
| 418 |
for output in outputs:
|
| 419 |
if hasattr(output, "outputs") and len(output.outputs) > 0:
|
|
@@ -424,70 +497,11 @@ class LLMHandler:
|
|
| 424 |
output_texts.append(output["text"])
|
| 425 |
else:
|
| 426 |
output_texts.append(str(output))
|
| 427 |
-
|
| 428 |
-
return output_texts
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
formatted_prompts: List[str],
|
| 433 |
-
temperature: float,
|
| 434 |
-
cfg_scale: float,
|
| 435 |
-
negative_prompt: str,
|
| 436 |
-
top_k: Optional[int],
|
| 437 |
-
top_p: Optional[float],
|
| 438 |
-
repetition_penalty: float,
|
| 439 |
-
use_constrained_decoding: bool = True,
|
| 440 |
-
constrained_decoding_debug: bool = False,
|
| 441 |
-
target_duration: Optional[float] = None,
|
| 442 |
-
generation_phase: str = "codes",
|
| 443 |
-
caption: str = "",
|
| 444 |
-
lyrics: str = "",
|
| 445 |
-
cot_text: str = "",
|
| 446 |
-
seeds: Optional[List[int]] = None,
|
| 447 |
-
) -> List[str]:
|
| 448 |
-
"""Batch generation using PyTorch backend"""
|
| 449 |
-
import random
|
| 450 |
-
|
| 451 |
-
batch_size = len(formatted_prompts)
|
| 452 |
-
output_texts = []
|
| 453 |
-
|
| 454 |
-
# Generate each item sequentially with different seeds
|
| 455 |
-
# (PyTorch backend doesn't support true batching efficiently)
|
| 456 |
-
for i, formatted_prompt in enumerate(formatted_prompts):
|
| 457 |
-
# Set seed for this item if provided
|
| 458 |
-
if seeds and i < len(seeds):
|
| 459 |
-
torch.manual_seed(seeds[i])
|
| 460 |
-
if torch.cuda.is_available():
|
| 461 |
-
torch.cuda.manual_seed_all(seeds[i])
|
| 462 |
-
|
| 463 |
-
# Generate using single-item method
|
| 464 |
-
output_text = self._run_pt_from_formatted(
|
| 465 |
-
formatted_prompt=formatted_prompt,
|
| 466 |
-
temperature=temperature,
|
| 467 |
-
cfg_scale=cfg_scale,
|
| 468 |
-
negative_prompt=negative_prompt,
|
| 469 |
-
top_k=top_k,
|
| 470 |
-
top_p=top_p,
|
| 471 |
-
repetition_penalty=repetition_penalty,
|
| 472 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 473 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 474 |
-
target_duration=target_duration,
|
| 475 |
-
user_metadata=None,
|
| 476 |
-
stop_at_reasoning=False,
|
| 477 |
-
skip_genres=True,
|
| 478 |
-
skip_caption=True,
|
| 479 |
-
skip_language=True,
|
| 480 |
-
generation_phase=generation_phase,
|
| 481 |
-
caption=caption,
|
| 482 |
-
lyrics=lyrics,
|
| 483 |
-
cot_text=cot_text,
|
| 484 |
-
)
|
| 485 |
-
|
| 486 |
-
output_texts.append(output_text)
|
| 487 |
-
|
| 488 |
-
return output_texts
|
| 489 |
|
| 490 |
-
def
|
| 491 |
self,
|
| 492 |
formatted_prompt: str,
|
| 493 |
temperature: float,
|
|
@@ -496,20 +510,20 @@ class LLMHandler:
|
|
| 496 |
top_k: Optional[int],
|
| 497 |
top_p: Optional[float],
|
| 498 |
repetition_penalty: float,
|
| 499 |
-
use_constrained_decoding: bool
|
| 500 |
-
constrained_decoding_debug: bool
|
| 501 |
-
target_duration: Optional[float]
|
| 502 |
-
user_metadata: Optional[Dict[str, Optional[str]]]
|
| 503 |
-
stop_at_reasoning: bool
|
| 504 |
-
skip_genres: bool
|
| 505 |
-
skip_caption: bool
|
| 506 |
-
skip_language: bool
|
| 507 |
-
generation_phase: str
|
| 508 |
-
caption: str
|
| 509 |
-
lyrics: str
|
| 510 |
-
cot_text: str
|
| 511 |
) -> str:
|
| 512 |
-
"""
|
| 513 |
inputs = self.llm_tokenizer(
|
| 514 |
formatted_prompt,
|
| 515 |
return_tensors="pt",
|
|
@@ -517,27 +531,19 @@ class LLMHandler:
|
|
| 517 |
truncation=True,
|
| 518 |
)
|
| 519 |
|
| 520 |
-
#
|
| 521 |
-
constrained_processor =
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
# Set skip_caption and skip_language based on flags
|
| 534 |
-
self.constrained_processor.set_skip_genres(skip_genres)
|
| 535 |
-
self.constrained_processor.set_skip_caption(skip_caption)
|
| 536 |
-
self.constrained_processor.set_skip_language(skip_language)
|
| 537 |
-
# Set generation phase for phase-aware processing
|
| 538 |
-
self.constrained_processor.set_generation_phase(generation_phase)
|
| 539 |
-
|
| 540 |
-
constrained_processor = self.constrained_processor
|
| 541 |
|
| 542 |
with self._load_model_context():
|
| 543 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
@@ -546,25 +552,18 @@ class LLMHandler:
|
|
| 546 |
max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
|
| 547 |
|
| 548 |
# Build logits processor list (only for CFG and repetition penalty)
|
| 549 |
-
logits_processor =
|
| 550 |
-
|
| 551 |
-
# Add repetition penalty if needed
|
| 552 |
-
if repetition_penalty != 1.0:
|
| 553 |
-
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 554 |
|
| 555 |
if cfg_scale > 1.0:
|
| 556 |
# Build unconditional prompt based on generation phase
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
formatted_unconditional_prompt = self.build_formatted_prompt(
|
| 566 |
-
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
|
| 567 |
-
)
|
| 568 |
|
| 569 |
# Tokenize both prompts together to ensure same length (with left padding)
|
| 570 |
# Left padding is important for generation tasks
|
|
@@ -657,7 +656,101 @@ class LLMHandler:
|
|
| 657 |
|
| 658 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 659 |
return output_text
|
| 660 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
|
| 662 |
"""Check if all required metadata are present."""
|
| 663 |
if user_metadata is None:
|
|
@@ -705,10 +798,13 @@ class LLMHandler:
|
|
| 705 |
constrained_decoding_debug: bool = False,
|
| 706 |
target_duration: Optional[float] = None,
|
| 707 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
|
|
|
| 708 |
use_cot_caption: bool = True,
|
| 709 |
use_cot_language: bool = True,
|
| 710 |
-
|
| 711 |
-
|
|
|
|
|
|
|
| 712 |
"""Two-phase LM generation: CoT generation followed by audio codes generation.
|
| 713 |
|
| 714 |
- infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
|
|
@@ -721,30 +817,67 @@ class LLMHandler:
                If specified, constrained decoding will inject these values directly.
            use_cot_caption: Whether to generate caption in CoT (default True).
            use_cot_language: Whether to generate language in CoT (default True).

        infer_type = (infer_type or "").strip().lower()
        if infer_type not in {"dit", "llm_dit"}:

        metadata = {}
        audio_codes = ""
        has_all_metas = self.has_all_metas(user_metadata)
-       # Timing variables
        phase1_time = 0.0
        phase2_time = 0.0

        # ========== PHASE 1: CoT Generation ==========
        phase1_start = time.time()

        # Build formatted prompt for CoT phase
        formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
        logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
        # Generate CoT (stop at </think>)
        cot_output_text, status = self.generate_from_formatted_prompt(
@@ -774,23 +907,63 @@ class LLMHandler:
            phase1_time = time.time() - phase1_start

            if not cot_output_text:
-               return {

            # Parse metadata from CoT output
            metadata, _ = self.parse_lm_output(cot_output_text)
        else:
            # Use user-provided metadata
            metadata = {k: v for k, v in user_metadata.items() if v is not None}

        # If infer_type is 'dit', stop here and return only metadata
        if infer_type == "dit":

        # ========== PHASE 2: Audio Codes Generation ==========
        phase2_start = time.time()

        # Format metadata as CoT using YAML (matching training format)
|
|
| 799 |
# Build formatted prompt with CoT for codes generation phase
|
| 800 |
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
|
| 801 |
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
|
| 802 |
-
# Generate audio codes
|
| 803 |
-
codes_output_text, status = self.generate_from_formatted_prompt(
|
| 804 |
-
formatted_prompt=formatted_prompt_with_cot,
|
| 805 |
-
cfg={
|
| 806 |
-
"temperature": temperature,
|
| 807 |
-
"cfg_scale": cfg_scale,
|
| 808 |
-
"negative_prompt": negative_prompt,
|
| 809 |
-
"top_k": top_k,
|
| 810 |
-
"top_p": top_p,
|
| 811 |
-
"repetition_penalty": repetition_penalty,
|
| 812 |
-
"target_duration": target_duration,
|
| 813 |
-
"user_metadata": None, # No user metadata injection in Phase 2
|
| 814 |
-
"skip_caption": True, # Skip caption since CoT is already included
|
| 815 |
-
"skip_language": True, # Skip language since CoT is already included
|
| 816 |
-
"generation_phase": "codes",
|
| 817 |
-
# Pass context for building unconditional prompt in codes phase
|
| 818 |
-
"caption": caption,
|
| 819 |
-
"lyrics": lyrics,
|
| 820 |
-
"cot_text": cot_text,
|
| 821 |
-
},
|
| 822 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 823 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 824 |
-
stop_at_reasoning=False, # Generate codes until EOS
|
| 825 |
-
)
|
| 826 |
-
|
| 827 |
-
if not codes_output_text:
|
| 828 |
-
return metadata, "", status
|
| 829 |
-
|
| 830 |
-
phase2_time = time.time() - phase2_start
|
| 831 |
-
|
| 832 |
-
# Parse audio codes from output (metadata should be same as Phase 1)
|
| 833 |
-
_, audio_codes = self.parse_lm_output(codes_output_text)
|
| 834 |
-
|
| 835 |
-
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 836 |
-
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
|
| 837 |
-
|
| 838 |
-
status_msg = f"✅ Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
|
| 839 |
-
return metadata, audio_codes, status_msg
|
| 840 |
-
|
| 841 |
-
def generate_with_stop_condition_batch(
|
| 842 |
-
self,
|
| 843 |
-
caption: str,
|
| 844 |
-
lyrics: str,
|
| 845 |
-
batch_size: int,
|
| 846 |
-
infer_type: str = "llm_dit",
|
| 847 |
-
temperature: float = 0.85,
|
| 848 |
-
cfg_scale: float = 1.0,
|
| 849 |
-
negative_prompt: str = "NO USER INPUT",
|
| 850 |
-
top_k: Optional[int] = None,
|
| 851 |
-
top_p: Optional[float] = None,
|
| 852 |
-
repetition_penalty: float = 1.0,
|
| 853 |
-
use_constrained_decoding: bool = True,
|
| 854 |
-
constrained_decoding_debug: bool = False,
|
| 855 |
-
target_duration: Optional[float] = None,
|
| 856 |
-
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 857 |
-
use_cot_caption: bool = True,
|
| 858 |
-
use_cot_language: bool = True,
|
| 859 |
-
is_format_caption: bool = False,
|
| 860 |
-
seeds: Optional[List[int]] = None,
|
| 861 |
-
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
| 862 |
-
"""
|
| 863 |
-
Batch version of generate_with_stop_condition.
|
| 864 |
-
|
| 865 |
-
Generates multiple audio codes with same conditions but different seeds (for diversity).
|
| 866 |
-
|
| 867 |
-
Args:
|
| 868 |
-
caption: Same caption for all items
|
| 869 |
-
lyrics: Same lyrics for all items
|
| 870 |
-
batch_size: Number of items to generate
|
| 871 |
-
seeds: Optional list of seeds for each batch item (for reproducibility)
|
| 872 |
-
... (other args same as generate_with_stop_condition)
|
| 873 |
-
|
| 874 |
-
Returns:
|
| 875 |
-
Tuple of (metadata_list, audio_codes_list, status_message)
|
| 876 |
-
- metadata_list: List of metadata dicts (same metadata for all items)
|
| 877 |
-
- audio_codes_list: List of audio code strings (one per item, different due to sampling)
|
| 878 |
-
- status_message: Generation status
|
| 879 |
-
"""
|
| 880 |
-
import random
|
| 881 |
-
import time
|
| 882 |
|
| 883 |
-
|
| 884 |
-
if
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
# Generate seeds if not provided
|
| 888 |
-
if seeds is None:
|
| 889 |
-
seeds = [random.randint(0, 2**32 - 1) for _ in range(batch_size)]
|
| 890 |
-
elif len(seeds) < batch_size:
|
| 891 |
-
# Pad with random seeds if not enough provided
|
| 892 |
-
seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(batch_size - len(seeds))]
|
| 893 |
-
else:
|
| 894 |
-
seeds = seeds[:batch_size] # Truncate if too many
|
| 895 |
-
|
| 896 |
-
# Timing variables
|
| 897 |
-
phase1_time = 0.0
|
| 898 |
-
phase2_time = 0.0
|
| 899 |
-
|
| 900 |
-
# ========== PHASE 1: CoT Generation (ONCE for all items) ==========
|
| 901 |
-
has_all_metas = self.has_all_metas(user_metadata)
|
| 902 |
-
|
| 903 |
-
if not has_all_metas or not is_format_caption:
|
| 904 |
-
logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
|
| 905 |
-
phase1_start = time.time()
|
| 906 |
|
| 907 |
-
#
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 918 |
use_constrained_decoding=use_constrained_decoding,
|
| 919 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 920 |
-
|
| 921 |
-
user_metadata=user_metadata,
|
| 922 |
-
use_cot_caption=use_cot_caption,
|
| 923 |
-
use_cot_language=use_cot_language,
|
| 924 |
-
is_format_caption=is_format_caption,
|
| 925 |
)
|
| 926 |
|
| 927 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 928 |
|
| 929 |
-
|
| 930 |
-
return [], [], status
|
| 931 |
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
# Replicate prompt for batch (all items have same prompt, differ by seeds)
|
| 955 |
-
formatted_prompts = [formatted_prompt] * batch_size
|
| 956 |
-
|
| 957 |
-
# Call backend-specific batch generation
|
| 958 |
-
try:
|
| 959 |
-
if self.llm_backend == "vllm":
|
| 960 |
-
codes_outputs = self._run_vllm_batch(
|
| 961 |
-
formatted_prompts=formatted_prompts,
|
| 962 |
-
temperature=temperature,
|
| 963 |
-
cfg_scale=cfg_scale,
|
| 964 |
-
negative_prompt=negative_prompt,
|
| 965 |
-
top_k=top_k,
|
| 966 |
-
top_p=top_p,
|
| 967 |
-
repetition_penalty=repetition_penalty,
|
| 968 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 969 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 970 |
-
target_duration=target_duration,
|
| 971 |
-
generation_phase="codes",
|
| 972 |
-
caption=caption,
|
| 973 |
-
lyrics=lyrics,
|
| 974 |
-
cot_text=cot_text,
|
| 975 |
-
seeds=seeds,
|
| 976 |
-
)
|
| 977 |
-
else: # pt backend
|
| 978 |
-
codes_outputs = self._run_pt_batch(
|
| 979 |
-
formatted_prompts=formatted_prompts,
|
| 980 |
-
temperature=temperature,
|
| 981 |
-
cfg_scale=cfg_scale,
|
| 982 |
-
negative_prompt=negative_prompt,
|
| 983 |
-
top_k=top_k,
|
| 984 |
-
top_p=top_p,
|
| 985 |
-
repetition_penalty=repetition_penalty,
|
| 986 |
-
use_constrained_decoding=use_constrained_decoding,
|
| 987 |
-
constrained_decoding_debug=constrained_decoding_debug,
|
| 988 |
-
target_duration=target_duration,
|
| 989 |
-
generation_phase="codes",
|
| 990 |
-
caption=caption,
|
| 991 |
-
lyrics=lyrics,
|
| 992 |
-
cot_text=cot_text,
|
| 993 |
-
seeds=seeds,
|
| 994 |
-
)
|
| 995 |
-
except Exception as e:
|
| 996 |
-
error_msg = f"❌ Error in batch codes generation: {str(e)}"
|
| 997 |
-
logger.error(error_msg)
|
| 998 |
-
return [], [], error_msg
|
| 999 |
-
|
| 1000 |
-
# Parse audio codes from each output
|
| 1001 |
-
audio_codes_list = []
|
| 1002 |
-
metadata_list = []
|
| 1003 |
-
for output_text in codes_outputs:
|
| 1004 |
-
_, audio_codes = self.parse_lm_output(output_text)
|
| 1005 |
-
audio_codes_list.append(audio_codes)
|
| 1006 |
-
metadata_list.append(metadata.copy()) # Same metadata for all
|
| 1007 |
-
|
| 1008 |
-
phase2_time = time.time() - phase2_start
|
| 1009 |
-
|
| 1010 |
-
# Log results
|
| 1011 |
-
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
|
| 1012 |
-
logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
|
| 1013 |
-
|
| 1014 |
-
status_msg = f"✅ Batch generation completed ({batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
|
| 1015 |
-
return metadata_list, audio_codes_list, status_msg
|
| 1016 |
-
|
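The seed handling in the deleted batch method (generate when missing, pad when short, truncate when long) reads compactly on its own; a small sketch with an illustrative function name:

    import random

    def fill_seeds(seeds, batch_size):
        # Generate, pad, or truncate so exactly one seed exists per batch item
        if seeds is None:
            return [random.randint(0, 2**32 - 1) for _ in range(batch_size)]
        if len(seeds) < batch_size:
            return list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(batch_size - len(seeds))]
        return list(seeds)[:batch_size]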
    def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
        """
        Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
@@ -1035,7 +1150,7 @@ class LLMHandler:
        if is_negative_prompt:
            # Unconditional prompt for CFG
            # Check if user provided a meaningful negative prompt (not the default)
-           has_negative_prompt = negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"

            if generation_phase == "cot":
                # CoT phase unconditional prompt
@@ -1086,7 +1201,7 @@ class LLMHandler:
        if is_negative_prompt:
            # Unconditional prompt for codes phase
            # Check if user provided a meaningful negative prompt
-           has_negative_prompt = negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"

            # Use empty CoT for unconditional
            cot_for_prompt = "<think>\n</think>"
@@ -1369,8 +1484,8 @@ class LLMHandler:

        try:
            if self.llm_backend == "vllm":
-               output_text = self.
-
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    negative_prompt=negative_prompt,
@@ -1393,8 +1508,8 @@ class LLMHandler:
            return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"

        # PyTorch backend
-       output_text = self.
-
            temperature=temperature,
            cfg_scale=cfg_scale,
            negative_prompt=negative_prompt,
@@ -1459,26 +1574,12 @@ class LLMHandler:
                eos_token_id = pad_token_id

            # Build logits processor for repetition penalty
-           logits_processor = LogitsProcessorList()
-           if repetition_penalty != 1.0:
-               logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

            with torch.no_grad():
                for step in range(max_new_tokens):
                    # Forward pass
-                   if past_key_values is None:
-                       outputs = model(
-                           input_ids=generated_ids,
-                           **model_kwargs,
-                           use_cache=use_cache,
-                       )
-                   else:
-                       outputs = model(
-                           input_ids=generated_ids[:, -1:],
-                           past_key_values=past_key_values,
-                           **model_kwargs,
-                           use_cache=use_cache,
-                       )

                    # Get logits for the last position
                    next_token_logits = outputs.logits[:, -1, :]  # [batch_size, vocab_size]
@@ -1491,41 +1592,18 @@ class LLMHandler:
                    for processor in logits_processor:
                        next_token_logits = processor(generated_ids, next_token_logits)

-                   # Apply top-k filtering
-                   if top_k is not None and top_k > 0:
-                       indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
-                       next_token_logits[indices_to_remove] = float('-inf')
-
-                   # Apply top-p filtering
-                   if top_p is not None and 0.0 < top_p < 1.0:
-                       sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
-                       cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-                       sorted_indices_to_remove = cumulative_probs > top_p
-                       sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-                       sorted_indices_to_remove[..., 0] = 0
-                       indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-                       next_token_logits[indices_to_remove] = float('-inf')

                    # Apply temperature and sample
-                   if temperature > 0:
-                       next_token_logits = next_token_logits / temperature
-                       probs = torch.softmax(next_token_logits, dim=-1)
-                       next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-                   else:
-                       next_tokens = torch.argmax(next_token_logits, dim=-1)

                    # Update constrained processor state
-                   if constrained_processor is not None:
-                       for b in range(next_tokens.shape[0]):
-                           constrained_processor.update_state(next_tokens[b].item())

                    # Check for EOS token
-                   should_stop = False
-                   if torch.any(next_tokens == eos_token_id):
-                       should_stop = True
-                   elif pad_token_id is not None and pad_token_id != eos_token_id:
-                       if torch.any(next_tokens == pad_token_id):
-                           should_stop = True

                    # Append token to sequence
                    next_tokens_unsqueezed = next_tokens.unsqueeze(1)
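The filtering and sampling logic removed above (and re-introduced later in this diff as `_apply_top_k_filter`, `_apply_top_p_filter`, and `_sample_tokens`) is standard top-k / nucleus filtering followed by temperature sampling. A self-contained sketch; the default values are illustrative, not the handler's:

    import torch

    def sample_next_token(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9, temperature: float = 0.85) -> torch.Tensor:
        # logits: [batch_size, vocab_size]
        if top_k > 0:
            kth_value = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_value, float('-inf'))
        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cumulative_probs > top_p
            remove[..., 1:] = remove[..., :-1].clone()  # keep the first token that crosses the threshold
            remove[..., 0] = False
            logits = logits.masked_fill(remove.scatter(1, sorted_indices, remove), float('-inf'))
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(1)
        return torch.argmax(logits, dim=-1)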
@@ -1601,28 +1679,12 @@ class LLMHandler:
                eos_token_id = pad_token_id

            # Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
-           logits_processor = LogitsProcessorList()
-           if repetition_penalty != 1.0:
-               logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

            with torch.no_grad():
                for step in range(max_new_tokens):
                    # Forward pass for the entire batch (conditional + unconditional)
-                   if past_key_values is None:
-                       # First step: full forward pass
-                       outputs = model(
-                           input_ids=generated_ids,
-                           **model_kwargs,
-                           use_cache=use_cache,
-                       )
-                   else:
-                       # Subsequent steps: only forward the last token (utilizing KV cache)
-                       outputs = model(
-                           input_ids=generated_ids[:, -1:],
-                           past_key_values=past_key_values,
-                           **model_kwargs,
-                           use_cache=use_cache,
-                       )

                    # Get logits for the last position
                    next_token_logits = outputs.logits[:, -1, :]  # [batch_size*2, vocab_size]
@@ -1645,45 +1707,20 @@ class LLMHandler:
                    for processor in logits_processor:
                        cfg_logits = processor(current_input_ids, cfg_logits)

-                   # Apply top-k filtering
-                   if top_k is not None and top_k > 0:
-                       indices_to_remove = cfg_logits < torch.topk(cfg_logits, top_k)[0][..., -1, None]
-                       cfg_logits[indices_to_remove] = float('-inf')
-
-                   # Apply top-p (nucleus) filtering
-                   if top_p is not None and 0.0 < top_p < 1.0:
-                       sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
-                       cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-                       # Remove tokens with cumulative probability above the threshold
-                       sorted_indices_to_remove = cumulative_probs > top_p
-                       # Shift the indices to the right to keep also the first token above the threshold
-                       sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-                       sorted_indices_to_remove[..., 0] = 0
-                       indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-                       cfg_logits[indices_to_remove] = float('-inf')

                    # Apply temperature and sample
-                   if temperature > 0:
-                       cfg_logits = cfg_logits / temperature
-                       probs = torch.softmax(cfg_logits, dim=-1)
-                       next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-                   else:
-                       next_tokens = torch.argmax(cfg_logits, dim=-1)

                    # Update constrained processor state AFTER sampling
-                   if constrained_processor is not None:
-                       for b in range(next_tokens.shape[0]):
-                           constrained_processor.update_state(next_tokens[b].item())

                    # Check for EOS token in conditional sequences BEFORE unsqueezing
                    # Stop if any conditional sequence generates EOS token
                    # next_tokens shape: [batch_size] (only conditional tokens)
-                   should_stop = False
-                   if torch.any(next_tokens == eos_token_id):
-                       should_stop = True
-                   elif pad_token_id is not None and pad_token_id != eos_token_id:
-                       if torch.any(next_tokens == pad_token_id):
-                           should_stop = True

                    # Apply the same sampled tokens to both conditional and unconditional sequences
                    next_tokens_unsqueezed = next_tokens.unsqueeze(1)
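The CFG combination step itself falls outside the visible hunks, but the surrounding code (a batch holding conditional rows followed by unconditional rows, reduced to `cfg_logits` before sampling) matches the conventional classifier-free guidance formula. A hedged sketch of that standard combination, not necessarily the exact expression used in this repository:

    import torch

    def combine_cfg_logits(logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # logits: [2 * batch_size, vocab_size], conditional rows first, unconditional rows second
        batch_size = logits.shape[0] // 2
        cond, uncond = logits[:batch_size], logits[batch_size:]
        # Push the distribution away from the unconditional prediction
        return uncond + cfg_scale * (cond - uncond)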
  import os
  import traceback
  import time
+ import random
+ from typing import Optional, Dict, Any, Tuple, List, Union
  from contextlib import contextmanager

  import yaml
| 86 |
except Exception as e:
|
| 87 |
return 0.9, False
|
| 88 |
|
| 89 |
+
def _has_meaningful_negative_prompt(self, negative_prompt: str) -> bool:
|
| 90 |
+
"""Check if negative prompt is meaningful (not default/empty)"""
|
| 91 |
+
return negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"
|
| 92 |
+
|
| 93 |
+
def _build_logits_processor(self, repetition_penalty: float) -> LogitsProcessorList:
|
| 94 |
+
"""Build logits processor list with repetition penalty if needed"""
|
| 95 |
+
logits_processor = LogitsProcessorList()
|
| 96 |
+
if repetition_penalty != 1.0:
|
| 97 |
+
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 98 |
+
return logits_processor
|
| 99 |
+
|
| 100 |
+
def _setup_constrained_processor(
|
| 101 |
+
self,
|
| 102 |
+
use_constrained_decoding: bool,
|
| 103 |
+
constrained_decoding_debug: bool,
|
| 104 |
+
target_duration: Optional[float],
|
| 105 |
+
user_metadata: Optional[Dict[str, Optional[str]]],
|
| 106 |
+
stop_at_reasoning: bool,
|
| 107 |
+
skip_genres: bool,
|
| 108 |
+
skip_caption: bool,
|
| 109 |
+
skip_language: bool,
|
| 110 |
+
generation_phase: str,
|
| 111 |
+
is_batch: bool = False,
|
| 112 |
+
metadata_temperature: Optional[float] = None,
|
| 113 |
+
codes_temperature: Optional[float] = None,
|
| 114 |
+
) -> Optional[MetadataConstrainedLogitsProcessor]:
|
| 115 |
+
"""Setup and configure constrained processor for generation"""
|
| 116 |
+
use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
|
| 117 |
+
|
| 118 |
+
if not use_constrained_decoding and not use_phase_temperatures:
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
# Reset processor state for new generation
|
| 122 |
+
self.constrained_processor.reset()
|
| 123 |
+
|
| 124 |
+
# Use shared processor, just update settings
|
| 125 |
+
self.constrained_processor.enabled = use_constrained_decoding
|
| 126 |
+
self.constrained_processor.debug = constrained_decoding_debug
|
| 127 |
+
|
| 128 |
+
# Phase temperatures only supported in single mode
|
| 129 |
+
if use_phase_temperatures:
|
| 130 |
+
self.constrained_processor.metadata_temperature = metadata_temperature
|
| 131 |
+
self.constrained_processor.codes_temperature = codes_temperature
|
| 132 |
+
else:
|
| 133 |
+
self.constrained_processor.metadata_temperature = None
|
| 134 |
+
self.constrained_processor.codes_temperature = None
|
| 135 |
+
|
| 136 |
+
self.constrained_processor.set_target_duration(target_duration)
|
| 137 |
+
|
| 138 |
+
# Batch mode uses default/disabled settings for these options
|
| 139 |
+
if is_batch:
|
| 140 |
+
self.constrained_processor.set_user_metadata(None)
|
| 141 |
+
self.constrained_processor.set_stop_at_reasoning(False)
|
| 142 |
+
self.constrained_processor.set_skip_genres(True)
|
| 143 |
+
self.constrained_processor.set_skip_caption(True)
|
| 144 |
+
self.constrained_processor.set_skip_language(True)
|
| 145 |
+
else:
|
| 146 |
+
# Single mode uses provided settings
|
| 147 |
+
self.constrained_processor.set_user_metadata(user_metadata)
|
| 148 |
+
self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
|
| 149 |
+
self.constrained_processor.set_skip_genres(skip_genres)
|
| 150 |
+
self.constrained_processor.set_skip_caption(skip_caption)
|
| 151 |
+
self.constrained_processor.set_skip_language(skip_language)
|
| 152 |
+
|
| 153 |
+
# Set generation phase for phase-aware processing
|
| 154 |
+
self.constrained_processor.set_generation_phase(generation_phase)
|
| 155 |
+
|
| 156 |
+
return self.constrained_processor
|
| 157 |
+
|
| 158 |
+
def _build_unconditional_prompt(
|
| 159 |
+
self,
|
| 160 |
+
caption: str,
|
| 161 |
+
lyrics: str,
|
| 162 |
+
cot_text: str,
|
| 163 |
+
negative_prompt: str,
|
| 164 |
+
generation_phase: str,
|
| 165 |
+
is_batch: bool = False,
|
| 166 |
+
) -> str:
|
| 167 |
+
"""Build unconditional prompt for CFG based on generation phase and batch mode"""
|
| 168 |
+
if is_batch or generation_phase == "codes":
|
| 169 |
+
# Codes phase or batch mode: use empty CoT in unconditional prompt
|
| 170 |
+
return self.build_formatted_prompt_with_cot(
|
| 171 |
+
caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
# CoT phase (single mode only): unconditional prompt
|
| 175 |
+
# If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
|
| 176 |
+
return self.build_formatted_prompt(
|
| 177 |
+
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]:
|
| 181 |
+
"""Load PyTorch model from path and return (success, status_message)"""
|
| 182 |
+
try:
|
| 183 |
+
self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
|
| 184 |
+
if not self.offload_to_cpu:
|
| 185 |
+
self.llm = self.llm.to(device).to(self.dtype)
|
| 186 |
+
else:
|
| 187 |
+
self.llm = self.llm.to("cpu").to(self.dtype)
|
| 188 |
+
self.llm.eval()
|
| 189 |
+
self.llm_backend = "pt"
|
| 190 |
+
self.llm_initialized = True
|
| 191 |
+
logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
|
| 192 |
+
status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}"
|
| 193 |
+
return True, status_msg
|
| 194 |
+
except Exception as e:
|
| 195 |
+
return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 196 |
+
|
| 197 |
+
def _apply_top_k_filter(self, logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor:
|
| 198 |
+
"""Apply top-k filtering to logits"""
|
| 199 |
+
if top_k is not None and top_k > 0:
|
| 200 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 201 |
+
logits[indices_to_remove] = float('-inf')
|
| 202 |
+
return logits
|
| 203 |
+
|
| 204 |
+
def _apply_top_p_filter(self, logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor:
|
| 205 |
+
"""Apply top-p (nucleus) filtering to logits"""
|
| 206 |
+
if top_p is not None and 0.0 < top_p < 1.0:
|
| 207 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 208 |
+
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 209 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 210 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 211 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 212 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 213 |
+
logits[indices_to_remove] = float('-inf')
|
| 214 |
+
return logits
|
| 215 |
+
|
| 216 |
+
def _sample_tokens(self, logits: torch.Tensor, temperature: float) -> torch.Tensor:
|
| 217 |
+
"""Sample tokens from logits with temperature"""
|
| 218 |
+
if temperature > 0:
|
| 219 |
+
logits = logits / temperature
|
| 220 |
+
probs = torch.softmax(logits, dim=-1)
|
| 221 |
+
return torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 222 |
+
else:
|
| 223 |
+
return torch.argmax(logits, dim=-1)
|
| 224 |
+
|
| 225 |
+
def _check_eos_token(self, tokens: torch.Tensor, eos_token_id: int, pad_token_id: Optional[int]) -> bool:
|
| 226 |
+
"""Check if any token in the batch is EOS or pad token"""
|
| 227 |
+
if torch.any(tokens == eos_token_id):
|
| 228 |
+
return True
|
| 229 |
+
if pad_token_id is not None and pad_token_id != eos_token_id:
|
| 230 |
+
if torch.any(tokens == pad_token_id):
|
| 231 |
+
return True
|
| 232 |
+
return False
|
| 233 |
+
|
| 234 |
+
def _update_constrained_processor_state(self, constrained_processor: Optional[MetadataConstrainedLogitsProcessor], tokens: torch.Tensor):
|
| 235 |
+
"""Update constrained processor state with generated tokens"""
|
| 236 |
+
if constrained_processor is not None:
|
| 237 |
+
for b in range(tokens.shape[0]):
|
| 238 |
+
constrained_processor.update_state(tokens[b].item())
|
| 239 |
+
|
| 240 |
+
def _forward_pass(
|
| 241 |
+
self,
|
| 242 |
+
model: Any,
|
| 243 |
+
generated_ids: torch.Tensor,
|
| 244 |
+
model_kwargs: Dict[str, Any],
|
| 245 |
+
past_key_values: Optional[Any],
|
| 246 |
+
use_cache: bool,
|
| 247 |
+
) -> Any:
|
| 248 |
+
"""Perform forward pass with KV cache support"""
|
| 249 |
+
if past_key_values is None:
|
| 250 |
+
outputs = model(
|
| 251 |
+
input_ids=generated_ids,
|
| 252 |
+
**model_kwargs,
|
| 253 |
+
use_cache=use_cache,
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
outputs = model(
|
| 257 |
+
input_ids=generated_ids[:, -1:],
|
| 258 |
+
past_key_values=past_key_values,
|
| 259 |
+
**model_kwargs,
|
| 260 |
+
use_cache=use_cache,
|
| 261 |
+
)
|
| 262 |
+
return outputs
|
| 263 |
+
|
| 264 |
+
def _normalize_batch_input(self, formatted_prompts: Union[str, List[str]]) -> Tuple[List[str], bool]:
|
| 265 |
+
"""Normalize batch input: convert single string to list and return (list, is_batch)"""
|
| 266 |
+
is_batch = isinstance(formatted_prompts, list)
|
| 267 |
+
if is_batch:
|
| 268 |
+
return formatted_prompts, is_batch
|
| 269 |
+
else:
|
| 270 |
+
return [formatted_prompts], is_batch
|
| 271 |
+
|
| 272 |
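The single-vs-batch convention introduced by `_normalize_batch_input` (and used by the unified `_run_vllm` / `_run_pt` entry points) can be exercised in isolation. A small sketch of the same normalize-then-unwrap pattern; the uppercase call merely stands in for the real generation step:

    from typing import List, Tuple, Union

    def normalize_batch_input(prompts: Union[str, List[str]]) -> Tuple[List[str], bool]:
        # Wrap single strings so downstream code can always iterate over a list
        if isinstance(prompts, list):
            return prompts, True
        return [prompts], False

    prompt_list, is_batch = normalize_batch_input("a formatted prompt")
    outputs = [p.upper() for p in prompt_list]       # stand-in for generation
    result = outputs if is_batch else outputs[0]     # str in, str out; list in, list out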
def initialize(
|
| 273 |
self,
|
| 274 |
checkpoint_dir: str,
|
|
|
|

        logger.info("loading 5Hz LM tokenizer...")
        start_time = time.time()
+       # TODO: tokenizer loading is slow; no workaround found yet
        llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
        logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
        self.llm_tokenizer = llm_tokenizer
            # vllm initialization failed, fallback to PyTorch
            if not self.llm_initialized:
                logger.warning("vllm initialization failed, falling back to PyTorch backend")
+               success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
+               if not success:
+                   return status_msg, False
+               status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
            # If vllm initialization succeeded, self.llm_initialized should already be True
        else:
            # Use PyTorch backend (pt)
+           success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
+           if not success:
+               return status_msg, False

        return status_msg, True

    except Exception as e:
+       return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False

    def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
        """Initialize 5Hz LM model using vllm backend"""

            return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
        except Exception as e:
            self.llm_initialized = False
+           return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
| 399 |
|
| 400 |
+
def _run_vllm(
|
| 401 |
self,
|
| 402 |
+
formatted_prompts: Union[str, List[str]],
|
| 403 |
temperature: float,
|
| 404 |
cfg_scale: float,
|
| 405 |
negative_prompt: str,
|
|
|
|
| 408 |
repetition_penalty: float,
|
| 409 |
use_constrained_decoding: bool = True,
|
| 410 |
constrained_decoding_debug: bool = False,
|
| 411 |
+
metadata_temperature: Optional[float] = None,
|
| 412 |
codes_temperature: Optional[float] = None,
|
| 413 |
target_duration: Optional[float] = None,
|
| 414 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
|
|
|
| 420 |
caption: str = "",
|
| 421 |
lyrics: str = "",
|
| 422 |
cot_text: str = "",
|
| 423 |
+
seeds: Optional[List[int]] = None,
|
| 424 |
+
) -> Union[str, List[str]]:
|
| 425 |
+
"""
|
| 426 |
+
Unified vllm generation function supporting both single and batch modes.
|
| 427 |
+
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
|
| 428 |
+
Returns a single string for single mode, or a list of strings for batch mode.
|
| 429 |
+
"""
|
| 430 |
from nanovllm import SamplingParams
|
| 431 |
|
| 432 |
+
# Determine if batch mode
|
| 433 |
+
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
|
| 434 |
+
batch_size = len(formatted_prompt_list)
|
| 435 |
+
|
| 436 |
# Determine effective temperature for sampler
|
| 437 |
+
# Batch mode doesn't support phase temperatures, so use simple temperature
|
| 438 |
+
# Single mode supports phase temperatures
|
| 439 |
+
use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
|
| 440 |
effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
|
| 441 |
|
| 442 |
+
# Setup constrained processor
|
| 443 |
+
constrained_processor = self._setup_constrained_processor(
|
| 444 |
+
use_constrained_decoding=use_constrained_decoding or use_phase_temperatures,
|
| 445 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 446 |
+
target_duration=target_duration,
|
| 447 |
+
user_metadata=user_metadata,
|
| 448 |
+
stop_at_reasoning=stop_at_reasoning,
|
| 449 |
+
skip_genres=skip_genres,
|
| 450 |
+
skip_caption=skip_caption,
|
| 451 |
+
skip_language=skip_language,
|
| 452 |
+
generation_phase=generation_phase,
|
| 453 |
+
is_batch=is_batch,
|
| 454 |
+
metadata_temperature=metadata_temperature,
|
| 455 |
+
codes_temperature=codes_temperature,
|
| 456 |
+
)
|
|
|
|
|
| 457 |
|
| 458 |
sampling_params = SamplingParams(
|
| 459 |
max_tokens=self.max_model_len - 64,
|
|
|
|
| 468 |
|
| 469 |
if cfg_scale > 1.0:
|
| 470 |
# Build unconditional prompt based on generation phase
|
| 471 |
+
formatted_unconditional_prompt = self._build_unconditional_prompt(
|
| 472 |
+
caption=caption,
|
| 473 |
+
lyrics=lyrics,
|
| 474 |
+
cot_text=cot_text,
|
| 475 |
+
negative_prompt=negative_prompt,
|
| 476 |
+
generation_phase=generation_phase,
|
| 477 |
+
is_batch=is_batch,
|
|
|
|
|
|
|
|
|
| 478 |
)
|
| 479 |
unconditional_prompts = [formatted_unconditional_prompt] * batch_size
|
| 480 |
|
| 481 |
outputs = self.llm.generate(
|
| 482 |
+
formatted_prompt_list,
|
| 483 |
sampling_params,
|
| 484 |
unconditional_prompts=unconditional_prompts,
|
| 485 |
)
|
| 486 |
else:
|
| 487 |
+
outputs = self.llm.generate(formatted_prompt_list, sampling_params)
|
| 488 |
+
|
| 489 |
+
# Extract text from outputs
|
| 490 |
output_texts = []
|
| 491 |
for output in outputs:
|
| 492 |
if hasattr(output, "outputs") and len(output.outputs) > 0:
|
|
|
|
| 497 |
output_texts.append(output["text"])
|
| 498 |
else:
|
| 499 |
output_texts.append(str(output))
|
|
|
|
|
|
|
| 500 |
|
| 501 |
+
# Return single string for single mode, list for batch mode
|
| 502 |
+
return output_texts[0] if not is_batch else output_texts
|
|
|
|
|
|
| 503 |
|
| 504 |
+
def _run_pt_single(
|
| 505 |
self,
|
| 506 |
formatted_prompt: str,
|
| 507 |
temperature: float,
|
|
|
|
| 510 |
top_k: Optional[int],
|
| 511 |
top_p: Optional[float],
|
| 512 |
repetition_penalty: float,
|
| 513 |
+
use_constrained_decoding: bool,
|
| 514 |
+
constrained_decoding_debug: bool,
|
| 515 |
+
target_duration: Optional[float],
|
| 516 |
+
user_metadata: Optional[Dict[str, Optional[str]]],
|
| 517 |
+
stop_at_reasoning: bool,
|
| 518 |
+
skip_genres: bool,
|
| 519 |
+
skip_caption: bool,
|
| 520 |
+
skip_language: bool,
|
| 521 |
+
generation_phase: str,
|
| 522 |
+
caption: str,
|
| 523 |
+
lyrics: str,
|
| 524 |
+
cot_text: str,
|
| 525 |
) -> str:
|
| 526 |
+
"""Internal helper function for single-item PyTorch generation."""
|
| 527 |
inputs = self.llm_tokenizer(
|
| 528 |
formatted_prompt,
|
| 529 |
return_tensors="pt",
|
|
|
|
| 531 |
truncation=True,
|
| 532 |
)
|
| 533 |
|
| 534 |
+
# Setup constrained processor
|
| 535 |
+
constrained_processor = self._setup_constrained_processor(
|
| 536 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 537 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 538 |
+
target_duration=target_duration,
|
| 539 |
+
user_metadata=user_metadata,
|
| 540 |
+
stop_at_reasoning=stop_at_reasoning,
|
| 541 |
+
skip_genres=skip_genres,
|
| 542 |
+
skip_caption=skip_caption,
|
| 543 |
+
skip_language=skip_language,
|
| 544 |
+
generation_phase=generation_phase,
|
| 545 |
+
is_batch=False,
|
| 546 |
+
)
|
|
|
|
|
|
|
| 547 |
|
| 548 |
with self._load_model_context():
|
| 549 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
| 552 |
max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
|
| 553 |
|
| 554 |
# Build logits processor list (only for CFG and repetition penalty)
|
| 555 |
+
logits_processor = self._build_logits_processor(repetition_penalty)
|
|
|
|
|
|
|
| 556 |
|
| 557 |
if cfg_scale > 1.0:
|
| 558 |
# Build unconditional prompt based on generation phase
|
| 559 |
+
formatted_unconditional_prompt = self._build_unconditional_prompt(
|
| 560 |
+
caption=caption,
|
| 561 |
+
lyrics=lyrics,
|
| 562 |
+
cot_text=cot_text,
|
| 563 |
+
negative_prompt=negative_prompt,
|
| 564 |
+
generation_phase=generation_phase,
|
| 565 |
+
is_batch=False,
|
| 566 |
+
)
|
|
|
|
|
| 567 |
|
| 568 |
# Tokenize both prompts together to ensure same length (with left padding)
|
| 569 |
# Left padding is important for generation tasks
|
|
|
|
| 656 |
|
| 657 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 658 |
return output_text
|
| 659 |
+
|
| 660 |
+
def _run_pt(
|
| 661 |
+
self,
|
| 662 |
+
formatted_prompts: Union[str, List[str]],
|
| 663 |
+
temperature: float,
|
| 664 |
+
cfg_scale: float,
|
| 665 |
+
negative_prompt: str,
|
| 666 |
+
top_k: Optional[int],
|
| 667 |
+
top_p: Optional[float],
|
| 668 |
+
repetition_penalty: float,
|
| 669 |
+
use_constrained_decoding: bool = True,
|
| 670 |
+
constrained_decoding_debug: bool = False,
|
| 671 |
+
target_duration: Optional[float] = None,
|
| 672 |
+
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 673 |
+
stop_at_reasoning: bool = False,
|
| 674 |
+
skip_genres: bool = True,
|
| 675 |
+
skip_caption: bool = False,
|
| 676 |
+
skip_language: bool = False,
|
| 677 |
+
generation_phase: str = "cot",
|
| 678 |
+
caption: str = "",
|
| 679 |
+
lyrics: str = "",
|
| 680 |
+
cot_text: str = "",
|
| 681 |
+
seeds: Optional[List[int]] = None,
|
| 682 |
+
) -> Union[str, List[str]]:
|
| 683 |
+
"""
|
| 684 |
+
Unified PyTorch generation function supporting both single and batch modes.
|
| 685 |
+
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
|
| 686 |
+
Returns a single string for single mode, or a list of strings for batch mode.
|
| 687 |
+
Note: PyTorch backend processes batch items sequentially (doesn't support true batching efficiently).
|
| 688 |
+
"""
|
| 689 |
+
# Determine if batch mode
|
| 690 |
+
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
|
| 691 |
+
|
| 692 |
+
# For batch mode, process each item sequentially with different seeds
|
| 693 |
+
if is_batch:
|
| 694 |
+
output_texts = []
|
| 695 |
+
for i, formatted_prompt in enumerate(formatted_prompt_list):
|
| 696 |
+
# Set seed for this item if provided
|
| 697 |
+
if seeds and i < len(seeds):
|
| 698 |
+
torch.manual_seed(seeds[i])
|
| 699 |
+
if torch.cuda.is_available():
|
| 700 |
+
torch.cuda.manual_seed_all(seeds[i])
|
| 701 |
+
|
| 702 |
+
# Generate using single-item method with batch-mode defaults
|
| 703 |
+
output_text = self._run_pt_single(
|
| 704 |
+
formatted_prompt=formatted_prompt,
|
| 705 |
+
temperature=temperature,
|
| 706 |
+
cfg_scale=cfg_scale,
|
| 707 |
+
negative_prompt=negative_prompt,
|
| 708 |
+
top_k=top_k,
|
| 709 |
+
top_p=top_p,
|
| 710 |
+
repetition_penalty=repetition_penalty,
|
| 711 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 712 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 713 |
+
target_duration=target_duration,
|
| 714 |
+
user_metadata=None,
|
| 715 |
+
stop_at_reasoning=False,
|
| 716 |
+
skip_genres=True,
|
| 717 |
+
skip_caption=True,
|
| 718 |
+
skip_language=True,
|
| 719 |
+
generation_phase=generation_phase,
|
| 720 |
+
caption=caption,
|
| 721 |
+
lyrics=lyrics,
|
| 722 |
+
cot_text=cot_text,
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
output_texts.append(output_text)
|
| 726 |
+
|
| 727 |
+
return output_texts
|
| 728 |
+
|
| 729 |
+
# Single mode: process the formatted prompt
|
| 730 |
+
formatted_prompt = formatted_prompt_list[0]
|
| 731 |
+
|
| 732 |
+
return self._run_pt_single(
|
| 733 |
+
formatted_prompt=formatted_prompt,
|
| 734 |
+
temperature=temperature,
|
| 735 |
+
cfg_scale=cfg_scale,
|
| 736 |
+
negative_prompt=negative_prompt,
|
| 737 |
+
top_k=top_k,
|
| 738 |
+
top_p=top_p,
|
| 739 |
+
repetition_penalty=repetition_penalty,
|
| 740 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 741 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 742 |
+
target_duration=target_duration,
|
| 743 |
+
user_metadata=user_metadata,
|
| 744 |
+
stop_at_reasoning=stop_at_reasoning,
|
| 745 |
+
skip_genres=skip_genres,
|
| 746 |
+
skip_caption=skip_caption,
|
| 747 |
+
skip_language=skip_language,
|
| 748 |
+
generation_phase=generation_phase,
|
| 749 |
+
caption=caption,
|
| 750 |
+
lyrics=lyrics,
|
| 751 |
+
cot_text=cot_text,
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
|
| 755 |
"""Check if all required metadata are present."""
|
| 756 |
if user_metadata is None:
|
|
|
|
| 798 |
constrained_decoding_debug: bool = False,
|
| 799 |
target_duration: Optional[float] = None,
|
| 800 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 801 |
+
use_cot_metas: bool = True,
|
| 802 |
use_cot_caption: bool = True,
|
| 803 |
use_cot_language: bool = True,
|
| 804 |
+
batch_size: Optional[int] = None,
|
| 805 |
+
seeds: Optional[List[int]] = None,
|
| 806 |
+
progress=None,
|
| 807 |
+
) -> Dict[str, Any]:
|
| 808 |
"""Two-phase LM generation: CoT generation followed by audio codes generation.
|
| 809 |
|
| 810 |
- infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
|
|
|
|
| 817 |
If specified, constrained decoding will inject these values directly.
|
| 818 |
use_cot_caption: Whether to generate caption in CoT (default True).
|
| 819 |
use_cot_language: Whether to generate language in CoT (default True).
|
| 820 |
+
batch_size: Optional batch size for batch generation. If None or 1, returns single result.
|
| 821 |
+
If > 1, returns batch results (lists).
|
| 822 |
+
seeds: Optional list of seeds for batch generation (for reproducibility).
|
| 823 |
+
Only used when batch_size > 1. TODO: not used yet
|
| 824 |
|
| 825 |
+
Returns:
|
| 826 |
+
Dictionary containing:
|
| 827 |
+
- metadata: Dict or List[Dict] - Generated metadata
|
| 828 |
+
- audio_codes: str or List[str] - Generated audio codes
|
| 829 |
+
- success: bool - Whether generation succeeded
|
| 830 |
+
- error: Optional[str] - Error message if failed
|
| 831 |
+
- extra_outputs: Dict with time_costs and other info
|
| 832 |
+
"""
|
| 833 |
+
if progress is None:
|
| 834 |
+
def progress(*args, **kwargs):
|
| 835 |
+
pass
|
| 836 |
+
|
| 837 |
infer_type = (infer_type or "").strip().lower()
|
| 838 |
if infer_type not in {"dit", "llm_dit"}:
|
| 839 |
+
error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
|
| 840 |
+
return {
|
| 841 |
+
"metadata": [] if (batch_size and batch_size > 1) else {},
|
| 842 |
+
"audio_codes": [] if (batch_size and batch_size > 1) else "",
|
| 843 |
+
"success": False,
|
| 844 |
+
"error": error_msg,
|
| 845 |
+
"extra_outputs": {"time_costs": {}},
|
| 846 |
+
}
|
| 847 |
+
|
| 848 |
+
# Determine if batch mode
|
| 849 |
+
is_batch = batch_size and batch_size > 1
|
| 850 |
+
actual_batch_size = batch_size if is_batch else 1
|
| 851 |
+
|
| 852 |
+
# Initialize variables
|
| 853 |
metadata = {}
|
| 854 |
audio_codes = ""
|
| 855 |
has_all_metas = self.has_all_metas(user_metadata)
|
|
|
|
|
|
|
| 856 |
phase1_time = 0.0
|
| 857 |
phase2_time = 0.0
|
| 858 |
|
| 859 |
+
# Handle seeds for batch mode
|
| 860 |
+
if is_batch:
|
| 861 |
+
if seeds is None:
|
| 862 |
+
seeds = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
|
| 863 |
+
elif len(seeds) < actual_batch_size:
|
| 864 |
+
seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size - len(seeds))]
|
| 865 |
+
else:
|
| 866 |
+
seeds = seeds[:actual_batch_size]
|
| 867 |
+
|
| 868 |
# ========== PHASE 1: CoT Generation ==========
|
| 869 |
+
# Skip CoT if all metadata are user-provided OR caption is already formatted
|
| 870 |
+
progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...")
|
| 871 |
+
if not has_all_metas and use_cot_metas:
|
| 872 |
+
if is_batch:
|
| 873 |
+
logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
|
| 874 |
+
else:
|
| 875 |
+
logger.info("Phase 1: Generating CoT metadata...")
|
| 876 |
phase1_start = time.time()
|
| 877 |
|
| 878 |
# Build formatted prompt for CoT phase
|
| 879 |
formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
|
| 880 |
+
|
| 881 |
logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
|
| 882 |
# Generate CoT (stop at </think>)
|
| 883 |
cot_output_text, status = self.generate_from_formatted_prompt(
|
|
|
|
| 907 |
phase1_time = time.time() - phase1_start
|
| 908 |
|
| 909 |
if not cot_output_text:
|
| 910 |
+
return {
|
| 911 |
+
"metadata": [] if is_batch else {},
|
| 912 |
+
"audio_codes": [] if is_batch else "",
|
| 913 |
+
"success": False,
|
| 914 |
+
"error": status,
|
| 915 |
+
"extra_outputs": {"time_costs": {"phase1_time": phase1_time}},
|
| 916 |
+
}
|
| 917 |
|
| 918 |
# Parse metadata from CoT output
|
| 919 |
metadata, _ = self.parse_lm_output(cot_output_text)
|
| 920 |
+
if is_batch:
|
| 921 |
+
logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
|
| 922 |
+
else:
|
| 923 |
+
logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
|
| 924 |
else:
|
| 925 |
# Use user-provided metadata
|
| 926 |
+
if is_batch:
|
| 927 |
+
logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
|
| 928 |
+
else:
|
| 929 |
+
logger.info("Phase 1: Using user-provided metadata (skipping generation)")
|
| 930 |
metadata = {k: v for k, v in user_metadata.items() if v is not None}
|
| 931 |
|
| 932 |
# If infer_type is 'dit', stop here and return only metadata
|
| 933 |
if infer_type == "dit":
|
| 934 |
+
if is_batch:
|
| 935 |
+
metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
|
| 936 |
+
return {
|
| 937 |
+
"metadata": metadata_list,
|
| 938 |
+
"audio_codes": [""] * actual_batch_size,
|
| 939 |
+
"success": True,
|
| 940 |
+
"error": None,
|
| 941 |
+
"extra_outputs": {
|
| 942 |
+
"time_costs": {
|
| 943 |
+
"phase1_time": phase1_time,
|
| 944 |
+
"total_time": phase1_time,
|
| 945 |
+
}
|
| 946 |
+
},
|
| 947 |
+
}
|
| 948 |
+
else:
|
| 949 |
+
return {
|
| 950 |
+
"metadata": metadata,
|
| 951 |
+
"audio_codes": "",
|
| 952 |
+
"success": True,
|
| 953 |
+
"error": None,
|
| 954 |
+
"extra_outputs": {
|
| 955 |
+
"time_costs": {
|
| 956 |
+
"phase1_time": phase1_time,
|
| 957 |
+
"total_time": phase1_time,
|
| 958 |
+
}
|
| 959 |
+
},
|
| 960 |
+
}
|
| 961 |
|
| 962 |
# ========== PHASE 2: Audio Codes Generation ==========
|
| 963 |
+
if is_batch:
|
| 964 |
+
logger.info(f"Batch Phase 2: Generating audio codes for {actual_batch_size} items...")
|
| 965 |
+
else:
|
| 966 |
+
logger.info("Phase 2: Generating audio codes...")
|
| 967 |
phase2_start = time.time()
|
| 968 |
|
| 969 |
# Format metadata as CoT using YAML (matching training format)
|
|
|
|
| 972 |
# Build formatted prompt with CoT for codes generation phase
|
| 973 |
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
|
| 974 |
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
|
|
|
|
|
|
|
|
| 975 |
|
| 976 |
+
progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
|
| 977 |
+
if is_batch:
|
| 978 |
+
# Batch mode: generate codes for all items
|
| 979 |
+
formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
|
|
|
|
|
|
|
|
| 980 |
|
| 981 |
+
# Call backend-specific batch generation
|
| 982 |
+
try:
|
| 983 |
+
if self.llm_backend == "vllm":
|
| 984 |
+
codes_outputs = self._run_vllm(
|
| 985 |
+
formatted_prompts=formatted_prompts,
|
| 986 |
+
temperature=temperature,
|
| 987 |
+
cfg_scale=cfg_scale,
|
| 988 |
+
negative_prompt=negative_prompt,
|
| 989 |
+
top_k=top_k,
|
| 990 |
+
top_p=top_p,
|
| 991 |
+
repetition_penalty=repetition_penalty,
|
| 992 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 993 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 994 |
+
target_duration=target_duration,
|
| 995 |
+
generation_phase="codes",
|
| 996 |
+
caption=caption,
|
| 997 |
+
lyrics=lyrics,
|
| 998 |
+
cot_text=cot_text,
|
| 999 |
+
seeds=seeds,
|
| 1000 |
+
)
|
| 1001 |
+
else: # pt backend
|
| 1002 |
+
codes_outputs = self._run_pt(
|
| 1003 |
+
formatted_prompts=formatted_prompts,
|
| 1004 |
+
temperature=temperature,
|
| 1005 |
+
cfg_scale=cfg_scale,
|
| 1006 |
+
negative_prompt=negative_prompt,
|
| 1007 |
+
top_k=top_k,
|
| 1008 |
+
top_p=top_p,
|
| 1009 |
+
repetition_penalty=repetition_penalty,
|
| 1010 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 1011 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1012 |
+
target_duration=target_duration,
|
| 1013 |
+
generation_phase="codes",
|
| 1014 |
+
caption=caption,
|
| 1015 |
+
lyrics=lyrics,
|
| 1016 |
+
cot_text=cot_text,
|
| 1017 |
+
seeds=seeds,
|
| 1018 |
+
)
|
| 1019 |
+
except Exception as e:
|
| 1020 |
+
error_msg = f"Error in batch codes generation: {str(e)}"
|
| 1021 |
+
logger.error(error_msg)
|
| 1022 |
+
return {
|
| 1023 |
+
"metadata": [],
|
| 1024 |
+
"audio_codes": [],
|
| 1025 |
+
"success": False,
|
| 1026 |
+
"error": error_msg,
|
| 1027 |
+
"extra_outputs": {
|
| 1028 |
+
"time_costs": {
|
| 1029 |
+
"phase1_time": phase1_time,
|
| 1030 |
+
"phase2_time": 0.0,
|
| 1031 |
+
"total_time": phase1_time,
|
| 1032 |
+
}
|
| 1033 |
+
},
|
| 1034 |
+
}
|
| 1035 |
+
|
| 1036 |
+
# Parse audio codes from each output
|
| 1037 |
+
audio_codes_list = []
|
| 1038 |
+
metadata_list = []
|
| 1039 |
+
for output_text in codes_outputs:
|
| 1040 |
+
_, audio_codes_item = self.parse_lm_output(output_text)
|
| 1041 |
+
audio_codes_list.append(audio_codes_item)
|
| 1042 |
+
metadata_list.append(metadata.copy()) # Same metadata for all
|
| 1043 |
+
|
| 1044 |
+
phase2_time = time.time() - phase2_start
|
| 1045 |
+
|
| 1046 |
+
# Log results
|
| 1047 |
+
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
|
| 1048 |
+
logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
|
| 1049 |
+
|
| 1050 |
+
total_time = phase1_time + phase2_time
|
| 1051 |
+
return {
|
| 1052 |
+
"metadata": metadata_list,
|
| 1053 |
+
"audio_codes": audio_codes_list,
|
| 1054 |
+
"success": True,
|
| 1055 |
+
"error": None,
|
| 1056 |
+
"extra_outputs": {
|
| 1057 |
+
"time_costs": {
|
| 1058 |
+
"phase1_time": phase1_time,
|
| 1059 |
+
"phase2_time": phase2_time,
|
| 1060 |
+
"total_time": total_time,
|
| 1061 |
+
},
|
| 1062 |
+
"codes_counts": codes_counts,
|
| 1063 |
+
"total_codes": sum(codes_counts),
|
| 1064 |
+
},
|
| 1065 |
+
}
|
| 1066 |
+
else:
|
| 1067 |
+
# Single mode: generate codes for one item
|
| 1068 |
+
codes_output_text, status = self.generate_from_formatted_prompt(
|
| 1069 |
+
formatted_prompt=formatted_prompt_with_cot,
|
| 1070 |
+
cfg={
|
| 1071 |
+
"temperature": temperature,
|
| 1072 |
+
"cfg_scale": cfg_scale,
|
| 1073 |
+
"negative_prompt": negative_prompt,
|
| 1074 |
+
"top_k": top_k,
|
| 1075 |
+
"top_p": top_p,
|
| 1076 |
+
"repetition_penalty": repetition_penalty,
|
| 1077 |
+
"target_duration": target_duration,
|
| 1078 |
+
"user_metadata": None, # No user metadata injection in Phase 2
|
| 1079 |
+
"skip_caption": True, # Skip caption since CoT is already included
|
| 1080 |
+
"skip_language": True, # Skip language since CoT is already included
|
| 1081 |
+
"generation_phase": "codes",
|
| 1082 |
+
# Pass context for building unconditional prompt in codes phase
|
| 1083 |
+
"caption": caption,
|
| 1084 |
+
"lyrics": lyrics,
|
| 1085 |
+
"cot_text": cot_text,
|
| 1086 |
+
},
|
| 1087 |
use_constrained_decoding=use_constrained_decoding,
|
| 1088 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 1089 |
+
stop_at_reasoning=False, # Generate codes until EOS
|
|
|
|
|
| 1090 |
)
|
| 1091 |
|
| 1092 |
+
if not codes_output_text:
|
| 1093 |
+
total_time = phase1_time + phase2_time
|
| 1094 |
+
return {
|
| 1095 |
+
"metadata": metadata,
|
| 1096 |
+
"audio_codes": "",
|
| 1097 |
+
"success": False,
|
| 1098 |
+
"error": status,
|
| 1099 |
+
"extra_outputs": {
|
| 1100 |
+
"time_costs": {
|
| 1101 |
+
"phase1_time": phase1_time,
|
| 1102 |
+
"phase2_time": phase2_time,
|
| 1103 |
+
"total_time": total_time,
|
| 1104 |
+
}
|
| 1105 |
+
},
|
| 1106 |
+
}
|
| 1107 |
|
| 1108 |
+
phase2_time = time.time() - phase2_start
|
|
|
|
| 1109 |
|
| 1110 |
+
# Parse audio codes from output (metadata should be same as Phase 1)
|
| 1111 |
+
_, audio_codes = self.parse_lm_output(codes_output_text)
|
| 1112 |
+
|
| 1113 |
+
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 1114 |
+
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
|
| 1115 |
+
|
| 1116 |
+
total_time = phase1_time + phase2_time
|
| 1117 |
+
return {
|
| 1118 |
+
"metadata": metadata,
|
| 1119 |
+
"audio_codes": audio_codes,
|
| 1120 |
+
"success": True,
|
| 1121 |
+
"error": None,
|
| 1122 |
+
"extra_outputs": {
|
| 1123 |
+
"time_costs": {
|
| 1124 |
+
"phase1_time": phase1_time,
|
| 1125 |
+
"phase2_time": phase2_time,
|
| 1126 |
+
"total_time": total_time,
|
| 1127 |
+
},
|
| 1128 |
+
"codes_count": codes_count,
|
| 1129 |
+
},
|
| 1130 |
+
}
|
| 1131 |
+
|
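With the return contract documented above, callers branch on a result dict instead of unpacking a tuple. A sketch of typical usage; the handler instance and argument values are illustrative only:

    result = llm_handler.generate_with_stop_condition(
        caption="an upbeat synth-pop track",
        lyrics="...",
        infer_type="llm_dit",
        temperature=0.85,
        batch_size=None,  # single mode: metadata is a dict, audio_codes a string
    )
    if result["success"]:
        metadata = result["metadata"]
        audio_codes = result["audio_codes"]
        time_costs = result["extra_outputs"]["time_costs"]
    else:
        print(result["error"])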
|
|
|
|
|
|
    def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
        """
        Build the chat-formatted prompt for 5Hz LM from caption/lyrics.

        if is_negative_prompt:
            # Unconditional prompt for CFG
            # Check if user provided a meaningful negative prompt (not the default)
+           has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)

            if generation_phase == "cot":
                # CoT phase unconditional prompt

        if is_negative_prompt:
            # Unconditional prompt for codes phase
            # Check if user provided a meaningful negative prompt
+           has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)

            # Use empty CoT for unconditional
            cot_for_prompt = "<think>\n</think>"

        try:
            if self.llm_backend == "vllm":
+               output_text = self._run_vllm(
+                   formatted_prompts=formatted_prompt,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    negative_prompt=negative_prompt,

            return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"

        # PyTorch backend
+       output_text = self._run_pt(
+           formatted_prompts=formatted_prompt,
            temperature=temperature,
            cfg_scale=cfg_scale,
            negative_prompt=negative_prompt,
| 1574 |
eos_token_id = pad_token_id
|
| 1575 |
|
| 1576 |
# Build logits processor for repetition penalty
|
| 1577 |
+
logits_processor = self._build_logits_processor(repetition_penalty)
|
|
|
|
|
|
|
| 1578 |
|
| 1579 |
with torch.no_grad():
|
| 1580 |
for step in range(max_new_tokens):
|
| 1581 |
# Forward pass
|
| 1582 |
+
outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1583 |
|
| 1584 |
# Get logits for the last position
|
| 1585 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
|
|
|
|
| 1592 |
for processor in logits_processor:
|
| 1593 |
next_token_logits = processor(generated_ids, next_token_logits)
|
| 1594 |
|
| 1595 |
+
# Apply top-k and top-p filtering
|
| 1596 |
+
next_token_logits = self._apply_top_k_filter(next_token_logits, top_k)
|
| 1597 |
+
next_token_logits = self._apply_top_p_filter(next_token_logits, top_p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1598 |
|
| 1599 |
# Apply temperature and sample
|
| 1600 |
+
next_tokens = self._sample_tokens(next_token_logits, temperature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1601 |
|
| 1602 |
# Update constrained processor state
|
| 1603 |
+
self._update_constrained_processor_state(constrained_processor, next_tokens)
|
|
|
|
|
|
|
| 1604 |
|
| 1605 |
# Check for EOS token
|
| 1606 |
+
should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1607 |
|
| 1608 |
# Append token to sequence
|
| 1609 |
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
|
|
|
|
| 1679 |
eos_token_id = pad_token_id
|
| 1680 |
|
| 1681 |
# Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
|
| 1682 |
+
logits_processor = self._build_logits_processor(repetition_penalty)
|
|
|
|
|
|
|
| 1683 |
|
| 1684 |
with torch.no_grad():
|
| 1685 |
for step in range(max_new_tokens):
|
| 1686 |
# Forward pass for the entire batch (conditional + unconditional)
|
| 1687 |
+
outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1688 |
|
| 1689 |
# Get logits for the last position
|
| 1690 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
|
|
|
|
| 1707 |
for processor in logits_processor:
|
| 1708 |
cfg_logits = processor(current_input_ids, cfg_logits)
|
| 1709 |
|
| 1710 |
+
# Apply top-k and top-p filtering
|
| 1711 |
+
cfg_logits = self._apply_top_k_filter(cfg_logits, top_k)
|
| 1712 |
+
cfg_logits = self._apply_top_p_filter(cfg_logits, top_p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1713 |
|
| 1714 |
# Apply temperature and sample
|
| 1715 |
+
next_tokens = self._sample_tokens(cfg_logits, temperature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1716 |
|
| 1717 |
# Update constrained processor state AFTER sampling
|
| 1718 |
+
self._update_constrained_processor_state(constrained_processor, next_tokens)
|
|
|
|
|
|
|
| 1719 |
|
| 1720 |
# Check for EOS token in conditional sequences BEFORE unsqueezing
|
| 1721 |
# Stop if any conditional sequence generates EOS token
|
| 1722 |
# next_tokens shape: [batch_size] (only conditional tokens)
|
| 1723 |
+
should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1724 |
|
| 1725 |
# Apply the same sampled tokens to both conditional and unconditional sequences
|
| 1726 |
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
|
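In the CFG loop above, the batch holds the conditional and unconditional sequences stacked together, and the helper names (_forward_pass, _apply_top_k_filter, _sample_tokens, _check_eos_token) are the PR's own. The step that turns the stacked logits into `cfg_logits` is not visible in this excerpt; the sketch below shows the standard classifier-free-guidance combination as an assumption about what that step computes, not a copy of the PR's code.

    import torch

    def combine_cfg_logits(stacked_logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # stacked_logits: [batch_size * 2, vocab_size], conditional half first,
        # matching the batch layout described in the comments above (an assumption).
        batch_size = stacked_logits.shape[0] // 2
        cond = stacked_logits[:batch_size]
        uncond = stacked_logits[batch_size:]
        # Standard CFG: push the unconditional prediction toward the conditional one.
        return uncond + cfg_scale * (cond - uncond)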
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
CHANGED
@@ -68,10 +68,16 @@ class ModelRunner:
|  68 |         self.model = Qwen3ForCausalLM(hf_config)
|  69 |         load_model(self.model, config.model)
|  70 |         self.sampler = Sampler()
|  71 |         self.warmup_model()
|  72 |         self.allocate_kv_cache()
|  73 |         if not self.enforce_eager:
|  74 |             self.capture_cudagraph()
|  75 |         torch.set_default_device("cpu")
|  76 |         torch.set_default_dtype(default_dtype)
|  77 |

@@ -84,6 +90,39 @@
|  84 |         self.shm = SharedMemory(name="nanovllm")
|  85 |         self.loop()
|  86 |
|  87 |     def exit(self):
|  88 |         if self.world_size > 1:
|  89 |             self.shm.close()

@@ -203,7 +242,7 @@
| 203 |             if i != seq.num_blocks - 1:
| 204 |                 end = start + self.block_size
| 205 |             else:
| 206 | -               end = start + seq.last_block_num_tokens
| 207 |             slot_mapping.extend(list(range(start, end)))
| 208 |         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
| 209 |             block_tables = self.prepare_block_tables(seqs)

@@ -216,57 +255,58 @@
| 216 |         return input_ids, positions
| 217 |
| 218 |     def prepare_decode(self, seqs: list[Sequence]):
| 219-222 | -   (removed lines; content not shown in this view)
| 223 | -       for seq in seqs:
| 224-231 | -   (removed lines; content not shown in this view)
| 232 |         block_tables = self.prepare_block_tables(seqs)
| 233 |         set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
| 234 |         return input_ids, positions
| 235 |
| 236 |     def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
| 237 | -       """
| 238 |         if is_cfg_batch:
| 239-240 | -   (removed lines; content not shown in this view)
| 241 | -           num_cond = len(seqs) // 2
| 242 | -           temperatures = []
| 243 | -           cfg_scales = []
| 244 | -           top_ks = []
| 245 | -           top_ps = []
| 246 | -           repetition_penalties = []
| 247 | -           for seq in seqs[:num_cond]:
| 248 | -               temperatures.append(seq.temperature)
| 249 | -               cfg_scales.append(seq.cfg_scale)
| 250 | -               top_ks.append(seq.top_k if seq.top_k is not None else 0)
| 251 | -               top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
| 252 | -               repetition_penalties.append(seq.repetition_penalty)
| 253 |         else:
| 254-269 | -   (removed lines; content not shown in this view)
| 270 |         return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
| 271 |
| 272 |     @torch.inference_mode()

@@ -293,27 +333,15 @@
| 293 |         [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
| 294 |         where uncond_seqi is the paired unconditional sequence of cond_seqi."""
| 295 |         # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
| 296 | -       is_cfg_batch =
| 297 | -       if len(seqs) > 0:
| 298 | -           # CFG batch if first sequence has cfg_scale > 1.0 and paired_seq
| 299 | -           if seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
| 300 | -               is_cfg_batch = True
| 301 | -               # Verify batch structure: first half conditional, second half unconditional
| 302 | -               num_cond = len(seqs) // 2
| 303 | -               for i in range(num_cond):
| 304 | -                   if seqs[i].is_unconditional or seqs[i + num_cond].is_unconditional == False:
| 305 | -                       is_cfg_batch = False
| 306 | -                       break
| 307 | -
| 308 |         if is_cfg_batch:
| 309 |             # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
| 310 |             num_cond = len(seqs) // 2
| 311 |             cond_seqs = seqs[:num_cond]
| 312 | -           uncond_seqs = seqs[num_cond:]
| 313 |
| 314 |             # Prepare inputs for both conditional and unconditional (they're already in the batch)
| 315 | -           input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
| 316 | -                                   else self.prepare_decode(seqs))
| 317 |             sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
| 318 |             if sample_params is not None:
| 319 |                 temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params

@@ -364,7 +392,7 @@
| 364 |             logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
| 365 |
| 366 |         # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
| 367 | -       cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
| 368 |
| 369 |         # Sample from CFG logits
| 370 |         token_ids_cfg = self.sampler(

@@ -373,7 +401,7 @@
| 373 |             top_ks=top_ks if top_ks is not None else None,
| 374 |             top_ps=top_ps if top_ps is not None else None,
| 375 |             repetition_penalties=None,  # Already applied above
| 376 | -           input_ids=cond_input_ids,
| 377 |         ).tolist()
| 378 |
| 379 |         # Update logits processor state after sampling

@@ -432,7 +460,7 @@
| 432 |             logits[i] = processed[0]
| 433 |
| 434 |         # Prepare input_ids for sampler
| 435 | -       seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
| 436 |
| 437 |         token_ids = self.sampler(
| 438 |             logits,

@@ -440,7 +468,7 @@
| 440 |             top_ks=top_ks if top_ks is not None else None,
| 441 |             top_ps=top_ps if top_ps is not None else None,
| 442 |             repetition_penalties=None,  # Already applied above
| 443 | -           input_ids=seq_input_ids,
| 444 |         ).tolist()
| 445 |
| 446 |         # Update logits processor state after sampling
|  68 |         self.model = Qwen3ForCausalLM(hf_config)
|  69 |         load_model(self.model, config.model)
|  70 |         self.sampler = Sampler()
|  71 | +
|  72 | +       # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
|  73 | +       # Must be called before warmup_model() since it uses these buffers
|  74 | +       self._allocate_sample_buffers()
|  75 | +
|  76 |         self.warmup_model()
|  77 |         self.allocate_kv_cache()
|  78 |         if not self.enforce_eager:
|  79 |             self.capture_cudagraph()
|  80 | +
|  81 |         torch.set_default_device("cpu")
|  82 |         torch.set_default_dtype(default_dtype)
|  83 |

|  90 |         self.shm = SharedMemory(name="nanovllm")
|  91 |         self.loop()
|  92 |
|  93 | +   def _allocate_sample_buffers(self):
|  94 | +       """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
|  95 | +       max_bs = self.config.max_num_seqs
|  96 | +       max_tokens = self.config.max_num_batched_tokens
|  97 | +       max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
|  98 | +
|  99 | +       # Pre-allocate pinned memory buffers on CPU for fast transfer
| 100 | +       # Must explicitly specify device="cpu" since default device may be "cuda"
| 101 | +       self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
| 102 | +       self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
| 103 | +       self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
| 104 | +       self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
| 105 | +       self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
| 106 | +
| 107 | +       # Pre-allocate decode buffers on CPU with pinned memory
| 108 | +       self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
| 109 | +       self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
| 110 | +       self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
| 111 | +       self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
| 112 | +
| 113 | +       # Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
| 114 | +       self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
| 115 | +       self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
| 116 | +       self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
| 117 | +       self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
| 118 | +
| 119 | +       # Pre-allocate block tables buffer (shared by both decode and prefill)
| 120 | +       self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
| 121 | +
| 122 | +       # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
| 123 | +       # Max length is max_model_len since sequences can be that long
| 124 | +       self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
| 125 | +
| 126 |     def exit(self):
| 127 |         if self.world_size > 1:
| 128 |             self.shm.close()

| 242 |             if i != seq.num_blocks - 1:
| 243 |                 end = start + self.block_size
| 244 |             else:
| 245 | +               end = start + seq.last_block_num_tokens
| 246 |             slot_mapping.extend(list(range(start, end)))
| 247 |         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
| 248 |             block_tables = self.prepare_block_tables(seqs)

| 255 |         return input_ids, positions
| 256 |
| 257 |     def prepare_decode(self, seqs: list[Sequence]):
| 258 | +       """Optimized decode preparation using pre-allocated buffers."""
| 259 | +       bs = len(seqs)
| 260 | +
| 261 | +       # Use pre-allocated CPU buffers
| 262 | +       for i, seq in enumerate(seqs):
| 263 | +           self._cpu_input_ids[i] = seq.last_token
| 264 | +           self._cpu_positions[i] = len(seq) - 1
| 265 | +           self._cpu_context_lens[i] = len(seq)
| 266 | +           self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
| 267 | +
| 268 | +       # Transfer to GPU using sliced views
| 269 | +       input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
| 270 | +       positions = self._cpu_positions[:bs].cuda(non_blocking=True)
| 271 | +       slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
| 272 | +       context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
| 273 |         block_tables = self.prepare_block_tables(seqs)
| 274 |         set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
| 275 |         return input_ids, positions
| 276 |
| 277 |     def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
| 278 | +       """Optimized sample preparation using pre-allocated buffers."""
| 279 |         if is_cfg_batch:
| 280 | +           num_seqs = len(seqs) // 2
| 281 | +           target_seqs = seqs[:num_seqs]
| 282 |         else:
| 283 | +           num_seqs = len(seqs)
| 284 | +           target_seqs = seqs
| 285 | +
| 286 | +       # Fill pre-allocated CPU buffers
| 287 | +       top_ks_is_zero = True
| 288 | +       top_ps_is_one = True
| 289 | +       repetition_penalties_is_one = True
| 290 | +       for i, seq in enumerate(target_seqs):
| 291 | +           self._cpu_temperatures[i] = seq.temperature
| 292 | +           self._cpu_cfg_scales[i] = seq.cfg_scale
| 293 | +           self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
| 294 | +           if seq.top_k is not None and seq.top_k > 0:
| 295 | +               top_ks_is_zero = False
| 296 | +           self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
| 297 | +           if seq.top_p is not None and seq.top_p == 1.0:
| 298 | +               top_ps_is_one = False
| 299 | +           self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
| 300 | +           if seq.repetition_penalty is not None and seq.repetition_penalty == 1.0:
| 301 | +               repetition_penalties_is_one = False
| 302 | +
| 303 | +       # Transfer to GPU using sliced views (single batched transfer)
| 304 | +       temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
| 305 | +       cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
| 306 | +       top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
| 307 | +       top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
| 308 | +       repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
| 309 | +
| 310 |         return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
| 311 |
| 312 |     @torch.inference_mode()

| 333 |         [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
| 334 |         where uncond_seqi is the paired unconditional sequence of cond_seqi."""
| 335 |         # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
| 336 | +       is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
| 337 |         if is_cfg_batch:
| 338 |             # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
| 339 |             num_cond = len(seqs) // 2
| 340 |             cond_seqs = seqs[:num_cond]
| 341 | +           # uncond_seqs = seqs[num_cond:]
| 342 |
| 343 |             # Prepare inputs for both conditional and unconditional (they're already in the batch)
| 344 | +           input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
| 345 |             sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
| 346 |             if sample_params is not None:
| 347 |                 temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params

| 392 |             logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
| 393 |
| 394 |         # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
| 395 | +       # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
| 396 |
| 397 |         # Sample from CFG logits
| 398 |         token_ids_cfg = self.sampler(

| 401 |             top_ks=top_ks if top_ks is not None else None,
| 402 |             top_ps=top_ps if top_ps is not None else None,
| 403 |             repetition_penalties=None,  # Already applied above
| 404 | +           # input_ids=cond_input_ids,
| 405 |         ).tolist()
| 406 |
| 407 |         # Update logits processor state after sampling

| 460 |             logits[i] = processed[0]
| 461 |
| 462 |         # Prepare input_ids for sampler
| 463 | +       # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
| 464 |
| 465 |         token_ids = self.sampler(
| 466 |             logits,

| 468 |             top_ks=top_ks if top_ks is not None else None,
| 469 |             top_ps=top_ps if top_ps is not None else None,
| 470 |             repetition_penalties=None,  # Already applied above
| 471 | +           # input_ids=seq_input_ids,
| 472 |         ).tolist()
| 473 |
| 474 |         # Update logits processor state after sampling
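The buffer changes above all follow one pattern: allocate page-locked (pinned) host tensors once at start-up, write each step's values into a slice of them, and copy only that slice to the GPU with non_blocking=True, so the transfer can overlap other work instead of allocating a fresh CPU tensor on every decode step. A self-contained sketch of that pattern follows; the class and variable names are illustrative, not taken from the PR, and it assumes a CUDA device is available.

    import torch

    class PinnedScratch:
        """Illustrative buffer-reuse pattern, not the PR's actual class."""

        def __init__(self, max_batch: int):
            # Allocated once; pinned memory allows asynchronous host-to-device copies.
            self.temperatures = torch.zeros(
                max_batch, dtype=torch.float32, device="cpu", pin_memory=True
            )

        def to_gpu(self, values: list) -> torch.Tensor:
            n = len(values)
            for i, v in enumerate(values):
                self.temperatures[i] = v          # fill only the slice we need
            # Sliced view + non_blocking copy: no new CPU tensor per decode step.
            return self.temperatures[:n].cuda(non_blocking=True)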
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py
CHANGED
@@ -3,6 +3,83 @@ from torch import nn
|  3 | from typing import Optional
|  4 |
|  5 |
|  6 | class Sampler(nn.Module):
|  7 |
|  8 |     def __init__(self):

@@ -19,56 +96,19 @@ class Sampler(nn.Module):
| 19 |         input_ids: Optional[torch.Tensor] = None,
| 20 |     ):
| 21 |         """
| 22 | -       Sample tokens from logits with optional top-k
| 23 |
| 24-25 | -  (removed lines; content not shown in this view)
| 26 | -       temperatures: [batch_size] temperature values
| 27 | -       top_ks: Optional [batch_size] top-k values (None or 0 means no top-k filtering)
| 28 | -       top_ps: Optional [batch_size] top-p values (None or 1.0 means no top-p filtering)
| 29 | -       repetition_penalties: Optional [batch_size] repetition penalty values (1.0 means no penalty)
| 30 | -       input_ids: Optional [batch_size, seq_len] input token ids for repetition penalty
| 31 |         """
| 32 | -       batch_size, vocab_size = logits.shape
| 33 | -
| 34 | -       # Note: Repetition penalty is applied in ModelRunner before calling sampler
| 35 | -       # This allows us to use the full sequence context
| 36 | -
| 37 |         # Apply temperature
| 38 |         logits = logits.float().div_(temperatures.unsqueeze(dim=1))
| 39-44 | -  (removed lines; content not shown in this view)
| 45 | -               # Get top-k logits, set others to -inf
| 46 | -               top_k_logits, top_k_indices = torch.topk(logits[i], int(top_k), dim=-1)
| 47 | -               filtered_logits = torch.full_like(logits[i], float('-inf'))
| 48 | -               filtered_logits[top_k_indices] = top_k_logits
| 49 | -               logits[i] = filtered_logits
| 50 | -
| 51 | -       # Apply top-p (nucleus) filtering if specified
| 52 | -       if top_ps is not None:
| 53 | -           probs = torch.softmax(logits, dim=-1)
| 54 | -           for i in range(batch_size):
| 55 | -               top_p = top_ps[i].item()
| 56 | -               if 0.0 < top_p < 1.0:
| 57 | -                   # Sort probabilities in descending order
| 58 | -                   sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
| 59 | -                   # Calculate cumulative probabilities
| 60 | -                   cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
| 61 | -                   # Find the cutoff point
| 62 | -                   cutoff_idx = (cumsum_probs <= top_p).sum().item()
| 63 | -                   if cutoff_idx < len(sorted_indices):
| 64 | -                       cutoff_idx += 1  # Include one more token to ensure we have at least one
| 65 | -                   # Create mask for tokens to keep
| 66 | -                   mask = torch.zeros_like(probs[i])
| 67 | -                   mask[sorted_indices[:cutoff_idx]] = 1.0
| 68 | -                   # Apply mask: set filtered tokens to -inf
| 69 | -                   logits[i] = torch.where(mask > 0, logits[i], torch.tensor(float('-inf'), device=logits.device))
| 70 | -
| 71 | -       # Sample using Gumbel-max trick (equivalent to sampling from softmax)
| 72 |         probs = torch.softmax(logits, dim=-1)
| 73 |         sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
| 74 | -       return sample_tokens
|   3 | from typing import Optional
|   4 |
|   5 |
|   6 | + def apply_top_k_top_p(
|   7 | +     logits: torch.Tensor,
|   8 | +     k: Optional[torch.Tensor],
|   9 | +     p: Optional[torch.Tensor],
|  10 | + ) -> torch.Tensor:
|  11 | +     """Apply top-k and top-p masks to the logits (vLLM style).
|  12 | +
|  13 | +     The logits tensor is updated in-place.
|  14 | +     """
|  15 | +     if p is None:
|  16 | +         if k is None:
|  17 | +             return logits
|  18 | +         # Avoid sorting vocab for top-k only case
|  19 | +         return apply_top_k_only(logits, k)
|  20 | +
|  21 | +     # Need to sort for top-p
|  22 | +     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|  23 | +
|  24 | +     if k is not None:
|  25 | +         # Apply top-k first
|  26 | +         vocab_size = logits_sort.size(1)
|  27 | +         # Clamp k to valid range
|  28 | +         k_clamped = k.clamp(1, vocab_size).long()
|  29 | +         top_k_mask_idx = vocab_size - k_clamped  # shape: [B]
|  30 | +         # Get the threshold value for each batch
|  31 | +         top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
|  32 | +         top_k_mask = logits_sort < top_k_thresh
|  33 | +         logits_sort.masked_fill_(top_k_mask, float('-inf'))
|  34 | +
|  35 | +     # Apply top-p
|  36 | +     probs_sort = logits_sort.softmax(dim=-1)
|  37 | +     probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)  # reuse buffer
|  38 | +     top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
|  39 | +     # Ensure at least one token is kept
|  40 | +     top_p_mask[:, -1] = False
|  41 | +     logits_sort.masked_fill_(top_p_mask, float('-inf'))
|  42 | +
|  43 | +     # Re-sort back to original positions
|  44 | +     logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
|  45 | +     return logits
|  46 | +
|  47 | +
|  48 | + def apply_top_k_only(
|  49 | +     logits: torch.Tensor,
|  50 | +     k: torch.Tensor,
|  51 | + ) -> torch.Tensor:
|  52 | +     """Apply top-k mask without sorting the entire vocab (vLLM style).
|  53 | +
|  54 | +     This is much faster than sorting for top-k only cases.
|  55 | +     The logits tensor is updated in-place.
|  56 | +     """
|  57 | +     vocab_size = logits.shape[1]
|  58 | +     # Handle cases where k >= vocab_size (no filtering needed)
|  59 | +     no_top_k_mask = (k <= 0) | (k >= vocab_size)
|  60 | +     # Set invalid k to 1 so we can still gather
|  61 | +     k_safe = k.masked_fill(no_top_k_mask, 1).long()
|  62 | +     # NOTE: This int() causes CPU-GPU sync, but torch.topk requires Python int
|  63 | +     max_top_k = int(k_safe.max().clamp(max=vocab_size))
|  64 | +
|  65 | +     # Get top-k values for all batches
|  66 | +     # topk.values has shape [batch_size, max_top_k]
|  67 | +     topk_values = logits.topk(max_top_k, dim=1).values
|  68 | +
|  69 | +     # Convert k to 0-based index: we want the k-th largest value (index k-1)
|  70 | +     # Clamp to valid range for gather
|  71 | +     k_index = (k_safe - 1).clamp(0, max_top_k - 1).unsqueeze(1)  # shape: [B, 1]
|  72 | +     # Gather the threshold value (the k-th largest)
|  73 | +     top_k_thresh = topk_values.gather(1, k_index)
|  74 | +
|  75 | +     # For rows with no top-k filtering, set threshold to -inf so nothing gets masked
|  76 | +     top_k_thresh.masked_fill_(no_top_k_mask.unsqueeze(1), float('-inf'))
|  77 | +
|  78 | +     # Mask all values below the threshold
|  79 | +     logits.masked_fill_(logits < top_k_thresh, float('-inf'))
|  80 | +     return logits
|  81 | +
|  82 | +
|  83 | class Sampler(nn.Module):
|  84 |
|  85 |     def __init__(self):

|  96 |         input_ids: Optional[torch.Tensor] = None,
|  97 |     ):
|  98 |         """
|  99 | +       Sample tokens from logits with optional top-k and top-p filtering.
| 100 |
| 101 | +       Condition checking is done OUTSIDE the compiled function to avoid
| 102 | +       graph breaks from .any() calls.
| 103 |         """
| 104 |         # Apply temperature
| 105 |         logits = logits.float().div_(temperatures.unsqueeze(dim=1))
| 106 | +
| 107 | +       logits = apply_top_k_top_p(
| 108 | +           logits,
| 109 | +           top_ks,
| 110 | +           top_ps,
| 111 | +       )
| 112 |         probs = torch.softmax(logits, dim=-1)
| 113 |         sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
| 114 | +       return sample_tokens
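The single sampling line kept at the end of Sampler.forward relies on the exponential-noise form of the Gumbel-max trick: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws from the same categorical distribution as torch.multinomial, without a separate sampling kernel. A small standalone sanity check of that equivalence (not part of the PR, runs on CPU):

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([[0.1, 0.2, 0.7]]).repeat(100_000, 1)

    # Same construction as in Sampler.forward above.
    noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
    samples = probs.div(noise).argmax(dim=-1)

    # Empirical frequencies should be close to [0.1, 0.2, 0.7].
    freq = torch.bincount(samples, minlength=3).float() / samples.numel()
    print(freq)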
acestep/third_parts/nano-vllm/pyproject.toml
CHANGED
@@ -15,8 +15,6 @@ dependencies = [
| 15 |     "triton-windows>=3.0.0; sys_platform == 'win32'",
| 16 |     "triton>=3.0.0; sys_platform != 'win32'",
| 17 |     "transformers>=4.51.0",
| 18 | -   "flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.3/flash_attn-2.8.3+cu128torch2.8.0cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl; sys_platform == 'win32'",
| 19 | -   "flash-attn; sys_platform != 'win32'",
| 20 |     "xxhash",
| 21 | ]
| 22 |

| 15 |     "triton-windows>=3.0.0; sys_platform == 'win32'",
| 16 |     "triton>=3.0.0; sys_platform != 'win32'",
| 17 |     "transformers>=4.51.0",
| 18 |     "xxhash",
| 19 | ]
| 20 |
profile_inference.py
ADDED
|
@@ -0,0 +1,682 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Enhanced profiling script for ACE-Step inference with deep LLM analysis
|
| 4 |
+
|
| 5 |
+
This script helps diagnose why LLM generation is slow by tracking:
|
| 6 |
+
1. Total tokens generated vs expected throughput (200 tokens/sec baseline)
|
| 7 |
+
2. Per-iteration timing to detect compilation overhead or slow operations
|
| 8 |
+
3. Constrained decoding overhead
|
| 9 |
+
4. CFG overhead (2x forward passes)
|
| 10 |
+
5. Model forward time vs sampling/processing time
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python profile_inference.py # Standard profiling with warmup
|
| 14 |
+
python profile_inference.py --no-warmup # Profile first run (includes compilation)
|
| 15 |
+
python profile_inference.py --llm-debug # Deep LLM performance debugging
|
| 16 |
+
python profile_inference.py --detailed # Add cProfile function-level analysis
|
| 17 |
+
|
| 18 |
+
Inference mode options:
|
| 19 |
+
python profile_inference.py --thinking # Enable CoT for code generation
|
| 20 |
+
python profile_inference.py --use-constrained-decoding # Use FSM constrained decoding
|
| 21 |
+
python profile_inference.py --use-cot-metas # Enable LM to generate metadata via CoT
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import time
|
| 25 |
+
import argparse
|
| 26 |
+
import sys
|
| 27 |
+
import os
|
| 28 |
+
from contextlib import contextmanager
|
| 29 |
+
from collections import defaultdict
|
| 30 |
+
import json
|
| 31 |
+
from typing import Tuple, Dict, Any, List
|
| 32 |
+
from functools import wraps
|
| 33 |
+
|
| 34 |
+
# Add project root to path
|
| 35 |
+
project_root = os.path.abspath(os.path.dirname(__file__))
|
| 36 |
+
if project_root not in sys.path:
|
| 37 |
+
sys.path.insert(0, project_root)
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
from acestep.inference import generate_music, GenerationParams, GenerationConfig
|
| 41 |
+
from acestep.handler import AceStepHandler
|
| 42 |
+
from acestep.llm_inference import LLMHandler
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PreciseTimer:
|
| 46 |
+
"""High-precision timer with CUDA synchronization for accurate GPU timing"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, device="cuda"):
|
| 49 |
+
self.device = device
|
| 50 |
+
self.timings = defaultdict(list)
|
| 51 |
+
self.enabled = True
|
| 52 |
+
|
| 53 |
+
def sync(self):
|
| 54 |
+
"""Synchronize CUDA operations for accurate timing"""
|
| 55 |
+
if self.enabled and self.device.startswith("cuda") and torch.cuda.is_available():
|
| 56 |
+
torch.cuda.synchronize()
|
| 57 |
+
|
| 58 |
+
@contextmanager
|
| 59 |
+
def time(self, name: str):
|
| 60 |
+
"""Time a code section with CUDA synchronization"""
|
| 61 |
+
if not self.enabled:
|
| 62 |
+
yield
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
self.sync()
|
| 66 |
+
start = time.perf_counter()
|
| 67 |
+
try:
|
| 68 |
+
yield
|
| 69 |
+
finally:
|
| 70 |
+
self.sync()
|
| 71 |
+
elapsed = time.perf_counter() - start
|
| 72 |
+
self.timings[name].append(elapsed)
|
| 73 |
+
|
| 74 |
+
def get_total(self, name: str) -> float:
|
| 75 |
+
"""Get total accumulated time for a section"""
|
| 76 |
+
return sum(self.timings.get(name, []))
|
| 77 |
+
|
| 78 |
+
def get_mean(self, name: str) -> float:
|
| 79 |
+
"""Get mean time per call for a section"""
|
| 80 |
+
times = self.timings.get(name, [])
|
| 81 |
+
return sum(times) / len(times) if times else 0.0
|
| 82 |
+
|
| 83 |
+
def get_count(self, name: str) -> int:
|
| 84 |
+
"""Get number of calls for a section"""
|
| 85 |
+
return len(self.timings.get(name, []))
|
| 86 |
+
|
| 87 |
+
def get_all(self, name: str) -> List[float]:
|
| 88 |
+
"""Get all timing samples for a section"""
|
| 89 |
+
return self.timings.get(name, [])
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class LLMDebugger:
|
| 93 |
+
"""Track detailed LLM performance metrics to diagnose slow generation"""
|
| 94 |
+
|
| 95 |
+
def __init__(self):
|
| 96 |
+
self.reset()
|
| 97 |
+
|
| 98 |
+
def reset(self):
|
| 99 |
+
"""Reset all metrics"""
|
| 100 |
+
self.total_tokens = 0
|
| 101 |
+
self.generation_start = None
|
| 102 |
+
self.generation_end = None
|
| 103 |
+
self.output_text = ""
|
| 104 |
+
self.prompt_length = 0
|
| 105 |
+
|
| 106 |
+
def start(self, prompt_length: int = 0):
|
| 107 |
+
"""Mark generation start"""
|
| 108 |
+
self.generation_start = time.perf_counter()
|
| 109 |
+
self.prompt_length = prompt_length
|
| 110 |
+
|
| 111 |
+
def end(self, output_text: str = ""):
|
| 112 |
+
"""Mark generation end and store output"""
|
| 113 |
+
self.generation_end = time.perf_counter()
|
| 114 |
+
self.output_text = output_text
|
| 115 |
+
|
| 116 |
+
def set_token_count(self, count: int):
|
| 117 |
+
"""Set total token count"""
|
| 118 |
+
self.total_tokens = count
|
| 119 |
+
|
| 120 |
+
def get_throughput(self) -> float:
|
| 121 |
+
"""Calculate actual tokens per second"""
|
| 122 |
+
if self.generation_start and self.generation_end and self.total_tokens > 0:
|
| 123 |
+
total_time = self.generation_end - self.generation_start
|
| 124 |
+
if total_time > 0:
|
| 125 |
+
return self.total_tokens / total_time
|
| 126 |
+
return 0.0
|
| 127 |
+
|
| 128 |
+
def print_analysis(self):
|
| 129 |
+
"""Print detailed LLM performance analysis"""
|
| 130 |
+
if not self.generation_start or not self.generation_end:
|
| 131 |
+
return
|
| 132 |
+
|
| 133 |
+
print("\n" + "=" * 100)
|
| 134 |
+
print("🔍 LLM PERFORMANCE DEEP DIVE")
|
| 135 |
+
print("=" * 100)
|
| 136 |
+
|
| 137 |
+
total_time = self.generation_end - self.generation_start
|
| 138 |
+
throughput = self.get_throughput()
|
| 139 |
+
|
| 140 |
+
# Basic metrics table
|
| 141 |
+
print(f"\n{'Metric':<40} {'Value':<20} {'Notes'}")
|
| 142 |
+
print("-" * 100)
|
| 143 |
+
print(f"{'Total Tokens Generated:':<40} {self.total_tokens:<20} (new tokens only)")
|
| 144 |
+
print(f"{'Prompt Length (estimate):':<40} {self.prompt_length:<20} (input tokens)")
|
| 145 |
+
print(f"{'Total Generation Time:':<40} {total_time:<20.3f} seconds")
|
| 146 |
+
print(f"{'Measured Throughput:':<40} {throughput:<20.1f} tokens/sec")
|
| 147 |
+
print(f"{'Expected Throughput:':<40} {'200':<20} tokens/sec (baseline)")
|
| 148 |
+
|
| 149 |
+
# Calculate performance gap
|
| 150 |
+
if throughput > 0:
|
| 151 |
+
slowdown = 200.0 / throughput
|
| 152 |
+
efficiency = (throughput / 200.0) * 100
|
| 153 |
+
print(f"{'Performance vs Baseline:':<40} {efficiency:<20.1f}% of expected")
|
| 154 |
+
print(f"{'Slowdown Factor:':<40} {slowdown:<20.2f}x slower")
|
| 155 |
+
|
| 156 |
+
# Analyze generated output
|
| 157 |
+
if self.output_text:
|
| 158 |
+
print(f"\n{'Output Analysis:':<40}")
|
| 159 |
+
print(f"{' Output length:':<40} {len(self.output_text):<20} characters")
|
| 160 |
+
|
| 161 |
+
# Count audio codes
|
| 162 |
+
import re
|
| 163 |
+
code_pattern = r'<\|audio_code_\d+\|>'
|
| 164 |
+
codes = re.findall(code_pattern, self.output_text)
|
| 165 |
+
if codes:
|
| 166 |
+
print(f"{' Audio codes generated:':<40} {len(codes):<20} codes")
|
| 167 |
+
print(f"{' Expected audio duration:':<40} {f'~{len(codes)/5:.1f}s':<20} (5 codes per second)")
|
| 168 |
+
if total_time > 0:
|
| 169 |
+
print(f"{' Time per audio code:':<40} {f'{total_time/len(codes)*1000:.1f}ms':<20}")
|
| 170 |
+
|
| 171 |
+
# Check for CoT section
|
| 172 |
+
if '<think>' in self.output_text and '</think>' in self.output_text:
|
| 173 |
+
cot_start = self.output_text.find('<think>')
|
| 174 |
+
cot_end = self.output_text.find('</think>') + 8
|
| 175 |
+
cot_section = self.output_text[cot_start:cot_end]
|
| 176 |
+
cot_token_est = len(cot_section) // 4
|
| 177 |
+
print(f"{' CoT section tokens (estimate):':<40} {f'~{cot_token_est}':<20}")
|
| 178 |
+
|
| 179 |
+
# Diagnostic guidance
|
| 180 |
+
print("\n" + "=" * 100)
|
| 181 |
+
print("🔧 DIAGNOSTIC GUIDANCE")
|
| 182 |
+
print("=" * 100)
|
| 183 |
+
|
| 184 |
+
if throughput < 50:
|
| 185 |
+
print("\n⚠️ CRITICAL: Throughput is extremely low (<50 tokens/sec)")
|
| 186 |
+
print("\nThis is ~4x slower than expected. Likely causes:")
|
| 187 |
+
print(" 1. ❗ Constrained decoding FSM overhead")
|
| 188 |
+
print(" → Each token triggers FSM state machine validation")
|
| 189 |
+
print(" → Try: set use_constrained_decoding=False in config")
|
| 190 |
+
print(" 2. ❗ CFG with double forward passes")
|
| 191 |
+
print(" → cfg_scale > 1.0 means running model twice per token")
|
| 192 |
+
print(" → Check: params.lm_cfg_scale value")
|
| 193 |
+
print(" 3. ❗ Running in eager mode without compilation")
|
| 194 |
+
print(" → PyTorch should compile kernels after warmup")
|
| 195 |
+
print(" → Check: torch._dynamo.config settings")
|
| 196 |
+
|
| 197 |
+
elif throughput < 100:
|
| 198 |
+
print("\n⚠️ WARNING: Throughput is low (50-100 tokens/sec)")
|
| 199 |
+
print("\nLikely causes:")
|
| 200 |
+
print(" 1. Constrained decoding overhead (~30-50% slowdown expected)")
|
| 201 |
+
print(" 2. CFG enabled (2x compute per token if cfg_scale > 1.0)")
|
| 202 |
+
print(" 3. Small model or inefficient GPU utilization")
|
| 203 |
+
|
| 204 |
+
elif throughput < 150:
|
| 205 |
+
print("\n⚠️ Throughput is below baseline but acceptable (100-150 tokens/sec)")
|
| 206 |
+
print("\nMinor overhead from:")
|
| 207 |
+
print(" - Constrained decoding: ~20-30% overhead")
|
| 208 |
+
print(" - Profiling instrumentation: ~5-10% overhead")
|
| 209 |
+
|
| 210 |
+
else:
|
| 211 |
+
print(f"\n✓ Throughput is good ({throughput:.1f} tokens/sec)")
|
| 212 |
+
print(" Performance is within acceptable range")
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Global instances
|
| 216 |
+
timer = None
|
| 217 |
+
llm_debugger = None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def wrap_method_with_timing(obj, method_name: str, timing_key: str):
|
| 221 |
+
"""Wrap a method with timing instrumentation"""
|
| 222 |
+
original_method = getattr(obj, method_name)
|
| 223 |
+
|
| 224 |
+
@wraps(original_method)
|
| 225 |
+
def timed_wrapper(*args, **kwargs):
|
| 226 |
+
with timer.time(timing_key):
|
| 227 |
+
return original_method(*args, **kwargs)
|
| 228 |
+
|
| 229 |
+
setattr(obj, method_name, timed_wrapper)
|
| 230 |
+
return original_method
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def wrap_llm_with_debug_tracking(llm_handler):
|
| 234 |
+
"""Wrap LLM generation with detailed performance tracking"""
|
| 235 |
+
original_method = llm_handler.generate_with_stop_condition
|
| 236 |
+
|
| 237 |
+
@wraps(original_method)
|
| 238 |
+
def debug_wrapper(*args, **kwargs):
|
| 239 |
+
# Estimate prompt length
|
| 240 |
+
caption = kwargs.get('caption', args[0] if len(args) > 0 else "")
|
| 241 |
+
lyrics = kwargs.get('lyrics', args[1] if len(args) > 1 else "")
|
| 242 |
+
prompt_estimate = len(caption) + len(lyrics)
|
| 243 |
+
prompt_tokens_estimate = prompt_estimate // 4
|
| 244 |
+
|
| 245 |
+
# Start tracking
|
| 246 |
+
llm_debugger.reset()
|
| 247 |
+
llm_debugger.start(prompt_length=prompt_tokens_estimate)
|
| 248 |
+
|
| 249 |
+
# Call original with timing
|
| 250 |
+
with timer.time('llm_inference'):
|
| 251 |
+
result = original_method(*args, **kwargs)
|
| 252 |
+
|
| 253 |
+
# Extract and analyze output
|
| 254 |
+
output_text = ""
|
| 255 |
+
if isinstance(result, tuple) and len(result) >= 2:
|
| 256 |
+
if isinstance(result[1], list):
|
| 257 |
+
# Batch mode
|
| 258 |
+
output_text = "".join(result[1])
|
| 259 |
+
else:
|
| 260 |
+
# Single mode
|
| 261 |
+
cot_output = ""
|
| 262 |
+
if isinstance(result[0], dict):
|
| 263 |
+
for v in result[0].values():
|
| 264 |
+
if isinstance(v, str):
|
| 265 |
+
cot_output += v
|
| 266 |
+
output_text = cot_output + str(result[1])
|
| 267 |
+
|
| 268 |
+
# Count tokens
|
| 269 |
+
import re
|
| 270 |
+
code_pattern = r'<\|audio_code_\d+\|>'
|
| 271 |
+
codes = re.findall(code_pattern, output_text)
|
| 272 |
+
remaining_text = re.sub(code_pattern, '', output_text)
|
| 273 |
+
cot_tokens_estimate = len(remaining_text) // 4
|
| 274 |
+
total_tokens = len(codes) + cot_tokens_estimate
|
| 275 |
+
|
| 276 |
+
llm_debugger.set_token_count(total_tokens)
|
| 277 |
+
llm_debugger.end(output_text)
|
| 278 |
+
|
| 279 |
+
return result
|
| 280 |
+
|
| 281 |
+
llm_handler.generate_with_stop_condition = debug_wrapper
|
| 282 |
+
return original_method
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def instrument_handlers(dit_handler, llm_handler, enable_llm_debug=False):
|
| 286 |
+
"""Add timing instrumentation to handler methods"""
|
| 287 |
+
originals = {}
|
| 288 |
+
|
| 289 |
+
# Instrument LLM
|
| 290 |
+
if llm_handler and llm_handler.llm_initialized:
|
| 291 |
+
if enable_llm_debug:
|
| 292 |
+
originals['llm_generate'] = wrap_llm_with_debug_tracking(llm_handler)
|
| 293 |
+
else:
|
| 294 |
+
originals['llm_generate'] = wrap_method_with_timing(
|
| 295 |
+
llm_handler, 'generate_with_stop_condition', 'llm_inference'
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Instrument DiT handler
|
| 299 |
+
originals['dit_prepare'] = wrap_method_with_timing(
|
| 300 |
+
dit_handler, 'prepare_batch_data', 'prepare_batch_data'
|
| 301 |
+
)
|
| 302 |
+
originals['dit_generate'] = wrap_method_with_timing(
|
| 303 |
+
dit_handler, 'service_generate', 'dit_inference'
|
| 304 |
+
)
|
| 305 |
+
originals['dit_decode'] = wrap_method_with_timing(
|
| 306 |
+
dit_handler, 'tiled_decode', 'vae_decode'
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
return originals
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def restore_handlers(dit_handler, llm_handler, originals):
|
| 313 |
+
"""Restore original handler methods after profiling"""
|
| 314 |
+
if llm_handler and 'llm_generate' in originals:
|
| 315 |
+
llm_handler.generate_with_stop_condition = originals['llm_generate']
|
| 316 |
+
|
| 317 |
+
dit_handler.prepare_batch_data = originals['dit_prepare']
|
| 318 |
+
dit_handler.service_generate = originals['dit_generate']
|
| 319 |
+
dit_handler.tiled_decode = originals['dit_decode']
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def print_profiling_results(total_time: float, show_llm_debug: bool = False):
|
| 323 |
+
"""Print comprehensive profiling results with performance insights"""
|
| 324 |
+
print("\n" + "=" * 100)
|
| 325 |
+
print("🎯 PROFILING RESULTS")
|
| 326 |
+
print("=" * 100)
|
| 327 |
+
|
| 328 |
+
# Define timing categories
|
| 329 |
+
model_sections = {
|
| 330 |
+
'llm_inference': 'LLM Inference (5Hz Language Model)',
|
| 331 |
+
'dit_inference': 'DiT Inference (Diffusion Transformer)',
|
| 332 |
+
'vae_decode': 'VAE Decode (Audio Decoder)',
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
non_model_sections = {
|
| 336 |
+
'prepare_batch_data': 'Prepare Batch Data (embedding, formatting)',
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
# Calculate totals
|
| 340 |
+
model_time = sum(timer.get_total(k) for k in model_sections.keys())
|
| 341 |
+
non_model_time = sum(timer.get_total(k) for k in non_model_sections.keys())
|
| 342 |
+
other_time = total_time - model_time - non_model_time
|
| 343 |
+
|
| 344 |
+
# Print summary table
|
| 345 |
+
print(f"\n{'CATEGORY':<50} {'TIME (s)':<12} {'%':<8} {'CALLS':<8}")
|
| 346 |
+
print("-" * 100)
|
| 347 |
+
|
| 348 |
+
# Model time breakdown
|
| 349 |
+
print(f"\n{'🤖 MODEL TIME (Total)':<50} {model_time:<12.3f} {100*model_time/total_time:>6.1f}% {'':<8}")
|
| 350 |
+
for key, desc in model_sections.items():
|
| 351 |
+
t = timer.get_total(key)
|
| 352 |
+
c = timer.get_count(key)
|
| 353 |
+
if c > 0:
|
| 354 |
+
mean = timer.get_mean(key)
|
| 355 |
+
pct = 100 * t / total_time
|
| 356 |
+
print(f" {'├─ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
|
| 357 |
+
|
| 358 |
+
# Non-model time breakdown
|
| 359 |
+
print(f"\n{'⚙️ NON-MODEL TIME (Total)':<50} {non_model_time:<12.3f} {100*non_model_time/total_time:>6.1f}% {'':<8}")
|
| 360 |
+
for key, desc in non_model_sections.items():
|
| 361 |
+
t = timer.get_total(key)
|
| 362 |
+
c = timer.get_count(key)
|
| 363 |
+
if c > 0:
|
| 364 |
+
mean = timer.get_mean(key)
|
| 365 |
+
pct = 100 * t / total_time
|
| 366 |
+
print(f" {'├─ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
|
| 367 |
+
|
| 368 |
+
# Other time
|
| 369 |
+
if other_time > 0.01:
|
| 370 |
+
pct = 100 * other_time / total_time
|
| 371 |
+
print(f"\n{'📦 OTHER TIME (I/O, overhead, audio save)':<50} {other_time:<12.3f} {pct:>6.1f}% {'':<8}")
|
| 372 |
+
|
| 373 |
+
print(f"\n{'📊 TOTAL TIME':<50} {total_time:<12.3f} {'100.0%':>6} {'':<8}")
|
| 374 |
+
|
| 375 |
+
# Show LLM detailed analysis if enabled
|
| 376 |
+
if show_llm_debug:
|
| 377 |
+
llm_debugger.print_analysis()
|
| 378 |
+
|
| 379 |
+
# Performance insights
|
| 380 |
+
print("\n" + "=" * 100)
|
| 381 |
+
print("💡 PERFORMANCE INSIGHTS")
|
| 382 |
+
print("=" * 100)
|
| 383 |
+
|
| 384 |
+
llm_t = timer.get_total('llm_inference')
|
| 385 |
+
dit_t = timer.get_total('dit_inference')
|
| 386 |
+
vae_t = timer.get_total('vae_decode')
|
| 387 |
+
prep_t = timer.get_total('prepare_batch_data')
|
| 388 |
+
|
| 389 |
+
# Model time insights
|
| 390 |
+
if model_time > 0:
|
| 391 |
+
print(f"\n✓ Model operations: {model_time:.3f}s ({100*model_time/total_time:.1f}% of total)")
|
| 392 |
+
|
| 393 |
+
if llm_t > 0:
|
| 394 |
+
print(f" - LLM: {llm_t:.3f}s ({100*llm_t/model_time:.1f}% of model time)")
|
| 395 |
+
if dit_t > 0:
|
| 396 |
+
print(f" - DiT: {dit_t:.3f}s ({100*dit_t/model_time:.1f}% of model time)")
|
| 397 |
+
if vae_t > 0:
|
| 398 |
+
print(f" - VAE: {vae_t:.3f}s ({100*vae_t/model_time:.1f}% of model time)")
|
| 399 |
+
|
| 400 |
+
# LLM bottleneck analysis
|
| 401 |
+
if llm_t > dit_t and llm_t > 5.0:
|
| 402 |
+
print(f"\n⚠️ LLM IS THE BOTTLENECK: {llm_t:.3f}s ({100*llm_t/total_time:.1f}% of total)")
|
| 403 |
+
print(f"\n Possible causes:")
|
| 404 |
+
print(f" 1. Generating too many tokens → use --llm-debug to verify")
|
| 405 |
+
print(f" 2. Constrained decoding overhead → FSM validation per token")
|
| 406 |
+
print(f" 3. CFG overhead → cfg_scale > 1.0 = 2x forward passes")
|
| 407 |
+
print(f" 4. First-token latency → warmup should help")
|
| 408 |
+
print(f" 5. KV cache inefficiency → should be ~5-10ms/token")
|
| 409 |
+
|
| 410 |
+
# Non-model insights
|
| 411 |
+
if non_model_time / total_time > 0.1:
|
| 412 |
+
print(f"\n⚠️ Non-model operations: {non_model_time:.3f}s ({100*non_model_time/total_time:.1f}%)")
|
| 413 |
+
if prep_t > 0.1:
|
| 414 |
+
print(f" - Batch preparation: {prep_t:.3f}s")
|
| 415 |
+
|
| 416 |
+
# I/O overhead
|
| 417 |
+
if other_time / total_time > 0.2:
|
| 418 |
+
print(f"\n⚠️ Overhead/I/O: {other_time:.3f}s ({100*other_time/total_time:.1f}%)")
|
| 419 |
+
|
| 420 |
+
# Recommendations
|
| 421 |
+
print("\n" + "=" * 100)
|
| 422 |
+
print("🚀 OPTIMIZATION RECOMMENDATIONS")
|
| 423 |
+
print("=" * 100)
|
| 424 |
+
|
| 425 |
+
if llm_t > dit_t * 2:
|
| 426 |
+
print("\n🎯 Priority: Optimize LLM")
|
| 427 |
+
print(" 1. Run: python profile_inference.py --llm-debug")
|
| 428 |
+
print(" → Shows exact token count and throughput")
|
| 429 |
+
print(" 2. Check constrained decoding overhead")
|
| 430 |
+
print(" 3. Check CFG scaling (lm_cfg_scale parameter)")
|
| 431 |
+
print(" 4. Profile nanovllm engine step() timing")
|
| 432 |
+
print(" 5. Compare vllm vs transformers backends")
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def run_profiled_generation(dit_handler, llm_handler, params, config,
|
| 436 |
+
enable_cprofile=False, enable_llm_debug=False):
|
| 437 |
+
"""Execute generation with full profiling instrumentation"""
|
| 438 |
+
# Instrument handlers
|
| 439 |
+
originals = instrument_handlers(dit_handler, llm_handler, enable_llm_debug)
|
| 440 |
+
|
| 441 |
+
try:
|
| 442 |
+
print("\n[Profiling] Starting generation...")
|
| 443 |
+
timer.sync()
|
| 444 |
+
total_start = time.perf_counter()
|
| 445 |
+
|
| 446 |
+
# Optional cProfile
|
| 447 |
+
prof = None
|
| 448 |
+
if enable_cprofile:
|
| 449 |
+
import cProfile
|
| 450 |
+
prof = cProfile.Profile()
|
| 451 |
+
prof.enable()
|
| 452 |
+
|
| 453 |
+
# Run generation
|
| 454 |
+
result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
|
| 455 |
+
|
| 456 |
+
# Stop timing
|
| 457 |
+
timer.sync()
|
| 458 |
+
total_time = time.perf_counter() - total_start
|
| 459 |
+
|
| 460 |
+
# Save cProfile if enabled
|
| 461 |
+
if enable_cprofile and prof:
|
| 462 |
+
prof.disable()
|
| 463 |
+
|
| 464 |
+
import pstats
|
| 465 |
+
import io
|
| 466 |
+
|
| 467 |
+
output_file = "profile_cprofile_detailed.txt"
|
| 468 |
+
with open(output_file, 'w') as f:
|
| 469 |
+
ps = pstats.Stats(prof, stream=f)
|
| 470 |
+
ps.sort_stats('cumulative')
|
| 471 |
+
ps.print_stats(100)
|
| 472 |
+
|
| 473 |
+
# Print top functions
|
| 474 |
+
print("\n" + "=" * 100)
|
| 475 |
+
print("📊 TOP 20 FUNCTIONS BY CUMULATIVE TIME (cProfile)")
|
| 476 |
+
print("=" * 100)
|
| 477 |
+
s = io.StringIO()
|
| 478 |
+
ps = pstats.Stats(prof, stream=s)
|
| 479 |
+
ps.sort_stats('cumulative')
|
| 480 |
+
ps.print_stats(20)
|
| 481 |
+
print(s.getvalue())
|
| 482 |
+
|
| 483 |
+
print(f"\nFull report: {output_file}")
|
| 484 |
+
|
| 485 |
+
# Print results
|
| 486 |
+
print_profiling_results(total_time, show_llm_debug=enable_llm_debug)
|
| 487 |
+
|
| 488 |
+
return result, total_time
|
| 489 |
+
|
| 490 |
+
finally:
|
| 491 |
+
restore_handlers(dit_handler, llm_handler, originals)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
|
| 495 |
+
"""Load configuration from example JSON file"""
|
| 496 |
+
try:
|
| 497 |
+
with open(example_file, 'r', encoding='utf-8') as f:
|
| 498 |
+
data = json.load(f)
|
| 499 |
+
|
| 500 |
+
params = GenerationParams(
|
| 501 |
+
caption=data.get('caption', ''),
|
| 502 |
+
lyrics=data.get('lyrics', ''),
|
| 503 |
+
bpm=data.get('bpm'),
|
| 504 |
+
keyscale=data.get('keyscale', ''),
|
| 505 |
+
timesignature=data.get('timesignature', ''),
|
| 506 |
+
vocal_language=data.get('language', 'unknown'),
|
| 507 |
+
duration=data.get('duration'),
|
| 508 |
+
thinking=data.get('think', False),
|
| 509 |
+
inference_steps=data.get('inference_steps', 8),
|
| 510 |
+
seed=data.get('seed', 42),
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
config = GenerationConfig(batch_size=data.get('batch_size', 1), seeds=[42])
|
| 514 |
+
|
| 515 |
+
return params, config
|
| 516 |
+
|
| 517 |
+
except Exception as e:
|
| 518 |
+
print(f" ❌ Failed to load: {e}")
|
| 519 |
+
return None, None
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def main():
    global timer, llm_debugger

    parser = argparse.ArgumentParser(
        description="Profile ACE-Step inference with LLM debugging"
    )
    parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints")
    parser.add_argument("--config-path", type=str, default="acestep-v15-turbo-rl")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--lm-model", type=str, default="acestep-5Hz-lm-0.6B-v3")
    parser.add_argument("--lm-backend", type=str, default="vllm")
    parser.add_argument("--no-warmup", action="store_true")
    parser.add_argument("--detailed", action="store_true")
    parser.add_argument("--llm-debug", action="store_true",
                        help="Enable deep LLM debugging (token count, throughput)")
    parser.add_argument("--example", type=str, default="example_05.json")

    # Inference mode parameters
    parser.add_argument("--thinking", action="store_true",
                        help="Enable CoT reasoning for LM to generate audio codes")
    parser.add_argument("--use-constrained-decoding", action="store_true",
                        help="Use FSM-based constrained decoding for meta generation")
    parser.add_argument("--use-cot-metas", action="store_true",
                        help="Enable LLM to generate music metadata via CoT reasoning")

    args = parser.parse_args()

    # Initialize
    timer = PreciseTimer(device=args.device)
    llm_debugger = LLMDebugger()

    print("=" * 100)
    print("🎵 ACE-Step Inference Profiler (LLM Performance Analysis)")
    print("=" * 100)
    print(f"\nConfiguration:")
    print(f" Device: {args.device}")
    print(f" LLM Backend: {args.lm_backend}")
    print(f" LLM Debug: {'Enabled' if args.llm_debug else 'Disabled'}")
    print(f" Warmup: {'Disabled' if args.no_warmup else 'Enabled'}")
    print(f"\nInference Mode:")
    print(f" Thinking (CoT): {'Enabled' if args.thinking else 'Disabled'}")
    print(f" Constrained Decoding: {'Enabled' if args.use_constrained_decoding else 'Disabled'}")
    print(f" Use CoT for Metas: {'Enabled' if args.use_cot_metas else 'Disabled'}")

    # Initialize models
    print(f"\nInitializing models...")

    dit_handler = AceStepHandler()
    llm_handler = LLMHandler()

    print(" 🎹 Initializing DiT...")
    status_dit, success_dit = dit_handler.initialize_service(
        project_root=project_root,
        config_path=args.config_path,
        device=args.device,
        use_flash_attention=True,
    )
    if not success_dit:
        print(f" ❌ Failed: {status_dit}")
        sys.exit(1)
    print(f" ✓ DiT ready")

    print(" 🧠 Initializing LLM...")
    if args.thinking or args.use_cot_metas:
        status_llm, success_llm = llm_handler.initialize(
            checkpoint_dir=args.checkpoint_dir,
            lm_model_path=args.lm_model,
            backend=args.lm_backend,
            device=args.device,
        )
        if success_llm:
            print(f" ✓ LLM ready ({args.lm_backend})")
        else:
            print(f" ⚠ Failed: {status_llm}")
    else:
        print(f" ✓ LLM not initialized (thinking or use_cot_metas is disabled)")

    # Load example
    example_file = os.path.join(project_root, "examples", "text2music", args.example)
    if not os.path.exists(example_file):
        print(f"\n❌ Not found: {example_file}")
        sys.exit(1)

    print(f"\n📄 Loading: {args.example}")
    params, config = load_example_config(example_file)

    if not params or not config:
        print("❌ Failed to load config")
        sys.exit(1)

    print(f" Caption: {params.caption[:60]}...")
    print(f" Batch: {config.batch_size}, Steps: {params.inference_steps}, LLM: {params.thinking}")

    # Warmup
    if not args.no_warmup:
        print("\n" + "=" * 100)
        print("🔥 WARMUP RUN")
        print("=" * 100)

        warmup_params = GenerationParams(
            caption=params.caption,
            lyrics=params.lyrics,
            bpm=params.bpm,
            keyscale=params.keyscale,
            timesignature=params.timesignature,
            vocal_language=params.vocal_language,
            duration=params.duration,
            thinking=args.thinking,
            use_cot_metas=args.use_cot_metas,
            inference_steps=params.inference_steps,
            seed=params.seed,
        )
        warmup_config = GenerationConfig(batch_size=1, seeds=[42])
        warmup_config.use_constrained_decoding = args.use_constrained_decoding

        warmup_start = time.perf_counter()
        warmup_result = generate_music(dit_handler, llm_handler, warmup_params, warmup_config, save_dir="./")
        warmup_time = time.perf_counter() - warmup_start

        print(f"\n✓ Warmup: {warmup_time:.2f}s")
        if not warmup_result.success:
            print(f"⚠️ Warning: {warmup_result.error}")

        # Reset
        timer = PreciseTimer(device=args.device)
        llm_debugger = LLMDebugger()

    # Profiling run
    print("\n" + "=" * 100)
    print("⏱️ PROFILING RUN")
    print("=" * 100)

    # Apply inference mode settings
    config.use_constrained_decoding = args.use_constrained_decoding
    # Override thinking and use_cot_metas parameters if specified via CLI
    if args.thinking:
        params.thinking = True
    if args.use_cot_metas:
        params.use_cot_metas = True

    result, total_time = run_profiled_generation(
        dit_handler, llm_handler, params, config,
        enable_cprofile=args.detailed,
        enable_llm_debug=args.llm_debug
    )

    if not result.success:
        print(f"\n❌ Failed: {result.error}")
        sys.exit(1)

    print(f"\n✅ Success! Generated {len(result.audios)} audio file(s)")

    # Final tips
    if args.detailed:
        print("\n💡 Check profile_cprofile_detailed.txt for function-level analysis")
    elif not args.llm_debug:
        print("\n💡 Run with --llm-debug to see LLM token count and throughput analysis")


if __name__ == "__main__":
    main()
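As a usage sketch (not part of the diff), the profiler above could be launched with the flags registered in main(); the flag names come from the argparse setup shown earlier, while the interpreter name and working directory are assumptions.

# Illustrative usage sketch: run the profiler with CoT reasoning, CoT metadata,
# and LLM debugging enabled, from the repository root (assumed).
import subprocess

subprocess.run(
    [
        "python", "profile_inference.py",
        "--thinking",                    # enable CoT reasoning for audio codes
        "--use-cot-metas",               # let the LLM generate music metadata
        "--llm-debug",                   # report token counts and throughput
        "--example", "example_05.json",  # default example referenced above
    ],
    check=True,
)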