ChuxiJ committed on
Commit a161649 · 1 Parent(s): 4bcd037
acestep/constrained_logits_processor.py CHANGED
@@ -35,6 +35,9 @@ class FSMState(Enum):
@@ -74,7 +77,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
-        **kwargs: Any,
@@ -89,6 +93,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -103,6 +108,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -143,6 +149,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -186,6 +202,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -196,6 +214,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -211,17 +230,19 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
-            current_field: Current field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature")
-        field_order = ["bpm", "caption", "duration","keyscale", "language", "timesignature"]
@@ -235,7 +256,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -257,7 +281,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
-        # Field order: bpm -> caption -> duration -> keyscale -> language -> timesignature
@@ -271,6 +295,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -284,6 +313,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -366,13 +400,14 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
-        for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature"]:
@@ -437,6 +472,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -531,7 +570,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -808,6 +847,133 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -859,8 +1025,36 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -870,6 +1064,108 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -1061,6 +1357,26 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -1265,6 +1581,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -1428,9 +1745,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
-                    # It's indentation, continue caption
@@ -1505,7 +1824,55 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -1561,7 +1928,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -1780,6 +2147,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
@@ -1787,8 +2168,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
-            # Track if this token is a newline (for transition detection)
-            if generated_token_id == self.newline_token:
@@ -1813,6 +2195,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
35
  DURATION_NAME = auto() # Generating "duration: "
36
  DURATION_VALUE = auto() # Generating numeric value 10-600
37
  NEWLINE_AFTER_DURATION = auto()
38
+ GENRES_NAME = auto() # Generating "genres: "
39
+ GENRES_VALUE = auto() # Generating any non-empty string
40
+ NEWLINE_AFTER_GENRES = auto()
41
  KEYSCALE_NAME = auto() # Generating "keyscale: "
42
  KEYSCALE_VALUE = auto() # Generating keyscale pattern
43
  NEWLINE_AFTER_KEYSCALE = auto()
 
77
  tokenizer: AutoTokenizer,
78
  enabled: bool = True,
79
  debug: bool = False,
80
+ genres_vocab_path: Optional[str] = None,
81
+ skip_genres: bool = True,
82
  ):
83
  """
84
  Initialize the constrained logits processor.
 
93
  self.tokenizer = tokenizer
94
  self.enabled = enabled
95
  self.debug = debug
96
+ self.skip_genres = skip_genres
97
  self.skip_caption = False # Set to True to skip caption field generation
98
  self.skip_language = False # Set to True to skip language field generation
99
  self.caption: Optional[str] = None # Set via update_caption() before each generation
 
108
  "keyscale": None,
109
  "language": None,
110
  "timesignature": None,
111
+ "genres": None,
112
  }
113
 
114
  # Temperature settings for different generation phases (set per-generation)
 
149
  # Pre-compute token IDs for efficiency
150
  self._precompute_tokens()
151
 
152
+ # Genres vocabulary for constrained decoding
153
+ self.genres_vocab_path = genres_vocab_path or os.path.join(
154
+ os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
155
+ )
156
+ self.genres_vocab: List[str] = [] # Full vocab
157
+ self.genres_vocab_mtime: float = 0.0
158
+ self.genres_trie: Dict = {} # Trie for full vocab (fallback)
159
+ self.caption_genres_trie: Dict = {} # Trie for caption-matched genres (priority)
160
+ self.caption_matched_genres: List[str] = [] # Genres matched from caption
161
+
162
  self._char_to_tokens: Dict[str, set] = {} # Precomputed char -> token IDs mapping
163
 
164
  # Precompute token mappings once (O(vocab_size), runs once at init)
 
202
  # Build language prefix tree (similar to keyscale but for language codes)
203
  self.language_prefix_tree = self._build_language_prefix_tree()
204
 
205
+ self._load_genres_vocab()
206
+
207
  # Fixed strings for each state
208
  # IMPORTANT: Do NOT include trailing space after colon - tokenizer will handle spacing
209
  # All matching should be done at token level, not string level
 
214
  FSMState.BPM_NAME: "bpm:",
215
  FSMState.CAPTION_NAME: "caption:",
216
  FSMState.DURATION_NAME: "duration:",
217
+ FSMState.GENRES_NAME: "genres:",
218
  FSMState.KEYSCALE_NAME: "keyscale:",
219
  FSMState.LANGUAGE_NAME: "language:",
220
  FSMState.TIMESIG_NAME: "timesignature:",
 
230
  even if the field is user-provided (we still need to generate the field name).
231
 
232
  Args:
233
+ current_field: Current field name ("bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature")
234
 
235
  Returns:
236
  Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields
237
  """
238
  # New field order: bpm -> caption -> duration -> keyscale -> language -> timesignature
239
+ # genres is optional and can be skipped
240
+ field_order = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"]
241
  field_to_state = {
242
  "bpm": FSMState.BPM_NAME,
243
  "caption": FSMState.CAPTION_NAME,
244
  "duration": FSMState.DURATION_NAME,
245
+ "genres": FSMState.GENRES_NAME,
246
  "keyscale": FSMState.KEYSCALE_NAME,
247
  "language": FSMState.LANGUAGE_NAME,
248
  "timesignature": FSMState.TIMESIG_NAME,
 
256
  # Find next field in order
257
  for i in range(current_idx + 1, len(field_order)):
258
  field = field_order[i]
259
+
260
+ # Skip fields based on flags
261
+ if field == "genres" and self.skip_genres:
262
+ continue
263
  if field == "caption" and self.skip_caption:
264
  continue
265
  if field == "language" and self.skip_language:
 
281
  }
282
 
283
  # Build transitions for all fields (even if user-provided, we still need to generate field name)
284
+ # Field order: bpm -> caption -> duration -> genres -> keyscale -> language -> timesignature
285
 
286
  # BPM field: NAME -> VALUE -> next field (caption or duration)
287
  self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE
 
295
  # Duration field: NAME -> VALUE -> next field
296
  self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE
297
  self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration")
298
+
299
+ # Genres field (only if not skipped): NAME -> VALUE -> next field
300
+ if not self.skip_genres:
301
+ self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
302
+ self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres")
303
 
304
  # Keyscale field: NAME -> VALUE -> next field (language or timesignature)
305
  self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE
 
313
  # Timesignature field: NAME -> VALUE -> THINK_END_TAG
314
  self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE
315
  self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG
316
+
317
+ def set_skip_genres(self, skip: bool):
318
+ """Set whether to skip genres generation and rebuild state transitions."""
319
+ self.skip_genres = skip
320
+ self._build_state_transitions()
321
 
322
  def set_skip_caption(self, skip: bool):
323
  """Set whether to skip caption generation and rebuild state transitions."""
 
400
  - "keyscale": Optional[str] - e.g., "G major"
401
  - "language": Optional[str] - e.g., "en"
402
  - "timesignature": Optional[str] - e.g., "4"
403
+ - "genres": Optional[str] - e.g., "Pop Rock"
404
  If None, clears all user-provided metadata.
405
  """
406
  if metadata is None:
407
  metadata = {}
408
 
409
  # Update user-provided metadata
410
+ for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]:
411
  if field in metadata:
412
  self.user_provided_metadata[field] = metadata[field]
413
  else:
 
472
 
473
  # Vocab size
474
  self.vocab_size = len(self.tokenizer)
475
+
476
+ # Comma token for multi-genre support
477
+ comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
478
+ self.comma_token = comma_tokens[-1] if comma_tokens else None
479
 
480
  # EOS token for duration-constrained codes generation
481
  self.eos_token_id = self.tokenizer.eos_token_id
 
570
 
571
  if self.debug:
572
  logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
573
+
574
  def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
575
  """
576
  Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
 
847
 
848
  print("=" * 60)
849
 
850
+
851
+ def _load_genres_vocab(self):
852
+ """
853
+ Load genres vocabulary from file. Supports hot reload by checking file mtime.
854
+ File format: one genre per line, lines starting with # are comments.
855
+ """
856
+ if not os.path.exists(self.genres_vocab_path):
857
+ if self.debug:
858
+ logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
859
+ return
860
+
861
+ try:
862
+ mtime = os.path.getmtime(self.genres_vocab_path)
863
+ if mtime <= self.genres_vocab_mtime:
864
+ return # File hasn't changed
865
+
866
+ with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
867
+ genres = []
868
+ for line in f:
869
+ line = line.strip()
870
+ if line and not line.startswith('#'):
871
+ genres.append(line.lower())
872
+
873
+ self.genres_vocab = genres
874
+ self.genres_vocab_mtime = mtime
875
+ self._build_genres_trie()
876
+
877
+ if self.debug:
878
+ logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
879
+ except Exception as e:
880
+ logger.warning(f"Failed to load genres vocab: {e}")
881
+
882
+ def _build_genres_trie(self):
883
+ """
884
+ Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
885
+ Each node is a dict with:
886
+ - '_end': True if this node represents a complete genre
887
+ - other keys: next characters in the trie
888
+ """
889
+ self.genres_trie = {}
890
+
891
+ for genre in self.genres_vocab:
892
+ node = self.genres_trie
893
+ for char in genre:
894
+ if char not in node:
895
+ node[char] = {}
896
+ node = node[char]
897
+ node['_end'] = True # Mark end of a complete genre
898
+
899
+ if self.debug:
900
+ logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
901
+
902
+ def _extract_caption_genres(self, caption: str):
903
+ """
904
+ Extract genres from the user's caption that match entries in the vocabulary.
905
+ This creates a smaller trie for faster and more relevant genre generation.
906
+
907
+ Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
908
+ 1. Extract words/phrases from caption
909
+ 2. For each word, use trie to find all vocab entries that START with this word
910
+ 3. Build a separate trie from matched genres
911
+ """
912
+ if not caption or not self.genres_vocab:
913
+ return
914
+
915
+ caption_lower = caption.lower()
916
+ matched_genres = set()
917
+
918
+ # Extract words from caption (split by common delimiters)
919
+ import re
920
+ words = re.split(r'[,\s\-_/\\|]+', caption_lower)
921
+ words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
922
+
923
+ # For each word, find genres in trie that start with this word
924
+ for word in words:
925
+ # Find all genres starting with this word using trie traversal
926
+ node = self._get_genres_trie_node(word)
927
+ if node is not None:
928
+ # Collect all complete genres under this node
929
+ self._collect_complete_genres(node, word, matched_genres)
930
+
931
+ # Also check if any word appears as a substring in short genres (< 20 chars)
932
+ # This is a quick check for common single-word genres
933
+ genres_set = set(self.genres_vocab)
934
+ for word in words:
935
+ if word in genres_set:
936
+ matched_genres.add(word)
937
+
938
+ if not matched_genres:
939
+ if self.debug:
940
+ logger.debug(f"No genres matched in caption, using full vocab")
941
+ return
942
+
943
+ # Build a trie from matched genres
944
+ self.caption_matched_genres = list(matched_genres)
945
+ self.caption_genres_trie = {}
946
+
947
+ for genre in matched_genres:
948
+ node = self.caption_genres_trie
949
+ for char in genre:
950
+ if char not in node:
951
+ node[char] = {}
952
+ node = node[char]
953
+ node['_end'] = True
954
+
955
+ if self.debug:
956
+ logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
957
+
958
+ def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
959
+ """
960
+ Recursively collect all complete genres under a trie node.
961
+ Limited depth to avoid too many matches.
962
+ """
963
+ if max_depth <= 0:
964
+ return
965
+
966
+ if node.get('_end', False):
967
+ result.add(prefix)
968
+
969
+ # Limit total collected genres to avoid slowdown
970
+ if len(result) >= 100:
971
+ return
972
+
973
+ for char, child_node in node.items():
974
+ if char not in ('_end', '_tokens'):
975
+ self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
976
+
977
  def _precompute_char_token_mapping(self):
978
  """
979
  Precompute mapping from characters to token IDs and token decoded texts.
 
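For illustration only, here is a hypothetical genres_vocab.txt in the format `_load_genres_vocab` expects (the path and genre names are made up), together with the list it would load:

# genres_vocab.txt -- one genre per line, '#' lines are comments (hypothetical content):
#
#     # seed genres
#     Pop
#     pop rock
#     jazz
#
# _load_genres_vocab() keeps non-empty, non-comment lines and lowercases them,
# so the loaded vocabulary would be:
genres_vocab = ["pop", "pop rock", "jazz"]
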
1025
 
1026
  if self.debug:
1027
  logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
1028
+
1029
+ def _try_reload_genres_vocab(self):
1030
+ """Check if genres vocab file has been updated and reload if necessary."""
1031
+ if not os.path.exists(self.genres_vocab_path):
1032
+ return
1033
+
1034
+ try:
1035
+ mtime = os.path.getmtime(self.genres_vocab_path)
1036
+ if mtime > self.genres_vocab_mtime:
1037
+ self._load_genres_vocab()
1038
+ except Exception:
1039
+ pass # Ignore errors during hot reload check
1040
+
1041
+ def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
1042
+ """
1043
+ Get the trie node for a given prefix.
1044
+ Returns None if the prefix is not valid (no genres start with this prefix).
1045
+ """
1046
+ node = self.genres_trie
1047
+ for char in prefix.lower():
1048
+ if char not in node:
1049
+ return None
1050
+ node = node[char]
1051
+ return node
1052
+
1053
+ def _is_complete_genre(self, text: str) -> bool:
1054
+ """Check if the given text is a complete genre in the vocabulary."""
1055
+ node = self._get_genres_trie_node(text.strip())
1056
+ return node is not None and node.get('_end', False)
1057
+
1058
  def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
1059
  """Get a trie node from a specific trie (helper for caption vs full trie)."""
1060
  node = trie
 
1064
  node = node[char]
1065
  return node
1066
 
1067
+ def _get_allowed_genres_tokens(self) -> List[int]:
1068
+ """
1069
+ Get allowed tokens for genres field based on trie matching.
1070
+
1071
+ The entire genres string (including commas) must match a complete entry in the vocab.
1072
+ For example, if vocab contains "pop, rock, jazz", the generated string must exactly
1073
+ match that entry - we don't treat commas as separators for individual genres.
1074
+
1075
+ Strategy:
1076
+ 1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
1077
+ 2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
1078
+ 3. Get valid next characters from current trie node
1079
+ 4. For each candidate token, verify the full decoded text forms a valid trie prefix
1080
+ """
1081
+ if not self.genres_vocab:
1082
+ # No vocab loaded, allow all except newline if empty
1083
+ return []
1084
+
1085
+ # Use the full accumulated value (don't split by comma - treat as single entry)
1086
+ accumulated = self.accumulated_value.lower()
1087
+ current_genre_prefix = accumulated.strip()
1088
+
1089
+ # Determine which trie to use: caption-matched (priority) or full vocab (fallback)
1090
+ use_caption_trie = False
1091
+ current_node = None
1092
+
1093
+ # Try caption-matched trie first if available
1094
+ if self.caption_genres_trie:
1095
+ if current_genre_prefix == "":
1096
+ current_node = self.caption_genres_trie
1097
+ use_caption_trie = True
1098
+ else:
1099
+ current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
1100
+ if current_node is not None:
1101
+ use_caption_trie = True
1102
+
1103
+ # Fallback to full vocab trie
1104
+ if current_node is None:
1105
+ if current_genre_prefix == "":
1106
+ current_node = self.genres_trie
1107
+ else:
1108
+ current_node = self._get_genres_trie_node(current_genre_prefix)
1109
+
1110
+ if current_node is None:
1111
+ # Invalid prefix, force newline to end
1112
+ if self.newline_token:
1113
+ return [self.newline_token]
1114
+ return []
1115
+
1116
+ # Get valid next characters from trie node
1117
+ valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
1118
+
1119
+ # If current value is a complete genre, allow newline to end
1120
+ is_complete = current_node.get('_end', False)
1121
+
1122
+ if not valid_next_chars:
1123
+ # No more characters to match, only allow newline if complete
1124
+ allowed = set()
1125
+ if is_complete and self.newline_token:
1126
+ allowed.add(self.newline_token)
1127
+ return list(allowed)
1128
+
1129
+ # Collect candidate tokens based on first character
1130
+ candidate_tokens = set()
1131
+ for char in valid_next_chars:
1132
+ if char in self._char_to_tokens:
1133
+ candidate_tokens.update(self._char_to_tokens[char])
1134
+
1135
+ # Select the appropriate trie for validation
1136
+ active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
1137
+
1138
+ # Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
1139
+ allowed = set()
1140
+ for token_id in candidate_tokens:
1141
+ # Use precomputed decoded text (already normalized)
1142
+ decoded_normalized = self._token_to_text.get(token_id, "")
1143
+
1144
+ if not decoded_normalized or not decoded_normalized.strip():
1145
+ # Token decodes to empty or only whitespace - allow if space/comma is a valid next char
1146
+ if ' ' in valid_next_chars or ',' in valid_next_chars:
1147
+ allowed.add(token_id)
1148
+ continue
1149
+
1150
+ # Build new prefix by appending decoded token
1151
+ # Handle space-prefixed tokens (e.g., " rock" from "pop rock")
1152
+ if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
1153
+ # Token has leading space/comma - append directly
1154
+ new_prefix = current_genre_prefix + decoded_normalized
1155
+ else:
1156
+ new_prefix = current_genre_prefix + decoded_normalized
1157
+
1158
+ # Check if new_prefix is a valid prefix in the active trie
1159
+ new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
1160
+ if new_node is not None:
1161
+ allowed.add(token_id)
1162
+
1163
+ # If current value is a complete genre, also allow newline
1164
+ if is_complete and self.newline_token:
1165
+ allowed.add(self.newline_token)
1166
+
1167
+ return list(allowed)
1168
+
1169
  def reset(self):
1170
  """Reset the processor state for a new generation."""
1171
  self.state = FSMState.THINK_TAG
 
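A self-contained sketch of the prefix walk that `_get_allowed_genres_tokens` performs, using a toy trie and made-up values (this is not the class method itself):

def build_trie(genres):
    # Same construction as _build_genres_trie: nested char dict, '_end' marks a full genre.
    trie = {}
    for genre in genres:
        node = trie
        for char in genre:
            node = node.setdefault(char, {})
        node["_end"] = True
    return trie

def allowed_next_chars(trie, prefix):
    # Walk the trie; an invalid prefix means the processor would force a newline.
    node = trie
    for char in prefix.lower():
        if char not in node:
            return set(), False
        node = node[char]
    next_chars = {k for k in node if k not in ("_end", "_tokens")}
    return next_chars, node.get("_end", False)

toy_trie = build_trie(["pop", "pop rock", "jazz"])
print(allowed_next_chars(toy_trie, "pop "))      # ({'r'}, False) -> only tokens continuing with 'r' stay allowed
print(allowed_next_chars(toy_trie, "pop rock"))  # (set(), True)  -> newline may end the field
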
1357
 
1358
  return newline_prob > max_digit_prob
1359
 
1360
+
1361
+ def _should_end_text_field(self, logits: torch.Tensor) -> bool:
1362
+ """
1363
+ Determine if we should end a text field (genres).
1364
+ Returns True if P(newline) > P(any other token) AND we have some content.
1365
+ """
1366
+ if not self.accumulated_value.strip():
1367
+ return False # Need at least some content
1368
+
1369
+ probs = torch.softmax(logits, dim=-1)
1370
+ newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
1371
+
1372
+ # Get max probability among non-newline tokens
1373
+ masked_probs = probs.clone()
1374
+ if self.newline_token:
1375
+ masked_probs[0, self.newline_token] = 0
1376
+ max_other_prob = masked_probs[0].max().item()
1377
+
1378
+ return newline_prob > max_other_prob
1379
+
1380
  def _get_allowed_keyscale_tokens(self) -> List[int]:
1381
  """
1382
  Get allowed tokens for keyscale field using the precomputed prefix tree.
 
1581
  "keyscale": "keyscale: ",
1582
  "language": "language: ",
1583
  "timesignature": "timesignature: ",
1584
+ "genres": "genres: ",
1585
  }
1586
  prefix = field_to_prefix[field_name]
1587
  full_text = f"{prefix}{value}\n"
 
1745
  # Allow free generation (no constraints) so LM can generate field name naturally
1746
  return scores
1747
  else:
1748
+ # It's indentation, continue caption (don't transition!)
1749
  self.caption_after_newline = False
1750
+ # Continue normal caption generation
1751
+ # Fall through to caption constraints below
1752
+
1753
  # If caption is ending (LM generating next field name), allow free generation
1754
  # and track the field name until we see colon
1755
  if self.caption_ending:
 
1824
  mask[0, self.newline_token] = 0
1825
 
1826
  scores = scores + mask
1827
+
1828
+ elif self.state == FSMState.GENRES_VALUE:
1829
+ # Check if field is user-provided and we haven't started injecting yet
1830
+ if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value:
1831
+ # Initialize token queue with field value tokens (value + newline)
1832
+ value = self.user_provided_metadata["genres"]
1833
+ value_text = f" {value}\n"
1834
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
1835
+ if value_tokens:
1836
+ self.user_field_token_queue = value_tokens
1837
+ self.current_user_field = "genres"
1838
+ # Inject first token
1839
+ mask[0, value_tokens[0]] = 0
1840
+ scores = scores + mask
1841
+ return scores
1842
+
1843
+ # Try to hot-reload genres vocab if file has changed
1844
+ self._try_reload_genres_vocab()
1845
+
1846
+ # Get allowed tokens based on genres vocabulary
1847
+ allowed = self._get_allowed_genres_tokens()
1848
+
1849
+ if allowed:
1850
+ # Use vocabulary-constrained decoding
1851
+ for t in allowed:
1852
+ mask[0, t] = 0
1853
+ scores = scores + mask
1854
+ elif self.genres_vocab:
1855
+ # Vocab is loaded but no valid continuation found
1856
+ # Force newline to end the field
1857
+ if self.newline_token:
1858
+ mask[0, self.newline_token] = 0
1859
+ if self.debug:
1860
+ logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
1861
+ scores = scores + mask
1862
+ else:
1863
+ # Fallback: no vocab loaded, use probability-based ending
1864
+ if self._should_end_text_field(scores):
1865
+ if self.newline_token:
1866
+ mask[0, self.newline_token] = 0
1867
+ self._transition_to_next_state()
1868
+ scores = scores + mask
1869
+ else:
1870
+ # Allow any token except newline if we don't have content yet
1871
+ if not self.accumulated_value.strip():
1872
+ if self.newline_token:
1873
+ scores[0, self.newline_token] = float('-inf')
1874
+ # Otherwise, don't constrain (fallback behavior)
1875
+
1876
  elif self.state == FSMState.KEYSCALE_VALUE:
1877
  # Check if field is user-provided and we haven't started injecting yet
1878
  if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
 
1928
  mask[0, value_tokens[0]] = 0
1929
  scores = scores + mask
1930
  return scores
1931
+
1932
  # If we haven't started generating language yet (empty accumulated_token_ids),
1933
  # select the top-1 probability token from all valid first tokens
1934
  if not self.accumulated_token_ids:
 
2147
  if token_str.strip().isdigit():
2148
  self.accumulated_value += token_str.strip()
2149
 
2150
+ elif self.state == FSMState.GENRES_VALUE:
2151
+ if generated_token_id == self.newline_token:
2152
+ # Newline ends the field
2153
+ self._transition_to_next_state()
2154
+ # IMPORTANT: After state transition, if new state is a fixed_strings state,
2155
+ # we should NOT update position_in_state with the newline token length,
2156
+ # because that token belongs to the old state, not the new state.
2157
+ # Return early to avoid the fixed_strings update logic below.
2158
+ if self.state in self.fixed_strings:
2159
+ return
2160
+ else:
2161
+ # Genres still uses string-based trie, so keep accumulated_value
2162
+ self.accumulated_value += token_str
2163
+
2164
  elif self.state == FSMState.CAPTION_VALUE:
2165
  # Track token count for 512 limit
2166
  self.caption_token_count += 1
 
2168
  # Accumulate caption text
2169
  self.accumulated_value += token_str
2170
 
2171
+ # Track if this token contains a newline (for transition detection)
2172
+ # Token may be '\n' alone or combined with other chars like '.\n'
2173
+ if '\n' in token_str:
2174
  # Mark that we need to check next token for field transition
2175
  self.caption_after_newline = True
2176
  else:
 
2195
  # Map field name to VALUE state
2196
  field_name_to_value_state = {
2197
  "duration": FSMState.DURATION_VALUE,
2198
+ "genres": FSMState.GENRES_VALUE,
2199
  "keyscale": FSMState.KEYSCALE_VALUE,
2200
  "language": FSMState.LANGUAGE_VALUE,
2201
  "timesignature": FSMState.TIMESIG_VALUE,
acestep/gradio_ui.py CHANGED
@@ -607,6 +607,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
@@ -618,11 +624,21 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
@@ -632,10 +648,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
-            info="Enable llm generate hints",
@@ -695,6 +715,9 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
@@ -720,7 +743,7 @@ def create_results_section(dit_handler) -> dict:
-                    "Send To Src Audio",
@@ -731,6 +754,17 @@ def create_results_section(dit_handler) -> dict:
@@ -739,7 +773,7 @@ def create_results_section(dit_handler) -> dict:
-                    "Send To Src Audio",
@@ -750,6 +784,17 @@ def create_results_section(dit_handler) -> dict:
@@ -780,6 +825,10 @@ def create_results_section(dit_handler) -> dict:
@@ -1042,11 +1091,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-    def sample_example_smart(task_type: str):
@@ -1060,6 +1110,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
@@ -1094,7 +1145,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-            gr.Info("🤖 Generated example using LM (Language Model)")
@@ -1285,6 +1336,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
@@ -1342,6 +1394,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
@@ -1471,7 +1524,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-            results_section["is_format_caption_state"]
@@ -1720,15 +1774,18 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-    def sample_example_smart_with_flag(task_type: str):
-        result = sample_example_smart(task_type)
-        inputs=[generation_section["task_type"]],
@@ -1743,13 +1800,14 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-    def transcribe_audio_codes(audio_code_string):
@@ -1763,7 +1821,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-        metadata, status = llm_handler.understand_audio_from_codes(audio_codes=audio_code_string, use_constrained_decoding=True)
@@ -1818,7 +1880,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-        inputs=[generation_section["text2music_audio_code_string"]],
@@ -1899,9 +1964,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-    # Save metadata handlers
-        fn=save_metadata,
@@ -1936,11 +2001,77 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-        outputs=[]
-        fn=save_metadata,
@@ -1975,7 +2106,73 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
-        outputs=[]
@@ -2017,4 +2214,152 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase

607
  info="Generate language in CoT (chain-of-thought)",
608
  scale=1,
609
  )
610
+ constrained_decoding_debug = gr.Checkbox(
611
+ label="Constrained Decoding Debug",
612
+ value=False,
613
+ info="Enable debug logging for constrained decoding (check to see detailed logs)",
614
+ scale=1,
615
+ )
616
 
617
  with gr.Row():
618
  audio_cover_strength = gr.Slider(
 
624
  info="Control how many denoising steps use LM-generated codes",
625
  scale=1,
626
  )
627
+ score_scale = gr.Slider(
628
+ minimum=1.0,
629
+ maximum=200.0,
630
+ value=10.0,
631
+ step=1.0,
632
+ label="Quality Score Sensitivity",
633
+ info="Lower = more sensitive to quality differences (default: 10.0)",
634
+ scale=1,
635
+ )
636
  output_alignment_preference = gr.Checkbox(
637
  label="Output Attention Focus Score (disabled)",
638
  value=False,
639
  info="Output attention focus score analysis",
640
  interactive=False,
641
+ visible=False,
642
  scale=1,
643
  )
644
 
 
648
  think_checkbox = gr.Checkbox(
649
  label="Think",
650
  value=True,
 
651
  scale=1,
652
  )
653
  generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive, scale=10)
654
+ instrumental_checkbox = gr.Checkbox(
655
+ label="Instrumental",
656
+ value=False,
657
+ scale=1,
658
+ )
659
 
660
  return {
661
  "service_config_accordion": service_config_accordion,
 
715
  "output_alignment_preference": output_alignment_preference,
716
  "think_checkbox": think_checkbox,
717
  "generate_btn": generate_btn,
718
+ "instrumental_checkbox": instrumental_checkbox,
719
+ "constrained_decoding_debug": constrained_decoding_debug,
720
+ "score_scale": score_scale,
721
  }
722
 
723
 
 
743
  )
744
  with gr.Row(equal_height=True):
745
  send_to_src_btn_1 = gr.Button(
746
+ "🔗 Send To Src Audio",
747
  variant="secondary",
748
  size="sm",
749
  scale=1
 
754
  size="sm",
755
  scale=1
756
  )
757
+ score_btn_1 = gr.Button(
758
+ "📊 Score",
759
+ variant="secondary",
760
+ size="sm",
761
+ scale=1
762
+ )
763
+ score_display_1 = gr.Textbox(
764
+ label="Quality Score (Sample 1)",
765
+ interactive=False,
766
+ placeholder="Click 'Score' to calculate perplexity-based quality score"
767
+ )
768
  with gr.Column():
769
  generated_audio_2 = gr.Audio(
770
  label="🎵 Generated Music (Sample 2)",
 
773
  )
774
  with gr.Row(equal_height=True):
775
  send_to_src_btn_2 = gr.Button(
776
+ "🔗 Send To Src Audio",
777
  variant="secondary",
778
  size="sm",
779
  scale=1
 
784
  size="sm",
785
  scale=1
786
  )
787
+ score_btn_2 = gr.Button(
788
+ "📊 Score",
789
+ variant="secondary",
790
+ size="sm",
791
+ scale=1
792
+ )
793
+ score_display_2 = gr.Textbox(
794
+ label="Quality Score (Sample 2)",
795
+ interactive=False,
796
+ placeholder="Click 'Score' to calculate perplexity-based quality score"
797
+ )
798
 
799
  with gr.Accordion("📁 Batch Results & Generation Details", open=False):
800
  generated_audio_batch = gr.File(
 
825
  "send_to_src_btn_2": send_to_src_btn_2,
826
  "save_btn_1": save_btn_1,
827
  "save_btn_2": save_btn_2,
828
+ "score_btn_1": score_btn_1,
829
+ "score_btn_2": score_btn_2,
830
+ "score_display_1": score_display_1,
831
+ "score_display_2": score_display_2,
832
  "generated_audio_batch": generated_audio_batch,
833
  "generation_info": generation_info,
834
  "align_score_1": align_score_1,
 
1091
  gr.Warning(f"Error loading example: {str(e)}")
1092
  return "", "", True, None, None, "", "", ""
1093
 
1094
+ def sample_example_smart(task_type: str, constrained_decoding_debug: bool = False):
1095
  """Smart sample function that uses LM if initialized, otherwise falls back to examples
1096
 
1097
  Args:
1098
  task_type: The task type (e.g., "text2music")
1099
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
1100
 
1101
  Returns:
1102
  Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
 
1110
  audio_codes="NO USER INPUT",
1111
  use_constrained_decoding=True,
1112
  temperature=0.85,
1113
+ constrained_decoding_debug=constrained_decoding_debug,
1114
  )
1115
 
1116
  if metadata:
 
1145
  if timesignature_value in [None, "N/A"]:
1146
  timesignature_value = ''
1147
 
1148
+ gr.Info("🤖 Generated example using LM")
1149
  return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
1150
  else:
1151
  gr.Warning("Failed to generate example using LM, falling back to examples directory")
 
1336
  use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
1337
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
1338
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
1339
+ constrained_decoding_debug,
1340
  progress=gr.Progress(track_tqdm=True)
1341
  ):
1342
  # If think is enabled (llm_dit mode) and use_cot_metas is True, generate audio codes using LM first
 
1394
  use_cot_caption=use_cot_caption,
1395
  use_cot_language=use_cot_language,
1396
  is_format_caption=is_format_caption,
1397
+ constrained_decoding_debug=constrained_decoding_debug,
1398
  )
1399
 
1400
  # Store LM-generated metadata and audio codes for display
 
1524
  generation_section["use_cot_metas"],
1525
  generation_section["use_cot_caption"],
1526
  generation_section["use_cot_language"],
1527
+ results_section["is_format_caption_state"],
1528
+ generation_section["constrained_decoding_debug"]
1529
  ],
1530
  outputs=[
1531
  results_section["generated_audio_1"],
 
1774
 
1775
  # Sample button - smart sample (uses LM if initialized, otherwise examples)
1776
  # Need to add is_format_caption return value to sample_example_smart
1777
+ def sample_example_smart_with_flag(task_type: str, constrained_decoding_debug: bool):
1778
  """Wrapper for sample_example_smart that adds is_format_caption flag"""
1779
+ result = sample_example_smart(task_type, constrained_decoding_debug)
1780
  # Add True at the end to set is_format_caption
1781
  return result + (True,)
1782
 
1783
  generation_section["sample_btn"].click(
1784
  fn=sample_example_smart_with_flag,
1785
+ inputs=[
1786
+ generation_section["task_type"],
1787
+ generation_section["constrained_decoding_debug"]
1788
+ ],
1789
  outputs=[
1790
  generation_section["captions"],
1791
  generation_section["lyrics"],
 
1800
  )
1801
 
1802
  # Transcribe audio codes to metadata (or generate example if empty)
1803
+ def transcribe_audio_codes(audio_code_string, constrained_decoding_debug):
1804
  """
1805
  Transcribe audio codes to metadata using LLM understanding.
1806
  If audio_code_string is empty, generate a sample example instead.
1807
 
1808
  Args:
1809
  audio_code_string: String containing audio codes (or empty for example generation)
1810
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
1811
 
1812
  Returns:
1813
  Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature)
 
1821
  audio_code_string = "NO USER INPUT"
1822
 
1823
  # Call LLM understanding
1824
+ metadata, status = llm_handler.understand_audio_from_codes(
1825
+ audio_codes=audio_code_string,
1826
+ use_constrained_decoding=True,
1827
+ constrained_decoding_debug=constrained_decoding_debug,
1828
+ )
1829
 
1830
  # Extract fields for UI update
1831
  caption = metadata.get('caption', '')
 
1880
 
1881
  generation_section["transcribe_btn"].click(
1882
  fn=transcribe_audio_codes,
1883
+ inputs=[
1884
+ generation_section["text2music_audio_code_string"],
1885
+ generation_section["constrained_decoding_debug"]
1886
+ ],
1887
  outputs=[
1888
  results_section["status_output"], # Show status
1889
  generation_section["captions"], # Update caption field
 
1964
  outputs=[generation_section["audio_uploads_accordion"]]
1965
  )
1966
 
1967
+ # Save metadata handlers - use JavaScript to trigger automatic download
1968
  results_section["save_btn_1"].click(
1969
+ fn=None,
1970
  inputs=[
1971
  generation_section["task_type"],
1972
  generation_section["captions"],
 
2001
  generation_section["complete_track_classes"],
2002
  results_section["lm_metadata_state"],
2003
  ],
2004
+ outputs=None,
2005
+ js="""
2006
+ (task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
2007
+ batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
2008
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format,
2009
+ lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
2010
+ use_cot_caption, use_cot_language, audio_cover_strength,
2011
+ think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
2012
+ track_name, complete_track_classes, lm_metadata) => {
2013
+ // Create metadata object
2014
+ const metadata = {
2015
+ saved_at: new Date().toISOString(),
2016
+ task_type: task_type,
2017
+ caption: captions || "",
2018
+ lyrics: lyrics || "",
2019
+ vocal_language: vocal_language,
2020
+ bpm: bpm,
2021
+ keyscale: key_scale || "",
2022
+ timesignature: time_signature || "",
2023
+ duration: audio_duration,
2024
+ batch_size: batch_size_input,
2025
+ inference_steps: inference_steps,
2026
+ guidance_scale: guidance_scale,
2027
+ seed: seed,
2028
+ random_seed: random_seed_checkbox,
2029
+ use_adg: use_adg,
2030
+ cfg_interval_start: cfg_interval_start,
2031
+ cfg_interval_end: cfg_interval_end,
2032
+ audio_format: audio_format,
2033
+ lm_temperature: lm_temperature,
2034
+ lm_cfg_scale: lm_cfg_scale,
2035
+ lm_top_k: lm_top_k,
2036
+ lm_top_p: lm_top_p,
2037
+ lm_negative_prompt: lm_negative_prompt,
2038
+ use_cot_caption: use_cot_caption,
2039
+ use_cot_language: use_cot_language,
2040
+ audio_cover_strength: audio_cover_strength,
2041
+ think: think_checkbox,
2042
+ audio_codes: text2music_audio_code_string || "",
2043
+ repainting_start: repainting_start,
2044
+ repainting_end: repainting_end,
2045
+ track_name: track_name,
2046
+ complete_track_classes: complete_track_classes || []
2047
+ };
2048
+
2049
+ if (lm_metadata) {
2050
+ metadata.lm_generated_metadata = lm_metadata;
2051
+ }
2052
+
2053
+ // Create JSON string
2054
+ const jsonStr = JSON.stringify(metadata, null, 2);
2055
+
2056
+ // Create blob and download
2057
+ const blob = new Blob([jsonStr], { type: 'application/json' });
2058
+ const url = URL.createObjectURL(blob);
2059
+ const a = document.createElement('a');
2060
+ a.href = url;
2061
+ const timestamp = new Date().toISOString().replace(/[-:]/g, '').replace('T', '_').split('.')[0];
2062
+ a.download = `generation_params_${timestamp}.json`;
2063
+ document.body.appendChild(a);
2064
+ a.click();
2065
+ document.body.removeChild(a);
2066
+ URL.revokeObjectURL(url);
2067
+
2068
+ return Array(32).fill(null);
2069
+ }
2070
+ """
2071
  )
2072
 
2073
  results_section["save_btn_2"].click(
2074
+ fn=None,
2075
  inputs=[
2076
  generation_section["task_type"],
2077
  generation_section["captions"],
 
2106
  generation_section["complete_track_classes"],
2107
  results_section["lm_metadata_state"],
2108
  ],
2109
+ outputs=None,
2110
+ js="""
2111
+ (task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
2112
+ batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
2113
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format,
2114
+ lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
2115
+ use_cot_caption, use_cot_language, audio_cover_strength,
2116
+ think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
2117
+ track_name, complete_track_classes, lm_metadata) => {
2118
+ // Create metadata object
2119
+ const metadata = {
2120
+ saved_at: new Date().toISOString(),
2121
+ task_type: task_type,
2122
+ caption: captions || "",
2123
+ lyrics: lyrics || "",
2124
+ vocal_language: vocal_language,
2125
+ bpm: bpm,
2126
+ keyscale: key_scale || "",
2127
+ timesignature: time_signature || "",
2128
+ duration: audio_duration,
2129
+ batch_size: batch_size_input,
2130
+ inference_steps: inference_steps,
2131
+ guidance_scale: guidance_scale,
2132
+ seed: seed,
2133
+ random_seed: random_seed_checkbox,
2134
+ use_adg: use_adg,
2135
+ cfg_interval_start: cfg_interval_start,
2136
+ cfg_interval_end: cfg_interval_end,
2137
+ audio_format: audio_format,
2138
+ lm_temperature: lm_temperature,
2139
+ lm_cfg_scale: lm_cfg_scale,
2140
+ lm_top_k: lm_top_k,
2141
+ lm_top_p: lm_top_p,
2142
+ lm_negative_prompt: lm_negative_prompt,
2143
+ use_cot_caption: use_cot_caption,
2144
+ use_cot_language: use_cot_language,
2145
+ audio_cover_strength: audio_cover_strength,
2146
+ think: think_checkbox,
2147
+ audio_codes: text2music_audio_code_string || "",
2148
+ repainting_start: repainting_start,
2149
+ repainting_end: repainting_end,
2150
+ track_name: track_name,
2151
+ complete_track_classes: complete_track_classes || []
2152
+ };
2153
+
2154
+ if (lm_metadata) {
2155
+ metadata.lm_generated_metadata = lm_metadata;
2156
+ }
2157
+
2158
+ // Create JSON string
2159
+ const jsonStr = JSON.stringify(metadata, null, 2);
2160
+
2161
+ // Create blob and download
2162
+ const blob = new Blob([jsonStr], { type: 'application/json' });
2163
+ const url = URL.createObjectURL(blob);
2164
+ const a = document.createElement('a');
2165
+ a.href = url;
2166
+ const timestamp = new Date().toISOString().replace(/[-:]/g, '').replace('T', '_').split('.')[0];
2167
+ a.download = `generation_params_${timestamp}.json`;
2168
+ document.body.appendChild(a);
2169
+ a.click();
2170
+ document.body.removeChild(a);
2171
+ URL.revokeObjectURL(url);
2172
+
2173
+ return Array(32).fill(null);
2174
+ }
2175
+ """
2176
  )
2177
 
2178
  # Load metadata handler - triggered when file is uploaded via UploadButton
 
2214
  results_section["is_format_caption_state"]
2215
  ]
2216
  )
2217
+
2218
+ # Instrumental checkbox handler - auto-fill [Instrumental] when checked
2219
+ def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
2220
+ """
2221
+ Handle instrumental checkbox changes.
2222
+ When checked: if no lyrics, fill with [Instrumental]
2223
+ When unchecked: if lyrics is [Instrumental], clear it
2224
+ """
2225
+ if instrumental_checked:
2226
+ # If checked and no lyrics, fill with [Instrumental]
2227
+ if not current_lyrics or not current_lyrics.strip():
2228
+ return "[Instrumental]"
2229
+ else:
2230
+ # Has lyrics, don't change
2231
+ return current_lyrics
2232
+ else:
2233
+ # If unchecked and lyrics is exactly [Instrumental], clear it
2234
+ if current_lyrics and current_lyrics.strip() == "[Instrumental]":
2235
+ return ""
2236
+ else:
2237
+ # Has other lyrics, don't change
2238
+ return current_lyrics
2239
+
2240
+ generation_section["instrumental_checkbox"].change(
2241
+ fn=handle_instrumental_checkbox,
2242
+ inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
2243
+ outputs=[generation_section["lyrics"]]
2244
+ )
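The checkbox handler above is a pure function of its two inputs, so its intended behaviour can be sketched with standalone calls (hypothetical, outside the Gradio event wiring):

    assert handle_instrumental_checkbox(True, "") == "[Instrumental]"        # checked, no lyrics -> auto-fill
    assert handle_instrumental_checkbox(True, "verse 1") == "verse 1"        # checked, lyrics present -> unchanged
    assert handle_instrumental_checkbox(False, "[Instrumental]") == ""       # unchecked, auto-filled marker -> cleared
    assert handle_instrumental_checkbox(False, "verse 1") == "verse 1"       # unchecked, real lyrics -> unchanged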
2245
+
2246
+ # Score calculation handlers
2247
+ def calculate_score_handler(audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
2248
+ """
2249
+ Calculate perplexity-based quality score for generated audio.
2250
+
2251
+ Args:
2252
+ audio_codes_str: Generated audio codes string
2253
+ caption: Caption text used for generation
2254
+ lyrics: Lyrics text used for generation
2255
+ lm_metadata: LM-generated metadata dictionary (from CoT generation)
2256
+ bpm: BPM value
2257
+ key_scale: Key scale value
2258
+ time_signature: Time signature value
2259
+ audio_duration: Audio duration value
2260
+ vocal_language: Vocal language value
2261
+ score_scale: Sensitivity scale parameter (lower = more sensitive)
2262
+
2263
+ Returns:
2264
+ Score display string
2265
+ """
2266
+ from acestep.test_time_scaling import calculate_perplexity, perplexity_to_score
2267
+
2268
+ if not llm_handler.llm_initialized:
2269
+ return "❌ LLM not initialized. Please initialize 5Hz LM first."
2270
+
2271
+ if not audio_codes_str or not audio_codes_str.strip():
2272
+ return "❌ No audio codes available. Please generate music first."
2273
+
2274
+ try:
2275
+ # Build metadata dictionary from both LM metadata and user inputs
2276
+ metadata = {}
2277
+
2278
+ # Priority 1: Use LM-generated metadata if available
2279
+ if lm_metadata and isinstance(lm_metadata, dict):
2280
+ metadata.update(lm_metadata)
2281
+
2282
+ # Priority 2: Add user-provided metadata (if not already in LM metadata)
2283
+ if bpm is not None and 'bpm' not in metadata:
2284
+ try:
2285
+ metadata['bpm'] = int(bpm)
2286
+ except (TypeError, ValueError):
2287
+ pass
2288
+
2289
+ if caption and 'caption' not in metadata:
2290
+ metadata['caption'] = caption
2291
+
2292
+ if audio_duration is not None and audio_duration > 0 and 'duration' not in metadata:
2293
+ try:
2294
+ metadata['duration'] = int(audio_duration)
2295
+ except (TypeError, ValueError):
2296
+ pass
2297
+
2298
+ if key_scale and key_scale.strip() and 'keyscale' not in metadata:
2299
+ metadata['keyscale'] = key_scale.strip()
2300
+
2301
+ if vocal_language and vocal_language.strip() and 'language' not in metadata:
2302
+ metadata['language'] = vocal_language.strip()
2303
+
2304
+ if time_signature and time_signature.strip() and 'timesignature' not in metadata:
2305
+ metadata['timesignature'] = time_signature.strip()
2306
+
2307
+ # Calculate perplexity
2308
+ perplexity, status = calculate_perplexity(
2309
+ llm_handler=llm_handler,
2310
+ audio_codes=audio_codes_str,
2311
+ caption=caption or "",
2312
+ lyrics=lyrics or "",
2313
+ metadata=metadata if metadata else None,
2314
+ temperature=1.0
2315
+ )
2316
+
2317
+ # Convert perplexity to normalized score [0, 1] (higher is better)
2318
+ normalized_score = perplexity_to_score(perplexity, scale=score_scale)
2319
+
2320
+ # Format display string
2321
+ if perplexity == float('inf'):
2322
+ return f"❌ Scoring failed: {status}"
2323
+ else:
2324
+ return f"✅ Quality Score: {normalized_score:.4f} (range: 0-1, higher is better)\nPerplexity: {perplexity:.4f}\nSensitivity: {score_scale}\n{status}"
2325
+
2326
+ except Exception as e:
2327
+ import traceback
2328
+ error_msg = f"❌ Error calculating score: {str(e)}\n{traceback.format_exc()}"
2329
+ return error_msg
2330
+
2331
+ # Connect score buttons to handlers
2332
+ results_section["score_btn_1"].click(
2333
+ fn=calculate_score_handler,
2334
+ inputs=[
2335
+ generation_section["text2music_audio_code_string"],
2336
+ generation_section["captions"],
2337
+ generation_section["lyrics"],
2338
+ results_section["lm_metadata_state"],
2339
+ generation_section["bpm"],
2340
+ generation_section["key_scale"],
2341
+ generation_section["time_signature"],
2342
+ generation_section["audio_duration"],
2343
+ generation_section["vocal_language"],
2344
+ generation_section["score_scale"]
2345
+ ],
2346
+ outputs=[results_section["score_display_1"]]
2347
+ )
2348
+
2349
+ results_section["score_btn_2"].click(
2350
+ fn=calculate_score_handler,
2351
+ inputs=[
2352
+ generation_section["text2music_audio_code_string"],
2353
+ generation_section["captions"],
2354
+ generation_section["lyrics"],
2355
+ results_section["lm_metadata_state"],
2356
+ generation_section["bpm"],
2357
+ generation_section["key_scale"],
2358
+ generation_section["time_signature"],
2359
+ generation_section["audio_duration"],
2360
+ generation_section["vocal_language"],
2361
+ generation_section["score_scale"]
2362
+ ],
2363
+ outputs=[results_section["score_display_2"]]
2364
+ )
2365
 
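calculate_score_handler above gives LM-generated metadata priority and only falls back to user-supplied values for fields the LM did not emit. A minimal sketch of that merge rule in isolation (hypothetical helper, not part of this commit):

    def merge_metadata(lm_metadata, user_fields):
        # LM-generated values win; user values only fill fields the LM left unset
        merged = dict(lm_metadata or {})
        for key, value in (user_fields or {}).items():
            if value not in (None, "") and key not in merged:
                merged[key] = value
        return merged

    merge_metadata({"bpm": 120}, {"bpm": 90, "keyscale": "C major"})
    # -> {"bpm": 120, "keyscale": "C major"}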
acestep/llm_inference.py CHANGED
@@ -39,6 +39,9 @@ class LLMHandler:
39
 
40
  # Shared constrained decoding processor (initialized once when LLM is loaded)
41
  self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
 
 
 
42
 
43
  def get_available_5hz_lm_models(self) -> List[str]:
44
  """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
@@ -246,6 +249,7 @@ class LLMHandler:
246
  target_duration: Optional[float] = None,
247
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
248
  stop_at_reasoning: bool = False,
 
249
  skip_caption: bool = False,
250
  skip_language: bool = False,
251
  generation_phase: str = "cot",
@@ -276,6 +280,7 @@ class LLMHandler:
276
  self.constrained_processor.set_user_metadata(user_metadata)
277
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
278
  # Set skip_caption and skip_language based on flags
 
279
  self.constrained_processor.set_skip_caption(skip_caption)
280
  self.constrained_processor.set_skip_language(skip_language)
281
  # Set generation phase for phase-aware processing
@@ -347,6 +352,7 @@ class LLMHandler:
347
  target_duration: Optional[float] = None,
348
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
349
  stop_at_reasoning: bool = False,
 
350
  skip_caption: bool = False,
351
  skip_language: bool = False,
352
  generation_phase: str = "cot",
@@ -376,6 +382,7 @@ class LLMHandler:
376
  self.constrained_processor.set_user_metadata(user_metadata)
377
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
378
  # Set skip_caption and skip_language based on flags
 
379
  self.constrained_processor.set_skip_caption(skip_caption)
380
  self.constrained_processor.set_skip_language(skip_language)
381
  # Set generation phase for phase-aware processing
@@ -597,6 +604,7 @@ class LLMHandler:
597
  "user_metadata": user_metadata,
598
  "skip_caption": not use_cot_caption,
599
  "skip_language": not use_cot_language,
 
600
  "generation_phase": "cot",
601
  # Pass context for building unconditional prompt in CoT phase
602
  "caption": caption,
@@ -863,7 +871,6 @@ class LLMHandler:
863
  - bpm: int or str
864
  - caption: str
865
  - duration: int or str
866
- - genres: str
867
  - keyscale: str
868
  - language: str
869
  - timesignature: str
@@ -901,6 +908,7 @@ class LLMHandler:
901
  "user_metadata": None, # No user metadata injection
902
  "skip_caption": False, # Generate caption
903
  "skip_language": False, # Generate language
 
904
  "generation_phase": "understand", # Understanding phase: generate CoT metadata, then free-form lyrics
905
  # Context for building unconditional prompt
906
  "caption": "",
@@ -1015,6 +1023,7 @@ class LLMHandler:
1015
  user_metadata = cfg.get("user_metadata") # User-provided metadata fields
1016
  skip_caption = cfg.get("skip_caption", False) # Skip caption generation in CoT
1017
  skip_language = cfg.get("skip_language", False) # Skip language generation in CoT
 
1018
  generation_phase = cfg.get("generation_phase", "cot") # "cot" or "codes"
1019
  # Additional context for codes phase unconditional prompt building
1020
  caption = cfg.get("caption", "")
@@ -1036,6 +1045,7 @@ class LLMHandler:
1036
  target_duration=target_duration,
1037
  user_metadata=user_metadata,
1038
  stop_at_reasoning=stop_at_reasoning,
 
1039
  skip_caption=skip_caption,
1040
  skip_language=skip_language,
1041
  generation_phase=generation_phase,
@@ -1059,6 +1069,7 @@ class LLMHandler:
1059
  target_duration=target_duration,
1060
  user_metadata=user_metadata,
1061
  stop_at_reasoning=stop_at_reasoning,
 
1062
  skip_caption=skip_caption,
1063
  skip_language=skip_language,
1064
  generation_phase=generation_phase,
@@ -1521,3 +1532,51 @@ class LLMHandler:
1521
  torch.cuda.empty_cache()
1522
  offload_time = time.time() - start_time
1523
  logger.info(f"Offloaded LLM to CPU in {offload_time:.4f}s")
39
 
40
  # Shared constrained decoding processor (initialized once when LLM is loaded)
41
  self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
42
+
43
+ # Shared HuggingFace model for perplexity calculation (when using vllm backend)
44
+ self._hf_model_for_scoring = None
45
 
46
  def get_available_5hz_lm_models(self) -> List[str]:
47
  """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
 
249
  target_duration: Optional[float] = None,
250
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
251
  stop_at_reasoning: bool = False,
252
+ skip_genres: bool = True,
253
  skip_caption: bool = False,
254
  skip_language: bool = False,
255
  generation_phase: str = "cot",
 
280
  self.constrained_processor.set_user_metadata(user_metadata)
281
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
282
  # Set skip_caption and skip_language based on flags
283
+ self.constrained_processor.set_skip_genres(skip_genres)
284
  self.constrained_processor.set_skip_caption(skip_caption)
285
  self.constrained_processor.set_skip_language(skip_language)
286
  # Set generation phase for phase-aware processing
 
352
  target_duration: Optional[float] = None,
353
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
354
  stop_at_reasoning: bool = False,
355
+ skip_genres: bool = True,
356
  skip_caption: bool = False,
357
  skip_language: bool = False,
358
  generation_phase: str = "cot",
 
382
  self.constrained_processor.set_user_metadata(user_metadata)
383
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
384
  # Set skip_caption and skip_language based on flags
385
+ self.constrained_processor.set_skip_genres(skip_genres)
386
  self.constrained_processor.set_skip_caption(skip_caption)
387
  self.constrained_processor.set_skip_language(skip_language)
388
  # Set generation phase for phase-aware processing
 
604
  "user_metadata": user_metadata,
605
  "skip_caption": not use_cot_caption,
606
  "skip_language": not use_cot_language,
607
+ "skip_genres": True, # Do not generate genres in the CoT phase
608
  "generation_phase": "cot",
609
  # Pass context for building unconditional prompt in CoT phase
610
  "caption": caption,
 
871
  - bpm: int or str
872
  - caption: str
873
  - duration: int or str
 
874
  - keyscale: str
875
  - language: str
876
  - timesignature: str
 
908
  "user_metadata": None, # No user metadata injection
909
  "skip_caption": False, # Generate caption
910
  "skip_language": False, # Generate language
911
+ "skip_genres": False, # Generate genres
912
  "generation_phase": "understand", # Understanding phase: generate CoT metadata, then free-form lyrics
913
  # Context for building unconditional prompt
914
  "caption": "",
 
1023
  user_metadata = cfg.get("user_metadata") # User-provided metadata fields
1024
  skip_caption = cfg.get("skip_caption", False) # Skip caption generation in CoT
1025
  skip_language = cfg.get("skip_language", False) # Skip language generation in CoT
1026
+ skip_genres = cfg.get("skip_genres", False) # Skip genres generation in CoT
1027
  generation_phase = cfg.get("generation_phase", "cot") # "cot" or "codes"
1028
  # Additional context for codes phase unconditional prompt building
1029
  caption = cfg.get("caption", "")
 
1045
  target_duration=target_duration,
1046
  user_metadata=user_metadata,
1047
  stop_at_reasoning=stop_at_reasoning,
1048
+ skip_genres=skip_genres,
1049
  skip_caption=skip_caption,
1050
  skip_language=skip_language,
1051
  generation_phase=generation_phase,
 
1069
  target_duration=target_duration,
1070
  user_metadata=user_metadata,
1071
  stop_at_reasoning=stop_at_reasoning,
1072
+ skip_genres=skip_genres,
1073
  skip_caption=skip_caption,
1074
  skip_language=skip_language,
1075
  generation_phase=generation_phase,
 
1532
  torch.cuda.empty_cache()
1533
  offload_time = time.time() - start_time
1534
  logger.info(f"Offloaded LLM to CPU in {offload_time:.4f}s")
1535
+
1536
+ def get_hf_model_for_scoring(self):
1537
+ """
1538
+ Get HuggingFace model for perplexity scoring.
1539
+
1540
+ For vllm backend, loads HuggingFace model from disk (weights are cached by transformers).
1541
+ For pt backend, returns the existing model.
1542
+
1543
+ Returns:
1544
+ HuggingFace model instance
1545
+ """
1546
+ if self.llm_backend == "pt":
1547
+ # For PyTorch backend, directly return the model
1548
+ return self.llm
1549
+
1550
+ elif self.llm_backend == "vllm":
1551
+ # For vllm backend, load HuggingFace model from disk
1552
+ # Note: transformers caches model weights, so this doesn't duplicate disk I/O
1553
+ if self._hf_model_for_scoring is None:
1554
+ logger.info("Loading HuggingFace model for scoring (from checkpoint)")
1555
+
1556
+ # Get model path from vllm config
1557
+ model_runner = self.llm.model_runner
1558
+ model_path = model_runner.config.model
1559
+
1560
+ # Load HuggingFace model from the same checkpoint
1561
+ # This will load the original unfused weights
1562
+ import time
1563
+ start_time = time.time()
1564
+ self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
1565
+ model_path,
1566
+ trust_remote_code=True,
1567
+ torch_dtype=self.dtype
1568
+ )
1569
+ load_time = time.time() - start_time
1570
+ logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
1571
+
1572
+ # Move to same device as vllm model
1573
+ device = next(model_runner.model.parameters()).device
1574
+ self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
1575
+ self._hf_model_for_scoring.eval()
1576
+
1577
+ logger.info(f"HuggingFace model for scoring ready on {device}")
1578
+
1579
+ return self._hf_model_for_scoring
1580
+
1581
+ else:
1582
+ raise ValueError(f"Unknown backend: {self.llm_backend}")
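Assuming an initialized handler, the scoring model is obtained lazily and then used only for forward passes; a rough usage sketch (the prompt string is illustrative):

    import torch

    model = llm_handler.get_hf_model_for_scoring()   # pt: existing model; vllm: HF copy loaded on first call
    tokenizer = llm_handler.llm_tokenizer
    device = next(model.parameters()).device
    inputs = tokenizer("<|audio_code_123|>", return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits               # [1, seq_len, vocab_size]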
acestep/test_time_scaling.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ Test-Time Scaling Module
3
+ Implements perplexity-based scoring for generated audio codes
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Optional, Dict, Any
8
+ from loguru import logger
9
+ import yaml
10
+
11
+
12
+ def perplexity_to_score(perplexity: float, scale: float = 100.0) -> float:
13
+ """
14
+ Convert perplexity to a normalized score in [0, 1] range.
15
+
16
+ Lower perplexity = higher score (better quality)
17
+ Uses exponential decay: score = exp(-perplexity / scale)
18
+
19
+ Args:
20
+ perplexity: Perplexity value (typically 1 to 1000+)
21
+ scale: Scale parameter to control score distribution (default 100.0)
22
+ - Smaller scale: more sensitive to perplexity changes
23
+ - Larger scale: less sensitive to perplexity changes
24
+
25
+ Returns:
26
+ Score in [0, 1] range, where 1 is perfect and 0 is worst
27
+
28
+ Examples:
29
+ perplexity=1 → score≈0.99 (excellent)
30
+ perplexity=50 → score≈0.61 (good if scale=100)
31
+ perplexity=100 → score≈0.37 (medium if scale=100)
32
+ perplexity=500 → score≈0.01 (poor if scale=100)
33
+ """
34
+ import math
35
+ return math.exp(-perplexity / scale)
36
+
37
+
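A quick check of the mapping documented above (matches the docstring examples at the default scale of 100):

    import math

    for ppl in (1, 50, 100, 500):
        print(ppl, round(math.exp(-ppl / 100.0), 2))   # 0.99, 0.61, 0.37, 0.01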
38
+ def calculate_perplexity(
39
+ llm_handler,
40
+ audio_codes: str,
41
+ caption: str = "",
42
+ lyrics: str = "",
43
+ metadata: Optional[Dict[str, Any]] = None,
44
+ temperature: float = 1.0,
45
+ ) -> Tuple[float, str]:
46
+ """
47
+ Calculate perplexity of generated audio codes conditioned on caption/lyrics/metadata.
48
+
49
+ This reverses the generation task: given audio codes as input, measure how well
50
+ the model can predict the CoT metadata and lyrics that should generate those codes.
51
+
52
+ Lower perplexity = model is less surprised = better quality generation
53
+ Use perplexity_to_score() to map perplexity to a normalized [0, 1] score (higher is better)
54
+
55
+ The understanding task format is:
56
+ Input: <|audio_code_123|><|audio_code_456|>...
57
+ Output: <think>\nmetadata_yaml\n</think>\n\n# Lyric\nlyrics_text
58
+
59
+ Args:
60
+ llm_handler: LLM handler instance with initialized model
61
+ audio_codes: Generated audio code string (e.g., "<|audio_code_123|><|audio_code_456|>...")
62
+ caption: Caption text used for generation
63
+ lyrics: Lyrics text used for generation
64
+ metadata: Dictionary with CoT metadata fields (bpm, duration, keyscale, language, timesignature, etc.)
65
+ temperature: Temperature for probability scaling (default 1.0)
66
+
67
+ Returns:
68
+ Tuple of (perplexity_value, status_message)
69
+
70
+ Example:
71
+ metadata = {'bpm': 120, 'duration': 30, 'keyscale': 'C major', 'language': 'en', 'timesignature': '4'}
72
+ perplexity, status = calculate_perplexity(
73
+ llm_handler,
74
+ audio_codes="<|audio_code_123|>...",
75
+ caption="calm piano",
76
+ lyrics="verse 1...",
77
+ metadata=metadata
78
+ )
79
+ score = perplexity_to_score(perplexity) # Higher score = better quality
80
+ """
81
+ if not llm_handler.llm_initialized:
82
+ return float('inf'), "❌ LLM not initialized"
83
+
84
+ if not audio_codes or not audio_codes.strip():
85
+ return float('inf'), "❌ No audio codes provided"
86
+
87
+ try:
88
+ # Build the understanding prompt: codes as input
89
+ # The model should generate: <think>metadata</think>\n# Lyric\n...
90
+ formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(
91
+ audio_codes=audio_codes,
92
+ is_negative_prompt=False
93
+ )
94
+
95
+ logger.info(f"Calculating perplexity for a {len(audio_codes)}-character audio code string")
96
+
97
+ # Build the expected output (target sequence) following understanding task format
98
+ # Format: <think>\nmetadata_yaml\n</think>\n\n# Lyric\nlyrics_text
99
+ target_parts = []
100
+
101
+ # Build CoT section with metadata
102
+ if metadata and isinstance(metadata, dict):
103
+ # Filter out None values and format as YAML (sorted keys)
104
+ cot_items = {}
105
+ for key in ['bpm', 'caption', 'duration', 'genres', 'keyscale', 'language', 'timesignature']:
106
+ if key in metadata and metadata[key] is not None:
107
+ cot_items[key] = metadata[key]
108
+
109
+ if cot_items:
110
+ cot_yaml = yaml.dump(cot_items, allow_unicode=True, sort_keys=True).strip()
111
+ target_parts.append(f"<think>\n{cot_yaml}\n</think>\n")
112
+
113
+ # Add Lyric section (note: understanding task uses "# Lyric" not "# Caption")
114
+ if lyrics:
115
+ target_parts.append(f"\n# Lyric\n{lyrics}\n")
116
+
117
+ target_text = "".join(target_parts)
118
+
119
+ if not target_text.strip():
120
+ return float('inf'), "❌ No target text to evaluate (lyrics or metadata required)"
121
+
122
+ logger.debug(f"Target text (first 200 chars): {target_text[:200]}...")
123
+
124
+ # Calculate perplexity using appropriate backend
125
+ if llm_handler.llm_backend == "vllm":
126
+ perplexity = _calculate_perplexity_vllm(
127
+ llm_handler,
128
+ formatted_prompt,
129
+ target_text,
130
+ temperature
131
+ )
132
+ else: # pt backend
133
+ perplexity = _calculate_perplexity_pt(
134
+ llm_handler,
135
+ formatted_prompt,
136
+ target_text,
137
+ temperature
138
+ )
139
+
140
+ status_msg = f"✅ Perplexity calculated: {perplexity:.4f}"
141
+ logger.info(status_msg)
142
+ return perplexity, status_msg
143
+
144
+ except Exception as e:
145
+ error_msg = f"❌ Error calculating perplexity: {str(e)}"
146
+ logger.error(error_msg)
147
+ import traceback
148
+ logger.error(traceback.format_exc())
149
+ return float('inf'), error_msg
150
+
151
+
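Concretely, with target tokens y_1, ..., y_N (the CoT metadata plus lyrics) conditioned on the audio-code prompt x, the quantity computed by the backend helpers below is the standard token-level perplexity:

    \mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N}\log p_\theta\left(y_i \mid x, y_{<i}\right)\right)

perplexity_to_score() then maps this value to exp(-PPL / scale).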
152
+ def _calculate_perplexity_pt(
153
+ llm_handler,
154
+ formatted_prompt: str,
155
+ target_text: str,
156
+ temperature: float
157
+ ) -> float:
158
+ """
159
+ Calculate perplexity using PyTorch backend.
160
+
161
+ For the vllm backend, this uses a separate HuggingFace model loaded from the same checkpoint.
162
+ For pt backend, this uses the original model.
163
+
164
+ Args:
165
+ llm_handler: LLM handler with pt or vllm backend
166
+ formatted_prompt: Formatted input prompt (audio codes)
167
+ target_text: Expected output text (CoT metadata + lyrics)
168
+ temperature: Temperature for probability scaling
169
+
170
+ Returns:
171
+ Perplexity value
172
+ """
173
+ # Get model for scoring (handles both pt and vllm backends)
174
+ model = llm_handler.get_hf_model_for_scoring()
175
+ tokenizer = llm_handler.llm_tokenizer
176
+ device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
177
+
178
+ # Tokenize prompt and target separately
179
+ prompt_tokens = tokenizer(
180
+ formatted_prompt,
181
+ return_tensors="pt",
182
+ padding=False,
183
+ truncation=True,
184
+ )
185
+
186
+ target_tokens = tokenizer(
187
+ target_text,
188
+ return_tensors="pt",
189
+ padding=False,
190
+ truncation=True,
191
+ )
192
+
193
+ # Concatenate prompt + target for full sequence
194
+ full_input_ids = torch.cat([
195
+ prompt_tokens['input_ids'],
196
+ target_tokens['input_ids']
197
+ ], dim=1).to(device)
198
+
199
+ # Create attention mask
200
+ attention_mask = torch.ones_like(full_input_ids)
201
+
202
+ # Forward pass to get logits
203
+ with torch.no_grad():
204
+ with llm_handler._load_model_context():
205
+ outputs = model(
206
+ input_ids=full_input_ids,
207
+ attention_mask=attention_mask
208
+ )
209
+ logits = outputs.logits # [batch_size, seq_len, vocab_size]
210
+
211
+ # Get the logits for predicting target tokens
212
+ # Shift logits and labels: logits[i] predicts token[i+1]
213
+ prompt_len = prompt_tokens['input_ids'].shape[1]
214
+ target_len = target_tokens['input_ids'].shape[1]
215
+
216
+ # Extract logits for positions that predict target tokens
217
+ # logits at positions [prompt_len-1 : prompt_len+target_len-1] predict target tokens
218
+ pred_logits = logits[0, prompt_len-1:prompt_len+target_len-1, :] # [target_len, vocab_size]
219
+ target_ids = target_tokens['input_ids'][0].to(device) # [target_len], moved to the logits device
220
+
221
+ # Apply temperature scaling
222
+ if temperature != 1.0:
223
+ pred_logits = pred_logits / temperature
224
+
225
+ # Calculate cross-entropy loss for each position
226
+ log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
227
+
228
+ # Gather log probabilities of target tokens
229
+ target_log_probs = log_probs[torch.arange(target_len, device=device), target_ids] # [target_len]
230
+
231
+ # Calculate perplexity: exp(-mean(log_probs))
232
+ mean_neg_log_prob = -target_log_probs.mean()
233
+ perplexity = torch.exp(mean_neg_log_prob).item()
234
+
235
+ return perplexity
236
+
237
+
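The slice logits[0, prompt_len-1 : prompt_len+target_len-1, :] above relies on the usual next-token shift; a tiny worked example with made-up lengths:

    # prompt_len = 3, target_len = 2, full sequence = [p0, p1, p2, t0, t1]
    # logits at position 2 (= prompt_len - 1) predict t0,
    # logits at position 3 predict t1,
    # so positions [prompt_len - 1, prompt_len + target_len - 1) cover exactly the target tokens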
238
+ def _calculate_perplexity_vllm(
239
+ llm_handler,
240
+ formatted_prompt: str,
241
+ target_text: str,
242
+ temperature: float
243
+ ) -> float:
244
+ """
245
+ Calculate perplexity using vllm backend.
246
+
247
+ Uses shared-weight HuggingFace model for perplexity calculation.
248
+ This avoids the complexity of nanovllm's context management.
249
+
250
+ Args:
251
+ llm_handler: LLM handler with vllm backend
252
+ formatted_prompt: Formatted input prompt (audio codes)
253
+ target_text: Expected output text (CoT metadata + lyrics)
254
+ temperature: Temperature for probability scaling
255
+
256
+ Returns:
257
+ Perplexity value
258
+ """
259
+ logger.debug("Using vllm backend with a HuggingFace model loaded from the same checkpoint for perplexity")
260
+ # Delegate to pt backend implementation which now handles both backends
261
+ return _calculate_perplexity_pt(llm_handler, formatted_prompt, target_text, temperature)