fix bugs
Files changed:
- acestep/constrained_logits_processor.py  +398 -15
- acestep/gradio_ui.py                      +362 -17
- acestep/llm_inference.py                  +60 -1
- acestep/test_time_scaling.py              +261 -0
acestep/constrained_logits_processor.py  CHANGED  (+398 -15)
@@ -35,6 +35,9 @@ class FSMState(Enum):
     DURATION_NAME = auto()    # Generating "duration: "
     DURATION_VALUE = auto()   # Generating numeric value 10-600
     NEWLINE_AFTER_DURATION = auto()
+    GENRES_NAME = auto()      # Generating "genres: "
+    GENRES_VALUE = auto()     # Generating any non-empty string
+    NEWLINE_AFTER_GENRES = auto()
     KEYSCALE_NAME = auto()    # Generating "keyscale: "
     KEYSCALE_VALUE = auto()   # Generating keyscale pattern
     NEWLINE_AFTER_KEYSCALE = auto()
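
With these states the FSM emits the metadata fields in a fixed order: bpm, caption, duration, genres (optional, skipped by default), keyscale, language, timesignature. As a rough illustration only (the field values below are made up, and exact spacing, tag text, and which fields appear depend on the tokenizer and the skip_* flags), a constrained CoT block looks like:

bpm: 120
caption: upbeat synth-driven pop rock with punchy drums
duration: 180
genres: pop rock
keyscale: G major
language: en
timesignature: 4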
@@ -74,7 +77,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         tokenizer: AutoTokenizer,
         enabled: bool = True,
         debug: bool = False,
+        genres_vocab_path: Optional[str] = None,
+        skip_genres: bool = True,
     ):
         """
         Initialize the constrained logits processor.
@@ -89,6 +93,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.tokenizer = tokenizer
         self.enabled = enabled
         self.debug = debug
+        self.skip_genres = skip_genres
         self.skip_caption = False  # Set to True to skip caption field generation
         self.skip_language = False  # Set to True to skip language field generation
         self.caption: Optional[str] = None  # Set via update_caption() before each generation
@@ -103,6 +108,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             "keyscale": None,
             "language": None,
             "timesignature": None,
+            "genres": None,
         }

         # Temperature settings for different generation phases (set per-generation)
@@ -143,6 +149,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Pre-compute token IDs for efficiency
         self._precompute_tokens()

+        # Genres vocabulary for constrained decoding
+        self.genres_vocab_path = genres_vocab_path or os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
+        )
+        self.genres_vocab: List[str] = []            # Full vocab
+        self.genres_vocab_mtime: float = 0.0
+        self.genres_trie: Dict = {}                  # Trie for full vocab (fallback)
+        self.caption_genres_trie: Dict = {}          # Trie for caption-matched genres (priority)
+        self.caption_matched_genres: List[str] = []  # Genres matched from caption
+
         self._char_to_tokens: Dict[str, set] = {}  # Precomputed char -> token IDs mapping

         # Precompute token mappings once (O(vocab_size), runs once at init)
# Precompute token mappings once (O(vocab_size), runs once at init)
|
|
|
|
| 202 |
# Build language prefix tree (similar to keyscale but for language codes)
|
| 203 |
self.language_prefix_tree = self._build_language_prefix_tree()
|
| 204 |
|
| 205 |
+
self._load_genres_vocab()
|
| 206 |
+
|
| 207 |
# Fixed strings for each state
|
| 208 |
# IMPORTANT: Do NOT include trailing space after colon - tokenizer will handle spacing
|
| 209 |
# All matching should be done at token level, not string level
|
|
|
|
| 214 |
FSMState.BPM_NAME: "bpm:",
|
| 215 |
FSMState.CAPTION_NAME: "caption:",
|
| 216 |
FSMState.DURATION_NAME: "duration:",
|
| 217 |
+
FSMState.GENRES_NAME: "genres:",
|
| 218 |
FSMState.KEYSCALE_NAME: "keyscale:",
|
| 219 |
FSMState.LANGUAGE_NAME: "language:",
|
| 220 |
FSMState.TIMESIG_NAME: "timesignature:",
|
|
|
|
| 230 |
even if the field is user-provided (we still need to generate the field name).
|
| 231 |
|
| 232 |
Args:
|
| 233 |
+
current_field: Current field name ("bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature")
|
| 234 |
|
| 235 |
Returns:
|
| 236 |
Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields
|
| 237 |
"""
|
| 238 |
# New field order: bpm -> caption -> duration -> keyscale -> language -> timesignature
|
| 239 |
+
# genres is optional and can be skipped
|
| 240 |
+
field_order = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"]
|
| 241 |
field_to_state = {
|
| 242 |
"bpm": FSMState.BPM_NAME,
|
| 243 |
"caption": FSMState.CAPTION_NAME,
|
| 244 |
"duration": FSMState.DURATION_NAME,
|
| 245 |
+
"genres": FSMState.GENRES_NAME,
|
| 246 |
"keyscale": FSMState.KEYSCALE_NAME,
|
| 247 |
"language": FSMState.LANGUAGE_NAME,
|
| 248 |
"timesignature": FSMState.TIMESIG_NAME,
@@ -235,7 +256,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Find next field in order
         for i in range(current_idx + 1, len(field_order)):
             field = field_order[i]
+
+            # Skip fields based on flags
+            if field == "genres" and self.skip_genres:
+                continue
             if field == "caption" and self.skip_caption:
                 continue
             if field == "language" and self.skip_language:
@@ -257,7 +281,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         }

         # Build transitions for all fields (even if user-provided, we still need to generate field name)
-        # Field order: bpm -> caption -> duration -> keyscale -> language -> timesignature
+        # Field order: bpm -> caption -> duration -> genres -> keyscale -> language -> timesignature

         # BPM field: NAME -> VALUE -> next field (caption or duration)
         self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE
@@ -271,6 +295,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Duration field: NAME -> VALUE -> next field
         self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE
         self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration")
+
+        # Genres field (only if not skipped): NAME -> VALUE -> next field
+        if not self.skip_genres:
+            self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
+            self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres")

         # Keyscale field: NAME -> VALUE -> next field (language or timesignature)
         self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE
@@ -284,6 +313,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Timesignature field: NAME -> VALUE -> THINK_END_TAG
         self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE
         self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG
+
+    def set_skip_genres(self, skip: bool):
+        """Set whether to skip genres generation and rebuild state transitions."""
+        self.skip_genres = skip
+        self._build_state_transitions()

     def set_skip_caption(self, skip: bool):
         """Set whether to skip caption generation and rebuild state transitions."""
@@ -366,13 +400,14 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         - "keyscale": Optional[str] - e.g., "G major"
         - "language": Optional[str] - e.g., "en"
         - "timesignature": Optional[str] - e.g., "4"
+        - "genres": Optional[str] - e.g., "Pop Rock"
         If None, clears all user-provided metadata.
         """
         if metadata is None:
             metadata = {}

         # Update user-provided metadata
-        for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature"]:
+        for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]:
             if field in metadata:
                 self.user_provided_metadata[field] = metadata[field]
             else:
@@ -437,6 +472,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):

         # Vocab size
         self.vocab_size = len(self.tokenizer)
+
+        # Comma token for multi-genre support
+        comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
+        self.comma_token = comma_tokens[-1] if comma_tokens else None

         # EOS token for duration-constrained codes generation
         self.eos_token_id = self.tokenizer.eos_token_id
@@ -531,7 +570,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):

         if self.debug:
             logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
-
+
     def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
         """
         Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
@@ -808,6 +847,133 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):

         print("=" * 60)

+    def _load_genres_vocab(self):
+        """
+        Load genres vocabulary from file. Supports hot reload by checking file mtime.
+        File format: one genre per line, lines starting with # are comments.
+        """
+        if not os.path.exists(self.genres_vocab_path):
+            if self.debug:
+                logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
+            return
+
+        try:
+            mtime = os.path.getmtime(self.genres_vocab_path)
+            if mtime <= self.genres_vocab_mtime:
+                return  # File hasn't changed
+
+            with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
+                genres = []
+                for line in f:
+                    line = line.strip()
+                    if line and not line.startswith('#'):
+                        genres.append(line.lower())
+
+            self.genres_vocab = genres
+            self.genres_vocab_mtime = mtime
+            self._build_genres_trie()
+
+            if self.debug:
+                logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
+        except Exception as e:
+            logger.warning(f"Failed to load genres vocab: {e}")
+
+    def _build_genres_trie(self):
+        """
+        Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
+        Each node is a dict with:
+        - '_end': True if this node represents a complete genre
+        - other keys: next characters in the trie
+        """
+        self.genres_trie = {}
+
+        for genre in self.genres_vocab:
+            node = self.genres_trie
+            for char in genre:
+                if char not in node:
+                    node[char] = {}
+                node = node[char]
+            node['_end'] = True  # Mark end of a complete genre
+
+        if self.debug:
+            logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
+
+    def _extract_caption_genres(self, caption: str):
+        """
+        Extract genres from the user's caption that match entries in the vocabulary.
+        This creates a smaller trie for faster and more relevant genre generation.
+
+        Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
+        1. Extract words/phrases from caption
+        2. For each word, use trie to find all vocab entries that START with this word
+        3. Build a separate trie from matched genres
+        """
+        if not caption or not self.genres_vocab:
+            return
+
+        caption_lower = caption.lower()
+        matched_genres = set()
+
+        # Extract words from caption (split by common delimiters)
+        import re
+        words = re.split(r'[,\s\-_/\\|]+', caption_lower)
+        words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
+
+        # For each word, find genres in trie that start with this word
+        for word in words:
+            # Find all genres starting with this word using trie traversal
+            node = self._get_genres_trie_node(word)
+            if node is not None:
+                # Collect all complete genres under this node
+                self._collect_complete_genres(node, word, matched_genres)
+
+        # Also check if any word appears as a substring in short genres (< 20 chars)
+        # This is a quick check for common single-word genres
+        genres_set = set(self.genres_vocab)
+        for word in words:
+            if word in genres_set:
+                matched_genres.add(word)
+
+        if not matched_genres:
+            if self.debug:
+                logger.debug(f"No genres matched in caption, using full vocab")
+            return
+
+        # Build a trie from matched genres
+        self.caption_matched_genres = list(matched_genres)
+        self.caption_genres_trie = {}
+
+        for genre in matched_genres:
+            node = self.caption_genres_trie
+            for char in genre:
+                if char not in node:
+                    node[char] = {}
+                node = node[char]
+            node['_end'] = True
+
+        if self.debug:
+            logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
+
+    def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
+        """
+        Recursively collect all complete genres under a trie node.
+        Limited depth to avoid too many matches.
+        """
+        if max_depth <= 0:
+            return
+
+        if node.get('_end', False):
+            result.add(prefix)
+
+        # Limit total collected genres to avoid slowdown
+        if len(result) >= 100:
+            return
+
+        for char, child_node in node.items():
+            if char not in ('_end', '_tokens'):
+                self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
+
     def _precompute_char_token_mapping(self):
         """
         Precompute mapping from characters to token IDs and token decoded texts.
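
A toy, self-contained illustration of the caption-matching idea above (the real method walks the trie rather than scanning a list, and the vocabulary and caption here are invented):

import re

vocab = ["pop", "pop rock", "jazz", "lo-fi hip hop"]             # invented vocabulary
caption = "An upbeat pop-rock tune with a lo-fi intro"

# split the caption on the same delimiters, keep words of length >= 2
words = [w for w in re.split(r'[,\s\-_/\\|]+', caption.lower()) if len(w) >= 2]
# keep vocabulary entries that start with any caption word
matched = sorted({g for g in vocab if any(g.startswith(w) for w in words)})
print(matched)   # ['lo-fi hip hop', 'pop', 'pop rock']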
@@ -859,8 +1025,36 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):

         if self.debug:
             logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
-
-
+
+    def _try_reload_genres_vocab(self):
+        """Check if genres vocab file has been updated and reload if necessary."""
+        if not os.path.exists(self.genres_vocab_path):
+            return
+
+        try:
+            mtime = os.path.getmtime(self.genres_vocab_path)
+            if mtime > self.genres_vocab_mtime:
+                self._load_genres_vocab()
+        except Exception:
+            pass  # Ignore errors during hot reload check
+
+    def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
+        """
+        Get the trie node for a given prefix.
+        Returns None if the prefix is not valid (no genres start with this prefix).
+        """
+        node = self.genres_trie
+        for char in prefix.lower():
+            if char not in node:
+                return None
+            node = node[char]
+        return node
+
+    def _is_complete_genre(self, text: str) -> bool:
+        """Check if the given text is a complete genre in the vocabulary."""
+        node = self._get_genres_trie_node(text.strip())
+        return node is not None and node.get('_end', False)
+
     def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
         """Get a trie node from a specific trie (helper for caption vs full trie)."""
         node = trie
@@ -870,6 +1064,108 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             node = node[char]
         return node

+    def _get_allowed_genres_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for genres field based on trie matching.
+
+        The entire genres string (including commas) must match a complete entry in the vocab.
+        For example, if vocab contains "pop, rock, jazz", the generated string must exactly
+        match that entry - we don't treat commas as separators for individual genres.
+
+        Strategy:
+        1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
+        2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
+        3. Get valid next characters from current trie node
+        4. For each candidate token, verify the full decoded text forms a valid trie prefix
+        """
+        if not self.genres_vocab:
+            # No vocab loaded, allow all except newline if empty
+            return []
+
+        # Use the full accumulated value (don't split by comma - treat as single entry)
+        accumulated = self.accumulated_value.lower()
+        current_genre_prefix = accumulated.strip()
+
+        # Determine which trie to use: caption-matched (priority) or full vocab (fallback)
+        use_caption_trie = False
+        current_node = None
+
+        # Try caption-matched trie first if available
+        if self.caption_genres_trie:
+            if current_genre_prefix == "":
+                current_node = self.caption_genres_trie
+                use_caption_trie = True
+            else:
+                current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
+                if current_node is not None:
+                    use_caption_trie = True
+
+        # Fallback to full vocab trie
+        if current_node is None:
+            if current_genre_prefix == "":
+                current_node = self.genres_trie
+            else:
+                current_node = self._get_genres_trie_node(current_genre_prefix)
+
+        if current_node is None:
+            # Invalid prefix, force newline to end
+            if self.newline_token:
+                return [self.newline_token]
+            return []
+
+        # Get valid next characters from trie node
+        valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
+
+        # If current value is a complete genre, allow newline to end
+        is_complete = current_node.get('_end', False)
+
+        if not valid_next_chars:
+            # No more characters to match, only allow newline if complete
+            allowed = set()
+            if is_complete and self.newline_token:
+                allowed.add(self.newline_token)
+            return list(allowed)
+
+        # Collect candidate tokens based on first character
+        candidate_tokens = set()
+        for char in valid_next_chars:
+            if char in self._char_to_tokens:
+                candidate_tokens.update(self._char_to_tokens[char])
+
+        # Select the appropriate trie for validation
+        active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
+
+        # Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
+        allowed = set()
+        for token_id in candidate_tokens:
+            # Use precomputed decoded text (already normalized)
+            decoded_normalized = self._token_to_text.get(token_id, "")
+
+            if not decoded_normalized or not decoded_normalized.strip():
+                # Token decodes to empty or only whitespace - allow if space/comma is a valid next char
+                if ' ' in valid_next_chars or ',' in valid_next_chars:
+                    allowed.add(token_id)
+                continue
+
+            # Build new prefix by appending decoded token
+            # Handle space-prefixed tokens (e.g., " rock" from "pop rock")
+            if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
+                # Token has leading space/comma - append directly
+                new_prefix = current_genre_prefix + decoded_normalized
+            else:
+                new_prefix = current_genre_prefix + decoded_normalized
+
+            # Check if new_prefix is a valid prefix in the active trie
+            new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
+            if new_node is not None:
+                allowed.add(token_id)
+
+        # If current value is a complete genre, also allow newline
+        if is_complete and self.newline_token:
+            allowed.add(self.newline_token)
+
+        return list(allowed)
+
     def reset(self):
         """Reset the processor state for a new generation."""
         self.state = FSMState.THINK_TAG
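
A compact, runnable sketch of that filtering step with toy data. The real char-to-token and token-to-text maps are precomputed from the tokenizer; the ids below are invented:

# gather candidate tokens by first character, then keep a token only if
# prefix + decoded(token) is still a valid prefix of some vocabulary entry
def build_trie(entries):
    trie = {}
    for entry in entries:
        node = trie
        for ch in entry:
            node = node.setdefault(ch, {})
        node['_end'] = True
    return trie

def walk(trie, prefix):
    node = trie
    for ch in prefix:
        node = node.get(ch)
        if node is None:
            return None
    return node

trie = build_trie(["pop rock", "pop rap"])
char_to_tokens = {'r': {17, 42}}                 # hypothetical char -> token ids
token_to_text = {17: 'rock', 42: 'rap'}          # hypothetical token id -> decoded text

prefix = "pop "
node = walk(trie, prefix)
next_chars = {k for k in node if k != '_end'}
allowed = {t for ch in next_chars for t in char_to_tokens.get(ch, ())
           if walk(trie, prefix + token_to_text[t]) is not None}
print(allowed)   # {17, 42}: both "pop rock" and "pop rap" remain reachable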
@@ -1061,6 +1357,26 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):

         return newline_prob > max_digit_prob

+    def _should_end_text_field(self, logits: torch.Tensor) -> bool:
+        """
+        Determine if we should end a text field (genres).
+        Returns True if P(newline) > P(any other token) AND we have some content.
+        """
+        if not self.accumulated_value.strip():
+            return False  # Need at least some content
+
+        probs = torch.softmax(logits, dim=-1)
+        newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
+
+        # Get max probability among non-newline tokens
+        masked_probs = probs.clone()
+        if self.newline_token:
+            masked_probs[0, self.newline_token] = 0
+        max_other_prob = masked_probs[0].max().item()
+
+        return newline_prob > max_other_prob
+
     def _get_allowed_keyscale_tokens(self) -> List[int]:
         """
         Get allowed tokens for keyscale field using the precomputed prefix tree.
@@ -1265,6 +1581,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             "keyscale": "keyscale: ",
             "language": "language: ",
             "timesignature": "timesignature: ",
+            "genres": "genres: ",
         }
         prefix = field_to_prefix[field_name]
         full_text = f"{prefix}{value}\n"
@@ -1428,9 +1745,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 # Allow free generation (no constraints) so LM can generate field name naturally
                 return scores
             else:
-                # It's indentation, continue caption
+                # It's indentation, continue caption (don't transition!)
                 self.caption_after_newline = False
+                # Continue normal caption generation
+                # Fall through to caption constraints below
+
         # If caption is ending (LM generating next field name), allow free generation
         # and track the field name until we see colon
         if self.caption_ending:
@@ -1505,7 +1824,55 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                     mask[0, self.newline_token] = 0

             scores = scores + mask
+
+        elif self.state == FSMState.GENRES_VALUE:
+            # Check if field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value:
+                # Initialize token queue with field value tokens (value + newline)
+                value = self.user_provided_metadata["genres"]
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "genres"
+                    # Inject first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # Try to hot-reload genres vocab if file has changed
+            self._try_reload_genres_vocab()
+
+            # Get allowed tokens based on genres vocabulary
+            allowed = self._get_allowed_genres_tokens()
+
+            if allowed:
+                # Use vocabulary-constrained decoding
+                for t in allowed:
+                    mask[0, t] = 0
+                scores = scores + mask
+            elif self.genres_vocab:
+                # Vocab is loaded but no valid continuation found
+                # Force newline to end the field
+                if self.newline_token:
+                    mask[0, self.newline_token] = 0
+                if self.debug:
+                    logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
+                scores = scores + mask
+            else:
+                # Fallback: no vocab loaded, use probability-based ending
+                if self._should_end_text_field(scores):
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                        self._transition_to_next_state()
+                    scores = scores + mask
+                else:
+                    # Allow any token except newline if we don't have content yet
+                    if not self.accumulated_value.strip():
+                        if self.newline_token:
+                            scores[0, self.newline_token] = float('-inf')
+                    # Otherwise, don't constrain (fallback behavior)
+
         elif self.state == FSMState.KEYSCALE_VALUE:
             # Check if field is user-provided and we haven't started injecting yet
             if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
@@ -1561,7 +1928,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                     mask[0, value_tokens[0]] = 0
                     scores = scores + mask
                     return scores
-
+
             # If we haven't started generating language yet (empty accumulated_token_ids),
             # select the top-1 probability token from all valid first tokens
             if not self.accumulated_token_ids:
@@ -1780,6 +2147,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             if token_str.strip().isdigit():
                 self.accumulated_value += token_str.strip()

+        elif self.state == FSMState.GENRES_VALUE:
+            if generated_token_id == self.newline_token:
+                # Newline ends the field
+                self._transition_to_next_state()
+                # IMPORTANT: After state transition, if new state is a fixed_strings state,
+                # we should NOT update position_in_state with the newline token length,
+                # because that token belongs to the old state, not the new state.
+                # Return early to avoid the fixed_strings update logic below.
+                if self.state in self.fixed_strings:
+                    return
+            else:
+                # Genres still uses string-based trie, so keep accumulated_value
+                self.accumulated_value += token_str
+
         elif self.state == FSMState.CAPTION_VALUE:
             # Track token count for 512 limit
             self.caption_token_count += 1
@@ -1787,8 +2168,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             # Accumulate caption text
             self.accumulated_value += token_str

-            # Track if this token
-
+            # Track if this token contains a newline (for transition detection)
+            # Token may be '\n' alone or combined with other chars like '.\n'
+            if '\n' in token_str:
                 # Mark that we need to check next token for field transition
                 self.caption_after_newline = True
             else:
@@ -1813,6 +2195,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             # Map field name to VALUE state
             field_name_to_value_state = {
                 "duration": FSMState.DURATION_VALUE,
+                "genres": FSMState.GENRES_VALUE,
                 "keyscale": FSMState.KEYSCALE_VALUE,
                 "language": FSMState.LANGUAGE_VALUE,
                 "timesignature": FSMState.TIMESIG_VALUE,
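
A hypothetical end-to-end sketch of plugging the processor into Hugging Face generation. The checkpoint path, prompt, and generation settings are placeholders, not values from this repository; only the class, skip_genres, update_caption, and reset come from the code above.

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained("path/to/acestep-lm")      # placeholder
model = AutoModelForCausalLM.from_pretrained("path/to/acestep-lm")   # placeholder

processor = MetadataConstrainedLogitsProcessor(tokenizer=tokenizer, skip_genres=False)
processor.update_caption("an upbeat pop rock tune")   # per the comment on self.caption
processor.reset()                                     # fresh FSM state for this generation

prompt_ids = tokenizer("Describe the track:", return_tensors="pt").input_ids
out = model.generate(
    prompt_ids,
    logits_processor=LogitsProcessorList([processor]),
    max_new_tokens=256,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))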
acestep/gradio_ui.py  CHANGED  (+362 -17)
@@ -607,6 +607,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
                info="Generate language in CoT (chain-of-thought)",
                scale=1,
            )
+            constrained_decoding_debug = gr.Checkbox(
+                label="Constrained Decoding Debug",
+                value=False,
+                info="Enable debug logging for constrained decoding (check to see detailed logs)",
+                scale=1,
+            )

        with gr.Row():
            audio_cover_strength = gr.Slider(
@@ -618,11 +624,21 @@
                info="Control how many denoising steps use LM-generated codes",
                scale=1,
            )
+            score_scale = gr.Slider(
+                minimum=1.0,
+                maximum=200.0,
+                value=10.0,
+                step=1.0,
+                label="Quality Score Sensitivity",
+                info="Lower = more sensitive to quality differences (default: 10.0)",
+                scale=1,
+            )
            output_alignment_preference = gr.Checkbox(
                label="Output Attention Focus Score (disabled)",
                value=False,
                info="Output attention focus score analysis",
                interactive=False,
+                visible=False,
                scale=1,
            )

@@ -632,10 +648,14 @@
        think_checkbox = gr.Checkbox(
            label="Think",
            value=True,
-            info="Enable llm generate hints",
            scale=1,
        )
        generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive, scale=10)
+        instrumental_checkbox = gr.Checkbox(
+            label="Instrumental",
+            value=False,
+            scale=1,
+        )

    return {
        "service_config_accordion": service_config_accordion,
@@ -695,6 +715,9 @@
        "output_alignment_preference": output_alignment_preference,
        "think_checkbox": think_checkbox,
        "generate_btn": generate_btn,
+        "instrumental_checkbox": instrumental_checkbox,
+        "constrained_decoding_debug": constrained_decoding_debug,
+        "score_scale": score_scale,
    }


@@ -720,7 +743,7 @@ def create_results_section(dit_handler) -> dict:
            )
            with gr.Row(equal_height=True):
                send_to_src_btn_1 = gr.Button(
-                    "Send To Src Audio",
+                    "🔗 Send To Src Audio",
                    variant="secondary",
                    size="sm",
                    scale=1
@@ -731,6 +754,17 @@
                    size="sm",
                    scale=1
                )
+                score_btn_1 = gr.Button(
+                    "📊 Score",
+                    variant="secondary",
+                    size="sm",
+                    scale=1
+                )
+            score_display_1 = gr.Textbox(
+                label="Quality Score (Sample 1)",
+                interactive=False,
+                placeholder="Click 'Score' to calculate perplexity-based quality score"
+            )
        with gr.Column():
            generated_audio_2 = gr.Audio(
                label="🎵 Generated Music (Sample 2)",
@@ -739,7 +773,7 @@
            )
            with gr.Row(equal_height=True):
                send_to_src_btn_2 = gr.Button(
-                    "Send To Src Audio",
+                    "🔗 Send To Src Audio",
                    variant="secondary",
                    size="sm",
                    scale=1
@@ -750,6 +784,17 @@
                    size="sm",
                    scale=1
                )
+                score_btn_2 = gr.Button(
+                    "📊 Score",
+                    variant="secondary",
+                    size="sm",
+                    scale=1
+                )
+            score_display_2 = gr.Textbox(
+                label="Quality Score (Sample 2)",
+                interactive=False,
+                placeholder="Click 'Score' to calculate perplexity-based quality score"
+            )

    with gr.Accordion("📁 Batch Results & Generation Details", open=False):
        generated_audio_batch = gr.File(
@@ -780,6 +825,10 @@
        "send_to_src_btn_2": send_to_src_btn_2,
        "save_btn_1": save_btn_1,
        "save_btn_2": save_btn_2,
+        "score_btn_1": score_btn_1,
+        "score_btn_2": score_btn_2,
+        "score_display_1": score_display_1,
+        "score_display_2": score_display_2,
        "generated_audio_batch": generated_audio_batch,
        "generation_info": generation_info,
        "align_score_1": align_score_1,
@@ -1042,11 +1091,12 @@
            gr.Warning(f"Error loading example: {str(e)}")
            return "", "", True, None, None, "", "", ""

-    def sample_example_smart(task_type: str):
+    def sample_example_smart(task_type: str, constrained_decoding_debug: bool = False):
        """Smart sample function that uses LM if initialized, otherwise falls back to examples

        Args:
            task_type: The task type (e.g., "text2music")
+            constrained_decoding_debug: Whether to enable debug logging for constrained decoding

        Returns:
            Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
@@ -1060,6 +1110,7 @@
                audio_codes="NO USER INPUT",
                use_constrained_decoding=True,
                temperature=0.85,
+                constrained_decoding_debug=constrained_decoding_debug,
            )

            if metadata:
@@ -1094,7 +1145,7 @@
                if timesignature_value in [None, "N/A"]:
                    timesignature_value = ''

-                gr.Info("🤖 Generated example using LM
+                gr.Info("🤖 Generated example using LM")
                return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
            else:
                gr.Warning("Failed to generate example using LM, falling back to examples directory")
@@ -1285,6 +1336,7 @@
        use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
        think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
        use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
+        constrained_decoding_debug,
        progress=gr.Progress(track_tqdm=True)
    ):
        # If think is enabled (llm_dit mode) and use_cot_metas is True, generate audio codes using LM first
@@ -1342,6 +1394,7 @@
                use_cot_caption=use_cot_caption,
                use_cot_language=use_cot_language,
                is_format_caption=is_format_caption,
+                constrained_decoding_debug=constrained_decoding_debug,
            )

            # Store LM-generated metadata and audio codes for display
@@ -1471,7 +1524,8 @@
            generation_section["use_cot_metas"],
            generation_section["use_cot_caption"],
            generation_section["use_cot_language"],
-            results_section["is_format_caption_state"]
+            results_section["is_format_caption_state"],
+            generation_section["constrained_decoding_debug"]
        ],
        outputs=[
            results_section["generated_audio_1"],
@@ -1720,15 +1774,18 @@

    # Sample button - smart sample (uses LM if initialized, otherwise examples)
    # Need to add is_format_caption return value to sample_example_smart
-    def sample_example_smart_with_flag(task_type: str):
+    def sample_example_smart_with_flag(task_type: str, constrained_decoding_debug: bool):
        """Wrapper for sample_example_smart that adds is_format_caption flag"""
-        result = sample_example_smart(task_type)
+        result = sample_example_smart(task_type, constrained_decoding_debug)
        # Add True at the end to set is_format_caption
        return result + (True,)

    generation_section["sample_btn"].click(
        fn=sample_example_smart_with_flag,
-        inputs=[
+        inputs=[
+            generation_section["task_type"],
+            generation_section["constrained_decoding_debug"]
+        ],
        outputs=[
            generation_section["captions"],
            generation_section["lyrics"],
@@ -1743,13 +1800,14 @@
    )

    # Transcribe audio codes to metadata (or generate example if empty)
-    def transcribe_audio_codes(audio_code_string):
+    def transcribe_audio_codes(audio_code_string, constrained_decoding_debug):
        """
        Transcribe audio codes to metadata using LLM understanding.
        If audio_code_string is empty, generate a sample example instead.

        Args:
            audio_code_string: String containing audio codes (or empty for example generation)
+            constrained_decoding_debug: Whether to enable debug logging for constrained decoding

        Returns:
            Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature)
@@ -1763,7 +1821,11 @@
            audio_code_string = "NO USER INPUT"

        # Call LLM understanding
-        metadata, status = llm_handler.understand_audio_from_codes(
+        metadata, status = llm_handler.understand_audio_from_codes(
+            audio_codes=audio_code_string,
+            use_constrained_decoding=True,
+            constrained_decoding_debug=constrained_decoding_debug,
+        )

        # Extract fields for UI update
        caption = metadata.get('caption', '')
|
|
|
|
| 1880 |
|
| 1881 |
generation_section["transcribe_btn"].click(
|
| 1882 |
fn=transcribe_audio_codes,
|
| 1883 |
+
inputs=[
|
| 1884 |
+
generation_section["text2music_audio_code_string"],
|
| 1885 |
+
generation_section["constrained_decoding_debug"]
|
| 1886 |
+
],
|
| 1887 |
outputs=[
|
| 1888 |
results_section["status_output"], # Show status
|
| 1889 |
generation_section["captions"], # Update caption field
|
|
|
|
| 1964 |
outputs=[generation_section["audio_uploads_accordion"]]
|
| 1965 |
)
|
| 1966 |
|
| 1967 |
+
# Save metadata handlers - use JavaScript to trigger automatic download
|
| 1968 |
results_section["save_btn_1"].click(
|
| 1969 |
+
fn=None,
|
| 1970 |
inputs=[
|
| 1971 |
generation_section["task_type"],
|
| 1972 |
generation_section["captions"],
|
|
|
|
| 2001 |
generation_section["complete_track_classes"],
|
| 2002 |
results_section["lm_metadata_state"],
|
| 2003 |
],
|
| 2004 |
+
outputs=None,
|
| 2005 |
+
js="""
|
| 2006 |
+
(task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
|
| 2007 |
+
batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
|
| 2008 |
+
use_adg, cfg_interval_start, cfg_interval_end, audio_format,
|
| 2009 |
+
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 2010 |
+
use_cot_caption, use_cot_language, audio_cover_strength,
|
| 2011 |
+
think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
|
| 2012 |
+
track_name, complete_track_classes, lm_metadata) => {
|
| 2013 |
+
// Create metadata object
|
| 2014 |
+
const metadata = {
|
| 2015 |
+
saved_at: new Date().toISOString(),
|
| 2016 |
+
task_type: task_type,
|
| 2017 |
+
caption: captions || "",
|
| 2018 |
+
lyrics: lyrics || "",
|
| 2019 |
+
vocal_language: vocal_language,
|
| 2020 |
+
bpm: bpm,
|
| 2021 |
+
keyscale: key_scale || "",
|
| 2022 |
+
timesignature: time_signature || "",
|
| 2023 |
+
duration: audio_duration,
|
| 2024 |
+
batch_size: batch_size_input,
|
| 2025 |
+
inference_steps: inference_steps,
|
| 2026 |
+
guidance_scale: guidance_scale,
|
| 2027 |
+
seed: seed,
|
| 2028 |
+
random_seed: random_seed_checkbox,
|
| 2029 |
+
use_adg: use_adg,
|
| 2030 |
+
cfg_interval_start: cfg_interval_start,
|
| 2031 |
+
cfg_interval_end: cfg_interval_end,
|
| 2032 |
+
audio_format: audio_format,
|
| 2033 |
+
lm_temperature: lm_temperature,
|
| 2034 |
+
lm_cfg_scale: lm_cfg_scale,
|
| 2035 |
+
lm_top_k: lm_top_k,
|
| 2036 |
+
lm_top_p: lm_top_p,
|
| 2037 |
+
lm_negative_prompt: lm_negative_prompt,
|
| 2038 |
+
use_cot_caption: use_cot_caption,
|
| 2039 |
+
use_cot_language: use_cot_language,
|
| 2040 |
+
audio_cover_strength: audio_cover_strength,
|
| 2041 |
+
think: think_checkbox,
|
| 2042 |
+
audio_codes: text2music_audio_code_string || "",
|
| 2043 |
+
repainting_start: repainting_start,
|
| 2044 |
+
repainting_end: repainting_end,
|
| 2045 |
+
track_name: track_name,
|
| 2046 |
+
complete_track_classes: complete_track_classes || []
|
| 2047 |
+
};
|
| 2048 |
+
|
| 2049 |
+
if (lm_metadata) {
|
| 2050 |
+
metadata.lm_generated_metadata = lm_metadata;
|
| 2051 |
+
}
|
| 2052 |
+
|
| 2053 |
+
// Create JSON string
|
| 2054 |
+
const jsonStr = JSON.stringify(metadata, null, 2);
|
| 2055 |
+
|
| 2056 |
+
// Create blob and download
|
| 2057 |
+
const blob = new Blob([jsonStr], { type: 'application/json' });
|
| 2058 |
+
const url = URL.createObjectURL(blob);
|
| 2059 |
+
const a = document.createElement('a');
|
| 2060 |
+
a.href = url;
|
| 2061 |
+
const timestamp = new Date().toISOString().replace(/[-:]/g, '').replace('T', '_').split('.')[0];
|
| 2062 |
+
a.download = `generation_params_${timestamp}.json`;
|
| 2063 |
+
document.body.appendChild(a);
|
| 2064 |
+
a.click();
|
| 2065 |
+
document.body.removeChild(a);
|
| 2066 |
+
URL.revokeObjectURL(url);
|
| 2067 |
+
|
| 2068 |
+
return Array(32).fill(null);
|
| 2069 |
+
}
|
| 2070 |
+
"""
|
| 2071 |
)
|
| 2072 |
|
| 2073 |
results_section["save_btn_2"].click(
|
| 2074 |
+
fn=None,
|
| 2075 |
inputs=[
|
| 2076 |
generation_section["task_type"],
|
| 2077 |
generation_section["captions"],
|
|
|
|
| 2106 |
generation_section["complete_track_classes"],
|
| 2107 |
results_section["lm_metadata_state"],
|
| 2108 |
],
|
| 2109 |
+
outputs=None,
|
| 2110 |
+
js="""
|
| 2111 |
+
(task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
|
| 2112 |
+
batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
|
| 2113 |
+
use_adg, cfg_interval_start, cfg_interval_end, audio_format,
|
| 2114 |
+
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 2115 |
+
use_cot_caption, use_cot_language, audio_cover_strength,
|
| 2116 |
+
think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
|
| 2117 |
+
track_name, complete_track_classes, lm_metadata) => {
|
| 2118 |
+
// Create metadata object
|
| 2119 |
+
const metadata = {
|
| 2120 |
+
saved_at: new Date().toISOString(),
|
| 2121 |
+
task_type: task_type,
|
| 2122 |
+
caption: captions || "",
|
| 2123 |
+
lyrics: lyrics || "",
|
| 2124 |
+
vocal_language: vocal_language,
|
| 2125 |
+
bpm: bpm,
|
| 2126 |
+
keyscale: key_scale || "",
|
| 2127 |
+
timesignature: time_signature || "",
|
| 2128 |
+
duration: audio_duration,
|
| 2129 |
+
batch_size: batch_size_input,
|
| 2130 |
+
inference_steps: inference_steps,
|
| 2131 |
+
guidance_scale: guidance_scale,
|
| 2132 |
+
seed: seed,
|
| 2133 |
+
random_seed: random_seed_checkbox,
|
| 2134 |
+
use_adg: use_adg,
|
| 2135 |
+
cfg_interval_start: cfg_interval_start,
|
| 2136 |
+
cfg_interval_end: cfg_interval_end,
|
| 2137 |
+
audio_format: audio_format,
|
| 2138 |
+
lm_temperature: lm_temperature,
|
| 2139 |
+
lm_cfg_scale: lm_cfg_scale,
|
| 2140 |
+
lm_top_k: lm_top_k,
|
| 2141 |
+
lm_top_p: lm_top_p,
|
| 2142 |
+
lm_negative_prompt: lm_negative_prompt,
|
| 2143 |
+
use_cot_caption: use_cot_caption,
|
| 2144 |
+
use_cot_language: use_cot_language,
|
| 2145 |
+
audio_cover_strength: audio_cover_strength,
|
| 2146 |
+
think: think_checkbox,
|
| 2147 |
+
audio_codes: text2music_audio_code_string || "",
|
| 2148 |
+
repainting_start: repainting_start,
|
| 2149 |
+
repainting_end: repainting_end,
|
| 2150 |
+
track_name: track_name,
|
| 2151 |
+
complete_track_classes: complete_track_classes || []
|
| 2152 |
+
};
|
| 2153 |
+
|
| 2154 |
+
if (lm_metadata) {
|
| 2155 |
+
metadata.lm_generated_metadata = lm_metadata;
|
| 2156 |
+
}
|
| 2157 |
+
|
| 2158 |
+
// Create JSON string
|
| 2159 |
+
const jsonStr = JSON.stringify(metadata, null, 2);
|
| 2160 |
+
|
| 2161 |
+
// Create blob and download
|
| 2162 |
+
const blob = new Blob([jsonStr], { type: 'application/json' });
|
| 2163 |
+
const url = URL.createObjectURL(blob);
|
| 2164 |
+
const a = document.createElement('a');
|
| 2165 |
+
a.href = url;
|
| 2166 |
+
const timestamp = new Date().toISOString().replace(/[-:]/g, '').replace('T', '_').split('.')[0];
|
| 2167 |
+
a.download = `generation_params_${timestamp}.json`;
|
| 2168 |
+
document.body.appendChild(a);
|
| 2169 |
+
a.click();
|
| 2170 |
+
document.body.removeChild(a);
|
| 2171 |
+
URL.revokeObjectURL(url);
|
| 2172 |
+
|
| 2173 |
+
return Array(32).fill(null);
|
| 2174 |
+
}
|
| 2175 |
+
"""
|
| 2176 |
)
|
| 2177 |
|
| 2178 |
# Load metadata handler - triggered when file is uploaded via UploadButton
|
|
|
|
| 2214 |
results_section["is_format_caption_state"]
|
| 2215 |
]
|
| 2216 |
)
|
| 2217 |
+
|
| 2218 |
+
# Instrumental checkbox handler - auto-fill [Instrumental] when checked
|
| 2219 |
+
def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
|
| 2220 |
+
"""
|
| 2221 |
+
Handle instrumental checkbox changes.
|
| 2222 |
+
When checked: if no lyrics, fill with [Instrumental]
|
| 2223 |
+
When unchecked: if lyrics is [Instrumental], clear it
|
| 2224 |
+
"""
|
| 2225 |
+
if instrumental_checked:
|
| 2226 |
+
# If checked and no lyrics, fill with [Instrumental]
|
| 2227 |
+
if not current_lyrics or not current_lyrics.strip():
|
| 2228 |
+
return "[Instrumental]"
|
| 2229 |
+
else:
|
| 2230 |
+
# Has lyrics, don't change
|
| 2231 |
+
return current_lyrics
|
| 2232 |
+
else:
|
| 2233 |
+
# If unchecked and lyrics is exactly [Instrumental], clear it
|
| 2234 |
+
if current_lyrics and current_lyrics.strip() == "[Instrumental]":
|
| 2235 |
+
return ""
|
| 2236 |
+
else:
|
| 2237 |
+
# Has other lyrics, don't change
|
| 2238 |
+
return current_lyrics
|
| 2239 |
+
|
| 2240 |
+
generation_section["instrumental_checkbox"].change(
|
| 2241 |
+
fn=handle_instrumental_checkbox,
|
| 2242 |
+
inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
|
| 2243 |
+
outputs=[generation_section["lyrics"]]
|
| 2244 |
+
)
|
| 2245 |
+
|
| 2246 |
+
# Score calculation handlers
|
| 2247 |
+
def calculate_score_handler(audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
|
| 2248 |
+
"""
|
| 2249 |
+
Calculate perplexity-based quality score for generated audio.
|
| 2250 |
+
|
| 2251 |
+
Args:
|
| 2252 |
+
audio_codes_str: Generated audio codes string
|
| 2253 |
+
caption: Caption text used for generation
|
| 2254 |
+
lyrics: Lyrics text used for generation
|
| 2255 |
+
lm_metadata: LM-generated metadata dictionary (from CoT generation)
|
| 2256 |
+
bpm: BPM value
|
| 2257 |
+
key_scale: Key scale value
|
| 2258 |
+
time_signature: Time signature value
|
| 2259 |
+
audio_duration: Audio duration value
|
| 2260 |
+
vocal_language: Vocal language value
|
| 2261 |
+
score_scale: Sensitivity scale parameter (lower = more sensitive)
|
| 2262 |
+
|
| 2263 |
+
Returns:
|
| 2264 |
+
Score display string
|
| 2265 |
+
"""
|
| 2266 |
+
from acestep.test_time_scaling import calculate_perplexity, perplexity_to_score
|
| 2267 |
+
|
| 2268 |
+
if not llm_handler.llm_initialized:
|
| 2269 |
+
return "❌ LLM not initialized. Please initialize 5Hz LM first."
|
| 2270 |
+
|
| 2271 |
+
if not audio_codes_str or not audio_codes_str.strip():
|
| 2272 |
+
return "❌ No audio codes available. Please generate music first."
|
| 2273 |
+
|
| 2274 |
+
try:
|
| 2275 |
+
# Build metadata dictionary from both LM metadata and user inputs
|
| 2276 |
+
metadata = {}
|
| 2277 |
+
|
| 2278 |
+
# Priority 1: Use LM-generated metadata if available
|
| 2279 |
+
if lm_metadata and isinstance(lm_metadata, dict):
|
| 2280 |
+
metadata.update(lm_metadata)
|
| 2281 |
+
|
| 2282 |
+
# Priority 2: Add user-provided metadata (if not already in LM metadata)
|
| 2283 |
+
if bpm is not None and 'bpm' not in metadata:
|
| 2284 |
+
try:
|
| 2285 |
+
metadata['bpm'] = int(bpm)
|
| 2286 |
+
except:
|
| 2287 |
+
pass
|
| 2288 |
+
|
| 2289 |
+
if caption and 'caption' not in metadata:
|
| 2290 |
+
metadata['caption'] = caption
|
| 2291 |
+
|
| 2292 |
+
if audio_duration is not None and audio_duration > 0 and 'duration' not in metadata:
|
| 2293 |
+
try:
|
| 2294 |
+
metadata['duration'] = int(audio_duration)
|
| 2295 |
+
except:
|
| 2296 |
+
pass
|
| 2297 |
+
|
| 2298 |
+
if key_scale and key_scale.strip() and 'keyscale' not in metadata:
|
| 2299 |
+
metadata['keyscale'] = key_scale.strip()
|
| 2300 |
+
|
| 2301 |
+
if vocal_language and vocal_language.strip() and 'language' not in metadata:
|
| 2302 |
+
metadata['language'] = vocal_language.strip()
|
| 2303 |
+
|
| 2304 |
+
if time_signature and time_signature.strip() and 'timesignature' not in metadata:
|
| 2305 |
+
metadata['timesignature'] = time_signature.strip()
|
| 2306 |
+
|
| 2307 |
+
# Calculate perplexity
|
| 2308 |
+
perplexity, status = calculate_perplexity(
|
| 2309 |
+
llm_handler=llm_handler,
|
| 2310 |
+
audio_codes=audio_codes_str,
|
| 2311 |
+
caption=caption or "",
|
| 2312 |
+
lyrics=lyrics or "",
|
| 2313 |
+
metadata=metadata if metadata else None,
|
| 2314 |
+
temperature=1.0
|
| 2315 |
+
)
|
| 2316 |
+
|
| 2317 |
+
# Convert perplexity to normalized score [0, 1] (higher is better)
|
| 2318 |
+
normalized_score = perplexity_to_score(perplexity, scale=score_scale)
|
| 2319 |
+
|
| 2320 |
+
# Format display string
|
| 2321 |
+
if perplexity == float('inf'):
|
| 2322 |
+
return f"❌ Scoring failed: {status}"
|
| 2323 |
+
else:
|
| 2324 |
+
return f"✅ Quality Score: {normalized_score:.4f} (range: 0-1, higher is better)\nPerplexity: {perplexity:.4f}\nSensitivity: {score_scale}\n{status}"
|
| 2325 |
+
|
| 2326 |
+
except Exception as e:
|
| 2327 |
+
import traceback
|
| 2328 |
+
error_msg = f"❌ Error calculating score: {str(e)}\n{traceback.format_exc()}"
|
| 2329 |
+
return error_msg
|
| 2330 |
+
|
| 2331 |
+
# Connect score buttons to handlers
|
| 2332 |
+
results_section["score_btn_1"].click(
|
| 2333 |
+
fn=calculate_score_handler,
|
| 2334 |
+
inputs=[
|
| 2335 |
+
generation_section["text2music_audio_code_string"],
|
| 2336 |
+
generation_section["captions"],
|
| 2337 |
+
generation_section["lyrics"],
|
| 2338 |
+
results_section["lm_metadata_state"],
|
| 2339 |
+
generation_section["bpm"],
|
| 2340 |
+
generation_section["key_scale"],
|
| 2341 |
+
generation_section["time_signature"],
|
| 2342 |
+
generation_section["audio_duration"],
|
| 2343 |
+
generation_section["vocal_language"],
|
| 2344 |
+
generation_section["score_scale"]
|
| 2345 |
+
],
|
| 2346 |
+
outputs=[results_section["score_display_1"]]
|
| 2347 |
+
)
|
| 2348 |
+
|
| 2349 |
+
results_section["score_btn_2"].click(
|
| 2350 |
+
fn=calculate_score_handler,
|
| 2351 |
+
inputs=[
|
| 2352 |
+
generation_section["text2music_audio_code_string"],
|
| 2353 |
+
generation_section["captions"],
|
| 2354 |
+
generation_section["lyrics"],
|
| 2355 |
+
results_section["lm_metadata_state"],
|
| 2356 |
+
generation_section["bpm"],
|
| 2357 |
+
generation_section["key_scale"],
|
| 2358 |
+
generation_section["time_signature"],
|
| 2359 |
+
generation_section["audio_duration"],
|
| 2360 |
+
generation_section["vocal_language"],
|
| 2361 |
+
generation_section["score_scale"]
|
| 2362 |
+
],
|
| 2363 |
+
outputs=[results_section["score_display_2"]]
|
| 2364 |
+
)
|
| 2365 |
|
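For reference, the merge that calculate_score_handler performs before scoring can be seen in isolation below. This is an illustrative sketch with made-up values, not part of the commit: LM-generated metadata from the CoT pass takes priority, and user-supplied UI fields only fill in keys the LM did not produce.

    # Illustrative only: the metadata merge order used by calculate_score_handler above.
    lm_metadata = {"bpm": 128, "keyscale": "A minor"}            # from CoT generation (made-up values)
    user_fields = {"bpm": 120, "caption": "calm piano",
                   "duration": 30, "language": "en"}             # from the UI widgets (made-up values)

    metadata = dict(lm_metadata)
    for key, value in user_fields.items():
        if value is not None and key not in metadata:
            metadata[key] = value

    print(metadata)
    # {'bpm': 128, 'keyscale': 'A minor', 'caption': 'calm piano', 'duration': 30, 'language': 'en'}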
acestep/llm_inference.py
CHANGED

@@ -39,6 +39,9 @@ class LLMHandler:
   39
   40           # Shared constrained decoding processor (initialized once when LLM is loaded)
   41           self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
   42 +
   43 +         # Shared HuggingFace model for perplexity calculation (when using vllm backend)
   44 +         self._hf_model_for_scoring = None
   45
   46       def get_available_5hz_lm_models(self) -> List[str]:
   47           """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""

@@ -246,6 +249,7 @@
  249           target_duration: Optional[float] = None,
  250           user_metadata: Optional[Dict[str, Optional[str]]] = None,
  251           stop_at_reasoning: bool = False,
  252 +         skip_genres: bool = True,
  253           skip_caption: bool = False,
  254           skip_language: bool = False,
  255           generation_phase: str = "cot",

@@ -276,6 +280,7 @@
  280           self.constrained_processor.set_user_metadata(user_metadata)
  281           self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
  282           # Set skip_caption and skip_language based on flags
  283 +         self.constrained_processor.set_skip_genres(skip_genres)
  284           self.constrained_processor.set_skip_caption(skip_caption)
  285           self.constrained_processor.set_skip_language(skip_language)
  286           # Set generation phase for phase-aware processing

@@ -347,6 +352,7 @@
  352           target_duration: Optional[float] = None,
  353           user_metadata: Optional[Dict[str, Optional[str]]] = None,
  354           stop_at_reasoning: bool = False,
  355 +         skip_genres: bool = True,
  356           skip_caption: bool = False,
  357           skip_language: bool = False,
  358           generation_phase: str = "cot",

@@ -376,6 +382,7 @@
  382           self.constrained_processor.set_user_metadata(user_metadata)
  383           self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
  384           # Set skip_caption and skip_language based on flags
  385 +         self.constrained_processor.set_skip_genres(skip_genres)
  386           self.constrained_processor.set_skip_caption(skip_caption)
  387           self.constrained_processor.set_skip_language(skip_language)
  388           # Set generation phase for phase-aware processing

@@ -597,6 +604,7 @@
  604               "user_metadata": user_metadata,
  605               "skip_caption": not use_cot_caption,
  606               "skip_language": not use_cot_language,
  607 +             "skip_genres": True,  # Generate genres
  608               "generation_phase": "cot",
  609               # Pass context for building unconditional prompt in CoT phase
  610               "caption": caption,

@@ -863,7 +871,6 @@
  871               - bpm: int or str
  872               - caption: str
  873               - duration: int or str
      -            - genres: str
  874               - keyscale: str
  875               - language: str
  876               - timesignature: str

@@ -901,6 +908,7 @@
  908               "user_metadata": None,  # No user metadata injection
  909               "skip_caption": False,  # Generate caption
  910               "skip_language": False,  # Generate language
  911 +             "skip_genres": False,  # Generate genres
  912               "generation_phase": "understand",  # Understanding phase: generate CoT metadata, then free-form lyrics
  913               # Context for building unconditional prompt
  914               "caption": "",

@@ -1015,6 +1023,7 @@
 1023           user_metadata = cfg.get("user_metadata")  # User-provided metadata fields
 1024           skip_caption = cfg.get("skip_caption", False)  # Skip caption generation in CoT
 1025           skip_language = cfg.get("skip_language", False)  # Skip language generation in CoT
 1026 +         skip_genres = cfg.get("skip_genres", False)  # Skip genres generation in CoT
 1027           generation_phase = cfg.get("generation_phase", "cot")  # "cot" or "codes"
 1028           # Additional context for codes phase unconditional prompt building
 1029           caption = cfg.get("caption", "")

@@ -1036,6 +1045,7 @@
 1045               target_duration=target_duration,
 1046               user_metadata=user_metadata,
 1047               stop_at_reasoning=stop_at_reasoning,
 1048 +             skip_genres=skip_genres,
 1049               skip_caption=skip_caption,
 1050               skip_language=skip_language,
 1051               generation_phase=generation_phase,

@@ -1059,6 +1069,7 @@
 1069               target_duration=target_duration,
 1070               user_metadata=user_metadata,
 1071               stop_at_reasoning=stop_at_reasoning,
 1072 +             skip_genres=skip_genres,
 1073               skip_caption=skip_caption,
 1074               skip_language=skip_language,
 1075               generation_phase=generation_phase,

@@ -1521,3 +1532,51 @@
 1532               torch.cuda.empty_cache()
 1533           offload_time = time.time() - start_time
 1534           logger.info(f"Offloaded LLM to CPU in {offload_time:.4f}s")
 1535 +
 1536 +     def get_hf_model_for_scoring(self):
 1537 +         """
 1538 +         Get HuggingFace model for perplexity scoring.
 1539 +
 1540 +         For vllm backend, loads HuggingFace model from disk (weights are cached by transformers).
 1541 +         For pt backend, returns the existing model.
 1542 +
 1543 +         Returns:
 1544 +             HuggingFace model instance
 1545 +         """
 1546 +         if self.llm_backend == "pt":
 1547 +             # For PyTorch backend, directly return the model
 1548 +             return self.llm
 1549 +
 1550 +         elif self.llm_backend == "vllm":
 1551 +             # For vllm backend, load HuggingFace model from disk
 1552 +             # Note: transformers caches model weights, so this doesn't duplicate disk I/O
 1553 +             if self._hf_model_for_scoring is None:
 1554 +                 logger.info("Loading HuggingFace model for scoring (from checkpoint)")
 1555 +
 1556 +                 # Get model path from vllm config
 1557 +                 model_runner = self.llm.model_runner
 1558 +                 model_path = model_runner.config.model
 1559 +
 1560 +                 # Load HuggingFace model from the same checkpoint
 1561 +                 # This will load the original unfused weights
 1562 +                 import time
 1563 +                 start_time = time.time()
 1564 +                 self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
 1565 +                     model_path,
 1566 +                     trust_remote_code=True,
 1567 +                     torch_dtype=self.dtype
 1568 +                 )
 1569 +                 load_time = time.time() - start_time
 1570 +                 logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
 1571 +
 1572 +                 # Move to same device as vllm model
 1573 +                 device = next(model_runner.model.parameters()).device
 1574 +                 self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
 1575 +                 self._hf_model_for_scoring.eval()
 1576 +
 1577 +                 logger.info(f"HuggingFace model for scoring ready on {device}")
 1578 +
 1579 +             return self._hf_model_for_scoring
 1580 +
 1581 +         else:
 1582 +             raise ValueError(f"Unknown backend: {self.llm_backend}")
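Taken together with the new acestep/test_time_scaling.py module below, the scoring path added here is expected to be driven roughly as follows. This is a hedged usage sketch rather than code from the commit: it assumes an already-initialized LLMHandler instance named llm_handler, and calculate_perplexity calls get_hf_model_for_scoring internally, so on the vllm backend the HuggingFace copy of the checkpoint is loaded lazily on first use.

    # Usage sketch only: assumes `llm_handler` is an initialized LLMHandler
    # (llm_handler.llm_initialized is True); the audio codes and metadata values are made up.
    from acestep.test_time_scaling import calculate_perplexity, perplexity_to_score

    perplexity, status = calculate_perplexity(
        llm_handler=llm_handler,
        audio_codes="<|audio_code_123|><|audio_code_456|>",
        caption="calm piano",
        lyrics="[Instrumental]",
        metadata={"bpm": 120, "duration": 30, "keyscale": "C major"},
    )
    score = perplexity_to_score(perplexity, scale=100.0)  # map to [0, 1], higher is better
    print(status, f"score={score:.4f}")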
acestep/test_time_scaling.py
ADDED

@@ -0,0 +1,261 @@
"""
Test-Time Scaling Module
Implements perplexity-based scoring for generated audio codes
"""
import torch
import torch.nn.functional as F
from typing import Tuple, Optional, Dict, Any
from loguru import logger
import yaml


def perplexity_to_score(perplexity: float, scale: float = 100.0) -> float:
    """
    Convert perplexity to a normalized score in [0, 1] range.

    Lower perplexity = higher score (better quality)
    Uses exponential decay: score = exp(-perplexity / scale)

    Args:
        perplexity: Perplexity value (typically 1 to 1000+)
        scale: Scale parameter to control score distribution (default 100.0)
            - Smaller scale: more sensitive to perplexity changes
            - Larger scale: less sensitive to perplexity changes

    Returns:
        Score in [0, 1] range, where 1 is perfect and 0 is worst

    Examples:
        perplexity=1   → score≈0.99 (excellent)
        perplexity=50  → score≈0.61 (good if scale=100)
        perplexity=100 → score≈0.37 (medium if scale=100)
        perplexity=500 → score≈0.01 (poor if scale=100)
    """
    import math
    return math.exp(-perplexity / scale)


def calculate_perplexity(
    llm_handler,
    audio_codes: str,
    caption: str = "",
    lyrics: str = "",
    metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 1.0,
) -> Tuple[float, str]:
    """
    Calculate perplexity of generated audio codes conditioned on caption/lyrics/metadata.

    This reverses the generation task: given audio codes as input, measure how well
    the model can predict the CoT metadata and lyrics that should generate those codes.

    Lower perplexity = model is less surprised = better quality generation
    Score = -perplexity (higher is better)

    The understanding task format is:
        Input: <|audio_code_123|><|audio_code_456|>...
        Output: <think>\nmetadata_yaml\n</think>\n\n# Lyric\nlyrics_text

    Args:
        llm_handler: LLM handler instance with initialized model
        audio_codes: Generated audio code string (e.g., "<|audio_code_123|><|audio_code_456|>...")
        caption: Caption text used for generation
        lyrics: Lyrics text used for generation
        metadata: Dictionary with CoT metadata fields (bpm, duration, keyscale, language, timesignature, etc.)
        temperature: Temperature for probability scaling (default 1.0)

    Returns:
        Tuple of (perplexity_value, status_message)

    Example:
        metadata = {'bpm': 120, 'duration': 30, 'keyscale': 'C major', 'language': 'en', 'timesignature': '4'}
        perplexity, status = calculate_perplexity(
            llm_handler,
            audio_codes="<|audio_code_123|>...",
            caption="calm piano",
            lyrics="verse 1...",
            metadata=metadata
        )
        score = -perplexity  # Higher score = better quality
    """
    if not llm_handler.llm_initialized:
        return float('inf'), "❌ LLM not initialized"

    if not audio_codes or not audio_codes.strip():
        return float('inf'), "❌ No audio codes provided"

    try:
        # Build the understanding prompt: codes as input
        # The model should generate: <think>metadata</think>\n# Lyric\n...
        formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(
            audio_codes=audio_codes,
            is_negative_prompt=False
        )

        logger.info(f"Calculating perplexity for {len(audio_codes)} character audio codes")

        # Build the expected output (target sequence) following understanding task format
        # Format: <think>\nmetadata_yaml\n</think>\n\n# Lyric\nlyrics_text
        target_parts = []

        # Build CoT section with metadata
        if metadata and isinstance(metadata, dict):
            # Filter out None values and format as YAML (sorted keys)
            cot_items = {}
            for key in ['bpm', 'caption', 'duration', 'genres', 'keyscale', 'language', 'timesignature']:
                if key in metadata and metadata[key] is not None:
                    cot_items[key] = metadata[key]

            if cot_items:
                cot_yaml = yaml.dump(cot_items, allow_unicode=True, sort_keys=True).strip()
                target_parts.append(f"<think>\n{cot_yaml}\n</think>\n")

        # Add Lyric section (note: understanding task uses "# Lyric" not "# Caption")
        if lyrics:
            target_parts.append(f"\n# Lyric\n{lyrics}\n")

        target_text = "".join(target_parts)

        if not target_text.strip():
            return float('inf'), "❌ No target text to evaluate (lyrics or metadata required)"

        logger.debug(f"Target text (first 200 chars): {target_text[:200]}...")

        # Calculate perplexity using appropriate backend
        if llm_handler.llm_backend == "vllm":
            perplexity = _calculate_perplexity_vllm(
                llm_handler,
                formatted_prompt,
                target_text,
                temperature
            )
        else:  # pt backend
            perplexity = _calculate_perplexity_pt(
                llm_handler,
                formatted_prompt,
                target_text,
                temperature
            )

        status_msg = f"✅ Perplexity calculated: {perplexity:.4f}"
        logger.info(status_msg)
        return perplexity, status_msg

    except Exception as e:
        error_msg = f"❌ Error calculating perplexity: {str(e)}"
        logger.error(error_msg)
        import traceback
        logger.error(traceback.format_exc())
        return float('inf'), error_msg


def _calculate_perplexity_pt(
    llm_handler,
    formatted_prompt: str,
    target_text: str,
    temperature: float
) -> float:
    """
    Calculate perplexity using PyTorch backend.

    For vllm backend, this uses a shared-weight HuggingFace model.
    For pt backend, this uses the original model.

    Args:
        llm_handler: LLM handler with pt or vllm backend
        formatted_prompt: Formatted input prompt (audio codes)
        target_text: Expected output text (CoT metadata + lyrics)
        temperature: Temperature for probability scaling

    Returns:
        Perplexity value
    """
    # Get model for scoring (handles both pt and vllm backends)
    model = llm_handler.get_hf_model_for_scoring()
    tokenizer = llm_handler.llm_tokenizer
    device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device

    # Tokenize prompt and target separately
    prompt_tokens = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
    )

    target_tokens = tokenizer(
        target_text,
        return_tensors="pt",
        padding=False,
        truncation=True,
    )

    # Concatenate prompt + target for full sequence
    full_input_ids = torch.cat([
        prompt_tokens['input_ids'],
        target_tokens['input_ids']
    ], dim=1).to(device)

    # Create attention mask
    attention_mask = torch.ones_like(full_input_ids)

    # Forward pass to get logits
    with torch.no_grad():
        with llm_handler._load_model_context():
            outputs = model(
                input_ids=full_input_ids,
                attention_mask=attention_mask
            )
            logits = outputs.logits  # [batch_size, seq_len, vocab_size]

    # Get the logits for predicting target tokens
    # Shift logits and labels: logits[i] predicts token[i+1]
    prompt_len = prompt_tokens['input_ids'].shape[1]
    target_len = target_tokens['input_ids'].shape[1]

    # Extract logits for positions that predict target tokens
    # logits at positions [prompt_len-1 : prompt_len+target_len-1] predict target tokens
    pred_logits = logits[0, prompt_len-1:prompt_len+target_len-1, :]  # [target_len, vocab_size]
    target_ids = target_tokens['input_ids'][0]  # [target_len]

    # Apply temperature scaling
    if temperature != 1.0:
        pred_logits = pred_logits / temperature

    # Calculate cross-entropy loss for each position
    log_probs = F.log_softmax(pred_logits, dim=-1)  # [target_len, vocab_size]

    # Gather log probabilities of target tokens
    target_log_probs = log_probs[torch.arange(target_len), target_ids]  # [target_len]

    # Calculate perplexity: exp(-mean(log_probs))
    mean_neg_log_prob = -target_log_probs.mean()
    perplexity = torch.exp(mean_neg_log_prob).item()

    return perplexity


def _calculate_perplexity_vllm(
    llm_handler,
    formatted_prompt: str,
    target_text: str,
    temperature: float
) -> float:
    """
    Calculate perplexity using vllm backend.

    Uses shared-weight HuggingFace model for perplexity calculation.
    This avoids the complexity of nanovllm's context management.

    Args:
        llm_handler: LLM handler with vllm backend
        formatted_prompt: Formatted input prompt (audio codes)
        target_text: Expected output text (CoT metadata + lyrics)
        temperature: Temperature for probability scaling

    Returns:
        Perplexity value
    """
    logger.debug("Using vllm backend with shared-weight HuggingFace model for perplexity")
    # Delegate to pt backend implementation which now handles both backends
    return _calculate_perplexity_pt(llm_handler, formatted_prompt, target_text, temperature)
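As a self-contained sanity check on the arithmetic above (illustrative only, not part of the new file): the mean negative log-probability that _calculate_perplexity_pt exponentiates is exactly what torch.nn.functional.cross_entropy returns with mean reduction, and perplexity_to_score then maps the resulting perplexity through exp(-perplexity / scale).

    # Toy check of the math used above; all tensors are random stand-ins.
    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    target_len, vocab_size = 4, 10
    pred_logits = torch.randn(target_len, vocab_size)     # logits that predict each target token
    target_ids = torch.randint(0, vocab_size, (target_len,))

    log_probs = F.log_softmax(pred_logits, dim=-1)
    target_log_probs = log_probs[torch.arange(target_len), target_ids]
    ppl_manual = torch.exp(-target_log_probs.mean())       # recipe from _calculate_perplexity_pt
    ppl_ce = torch.exp(F.cross_entropy(pred_logits, target_ids))
    assert torch.allclose(ppl_manual, ppl_ce)

    # perplexity_to_score at the default scale=100 reproduces the docstring examples:
    for ppl in (1, 50, 100, 500):
        print(ppl, round(math.exp(-ppl / 100.0), 2))       # 0.99, 0.61, 0.37, 0.01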