ChuxiJ committed
Commit 9e64ac5 · 1 Parent(s): c0934b3

support lm duration & fix keyscale bias

Files changed (1)
  1. acestep/llm_inference.py +171 -77
acestep/llm_inference.py CHANGED
@@ -7,7 +7,7 @@ import re
 import traceback
 import time
 from enum import Enum, auto
-from typing import Optional, Dict, Any, Tuple, List, Callable
+from typing import Optional, Dict, Any, Tuple, List, Callable, Set
 from contextlib import contextmanager
 
 import torch
@@ -102,6 +102,12 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.metadata_temperature: Optional[float] = None
         self.codes_temperature: Optional[float] = None
 
+        # Duration constraint for codes generation
+        # 5 codes = 1 second, so target_codes = target_duration * 5
+        self.target_duration: Optional[float] = None  # User-specified duration in seconds
+        self.target_codes: Optional[int] = None  # Computed target codes count
+        self.codes_count: int = 0  # Counter for generated codes
+
         # Current state
         self.state = FSMState.THINK_TAG
         self.position_in_state = 0  # Position within current state's fixed string
@@ -242,6 +248,54 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Comma token for multi-genre support
         comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
         self.comma_token = comma_tokens[-1] if comma_tokens else None
+
+        # EOS token for duration-constrained codes generation
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+        # Build valid keyscales set and prefix tree for constrained decoding
+        # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
+        notes = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
+        accidentals = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
+        modes = ['major', 'minor']
+
+        self.valid_keyscales = set()
+        for note in notes:
+            for acc in accidentals:
+                for mode in modes:
+                    self.valid_keyscales.add(f"{note}{acc} {mode}")
+
+        # Build prefix tree for keyscale constrained decoding
+        self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
+
+    def _build_keyscale_prefix_tree(self) -> Dict[str, Set[int]]:
+        """
+        Build keyscale prefix to allowed tokens mapping.
+        For each prefix of each valid keyscale, we store the set of tokens
+        that can continue to form a valid keyscale.
+        """
+        prefix_to_tokens: Dict[str, Set[int]] = {}
+
+        for keyscale in self.valid_keyscales:
+            for i in range(len(keyscale)):
+                prefix = keyscale[:i]
+                next_char = keyscale[i]
+                # Encode the next character
+                tokens = self.tokenizer.encode(next_char, add_special_tokens=False)
+                if prefix not in prefix_to_tokens:
+                    prefix_to_tokens[prefix] = set()
+                prefix_to_tokens[prefix].update(tokens)
+
+        # For complete keyscales, allow newline token
+        for keyscale in self.valid_keyscales:
+            if keyscale not in prefix_to_tokens:
+                prefix_to_tokens[keyscale] = set()
+            if self.newline_token:
+                prefix_to_tokens[keyscale].add(self.newline_token)
+
+        if self.debug:
+            logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} prefixes for {len(self.valid_keyscales)} valid keyscales")
+
+        return prefix_to_tokens
 
     def _load_genres_vocab(self):
         """
@@ -566,6 +620,25 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.state = FSMState.THINK_TAG
         self.position_in_state = 0
         self.accumulated_value = ""
+        self.codes_count = 0  # Reset codes counter
+
+    def set_target_duration(self, duration: Optional[float]):
+        """
+        Set the target duration for codes generation.
+
+        Args:
+            duration: Target duration in seconds. If None, no duration constraint is applied.
+                5 codes = 1 second, so target_codes = duration * 5.
+        """
+        self.target_duration = duration
+        if duration is not None and duration > 0:
+            self.target_codes = int(duration * 5)
+            if self.debug:
+                logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
+        else:
+            self.target_codes = None
+            if self.debug:
+                logger.debug("Target duration cleared, no duration constraint")
 
     def update_caption(self, caption: Optional[str]):
        """
@@ -702,71 +775,23 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         return newline_prob > max_other_prob
 
     def _get_allowed_keyscale_tokens(self) -> List[int]:
-        """Get allowed tokens for keyscale field based on accumulated value."""
-        # Don't use strip() - we need to track spaces properly
+        """
+        Get allowed tokens for keyscale field using prefix tree.
+        Only allows tokens that can lead to valid keyscales like:
+        - "A major", "A minor", "A# major", "Ab minor", etc.
+        """
         acc = self.accumulated_value
-        acc_stripped = acc.strip()
-
-        if not acc_stripped:
-            # First character: must be a note (A-G)
-            return list(self.note_tokens.values())
-
-        # Check if we already have a space
-        has_space = " " in acc
-
-        # Parse what we have
-        if not has_space:
-            # No space yet
-            if len(acc_stripped) == 1 and acc_stripped.upper() in "ABCDEFG":
-                # After note: can be # ♯ b ♭ or space (for major/minor)
-                allowed = self.sharp_tokens + self.flat_tokens
-                if self.space_token:
-                    allowed.append(self.space_token)
-                return allowed
-
-            if len(acc_stripped) >= 2 and acc_stripped[-1] in "#♯b♭":
-                # After accidental: must be space
-                return [self.space_token] if self.space_token else []
 
-        if has_space:
-            # After space: should be major or minor
-            after_space = acc.split(" ", 1)[-1].lower()
-
-            # Allow tokens that continue "major" or "minor"
-            allowed = []
-            for word in ["major", "minor"]:
-                if word.startswith(after_space):
-                    remaining = word[len(after_space):]
-                    if remaining:
-                        # Try to encode the next character
-                        tokens = self.tokenizer.encode(remaining[0], add_special_tokens=False)
-                        allowed.extend(tokens)
-                        # Also try encoding the whole remaining part
-                        tokens = self.tokenizer.encode(remaining, add_special_tokens=False)
-                        if tokens:
-                            allowed.append(tokens[0])
-
-            # If after_space is exactly "major" or "minor", allow newline
-            if after_space in ["major", "minor"]:
-                if self.newline_token:
-                    allowed.append(self.newline_token)
-
-            # If no tokens found but we have incomplete word, this is an error state
-            # Force newline if we've tried enough
-            if not allowed and len(after_space) > 5:
-                if self.newline_token:
-                    allowed.append(self.newline_token)
-
-            return list(set(allowed))
+        if acc in self.keyscale_prefix_tree:
+            return list(self.keyscale_prefix_tree[acc])
 
+        # No valid continuation found - return empty list
+        # The caller will handle this by forcing newline to end the field
         return []
 
     def _is_keyscale_complete(self) -> bool:
-        """Check if keyscale value is complete and valid."""
-        acc = self.accumulated_value.strip().lower()
-        # Pattern: [A-G][#♯b♭]? (major|minor)
-        pattern = r'^[a-g][#♯b♭]?\s*(major|minor)$'
-        return bool(re.match(pattern, acc, re.IGNORECASE))
+        """Check if keyscale value is complete and valid by checking against valid_keyscales set."""
+        return self.accumulated_value in self.valid_keyscales
 
     def _get_allowed_timesig_tokens(self) -> List[int]:
         """Get allowed tokens for timesignature field."""
@@ -797,8 +822,24 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         if not self.enabled:
             return self._apply_temperature_scaling(scores)
 
-        if self.state == FSMState.COMPLETED or self.state == FSMState.CODES_GENERATION:
-            # No constraints in codes generation phase, but still apply temperature
+        if self.state == FSMState.COMPLETED:
+            return self._apply_temperature_scaling(scores)
+
+        if self.state == FSMState.CODES_GENERATION:
+            # Apply duration constraint in codes generation phase
+            if self.target_codes is not None and self.eos_token_id is not None:
+                if self.codes_count < self.target_codes:
+                    # Block EOS token until target codes count is reached
+                    scores[:, self.eos_token_id] = float('-inf')
+                    if self.debug:
+                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
+                else:
+                    # Force EOS token when target codes count is reached
+                    mask = torch.full_like(scores, float('-inf'))
+                    mask[:, self.eos_token_id] = 0
+                    scores = scores + mask
+                    if self.debug:
+                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
             return self._apply_temperature_scaling(scores)
 
         batch_size = scores.shape[0]
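
Note: the CODES_GENERATION branch above is a standard logits-masking pattern: write -inf into the EOS column to block it, or add a mask that is -inf everywhere except EOS to force it. A self-contained sketch with a toy vocabulary (eos_id and sizes are illustrative, not the repo's actual ids):

    import torch

    def mask_eos(scores: torch.Tensor, eos_id: int, codes_count: int, target: int) -> torch.Tensor:
        """Block EOS while under target, force EOS once the target is reached."""
        if codes_count < target:
            scores[:, eos_id] = float('-inf')            # EOS can never be sampled
        else:
            mask = torch.full_like(scores, float('-inf'))
            mask[:, eos_id] = 0                          # only EOS survives the mask
            scores = scores + mask
        return scores

    scores = torch.zeros(1, 8)                           # batch of 1, toy vocab of 8
    blocked = mask_eos(scores.clone(), eos_id=7, codes_count=3, target=5)
    forced = mask_eos(scores.clone(), eos_id=7, codes_count=5, target=5)
    assert blocked[0, 7] == float('-inf')                # EOS blocked mid-generation
    assert forced[0, 0] == float('-inf') and forced[0, 7] == 0  # everything but EOS blocked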
@@ -890,21 +931,40 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             scores = scores + mask
 
         elif self.state == FSMState.DURATION_VALUE:
-            min_val, max_val = self.field_specs["duration"]["min"], self.field_specs["duration"]["max"]
-
-            if self._should_end_numeric_field(scores, min_val, max_val):
-                if self.newline_token:
-                    mask[0, self.newline_token] = 0
-                self._transition_to_next_state()
+            # If target_duration is set, force generate that exact value
+            if self.target_duration is not None:
+                target_str = str(int(self.target_duration))
+                current_pos = len(self.accumulated_value)
+
+                if current_pos < len(target_str):
+                    # Force the next digit
+                    next_digit = int(target_str[current_pos])
+                    if next_digit in self.digit_tokens:
+                        mask[0, self.digit_tokens[next_digit]] = 0
+                else:
+                    # All digits generated, force newline
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    self._transition_to_next_state()
+
+                scores = scores + mask
             else:
-                allowed = self._get_allowed_digit_tokens(min_val, max_val)
-                for t in allowed:
-                    mask[0, t] = 0
-                current = int(self.accumulated_value) if self.accumulated_value else 0
-                if min_val <= current <= max_val and self.newline_token:
-                    mask[0, self.newline_token] = 0
-
-            scores = scores + mask
+                # Normal duration generation with range constraint
+                min_val, max_val = self.field_specs["duration"]["min"], self.field_specs["duration"]["max"]
+
+                if self._should_end_numeric_field(scores, min_val, max_val):
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    self._transition_to_next_state()
+                else:
+                    allowed = self._get_allowed_digit_tokens(min_val, max_val)
+                    for t in allowed:
+                        mask[0, t] = 0
+                    current = int(self.accumulated_value) if self.accumulated_value else 0
+                    if min_val <= current <= max_val and self.newline_token:
+                        mask[0, self.newline_token] = 0
+
+                scores = scores + mask
 
         elif self.state == FSMState.GENRES_VALUE:
             # Try to hot-reload genres vocab if file has changed
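
Note: with target_duration set, the DURATION_VALUE branch above degenerates to spelling out one fixed string: each step unmasks exactly one digit of str(int(target_duration)), then the newline. A character-level sketch of that walk (toy helper, not the repo's API):

    def allowed_next(target_duration: float, accumulated: str) -> str:
        """The single character the constrained FSM permits at this step."""
        target_str = str(int(target_duration))
        pos = len(accumulated)
        return target_str[pos] if pos < len(target_str) else "\n"

    # Walking target_duration=120.0 emits '1', '2', '0', then the field terminator.
    acc, steps = "", []
    while True:
        ch = allowed_next(120.0, acc)
        steps.append(ch)
        if ch == "\n":
            break
        acc += ch
    assert steps == ["1", "2", "0", "\n"]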
@@ -997,7 +1057,14 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         if not self.enabled:
             return
 
-        if self.state == FSMState.COMPLETED or self.state == FSMState.CODES_GENERATION:
+        if self.state == FSMState.COMPLETED:
+            return
+
+        if self.state == FSMState.CODES_GENERATION:
+            # Count generated codes for duration constraint
+            self.codes_count += 1
+            if self.debug and self.target_codes is not None:
+                logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
             return
 
         token_str = self.tokenizer.decode([generated_token_id])
@@ -1258,6 +1325,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with vllm backend
 
@@ -1276,6 +1344,8 @@ class LLMHandler:
                 If None, uses base temperature
             codes_temperature: Temperature for audio codes generation (higher = more diverse)
                 If None, uses base temperature
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
         """
         try:
             from nanovllm import SamplingParams
@@ -1298,6 +1368,7 @@ class LLMHandler:
            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
            self.constrained_processor.update_caption(caption)
+            self.constrained_processor.set_target_duration(target_duration)
 
            constrained_processor = self.constrained_processor
            update_state_fn = constrained_processor.update_state
@@ -1357,6 +1428,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> str:
         """Shared vllm path: accept prebuilt formatted prompt and return text."""
         from nanovllm import SamplingParams
@@ -1374,6 +1446,7 @@ class LLMHandler:
         self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
         self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
         self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+        self.constrained_processor.set_target_duration(target_duration)
 
         constrained_processor = self.constrained_processor
 
@@ -1427,6 +1500,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with PyTorch backend
 
@@ -1445,6 +1519,8 @@ class LLMHandler:
                 If None, uses base temperature
             codes_temperature: Temperature for audio codes generation (higher = more diverse)
                 If None, uses base temperature
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
         """
         try:
             formatted_prompt = self.build_formatted_prompt(caption, lyrics)
@@ -1496,6 +1572,7 @@ class LLMHandler:
            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
            self.constrained_processor.update_caption(caption)
+            self.constrained_processor.set_target_duration(target_duration)
 
            constrained_processor = self.constrained_processor
 
@@ -1623,6 +1700,7 @@ class LLMHandler:
         repetition_penalty: float,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        target_duration: Optional[float] = None,
     ) -> str:
         """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
         inputs = self.llm_tokenizer(
@@ -1639,6 +1717,7 @@ class LLMHandler:
         self.constrained_processor.enabled = use_constrained_decoding
         self.constrained_processor.debug = constrained_decoding_debug
         self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+        self.constrained_processor.set_target_duration(target_duration)
 
         constrained_processor = self.constrained_processor
 
@@ -1764,6 +1843,7 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM
 
@@ -1782,6 +1862,8 @@ class LLMHandler:
                 Recommended: 0.3-0.5 for accurate metadata
             codes_temperature: Temperature for audio codes generation (higher = more diverse)
                 Recommended: 0.7-1.0 for diverse codes
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
         """
         # Check if 5Hz LM is initialized
         if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
@@ -1811,6 +1893,7 @@ class LLMHandler:
                 constrained_decoding_debug=constrained_decoding_debug,
                 metadata_temperature=metadata_temperature,
                 codes_temperature=codes_temperature,
+                target_duration=target_duration,
             )
         else:
             return self.generate_with_5hz_lm_pt(
@@ -1826,6 +1909,7 @@ class LLMHandler:
                 constrained_decoding_debug=constrained_decoding_debug,
                 metadata_temperature=metadata_temperature,
                 codes_temperature=codes_temperature,
+                target_duration=target_duration,
             )
 
     def generate_with_stop_condition(
@@ -1843,11 +1927,16 @@ class LLMHandler:
         constrained_decoding_debug: bool = False,
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
+        target_duration: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Feishu-compatible LM generation.
 
        - infer_type='dit': stop at </think> and return metas only (no audio codes)
        - infer_type='llm_dit': normal generation (metas + audio codes)
+
+        Args:
+            target_duration: Target duration in seconds for codes generation constraint.
+                5 codes = 1 second. If specified, blocks EOS until target reached.
        """
        infer_type = (infer_type or "").strip().lower()
        if infer_type not in {"dit", "llm_dit"}:
@@ -1867,6 +1956,7 @@ class LLMHandler:
            constrained_decoding_debug=constrained_decoding_debug,
            metadata_temperature=metadata_temperature,
            codes_temperature=codes_temperature,
+            target_duration=target_duration,
        )
 
        # dit: generate and truncate at reasoning end tag
@@ -1934,6 +2024,7 @@ class LLMHandler:
                - cfg_scale (float)
                - negative_prompt (str) used when cfg_scale > 1
                - top_k (int), top_p (float), repetition_penalty (float)
+                - target_duration (float): Target duration in seconds for codes generation
            use_constrained_decoding: Whether to use FSM-based constrained decoding
            constrained_decoding_debug: Whether to enable debug logging for constrained decoding
 
@@ -1956,6 +2047,7 @@ class LLMHandler:
        top_k = cfg.get("top_k")
        top_p = cfg.get("top_p")
        repetition_penalty = cfg.get("repetition_penalty", 1.0)
+        target_duration = cfg.get("target_duration")
 
        try:
            if self.llm_backend == "vllm":
@@ -1969,6 +2061,7 @@ class LLMHandler:
                    repetition_penalty=repetition_penalty,
                    use_constrained_decoding=use_constrained_decoding,
                    constrained_decoding_debug=constrained_decoding_debug,
+                    target_duration=target_duration,
                )
                return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
 
@@ -1983,6 +2076,7 @@ class LLMHandler:
                    repetition_penalty=repetition_penalty,
                    use_constrained_decoding=use_constrained_decoding,
                    constrained_decoding_debug=constrained_decoding_debug,
+                    target_duration=target_duration,
                )
                return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
 
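Note: caller-side, the new knob arrives two ways in the hunks above: as an explicit target_duration kwarg on the generate_* methods, and via cfg.get("target_duration") on the cfg-driven path. A hedged sketch of the cfg payload (only keys visible in this diff; the commented-out call is an assumption about the surrounding API, since the enclosing method's full signature is not shown here):

    # Runnable sketch of the cfg payload; values are illustrative.
    cfg = {
        "top_k": 50,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "target_duration": 30.0,  # seconds; 30 * 5 = 150 codes before EOS is allowed
    }
    print(cfg.get("target_duration"))  # -> 30.0, read exactly as the diff does
    # metas, codes, status = handler.generate_with_stop_condition(
    #     ..., infer_type="llm_dit", target_duration=cfg["target_duration"],
    # )  # 'handler' and the elided kwargs are assumptions, not from this diff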