support lm duration & fix keyscale bias

acestep/llm_inference.py  (+171 -77)
@@ -7,7 +7,7 @@ import re
 import traceback
 import time
 from enum import Enum, auto
-from typing import Optional, Dict, Any, Tuple, List, Callable
+from typing import Optional, Dict, Any, Tuple, List, Callable, Set
 from contextlib import contextmanager

 import torch
@@ -102,6 +102,12 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.metadata_temperature: Optional[float] = None
         self.codes_temperature: Optional[float] = None

+        # Duration constraint for codes generation
+        # 5 codes = 1 second, so target_codes = target_duration * 5
+        self.target_duration: Optional[float] = None  # User-specified duration in seconds
+        self.target_codes: Optional[int] = None  # Computed target codes count
+        self.codes_count: int = 0  # Counter for generated codes
+
         # Current state
         self.state = FSMState.THINK_TAG
         self.position_in_state = 0  # Position within current state's fixed string
@@ -242,6 +248,54 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Comma token for multi-genre support
         comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
         self.comma_token = comma_tokens[-1] if comma_tokens else None
+
+        # EOS token for duration-constrained codes generation
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+        # Build valid keyscales set and prefix tree for constrained decoding
+        # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
+        notes = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
+        accidentals = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
+        modes = ['major', 'minor']
+
+        self.valid_keyscales = set()
+        for note in notes:
+            for acc in accidentals:
+                for mode in modes:
+                    self.valid_keyscales.add(f"{note}{acc} {mode}")
+
+        # Build prefix tree for keyscale constrained decoding
+        self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
+
+    def _build_keyscale_prefix_tree(self) -> Dict[str, Set[int]]:
+        """
+        Build keyscale prefix to allowed tokens mapping.
+        For each prefix of each valid keyscale, we store the set of tokens
+        that can continue to form a valid keyscale.
+        """
+        prefix_to_tokens: Dict[str, Set[int]] = {}
+
+        for keyscale in self.valid_keyscales:
+            for i in range(len(keyscale)):
+                prefix = keyscale[:i]
+                next_char = keyscale[i]
+                # Encode the next character
+                tokens = self.tokenizer.encode(next_char, add_special_tokens=False)
+                if prefix not in prefix_to_tokens:
+                    prefix_to_tokens[prefix] = set()
+                prefix_to_tokens[prefix].update(tokens)
+
+        # For complete keyscales, allow newline token
+        for keyscale in self.valid_keyscales:
+            if keyscale not in prefix_to_tokens:
+                prefix_to_tokens[keyscale] = set()
+            if self.newline_token:
+                prefix_to_tokens[keyscale].add(self.newline_token)
+
+        if self.debug:
+            logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} prefixes for {len(self.valid_keyscales)} valid keyscales")
+
+        return prefix_to_tokens

     def _load_genres_vocab(self):
         """
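The prefix tree added here replaces the hand-written branching that previously biased keyscale decoding: every prefix of every valid keyscale maps to exactly the token ids that can extend it. A minimal standalone sketch of the same construction, with a toy character-level tokenizer (ids are just `ord(c)`) standing in for the model tokenizer:

```python
# Sketch only: a toy character-level "tokenizer" stands in for tokenizer.encode
from typing import Dict, List, Set

def build_prefix_tree(valid_values: Set[str], encode) -> Dict[str, Set[int]]:
    tree: Dict[str, Set[int]] = {}
    for value in valid_values:
        for i in range(len(value)):
            # Token ids that may follow this prefix of a valid value
            tree.setdefault(value[:i], set()).update(encode(value[i]))
    return tree

def encode(s: str) -> List[int]:
    return [ord(c) for c in s]  # one toy id per character

valid = {f"{n}{a} {m}" for n in "ABCDEFG" for a in ("", "#", "b") for m in ("major", "minor")}
tree = build_prefix_tree(valid, encode)

print(sorted(map(chr, tree[""])))      # ['A', ..., 'G']: first char must be a note
print(sorted(map(chr, tree["A"])))     # [' ', '#', 'b']: accidental, or space before the mode
print(sorted(map(chr, tree["A# m"])))  # ['a', 'i']: only "major"/"minor" can continue
```

The table is built once in `__init__`, so the per-step cost during decoding is a single dictionary lookup.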
@@ -566,6 +620,25 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.state = FSMState.THINK_TAG
         self.position_in_state = 0
         self.accumulated_value = ""
+        self.codes_count = 0  # Reset codes counter
+
+    def set_target_duration(self, duration: Optional[float]):
+        """
+        Set the target duration for codes generation.
+
+        Args:
+            duration: Target duration in seconds. If None, no duration constraint is applied.
+                5 codes = 1 second, so target_codes = duration * 5.
+        """
+        self.target_duration = duration
+        if duration is not None and duration > 0:
+            self.target_codes = int(duration * 5)
+            if self.debug:
+                logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
+        else:
+            self.target_codes = None
+            if self.debug:
+                logger.debug("Target duration cleared, no duration constraint")

     def update_caption(self, caption: Optional[str]):
         """
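The arithmetic in `set_target_duration` is the whole contract: at 5 Hz, one second is five codes, and `int()` truncates any fractional code. A quick check of the mapping:

```python
# duration -> target codes at 5 codes per second; int() truncates
for seconds in (10, 30, 47.6, 120.4):
    print(f"{seconds}s -> {int(seconds * 5)} codes")
# 10s -> 50 codes
# 30s -> 150 codes
# 47.6s -> 238 codes
# 120.4s -> 602 codes
```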
@@ -702,71 +775,23 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         return newline_prob > max_other_prob

     def _get_allowed_keyscale_tokens(self) -> List[int]:
-        """
+        """
+        Get allowed tokens for keyscale field using prefix tree.
+        Only allows tokens that can lead to valid keyscales like:
+        - "A major", "A minor", "A# major", "Ab minor", etc.
+        """
         acc = self.accumulated_value
-        acc_stripped = acc.strip()
-
-        if not acc_stripped:
-            # First character: must be a note (A-G)
-            return list(self.note_tokens.values())
-
-        # Check if we already have a space
-        has_space = " " in acc
-
-        # Parse what we have
-        if not has_space:
-            # No space yet
-            if len(acc_stripped) == 1 and acc_stripped.upper() in "ABCDEFG":
-                # After note: can be # ♯ b ♭ or space (for major/minor)
-                allowed = self.sharp_tokens + self.flat_tokens
-                if self.space_token:
-                    allowed.append(self.space_token)
-                return allowed
-
-            if len(acc_stripped) >= 2 and acc_stripped[-1] in "#♯b♭":
-                # After accidental: must be space
-                return [self.space_token] if self.space_token else []

-        if has_space:
-            after_space = acc.split(" ", 1)[-1].lower()
-
-            # Allow tokens that continue "major" or "minor"
-            allowed = []
-            for word in ["major", "minor"]:
-                if word.startswith(after_space):
-                    remaining = word[len(after_space):]
-                    if remaining:
-                        # Try to encode the next character
-                        tokens = self.tokenizer.encode(remaining[0], add_special_tokens=False)
-                        allowed.extend(tokens)
-                        # Also try encoding the whole remaining part
-                        tokens = self.tokenizer.encode(remaining, add_special_tokens=False)
-                        if tokens:
-                            allowed.append(tokens[0])
-
-            # If after_space is exactly "major" or "minor", allow newline
-            if after_space in ["major", "minor"]:
-                if self.newline_token:
-                    allowed.append(self.newline_token)
-
-            # If no tokens found but we have incomplete word, this is an error state
-            # Force newline if we've tried enough
-            if not allowed and len(after_space) > 5:
-                if self.newline_token:
-                    allowed.append(self.newline_token)
-
-            return list(set(allowed))
+        if acc in self.keyscale_prefix_tree:
+            return list(self.keyscale_prefix_tree[acc])

+        # No valid continuation found - return empty list
+        # The caller will handle this by forcing newline to end the field
         return []

     def _is_keyscale_complete(self) -> bool:
-        """Check if keyscale value is complete and valid."""
-
-        # Pattern: [A-G][#♯b♭]? (major|minor)
-        pattern = r'^[a-g][#♯b♭]?\s*(major|minor)$'
-        return bool(re.match(pattern, acc, re.IGNORECASE))
+        """Check if keyscale value is complete and valid by checking against valid_keyscales set."""
+        return self.accumulated_value in self.valid_keyscales

     def _get_allowed_timesig_tokens(self) -> List[int]:
         """Get allowed tokens for timesignature field."""
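Note what this hunk fixes besides style: the old `_is_keyscale_complete` matched its regex against `acc`, a name never defined in that method, and its `[a-g]` pattern with `re.IGNORECASE` accepted lowercase note letters that the rest of the FSM never emits. The new version is plain set membership against the 70 precomputed keyscales, sketched below:

```python
# Membership sketch; mirrors the valid_keyscales construction in __init__
valid_keyscales = {
    f"{note}{acc} {mode}"
    for note in "ABCDEFG"
    for acc in ("", "#", "b", "♯", "♭")
    for mode in ("major", "minor")
}
assert len(valid_keyscales) == 70  # 7 notes x 5 accidentals x 2 modes

print("A# major" in valid_keyscales)  # True
print("a# major" in valid_keyscales)  # False: the old regex accepted this, the set does not
print("H minor" in valid_keyscales)   # False
```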
@@ -797,8 +822,24 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         if not self.enabled:
             return self._apply_temperature_scaling(scores)

-        if self.state == FSMState.COMPLETED or self.state == FSMState.CODES_GENERATION:
+        if self.state == FSMState.COMPLETED:
+            return self._apply_temperature_scaling(scores)
+
+        if self.state == FSMState.CODES_GENERATION:
+            # Apply duration constraint in codes generation phase
+            if self.target_codes is not None and self.eos_token_id is not None:
+                if self.codes_count < self.target_codes:
+                    # Block EOS token until target codes count is reached
+                    scores[:, self.eos_token_id] = float('-inf')
+                    if self.debug:
+                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
+                else:
+                    # Force EOS token when target codes count is reached
+                    mask = torch.full_like(scores, float('-inf'))
+                    mask[:, self.eos_token_id] = 0
+                    scores = scores + mask
+                    if self.debug:
+                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
             return self._apply_temperature_scaling(scores)

         batch_size = scores.shape[0]
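This hunk is the "lm duration" half of the commit: during CODES_GENERATION, EOS is masked to -inf until `codes_count` reaches `target_codes`, after which everything except EOS is masked. A self-contained sketch of that gate, assuming a `(batch, vocab)` score tensor:

```python
import torch

def gate_eos(scores: torch.Tensor, eos_id: int, codes_count: int, target_codes: int) -> torch.Tensor:
    if codes_count < target_codes:
        scores[:, eos_id] = float("-inf")             # EOS cannot be sampled yet
    else:
        mask = torch.full_like(scores, float("-inf"))
        mask[:, eos_id] = 0                           # only EOS survives
        scores = scores + mask
    return scores

scores = torch.zeros(1, 8)
early = gate_eos(scores.clone(), eos_id=7, codes_count=10, target_codes=150)
late = gate_eos(scores.clone(), eos_id=7, codes_count=150, target_codes=150)
print(early[0, 7].item())        # -inf: generation must continue
print(late[0, :7].max().item())  # -inf: generation must stop
```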
@@ -890,21 +931,40 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             scores = scores + mask

         elif self.state == FSMState.DURATION_VALUE:
-            min_val, max_val = self.field_specs["duration"]["min"], self.field_specs["duration"]["max"]
-
-            if self._should_end_numeric_field(scores, min_val, max_val):
-                if self.newline_token:
-                    mask[0, self.newline_token] = 0
-                self._transition_to_next_state()
+            # If target_duration is set, force generate that exact value
+            if self.target_duration is not None:
+                target_str = str(int(self.target_duration))
+                current_pos = len(self.accumulated_value)
+
+                if current_pos < len(target_str):
+                    # Force the next digit
+                    next_digit = int(target_str[current_pos])
+                    if next_digit in self.digit_tokens:
+                        mask[0, self.digit_tokens[next_digit]] = 0
+                else:
+                    # All digits generated, force newline
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    self._transition_to_next_state()
+
+                scores = scores + mask
             else:
-                allowed = self._get_allowed_digit_tokens(min_val, max_val)
-                for t in allowed:
-                    mask[0, t] = 0
-                current = int(self.accumulated_value) if self.accumulated_value else 0
-                if min_val <= current <= max_val and self.newline_token:
-                    mask[0, self.newline_token] = 0
-
-            scores = scores + mask
+                # Normal duration generation with range constraint
+                min_val, max_val = self.field_specs["duration"]["min"], self.field_specs["duration"]["max"]
+
+                if self._should_end_numeric_field(scores, min_val, max_val):
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    self._transition_to_next_state()
+                else:
+                    allowed = self._get_allowed_digit_tokens(min_val, max_val)
+                    for t in allowed:
+                        mask[0, t] = 0
+                    current = int(self.accumulated_value) if self.accumulated_value else 0
+                    if min_val <= current <= max_val and self.newline_token:
+                        mask[0, self.newline_token] = 0
+
+                scores = scores + mask

         elif self.state == FSMState.GENRES_VALUE:
             # Try to hot-reload genres vocab if file has changed
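With `target_duration` set, the DURATION_VALUE branch stops sampling entirely: it replays `str(int(target_duration))` one digit per step, then forces the field-closing newline. A toy replay of that loop (token ids here are hypothetical stand-ins, not real vocabulary ids):

```python
# Toy replay of the forced-duration path
def next_forced_token(target_duration: float, accumulated: str,
                      digit_tokens: dict, newline_token: int) -> int:
    target_str = str(int(target_duration))
    pos = len(accumulated)
    if pos < len(target_str):
        return digit_tokens[int(target_str[pos])]  # force the next digit
    return newline_token                           # all digits emitted: close the field

digit_tokens = {d: 100 + d for d in range(10)}  # hypothetical ids
NEWLINE = 10

accumulated = ""
while (tok := next_forced_token(187.0, accumulated, digit_tokens, NEWLINE)) != NEWLINE:
    accumulated += str(tok - 100)
print(accumulated)  # 187
```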
@@ -997,7 +1057,14 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         if not self.enabled:
             return

-        if self.state == FSMState.COMPLETED or self.state == FSMState.CODES_GENERATION:
+        if self.state == FSMState.COMPLETED:
+            return
+
+        if self.state == FSMState.CODES_GENERATION:
+            # Count generated codes for duration constraint
+            self.codes_count += 1
+            if self.debug and self.target_codes is not None:
+                logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
             return

         token_str = self.tokenizer.decode([generated_token_id])
@@ -1258,6 +1325,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with vllm backend

@@ -1276,6 +1344,8 @@ class LLMHandler:
                 If None, uses base temperature
             codes_temperature: Temperature for audio codes generation (higher = more diverse)
                 If None, uses base temperature
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
         """
         try:
             from nanovllm import SamplingParams
@@ -1298,6 +1368,7 @@ class LLMHandler:
         self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
         self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
         self.constrained_processor.update_caption(caption)
+        self.constrained_processor.set_target_duration(target_duration)

         constrained_processor = self.constrained_processor
         update_state_fn = constrained_processor.update_state
@@ -1357,6 +1428,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> str:
         """Shared vllm path: accept prebuilt formatted prompt and return text."""
         from nanovllm import SamplingParams
@@ -1374,6 +1446,7 @@ class LLMHandler:
         self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
         self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
         self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+        self.constrained_processor.set_target_duration(target_duration)

         constrained_processor = self.constrained_processor

@@ -1427,6 +1500,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with PyTorch backend

@@ -1445,6 +1519,8 @@ class LLMHandler:
                 If None, uses base temperature
             codes_temperature: Temperature for audio codes generation (higher = more diverse)
                 If None, uses base temperature
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
         """
         try:
             formatted_prompt = self.build_formatted_prompt(caption, lyrics)
@@ -1496,6 +1572,7 @@ class LLMHandler:
         self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
         self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
         self.constrained_processor.update_caption(caption)
+        self.constrained_processor.set_target_duration(target_duration)

         constrained_processor = self.constrained_processor

@@ -1623,6 +1700,7 @@ class LLMHandler:
         repetition_penalty: float,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        target_duration: Optional[float] = None,
     ) -> str:
         """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
         inputs = self.llm_tokenizer(
@@ -1639,6 +1717,7 @@ class LLMHandler:
         self.constrained_processor.enabled = use_constrained_decoding
         self.constrained_processor.debug = constrained_decoding_debug
         self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+        self.constrained_processor.set_target_duration(target_duration)

         constrained_processor = self.constrained_processor

@@ -1764,6 +1843,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM

@@ -1782,6 +1862,8 @@ class LLMHandler:
                 Recommended: 0.3-0.5 for accurate metadata
             codes_temperature: Temperature for audio codes generation (higher = more diverse)
                 Recommended: 0.7-1.0 for diverse codes
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
         """
         # Check if 5Hz LM is initialized
         if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
@@ -1811,6 +1893,7 @@ class LLMHandler:
                 constrained_decoding_debug=constrained_decoding_debug,
                 metadata_temperature=metadata_temperature,
                 codes_temperature=codes_temperature,
+                target_duration=target_duration,
             )
         else:
             return self.generate_with_5hz_lm_pt(
@@ -1826,6 +1909,7 @@ class LLMHandler:
                 constrained_decoding_debug=constrained_decoding_debug,
                 metadata_temperature=metadata_temperature,
                 codes_temperature=codes_temperature,
+                target_duration=target_duration,
             )

     def generate_with_stop_condition(
@@ -1843,11 +1927,16 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Feishu-compatible LM generation.

         - infer_type='dit': stop at </think> and return metas only (no audio codes)
         - infer_type='llm_dit': normal generation (metas + audio codes)
+
+        Args:
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
         """
         infer_type = (infer_type or "").strip().lower()
         if infer_type not in {"dit", "llm_dit"}:
@@ -1867,6 +1956,7 @@ class LLMHandler:
             constrained_decoding_debug=constrained_decoding_debug,
             metadata_temperature=metadata_temperature,
             codes_temperature=codes_temperature,
+            target_duration=target_duration,
         )

         # dit: generate and truncate at reasoning end tag
@@ -1934,6 +2024,7 @@ class LLMHandler:
             - cfg_scale (float)
             - negative_prompt (str) used when cfg_scale > 1
             - top_k (int), top_p (float), repetition_penalty (float)
+            - target_duration (float): Target duration in seconds for codes generation
             use_constrained_decoding: Whether to use FSM-based constrained decoding
             constrained_decoding_debug: Whether to enable debug logging for constrained decoding

@@ -1956,6 +2047,7 @@ class LLMHandler:
         top_k = cfg.get("top_k")
         top_p = cfg.get("top_p")
         repetition_penalty = cfg.get("repetition_penalty", 1.0)
+        target_duration = cfg.get("target_duration")

         try:
             if self.llm_backend == "vllm":
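At the top-level entry point the constraint arrives through the generation config dict, next to the sampling knobs the docstring already lists. A hypothetical cfg for a 30-second target (the handler call itself is elided here):

```python
# Hypothetical cfg; keys mirror the cfg.get(...) reads above
cfg = {
    "cfg_scale": 1.0,
    "top_k": 50,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
    "target_duration": 30.0,  # 30 s of audio -> 150 codes before EOS is allowed
}
```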
@@ -1969,6 +2061,7 @@ class LLMHandler:
                     repetition_penalty=repetition_penalty,
                     use_constrained_decoding=use_constrained_decoding,
                     constrained_decoding_debug=constrained_decoding_debug,
+                    target_duration=target_duration,
                 )
                 return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"

@@ -1983,6 +2076,7 @@ class LLMHandler:
                     repetition_penalty=repetition_penalty,
                     use_constrained_decoding=use_constrained_decoding,
                     constrained_decoding_debug=constrained_decoding_debug,
+                    target_duration=target_duration,
                 )
                 return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
|
| 7 |
import traceback
|
| 8 |
import time
|
| 9 |
from enum import Enum, auto
|
| 10 |
+
from typing import Optional, Dict, Any, Tuple, List, Callable, Set
|
| 11 |
from contextlib import contextmanager
|
| 12 |
|
| 13 |
import torch
|
|
|
|
| 102 |
self.metadata_temperature: Optional[float] = None
|
| 103 |
self.codes_temperature: Optional[float] = None
|
| 104 |
|
| 105 |
+
# Duration constraint for codes generation
|
| 106 |
+
# 5 codes = 1 second, so target_codes = target_duration * 5
|
| 107 |
+
self.target_duration: Optional[float] = None # User-specified duration in seconds
|
| 108 |
+
self.target_codes: Optional[int] = None # Computed target codes count
|
| 109 |
+
self.codes_count: int = 0 # Counter for generated codes
|
| 110 |
+
|
| 111 |
# Current state
|
| 112 |
self.state = FSMState.THINK_TAG
|
| 113 |
self.position_in_state = 0 # Position within current state's fixed string
|
|
|
|
| 248 |
# Comma token for multi-genre support
|
| 249 |
comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
|
| 250 |
self.comma_token = comma_tokens[-1] if comma_tokens else None
|
| 251 |
+
|
| 252 |
+
# EOS token for duration-constrained codes generation
|
| 253 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
| 254 |
+
|
| 255 |
+
# Build valid keyscales set and prefix tree for constrained decoding
|
| 256 |
+
# 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
|
| 257 |
+
notes = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
|
| 258 |
+
accidentals = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
|
| 259 |
+
modes = ['major', 'minor']
|
| 260 |
+
|
| 261 |
+
self.valid_keyscales = set()
|
| 262 |
+
for note in notes:
|
| 263 |
+
for acc in accidentals:
|
| 264 |
+
for mode in modes:
|
| 265 |
+
self.valid_keyscales.add(f"{note}{acc} {mode}")
|
| 266 |
+
|
| 267 |
+
# Build prefix tree for keyscale constrained decoding
|
| 268 |
+
self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
|
| 269 |
+
|
| 270 |
+
def _build_keyscale_prefix_tree(self) -> Dict[str, Set[int]]:
|
| 271 |
+
"""
|
| 272 |
+
Build keyscale prefix to allowed tokens mapping.
|
| 273 |
+
For each prefix of each valid keyscale, we store the set of tokens
|
| 274 |
+
that can continue to form a valid keyscale.
|
| 275 |
+
"""
|
| 276 |
+
prefix_to_tokens: Dict[str, Set[int]] = {}
|
| 277 |
+
|
| 278 |
+
for keyscale in self.valid_keyscales:
|
| 279 |
+
for i in range(len(keyscale)):
|
| 280 |
+
prefix = keyscale[:i]
|
| 281 |
+
next_char = keyscale[i]
|
| 282 |
+
# Encode the next character
|
| 283 |
+
tokens = self.tokenizer.encode(next_char, add_special_tokens=False)
|
| 284 |
+
if prefix not in prefix_to_tokens:
|
| 285 |
+
prefix_to_tokens[prefix] = set()
|
| 286 |
+
prefix_to_tokens[prefix].update(tokens)
|
| 287 |
+
|
| 288 |
+
# For complete keyscales, allow newline token
|
| 289 |
+
for keyscale in self.valid_keyscales:
|
| 290 |
+
if keyscale not in prefix_to_tokens:
|
| 291 |
+
prefix_to_tokens[keyscale] = set()
|
| 292 |
+
if self.newline_token:
|
| 293 |
+
prefix_to_tokens[keyscale].add(self.newline_token)
|
| 294 |
+
|
| 295 |
+
if self.debug:
|
| 296 |
+
logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} prefixes for {len(self.valid_keyscales)} valid keyscales")
|
| 297 |
+
|
| 298 |
+
return prefix_to_tokens
|
| 299 |
|
| 300 |
def _load_genres_vocab(self):
|
| 301 |
"""
|
|
|
|
| 620 |
self.state = FSMState.THINK_TAG
|
| 621 |
self.position_in_state = 0
|
| 622 |
self.accumulated_value = ""
|
| 623 |
+
self.codes_count = 0 # Reset codes counter
|
| 624 |
+
|
| 625 |
+
def set_target_duration(self, duration: Optional[float]):
|
| 626 |
+
"""
|
| 627 |
+
Set the target duration for codes generation.
|
| 628 |
+
|
| 629 |
+
Args:
|
| 630 |
+
duration: Target duration in seconds. If None, no duration constraint is applied.
|
| 631 |
+
5 codes = 1 second, so target_codes = duration * 5.
|
| 632 |
+
"""
|
| 633 |
+
self.target_duration = duration
|
| 634 |
+
if duration is not None and duration > 0:
|
| 635 |
+
self.target_codes = int(duration * 5)
|
| 636 |
+
if self.debug:
|
| 637 |
+
logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
|
| 638 |
+
else:
|
| 639 |
+
self.target_codes = None
|
| 640 |
+
if self.debug:
|
| 641 |
+
logger.debug("Target duration cleared, no duration constraint")
|
| 642 |
|
| 643 |
def update_caption(self, caption: Optional[str]):
|
| 644 |
"""
|
|
|
|
| 775 |
return newline_prob > max_other_prob
|
| 776 |
|
| 777 |
def _get_allowed_keyscale_tokens(self) -> List[int]:
|
| 778 |
+
"""
|
| 779 |
+
Get allowed tokens for keyscale field using prefix tree.
|
| 780 |
+
Only allows tokens that can lead to valid keyscales like:
|
| 781 |
+
- "A major", "A minor", "A# major", "Ab minor", etc.
|
| 782 |
+
"""
|
| 783 |
acc = self.accumulated_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 784 |
|
| 785 |
+
if acc in self.keyscale_prefix_tree:
|
| 786 |
+
return list(self.keyscale_prefix_tree[acc])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 787 |
|
| 788 |
+
# No valid continuation found - return empty list
|
| 789 |
+
# The caller will handle this by forcing newline to end the field
|
| 790 |
return []
|
| 791 |
|
| 792 |
def _is_keyscale_complete(self) -> bool:
|
| 793 |
+
"""Check if keyscale value is complete and valid by checking against valid_keyscales set."""
|
| 794 |
+
return self.accumulated_value in self.valid_keyscales
|
|
|
|
|
|
|
|
|
|
| 795 |
|
| 796 |
def _get_allowed_timesig_tokens(self) -> List[int]:
|
| 797 |
"""Get allowed tokens for timesignature field."""
|
|
|
|
| 822 |
if not self.enabled:
|
| 823 |
return self._apply_temperature_scaling(scores)
|
| 824 |
|
| 825 |
+
if self.state == FSMState.COMPLETED:
|
| 826 |
+
return self._apply_temperature_scaling(scores)
|
| 827 |
+
|
| 828 |
+
if self.state == FSMState.CODES_GENERATION:
|
| 829 |
+
# Apply duration constraint in codes generation phase
|
| 830 |
+
if self.target_codes is not None and self.eos_token_id is not None:
|
| 831 |
+
if self.codes_count < self.target_codes:
|
| 832 |
+
# Block EOS token until target codes count is reached
|
| 833 |
+
scores[:, self.eos_token_id] = float('-inf')
|
| 834 |
+
if self.debug:
|
| 835 |
+
logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
|
| 836 |
+
else:
|
| 837 |
+
# Force EOS token when target codes count is reached
|
| 838 |
+
mask = torch.full_like(scores, float('-inf'))
|
| 839 |
+
mask[:, self.eos_token_id] = 0
|
| 840 |
+
scores = scores + mask
|
| 841 |
+
if self.debug:
|
| 842 |
+
logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
|
| 843 |
return self._apply_temperature_scaling(scores)
|
| 844 |
|
| 845 |
batch_size = scores.shape[0]
|
|
|
|
| 931 |
scores = scores + mask
|
| 932 |
|
| 933 |
elif self.state == FSMState.DURATION_VALUE:
|
| 934 |
+
# If target_duration is set, force generate that exact value
|
| 935 |
+
if self.target_duration is not None:
|
| 936 |
+
target_str = str(int(self.target_duration))
|
| 937 |
+
current_pos = len(self.accumulated_value)
|
| 938 |
+
|
| 939 |
+
if current_pos < len(target_str):
|
| 940 |
+
# Force the next digit
|
| 941 |
+
next_digit = int(target_str[current_pos])
|
| 942 |
+
if next_digit in self.digit_tokens:
|
| 943 |
+
mask[0, self.digit_tokens[next_digit]] = 0
|
| 944 |
+
else:
|
| 945 |
+
# All digits generated, force newline
|
| 946 |
+
if self.newline_token:
|
| 947 |
+
mask[0, self.newline_token] = 0
|
| 948 |
+
self._transition_to_next_state()
|
| 949 |
+
|
| 950 |
+
scores = scores + mask
|
| 951 |
else:
|
| 952 |
+
# Normal duration generation with range constraint
|
| 953 |
+
min_val, max_val = self.field_specs["duration"]["min"], self.field_specs["duration"]["max"]
|
| 954 |
+
|
| 955 |
+
if self._should_end_numeric_field(scores, min_val, max_val):
|
| 956 |
+
if self.newline_token:
|
| 957 |
+
mask[0, self.newline_token] = 0
|
| 958 |
+
self._transition_to_next_state()
|
| 959 |
+
else:
|
| 960 |
+
allowed = self._get_allowed_digit_tokens(min_val, max_val)
|
| 961 |
+
for t in allowed:
|
| 962 |
+
mask[0, t] = 0
|
| 963 |
+
current = int(self.accumulated_value) if self.accumulated_value else 0
|
| 964 |
+
if min_val <= current <= max_val and self.newline_token:
|
| 965 |
+
mask[0, self.newline_token] = 0
|
| 966 |
+
|
| 967 |
+
scores = scores + mask
|
| 968 |
|
| 969 |
elif self.state == FSMState.GENRES_VALUE:
|
| 970 |
# Try to hot-reload genres vocab if file has changed
|
|
|
|
| 1057 |
if not self.enabled:
|
| 1058 |
return
|
| 1059 |
|
| 1060 |
+
if self.state == FSMState.COMPLETED:
|
| 1061 |
+
return
|
| 1062 |
+
|
| 1063 |
+
if self.state == FSMState.CODES_GENERATION:
|
| 1064 |
+
# Count generated codes for duration constraint
|
| 1065 |
+
self.codes_count += 1
|
| 1066 |
+
if self.debug and self.target_codes is not None:
|
| 1067 |
+
logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
|
| 1068 |
return
|
| 1069 |
|
| 1070 |
token_str = self.tokenizer.decode([generated_token_id])
|
|
|
|
| 1325 |
constrained_decoding_debug: bool = False,
|
| 1326 |
metadata_temperature: Optional[float] = 0.85,
|
| 1327 |
codes_temperature: Optional[float] = None,
|
| 1328 |
+
target_duration: Optional[float] = None,
|
| 1329 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 1330 |
"""Generate metadata and audio codes using 5Hz LM with vllm backend
|
| 1331 |
|
|
|
|
| 1344 |
If None, uses base temperature
|
| 1345 |
codes_temperature: Temperature for audio codes generation (higher = more diverse)
|
| 1346 |
If None, uses base temperature
|
| 1347 |
+
target_duration: Target duration in seconds for codes generation constraint.
|
| 1348 |
+
5 codes = 1 second. If specified, blocks EOS until target reached.
|
| 1349 |
"""
|
| 1350 |
try:
|
| 1351 |
from nanovllm import SamplingParams
|
|
|
|
| 1368 |
self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
|
| 1369 |
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 1370 |
self.constrained_processor.update_caption(caption)
|
| 1371 |
+
self.constrained_processor.set_target_duration(target_duration)
|
| 1372 |
|
| 1373 |
constrained_processor = self.constrained_processor
|
| 1374 |
update_state_fn = constrained_processor.update_state
|
|
|
|
| 1428 |
constrained_decoding_debug: bool = False,
|
| 1429 |
metadata_temperature: Optional[float] = 0.85,
|
| 1430 |
codes_temperature: Optional[float] = None,
|
| 1431 |
+
target_duration: Optional[float] = None,
|
| 1432 |
) -> str:
|
| 1433 |
"""Shared vllm path: accept prebuilt formatted prompt and return text."""
|
| 1434 |
from nanovllm import SamplingParams
|
|
|
|
| 1446 |
self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
|
| 1447 |
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 1448 |
self.constrained_processor.update_caption(formatted_prompt) # Use formatted prompt for genre extraction
|
| 1449 |
+
self.constrained_processor.set_target_duration(target_duration)
|
| 1450 |
|
| 1451 |
constrained_processor = self.constrained_processor
|
| 1452 |
|
|
|
|
| 1500 |
constrained_decoding_debug: bool = False,
|
| 1501 |
metadata_temperature: Optional[float] = 0.85,
|
| 1502 |
codes_temperature: Optional[float] = None,
|
| 1503 |
+
target_duration: Optional[float] = None,
|
| 1504 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 1505 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend
|
| 1506 |
|
|
|
|
| 1519 |
If None, uses base temperature
|
| 1520 |
codes_temperature: Temperature for audio codes generation (higher = more diverse)
|
| 1521 |
If None, uses base temperature
|
| 1522 |
+
target_duration: Target duration in seconds for codes generation constraint.
|
| 1523 |
+
5 codes = 1 second. If specified, blocks EOS until target reached.
|
| 1524 |
"""
|
| 1525 |
try:
|
| 1526 |
formatted_prompt = self.build_formatted_prompt(caption, lyrics)
|
|
|
|
| 1572 |
self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
|
| 1573 |
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 1574 |
self.constrained_processor.update_caption(caption)
|
| 1575 |
+
self.constrained_processor.set_target_duration(target_duration)
|
| 1576 |
|
| 1577 |
constrained_processor = self.constrained_processor
|
| 1578 |
|
|
|
|
| 1700 |
repetition_penalty: float,
|
| 1701 |
use_constrained_decoding: bool = True,
|
| 1702 |
constrained_decoding_debug: bool = False,
|
| 1703 |
+
target_duration: Optional[float] = None,
|
| 1704 |
) -> str:
|
| 1705 |
"""Shared PyTorch path: accept prebuilt formatted prompt and return text."""
|
| 1706 |
inputs = self.llm_tokenizer(
|
|
|
|
| 1717 |
self.constrained_processor.enabled = use_constrained_decoding
|
| 1718 |
self.constrained_processor.debug = constrained_decoding_debug
|
| 1719 |
self.constrained_processor.update_caption(formatted_prompt) # Use formatted prompt for genre extraction
|
| 1720 |
+
self.constrained_processor.set_target_duration(target_duration)
|
| 1721 |
|
| 1722 |
constrained_processor = self.constrained_processor
|
| 1723 |
|
|
|
|
| 1843 |
constrained_decoding_debug: bool = False,
|
| 1844 |
metadata_temperature: Optional[float] = 0.85,
|
| 1845 |
codes_temperature: Optional[float] = None,
|
| 1846 |
+
target_duration: Optional[float] = None,
|
| 1847 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 1848 |
"""Generate metadata and audio codes using 5Hz LM
|
| 1849 |
|
|
|
|
| 1862 |
Recommended: 0.3-0.5 for accurate metadata
|
| 1863 |
codes_temperature: Temperature for audio codes generation (higher = more diverse)
|
| 1864 |
Recommended: 0.7-1.0 for diverse codes
|
| 1865 |
+
target_duration: Target duration in seconds for codes generation constraint.
|
| 1866 |
+
5 codes = 1 second. If specified, blocks EOS until target reached.
|
| 1867 |
"""
|
| 1868 |
# Check if 5Hz LM is initialized
|
| 1869 |
if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
|
|
|
|
| 1893 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 1894 |
metadata_temperature=metadata_temperature,
|
| 1895 |
codes_temperature=codes_temperature,
|
| 1896 |
+
target_duration=target_duration,
|
| 1897 |
)
|
| 1898 |
else:
|
| 1899 |
return self.generate_with_5hz_lm_pt(
|
|
|
|
| 1909 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 1910 |
metadata_temperature=metadata_temperature,
|
| 1911 |
codes_temperature=codes_temperature,
|
| 1912 |
+
target_duration=target_duration,
|
| 1913 |
)
|
| 1914 |
|
| 1915 |
def generate_with_stop_condition(
|
|
|
|
| 1927 |
constrained_decoding_debug: bool = False,
|
| 1928 |
metadata_temperature: Optional[float] = 0.85,
|
| 1929 |
codes_temperature: Optional[float] = None,
|
| 1930 |
+
target_duration: Optional[float] = None,
|
| 1931 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 1932 |
"""Feishu-compatible LM generation.
|
| 1933 |
|
| 1934 |
- infer_type='dit': stop at </think> and return metas only (no audio codes)
|
| 1935 |
- infer_type='llm_dit': normal generation (metas + audio codes)
|
| 1936 |
+
|
| 1937 |
+
Args:
|
| 1938 |
+
target_duration: Target duration in seconds for codes generation constraint.
|
| 1939 |
+
5 codes = 1 second. If specified, blocks EOS until target reached.
|
| 1940 |
"""
|
| 1941 |
infer_type = (infer_type or "").strip().lower()
|
| 1942 |
if infer_type not in {"dit", "llm_dit"}:
|
|
|
|
| 1956 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 1957 |
metadata_temperature=metadata_temperature,
|
| 1958 |
codes_temperature=codes_temperature,
|
| 1959 |
+
target_duration=target_duration,
|
| 1960 |
)
|
| 1961 |
|
| 1962 |
# dit: generate and truncate at reasoning end tag
|
|
|
|
| 2024 |
- cfg_scale (float)
|
| 2025 |
- negative_prompt (str) used when cfg_scale > 1
|
| 2026 |
- top_k (int), top_p (float), repetition_penalty (float)
|
| 2027 |
+
- target_duration (float): Target duration in seconds for codes generation
|
| 2028 |
use_constrained_decoding: Whether to use FSM-based constrained decoding
|
| 2029 |
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 2030 |
|
|
|
|
| 2047 |
top_k = cfg.get("top_k")
|
| 2048 |
top_p = cfg.get("top_p")
|
| 2049 |
repetition_penalty = cfg.get("repetition_penalty", 1.0)
|
| 2050 |
+
target_duration = cfg.get("target_duration")
|
| 2051 |
|
| 2052 |
try:
|
| 2053 |
if self.llm_backend == "vllm":
|
|
|
|
| 2061 |
repetition_penalty=repetition_penalty,
|
| 2062 |
use_constrained_decoding=use_constrained_decoding,
|
| 2063 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 2064 |
+
target_duration=target_duration,
|
| 2065 |
)
|
| 2066 |
return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
|
| 2067 |
|
|
|
|
| 2076 |
repetition_penalty=repetition_penalty,
|
| 2077 |
use_constrained_decoding=use_constrained_decoding,
|
| 2078 |
constrained_decoding_debug=constrained_decoding_debug,
|
| 2079 |
+
target_duration=target_duration,
|
| 2080 |
)
|
| 2081 |
return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
|
| 2082 |
|