ChuxiJ committed
Commit c0934b3 · 1 Parent(s): ae5026d

add MetadataConstrainedLogitsProcessor

Files changed (4):
  1. .gitignore +2 -0
  2. API.md +41 -0
  3. acestep/genres_vocab.txt +0 -0
  4. acestep/llm_inference.py +647 -57
.gitignore CHANGED
@@ -218,3 +218,5 @@ checkpoints.7z
 README_old.md
 discord_bot/
 feishu_bot/
+tmp*
+torchinductor_root/
API.md CHANGED
@@ -61,6 +61,22 @@ Suitable for passing only text parameters, or referencing audio file paths that
 | `seed` | int | `-1` | Specify seed (when use_random_seed=false) |
 | `batch_size` | int | null | Batch generation count |
 
+**5Hz LM Parameters (Optional, server-side codes generation)**:
+
+If you want the server to generate `audio_code_string` using the 5Hz LM (equivalent to Gradio's **Generate LM Hints** button), set `use_5hz_lm=true`.
+
+| Parameter Name | Type | Default | Description |
+| :--- | :--- | :--- | :--- |
+| `use_5hz_lm` | bool | `false` | Enable server-side 5Hz LM code generation |
+| `lm_model_path` | string | null | 5Hz LM checkpoint dir name (e.g. `acestep-5Hz-lm-0.6B`) |
+| `lm_backend` | string | `"vllm"` | `vllm` or `pt` |
+| `lm_temperature` | float | `0.6` | Sampling temperature |
+| `lm_cfg_scale` | float | `1.0` | CFG scale (>1 enables CFG) |
+| `lm_negative_prompt` | string | `"NO USER INPUT"` | Negative prompt used by CFG |
+| `lm_top_k` | int | null | Top-k (0/null disables) |
+| `lm_top_p` | float | null | Top-p (>=1/null disables) |
+| `lm_repetition_penalty` | float | `1.0` | Repetition penalty |
+
 **Edit/Reference Audio Parameters** (requires absolute path on server):
 
 | Parameter Name | Type | Default | Description |

@@ -108,6 +124,31 @@ curl -X POST http://localhost:8001/v1/music/generate \
   }'
 ```
 
+**JSON Method (server-side 5Hz LM)**:
+
+```bash
+curl -X POST http://localhost:8001/v1/music/generate \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "caption": "upbeat pop song",
+    "lyrics": "Hello world",
+    "use_5hz_lm": true,
+    "lm_temperature": 0.6,
+    "lm_cfg_scale": 1.0,
+    "lm_top_k": 0,
+    "lm_top_p": 1.0,
+    "lm_repetition_penalty": 1.0
+  }'
+```
+
+When `use_5hz_lm=true` and the server generates LM codes, the job `result` will also include the following optional fields:
+
+- `bpm`
+- `duration`
+- `genres`
+- `keyscale`
+- `timesignature`
+
 > Note: If you use `curl -d` but **forget** to add `-H 'Content-Type: application/json'`, curl will default to sending `application/x-www-form-urlencoded`, and older server versions will return 415.
 
 **Form Method (no file upload, application/x-www-form-urlencoded)**:
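Editor's note: for reference, here is a minimal Python client for the JSON method above. This is a sketch, not part of the API surface: it assumes only what this document states (the endpoint, the JSON body, and the optional metadata fields on the job `result`); the exact response envelope may differ by server version.

```python
import requests

payload = {
    "caption": "upbeat pop song",
    "lyrics": "Hello world",
    "use_5hz_lm": True,
    "lm_temperature": 0.6,
    "lm_cfg_scale": 1.0,
    "lm_top_k": 0,          # 0/null disables top-k
    "lm_top_p": 1.0,        # >=1/null disables top-p
    "lm_repetition_penalty": 1.0,
}

# Using json= sets Content-Type: application/json automatically,
# avoiding the 415 pitfall described in the note above.
resp = requests.post("http://localhost:8001/v1/music/generate", json=payload)
resp.raise_for_status()
data = resp.json()

# When use_5hz_lm=true, the job result may additionally carry these fields.
result = data.get("result") or {}
for key in ("bpm", "duration", "genres", "keyscale", "timesignature"):
    if key in result:
        print(key, result[key])
```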
acestep/genres_vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
acestep/llm_inference.py CHANGED
  tokenizer: AutoTokenizer,
73
  enabled: bool = True,
74
  debug: bool = False,
75
+ genres_vocab_path: Optional[str] = None,
76
+ skip_genres: bool = True,
77
  ):
78
  """
79
  Initialize the constrained logits processor.
80
 
81
+ This processor should be initialized once when loading the LLM and reused
82
+ for all generations. Use update_caption() before each generation to update
83
+ the caption-based genre filtering.
84
+
85
  Args:
86
  tokenizer: The tokenizer to use for encoding/decoding
87
  enabled: Whether to enable constrained decoding
88
  debug: Whether to print debug information
89
+ genres_vocab_path: Path to genres vocabulary file (one genre per line)
90
+ If None, defaults to "acestep/genres_vocab.txt"
91
+ skip_genres: Whether to skip genres generation in metadata (default True)
92
  """
93
  self.tokenizer = tokenizer
94
  self.enabled = enabled
95
  self.debug = debug
96
+ self.skip_genres = skip_genres
97
+ self.caption: Optional[str] = None # Set via update_caption() before each generation
98
+
99
+ # Temperature settings for different generation phases (set per-generation)
100
+ # If set, the processor will apply temperature scaling (divide logits by temperature)
101
+ # Note: Set base sampler temperature to 1.0 when using processor-based temperature
102
+ self.metadata_temperature: Optional[float] = None
103
+ self.codes_temperature: Optional[float] = None
104
 
105
  # Current state
106
  self.state = FSMState.THINK_TAG
 
@@ -93,6 +110,23 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Pre-compute token IDs for efficiency
         self._precompute_tokens()
 
+        # Genres vocabulary for constrained decoding
+        self.genres_vocab_path = genres_vocab_path or os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
+        )
+        self.genres_vocab: List[str] = []  # Full vocab
+        self.genres_vocab_mtime: float = 0.0
+        self.genres_trie: Dict = {}  # Trie for full vocab (fallback)
+        self.caption_genres_trie: Dict = {}  # Trie for caption-matched genres (priority)
+        self.caption_matched_genres: List[str] = []  # Genres matched from caption
+        self._char_to_tokens: Dict[str, set] = {}  # Precomputed char -> token IDs mapping
+
+        # Precompute token mappings once (O(vocab_size), runs once at init)
+        self._precompute_char_token_mapping()
+        self._load_genres_vocab()
+
+        # Note: Caption-based genre filtering is initialized via update_caption() before each generation
+
         # Field definitions
         self.field_specs = {
             "bpm": {"min": 30, "max": 300},
@@ -117,7 +151,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.THINK_END_TAG: "</think>",
         }
 
-        # State transitions
+        # State transitions - build dynamically based on skip_genres
+        self._build_state_transitions()
+
+    def _build_state_transitions(self):
+        """Build state transition map based on skip_genres setting."""
         self.next_state = {
             FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
             FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME,

@@ -126,10 +164,6 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.NEWLINE_AFTER_BPM: FSMState.DURATION_NAME,
             FSMState.DURATION_NAME: FSMState.DURATION_VALUE,
             FSMState.DURATION_VALUE: FSMState.NEWLINE_AFTER_DURATION,
-            FSMState.NEWLINE_AFTER_DURATION: FSMState.GENRES_NAME,
-            FSMState.GENRES_NAME: FSMState.GENRES_VALUE,
-            FSMState.GENRES_VALUE: FSMState.NEWLINE_AFTER_GENRES,
-            FSMState.NEWLINE_AFTER_GENRES: FSMState.KEYSCALE_NAME,
             FSMState.KEYSCALE_NAME: FSMState.KEYSCALE_VALUE,
             FSMState.KEYSCALE_VALUE: FSMState.NEWLINE_AFTER_KEYSCALE,
             FSMState.NEWLINE_AFTER_KEYSCALE: FSMState.TIMESIG_NAME,

@@ -139,6 +173,21 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
             FSMState.CODES_GENERATION: FSMState.COMPLETED,
         }
+
+        if self.skip_genres:
+            # Skip genres: NEWLINE_AFTER_DURATION -> KEYSCALE_NAME directly
+            self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.KEYSCALE_NAME
+        else:
+            # Include genres in the flow
+            self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.GENRES_NAME
+            self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
+            self.next_state[FSMState.GENRES_VALUE] = FSMState.NEWLINE_AFTER_GENRES
+            self.next_state[FSMState.NEWLINE_AFTER_GENRES] = FSMState.KEYSCALE_NAME
+
+    def set_skip_genres(self, skip: bool):
+        """Set whether to skip genres generation and rebuild state transitions."""
+        self.skip_genres = skip
+        self._build_state_transitions()
 
     def _precompute_tokens(self):
         """Pre-compute commonly used token IDs for efficiency."""
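Editor's note on the transition rewiring above: a small standalone sketch (FSM states stubbed with plain strings instead of the module's FSMState enum) shows how `skip_genres` bypasses the genres states; walking the map makes the two flows explicit.

```python
def build_transitions(skip_genres: bool) -> dict:
    # Mirrors _build_state_transitions, with string stand-ins for FSMState
    nxt = {
        "DURATION_VALUE": "NEWLINE_AFTER_DURATION",
        "KEYSCALE_NAME": "KEYSCALE_VALUE",
    }
    if skip_genres:
        nxt["NEWLINE_AFTER_DURATION"] = "KEYSCALE_NAME"
    else:
        nxt["NEWLINE_AFTER_DURATION"] = "GENRES_NAME"
        nxt["GENRES_NAME"] = "GENRES_VALUE"
        nxt["GENRES_VALUE"] = "NEWLINE_AFTER_GENRES"
        nxt["NEWLINE_AFTER_GENRES"] = "KEYSCALE_NAME"
    return nxt

def walk(nxt: dict, state: str, steps: int) -> str:
    path = [state]
    for _ in range(steps):
        state = nxt.get(state)
        if state is None:
            break
        path.append(state)
    return " -> ".join(path)

print(walk(build_transitions(skip_genres=True), "DURATION_VALUE", 3))
# DURATION_VALUE -> NEWLINE_AFTER_DURATION -> KEYSCALE_NAME -> KEYSCALE_VALUE
print(walk(build_transitions(skip_genres=False), "DURATION_VALUE", 6))
# DURATION_VALUE -> NEWLINE_AFTER_DURATION -> GENRES_NAME -> GENRES_VALUE
#   -> NEWLINE_AFTER_GENRES -> KEYSCALE_NAME -> KEYSCALE_VALUE
```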
@@ -189,6 +238,328 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         # Vocab size
         self.vocab_size = len(self.tokenizer)
+
+        # Comma token for multi-genre support
+        comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
+        self.comma_token = comma_tokens[-1] if comma_tokens else None
+
+    def _load_genres_vocab(self):
+        """
+        Load genres vocabulary from file. Supports hot reload by checking file mtime.
+        File format: one genre per line, lines starting with # are comments.
+        """
+        if not os.path.exists(self.genres_vocab_path):
+            if self.debug:
+                logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
+            return
+
+        try:
+            mtime = os.path.getmtime(self.genres_vocab_path)
+            if mtime <= self.genres_vocab_mtime:
+                return  # File hasn't changed
+
+            with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
+                genres = []
+                for line in f:
+                    line = line.strip()
+                    if line and not line.startswith('#'):
+                        genres.append(line.lower())
+
+            self.genres_vocab = genres
+            self.genres_vocab_mtime = mtime
+            self._build_genres_trie()
+
+            if self.debug:
+                logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
+        except Exception as e:
+            logger.warning(f"Failed to load genres vocab: {e}")
+
+    def _build_genres_trie(self):
+        """
+        Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
+        Each node is a dict with:
+        - '_end': True if this node represents a complete genre
+        - other keys: next characters in the trie
+        """
+        self.genres_trie = {}
+
+        for genre in self.genres_vocab:
+            node = self.genres_trie
+            for char in genre:
+                if char not in node:
+                    node[char] = {}
+                node = node[char]
+            node['_end'] = True  # Mark end of a complete genre
+
+        if self.debug:
+            logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
+
+    def _extract_caption_genres(self, caption: str):
+        """
+        Extract genres from the user's caption that match entries in the vocabulary.
+        This creates a smaller trie for faster and more relevant genre generation.
+
+        Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
+        1. Extract words/phrases from caption
+        2. For each word, use trie to find all vocab entries that START with this word
+        3. Build a separate trie from matched genres
+        """
+        if not caption or not self.genres_vocab:
+            return
+
+        caption_lower = caption.lower()
+        matched_genres = set()
+
+        # Extract words from caption (split by common delimiters)
+        import re
+        words = re.split(r'[,\s\-_/\\|]+', caption_lower)
+        words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
+
+        # For each word, find genres in trie that start with this word
+        for word in words:
+            # Find all genres starting with this word using trie traversal
+            node = self._get_genres_trie_node(word)
+            if node is not None:
+                # Collect all complete genres under this node
+                self._collect_complete_genres(node, word, matched_genres)
+
+        # Also check if any word appears as a substring in short genres (< 20 chars)
+        # This is a quick check for common single-word genres
+        genres_set = set(self.genres_vocab)
+        for word in words:
+            if word in genres_set:
+                matched_genres.add(word)
+
+        if not matched_genres:
+            if self.debug:
+                logger.debug(f"No genres matched in caption, using full vocab")
+            return
+
+        # Build a trie from matched genres
+        self.caption_matched_genres = list(matched_genres)
+        self.caption_genres_trie = {}
+
+        for genre in matched_genres:
+            node = self.caption_genres_trie
+            for char in genre:
+                if char not in node:
+                    node[char] = {}
+                node = node[char]
+            node['_end'] = True
+
+        if self.debug:
+            logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
+
+    def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
+        """
+        Recursively collect all complete genres under a trie node.
+        Limited depth to avoid too many matches.
+        """
+        if max_depth <= 0:
+            return
+
+        if node.get('_end', False):
+            result.add(prefix)
+
+        # Limit total collected genres to avoid slowdown
+        if len(result) >= 100:
+            return
+
+        for char, child_node in node.items():
+            if char not in ('_end', '_tokens'):
+                self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
+
+    def _precompute_char_token_mapping(self):
+        """
+        Precompute mapping from characters to token IDs and token decoded texts.
+        This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
+
+        Time complexity: O(vocab_size) - runs once during initialization
+
+        Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
+        We need to handle both the raw first char and the first non-space char.
+        """
+        self._char_to_tokens: Dict[str, set] = {}
+        self._token_to_text: Dict[int, str] = {}  # Precomputed decoded text for each token
+
+        # For each token in vocabulary, get its decoded text
+        for token_id in range(self.vocab_size):
+            try:
+                text = self.tokenizer.decode([token_id])
+
+                if not text:
+                    continue
+
+                # Store the decoded text (normalized to lowercase)
+                # Keep leading spaces for proper concatenation (e.g., " rock" in "pop rock")
+                # Only rstrip trailing whitespace, unless it's a pure whitespace token
+                text_lower = text.lower()
+                if text_lower.strip():  # Has non-whitespace content
+                    normalized_text = text_lower.rstrip()
+                else:  # Pure whitespace token
+                    normalized_text = " "  # Normalize to single space
+                self._token_to_text[token_id] = normalized_text
+
+                # Map first character (including space) to this token
+                first_char = text[0].lower()
+                if first_char not in self._char_to_tokens:
+                    self._char_to_tokens[first_char] = set()
+                self._char_to_tokens[first_char].add(token_id)
+
+                # Also map first non-space character to this token
+                # This handles tokenizers that add space prefixes (e.g., " pop" -> maps to 'p')
+                stripped_text = text.lstrip()
+                if stripped_text and stripped_text != text:
+                    first_nonspace_char = stripped_text[0].lower()
+                    if first_nonspace_char not in self._char_to_tokens:
+                        self._char_to_tokens[first_nonspace_char] = set()
+                    self._char_to_tokens[first_nonspace_char].add(token_id)
+
+            except Exception:
+                continue
+
+        if self.debug:
+            logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
+
+    def _try_reload_genres_vocab(self):
+        """Check if genres vocab file has been updated and reload if necessary."""
+        if not os.path.exists(self.genres_vocab_path):
+            return
+
+        try:
+            mtime = os.path.getmtime(self.genres_vocab_path)
+            if mtime > self.genres_vocab_mtime:
+                self._load_genres_vocab()
+        except Exception:
+            pass  # Ignore errors during hot reload check
+
+    def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
+        """
+        Get the trie node for a given prefix.
+        Returns None if the prefix is not valid (no genres start with this prefix).
+        """
+        node = self.genres_trie
+        for char in prefix.lower():
+            if char not in node:
+                return None
+            node = node[char]
+        return node
+
+    def _is_complete_genre(self, text: str) -> bool:
+        """Check if the given text is a complete genre in the vocabulary."""
+        node = self._get_genres_trie_node(text.strip())
+        return node is not None and node.get('_end', False)
+
+    def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
+        """Get a trie node from a specific trie (helper for caption vs full trie)."""
+        node = trie
+        for char in prefix.lower():
+            if char not in node:
+                return None
+            node = node[char]
+        return node
+
+    def _get_allowed_genres_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for genres field based on trie matching.
+
+        The entire genres string (including commas) must match a complete entry in the vocab.
+        For example, if vocab contains "pop, rock, jazz", the generated string must exactly
+        match that entry - we don't treat commas as separators for individual genres.
+
+        Strategy:
+        1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
+        2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
+        3. Get valid next characters from current trie node
+        4. For each candidate token, verify the full decoded text forms a valid trie prefix
+        """
+        if not self.genres_vocab:
+            # No vocab loaded, allow all except newline if empty
+            return []
+
+        # Use the full accumulated value (don't split by comma - treat as single entry)
+        accumulated = self.accumulated_value.lower()
+        current_genre_prefix = accumulated.strip()
+
+        # Determine which trie to use: caption-matched (priority) or full vocab (fallback)
+        use_caption_trie = False
+        current_node = None
+
+        # Try caption-matched trie first if available
+        if self.caption_genres_trie:
+            if current_genre_prefix == "":
+                current_node = self.caption_genres_trie
+                use_caption_trie = True
+            else:
+                current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
+                if current_node is not None:
+                    use_caption_trie = True
+
+        # Fallback to full vocab trie
+        if current_node is None:
+            if current_genre_prefix == "":
+                current_node = self.genres_trie
+            else:
+                current_node = self._get_genres_trie_node(current_genre_prefix)
+
+        if current_node is None:
+            # Invalid prefix, force newline to end
+            if self.newline_token:
+                return [self.newline_token]
+            return []
+
+        # Get valid next characters from trie node
+        valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
+
+        # If current value is a complete genre, allow newline to end
+        is_complete = current_node.get('_end', False)
+
+        if not valid_next_chars:
+            # No more characters to match, only allow newline if complete
+            allowed = set()
+            if is_complete and self.newline_token:
+                allowed.add(self.newline_token)
+            return list(allowed)
+
+        # Collect candidate tokens based on first character
+        candidate_tokens = set()
+        for char in valid_next_chars:
+            if char in self._char_to_tokens:
+                candidate_tokens.update(self._char_to_tokens[char])
+
+        # Select the appropriate trie for validation
+        active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
+
+        # Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
+        allowed = set()
+        for token_id in candidate_tokens:
+            # Use precomputed decoded text (already normalized)
+            decoded_normalized = self._token_to_text.get(token_id, "")
+
+            if not decoded_normalized or not decoded_normalized.strip():
+                # Token decodes to empty or only whitespace - allow if space/comma is a valid next char
+                if ' ' in valid_next_chars or ',' in valid_next_chars:
+                    allowed.add(token_id)
+                continue
+
+            # Build new prefix by appending decoded token
+            # Handle space-prefixed tokens (e.g., " rock" from "pop rock")
+            if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
+                # Token has leading space/comma - append directly
+                new_prefix = current_genre_prefix + decoded_normalized
+            else:
+                new_prefix = current_genre_prefix + decoded_normalized
+
+            # Check if new_prefix is a valid prefix in the active trie
+            new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
+            if new_node is not None:
+                allowed.add(token_id)
+
+        # If current value is a complete genre, also allow newline
+        if is_complete and self.newline_token:
+            allowed.add(self.newline_token)
+
+        return list(allowed)
 
     def reset(self):
        """Reset the processor state for a new generation."""
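Editor's note: the trie logic in the hunk above is easy to check in isolation. A minimal sketch of the same idea (a dict-based trie with an `'_end'` marker, then prefix-constrained continuation), independent of any tokenizer:

```python
def build_trie(vocab):
    trie = {}
    for entry in vocab:
        node = trie
        for ch in entry:
            node = node.setdefault(ch, {})
        node["_end"] = True  # marks a complete entry
    return trie

def lookup(trie, prefix):
    node = trie
    for ch in prefix:
        if ch not in node:
            return None  # no vocab entry starts with this prefix
        node = node[ch]
    return node

trie = build_trie(["pop", "pop rock", "power metal"])
node = lookup(trie, "pop")
print(sorted(k for k in node if k != "_end"))  # [' ']  - only "pop rock" continues
print(node.get("_end", False))                 # True   - "pop" itself is complete
print(lookup(trie, "jazz"))                    # None   - would force the newline fallback
```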
@@ -196,6 +567,27 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.position_in_state = 0
         self.accumulated_value = ""
 
+    def update_caption(self, caption: Optional[str]):
+        """
+        Update the caption and rebuild the caption-matched genres trie.
+        Call this before each generation to prioritize genres from the new caption.
+
+        Args:
+            caption: User's input caption. If None or empty, clears caption matching.
+        """
+        # Check for hot reload of genres vocabulary
+        self._try_reload_genres_vocab()
+
+        self.caption = caption
+        self.caption_genres_trie = {}
+        self.caption_matched_genres = []
+
+        if caption:
+            self._extract_caption_genres(caption)
+
+        # Also reset FSM state for new generation
+        self.reset()
+
     def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
         """
         Get the token IDs that can continue the fixed string from current position.
@@ -400,13 +792,14 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             scores: [batch_size, vocab_size] logits for next token
 
         Returns:
-            Modified scores with invalid tokens masked to -inf
+            Modified scores with invalid tokens masked to -inf and temperature scaling applied
         """
         if not self.enabled:
-            return scores
+            return self._apply_temperature_scaling(scores)
 
         if self.state == FSMState.COMPLETED or self.state == FSMState.CODES_GENERATION:
-            return scores  # No constraints in codes generation phase
+            # No constraints in codes generation phase, but still apply temperature
+            return self._apply_temperature_scaling(scores)
 
         batch_size = scores.shape[0]
 
@@ -414,7 +807,39 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         for b in range(batch_size):
             scores[b] = self._process_single_sequence(input_ids[b], scores[b:b+1])
 
-        return scores
+        # Apply temperature scaling after constraint masking
+        return self._apply_temperature_scaling(scores)
+
+    def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Apply temperature scaling based on current generation phase.
+
+        Temperature scaling: logits = logits / temperature
+        - Lower temperature (< 1.0) makes distribution sharper (more deterministic)
+        - Higher temperature (> 1.0) makes distribution flatter (more diverse)
+
+        Args:
+            scores: [batch_size, vocab_size] logits
+
+        Returns:
+            Temperature-scaled logits
+        """
+        # Determine which temperature to use based on current state
+        if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED:
+            temperature = self.codes_temperature
+        else:
+            temperature = self.metadata_temperature
+
+        # If no temperature is set for this phase, return scores unchanged
+        if temperature is None:
+            return scores
+
+        # Avoid division by zero
+        if temperature <= 0:
+            temperature = 1e-6
+
+        # Apply temperature scaling
+        return scores / temperature
 
     def _process_single_sequence(
         self,
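Editor's note: numerically, the processor's phase-dependent scaling is equivalent to sampling at that temperature, provided the base sampler runs at 1.0 (which is why the handler hunks below set the sampler temperature to 1.0 whenever phase temperatures are in use). A quick check of that identity:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])

# Sampler at temperature t: softmax(logits / t)
t = 0.85
direct = torch.softmax(logits / t, dim=-1)

# Processor divides logits by t; the sampler then applies temperature 1.0
processed = logits / t
via_processor = torch.softmax(processed / 1.0, dim=-1)

print(torch.allclose(direct, via_processor))  # True
```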
@@ -482,17 +907,38 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             scores = scores + mask
 
         elif self.state == FSMState.GENRES_VALUE:
-            if self._should_end_text_field(scores):
+            # Try to hot-reload genres vocab if file has changed
+            self._try_reload_genres_vocab()
+
+            # Get allowed tokens based on genres vocabulary
+            allowed = self._get_allowed_genres_tokens()
+
+            if allowed:
+                # Use vocabulary-constrained decoding
+                for t in allowed:
+                    mask[0, t] = 0
+                scores = scores + mask
+            elif self.genres_vocab:
+                # Vocab is loaded but no valid continuation found
+                # Force newline to end the field
                 if self.newline_token:
                     mask[0, self.newline_token] = 0
-                self._transition_to_next_state()
+                if self.debug:
+                    logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
                 scores = scores + mask
             else:
-                # Allow any token except newline if we don't have content yet
-                if not self.accumulated_value.strip():
+                # Fallback: no vocab loaded, use probability-based ending
+                if self._should_end_text_field(scores):
                     if self.newline_token:
-                        scores[0, self.newline_token] = float('-inf')
-                # Otherwise, don't constrain (allow any token including newline)
+                        mask[0, self.newline_token] = 0
+                    self._transition_to_next_state()
+                    scores = scores + mask
+                else:
+                    # Allow any token except newline if we don't have content yet
+                    if not self.accumulated_value.strip():
+                        if self.newline_token:
+                            scores[0, self.newline_token] = float('-inf')
+                    # Otherwise, don't constrain (fallback behavior)
 
         elif self.state == FSMState.KEYSCALE_VALUE:
             if self._is_keyscale_complete():
@@ -591,6 +1037,8 @@
 
 class LLMHandler:
     """5Hz LM Handler for audio code generation"""
+
+    STOP_REASONING_TAG = "</think>"
 
     def __init__(self):
         """Initialize LLMHandler with default values"""
@@ -602,6 +1050,9 @@ class LLMHandler:
         self.device = "cpu"
        self.dtype = torch.float32
        self.offload_to_cpu = False
+
+        # Shared constrained decoding processor (initialized once when LLM is loaded)
+        self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
 
    def get_available_5hz_lm_models(self) -> List[str]:
        """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
@@ -690,6 +1141,16 @@ class LLMHandler:
         logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
         self.llm_tokenizer = llm_tokenizer
 
+        # Initialize shared constrained decoding processor (one-time initialization)
+        logger.info("Initializing constrained decoding processor...")
+        processor_start = time.time()
+        self.constrained_processor = MetadataConstrainedLogitsProcessor(
+            tokenizer=self.llm_tokenizer,
+            enabled=True,
+            debug=False,
+        )
+        logger.info(f"Constrained processor initialized in {time.time() - processor_start:.2f} seconds")
+
         # Initialize based on user-selected backend
         if backend == "vllm":
             # Try to initialize with vllm
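Editor's note: with the processor built once at load time, per-generation reuse looks like the following sketch. It uses only attributes and methods introduced in this commit; `handler` stands for an already-initialized LLMHandler and is an assumption of the example.

```python
# handler: an initialized LLMHandler whose constrained_processor was built at load time
proc = handler.constrained_processor

# Before each generation: reconfigure instead of rebuilding,
# so the O(vocab_size) char->token precompute never reruns.
proc.enabled = True
proc.metadata_temperature = 0.85   # sharper, more accurate metadata
proc.codes_temperature = None      # leave codes at the sampler temperature
proc.update_caption("dreamy synthwave with female vocals")  # also resets FSM state
```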
@@ -795,13 +1256,15 @@ class LLMHandler:
         repetition_penalty: float = 1.0,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with vllm backend
 
         Args:
             caption: Text caption for music generation
             lyrics: Lyrics for music generation
-            temperature: Sampling temperature
+            temperature: Base sampling temperature (used if phase-specific temps not set)
             cfg_scale: CFG scale (>1.0 enables CFG)
             negative_prompt: Negative prompt for CFG
             top_k: Top-k sampling parameter

@@ -809,6 +1272,10 @@ class LLMHandler:
             repetition_penalty: Repetition penalty
             use_constrained_decoding: Whether to use FSM-based constrained decoding
             constrained_decoding_debug: Whether to print debug info for constrained decoding
+            metadata_temperature: Temperature for metadata generation (lower = more accurate)
+                If None, uses base temperature
+            codes_temperature: Temperature for audio codes generation (higher = more diverse)
+                If None, uses base temperature
         """
         try:
             from nanovllm import SamplingParams
@@ -816,20 +1283,28 @@ class LLMHandler:
         formatted_prompt = self.build_formatted_prompt(caption, lyrics)
         logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
 
-        # Create constrained decoding processor if enabled
+        # Determine effective temperature for sampler
+        # If using phase-specific temperatures, set sampler temp to 1.0 (processor handles it)
+        use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
+        effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
+
+        # Use shared constrained decoding processor if enabled
         constrained_processor = None
         update_state_fn = None
-        if use_constrained_decoding:
-            constrained_processor = MetadataConstrainedLogitsProcessor(
-                tokenizer=self.llm_tokenizer,
-                enabled=True,
-                debug=constrained_decoding_debug,
-            )
+        if use_constrained_decoding or use_phase_temperatures:
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
+            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
+            self.constrained_processor.update_caption(caption)
+
+            constrained_processor = self.constrained_processor
             update_state_fn = constrained_processor.update_state
 
         sampling_params = SamplingParams(
             max_tokens=self.max_model_len-64,
-            temperature=temperature,
+            temperature=effective_sampler_temp,
             cfg_scale=cfg_scale,
             top_k=top_k,
             top_p=top_p,
@@ -880,21 +1355,31 @@ class LLMHandler:
         repetition_penalty: float,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> str:
         """Shared vllm path: accept prebuilt formatted prompt and return text."""
         from nanovllm import SamplingParams
 
-        # Create constrained processor if enabled
+        # Determine effective temperature for sampler
+        use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
+        effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
+
+        # Use shared constrained processor if enabled
         constrained_processor = None
-        if use_constrained_decoding:
-            constrained_processor = MetadataConstrainedLogitsProcessor(
-                tokenizer=self.llm_tokenizer,
-                debug=constrained_decoding_debug,
-            )
+        if use_constrained_decoding or use_phase_temperatures:
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
+            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
+            self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+
+            constrained_processor = self.constrained_processor
 
         sampling_params = SamplingParams(
             max_tokens=self.max_model_len - 64,
-            temperature=temperature,
+            temperature=effective_sampler_temp,
             cfg_scale=cfg_scale,
             top_k=top_k,
             top_p=top_p,
@@ -940,13 +1425,15 @@ class LLMHandler:
         repetition_penalty: float = 1.0,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with PyTorch backend
 
         Args:
             caption: Text caption for music generation
             lyrics: Lyrics for music generation
-            temperature: Sampling temperature
+            temperature: Base sampling temperature (used if phase-specific temps not set)
             cfg_scale: CFG scale (>1.0 enables CFG)
             negative_prompt: Negative prompt for CFG
             top_k: Top-k sampling parameter

@@ -954,6 +1441,10 @@ class LLMHandler:
             repetition_penalty: Repetition penalty
             use_constrained_decoding: Whether to use FSM-based constrained decoding
             constrained_decoding_debug: Whether to print debug info for constrained decoding
+            metadata_temperature: Temperature for metadata generation (lower = more accurate)
+                If None, uses base temperature
+            codes_temperature: Temperature for audio codes generation (higher = more diverse)
+                If None, uses base temperature
         """
         try:
             formatted_prompt = self.build_formatted_prompt(caption, lyrics)
@@ -992,14 +1483,21 @@ class LLMHandler:
 
         streamer = TqdmTokenStreamer(total=max_new_tokens)
 
-        # Create constrained decoding processor if enabled
+        # Determine if using phase-specific temperatures
+        use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
+        effective_temperature = 1.0 if use_phase_temperatures else temperature
+
+        # Use shared constrained decoding processor if enabled
         constrained_processor = None
-        if use_constrained_decoding:
-            constrained_processor = MetadataConstrainedLogitsProcessor(
-                tokenizer=self.llm_tokenizer,
-                enabled=True,
-                debug=constrained_decoding_debug,
-            )
+        if use_constrained_decoding or use_phase_temperatures:
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
+            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
+            self.constrained_processor.update_caption(caption)
+
+            constrained_processor = self.constrained_processor
 
         # Build logits processor list (only for CFG and repetition penalty)
         logits_processor = LogitsProcessorList()
@@ -1036,7 +1534,7 @@ class LLMHandler:
                 batch_input_ids=batch_input_ids,
                 batch_attention_mask=batch_attention_mask,
                 max_new_tokens=max_new_tokens,
-                temperature=temperature,
+                temperature=effective_temperature,
                 cfg_scale=cfg_scale,
                 top_k=top_k,
                 top_p=top_p,
@@ -1048,8 +1546,8 @@ class LLMHandler:
 
             # Extract only the conditional output (first in batch)
             outputs = outputs[0:1]  # Keep only conditional output
-        elif use_constrained_decoding:
-            # Use custom generation loop for constrained decoding (non-CFG)
+        elif use_constrained_decoding or use_phase_temperatures:
+            # Use custom generation loop for constrained decoding or phase temperatures (non-CFG)
             input_ids = inputs['input_ids']
             attention_mask = inputs.get('attention_mask', None)
 
@@ -1057,7 +1555,7 @@ class LLMHandler:
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_new_tokens=max_new_tokens,
-                temperature=temperature,
+                temperature=effective_temperature,
                 top_k=top_k,
                 top_p=top_p,
                 repetition_penalty=repetition_penalty,
@@ -1071,8 +1569,8 @@ class LLMHandler:
             outputs = self.llm.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
-                temperature=temperature if temperature > 0 else 1.0,
-                do_sample=True if temperature > 0 else False,
+                temperature=effective_temperature if effective_temperature > 0 else 1.0,
+                do_sample=True if effective_temperature > 0 else False,
                 top_k=top_k if top_k is not None and top_k > 0 else None,
                 top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
                 logits_processor=logits_processor if len(logits_processor) > 0 else None,
@@ -1134,13 +1632,15 @@ class LLMHandler:
             truncation=True,
         )
 
-        # Create constrained processor if enabled
+        # Use shared constrained processor if enabled
         constrained_processor = None
         if use_constrained_decoding:
-            constrained_processor = MetadataConstrainedLogitsProcessor(
-                tokenizer=self.llm_tokenizer,
-                debug=constrained_decoding_debug,
-            )
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+
+            constrained_processor = self.constrained_processor
 
         with self._load_model_context():
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -1262,13 +1762,15 @@ class LLMHandler:
         repetition_penalty: float = 1.0,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM
 
         Args:
             caption: Text caption for music generation
             lyrics: Lyrics for music generation
-            temperature: Sampling temperature
+            temperature: Base sampling temperature (used if phase-specific temps not set)
             cfg_scale: CFG scale (>1.0 enables CFG)
             negative_prompt: Negative prompt for CFG
             top_k: Top-k sampling parameter

@@ -1276,6 +1778,10 @@ class LLMHandler:
             repetition_penalty: Repetition penalty
             use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
             constrained_decoding_debug: Whether to print debug info for constrained decoding
+            metadata_temperature: Temperature for metadata generation (lower = more accurate)
+                Recommended: 0.3-0.5 for accurate metadata
+            codes_temperature: Temperature for audio codes generation (higher = more diverse)
+                Recommended: 0.7-1.0 for diverse codes
         """
         # Check if 5Hz LM is initialized
         if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
@@ -1293,17 +1799,101 @@ class LLMHandler:
 
         if self.llm_backend == "vllm":
             return self.generate_with_5hz_lm_vllm(
-                caption, lyrics, temperature, cfg_scale, negative_prompt,
-                top_k, top_p, repetition_penalty,
-                use_constrained_decoding, constrained_decoding_debug
+                caption=caption,
+                lyrics=lyrics,
+                temperature=temperature,
+                cfg_scale=cfg_scale,
+                negative_prompt=negative_prompt,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                use_constrained_decoding=use_constrained_decoding,
+                constrained_decoding_debug=constrained_decoding_debug,
+                metadata_temperature=metadata_temperature,
+                codes_temperature=codes_temperature,
             )
         else:
             return self.generate_with_5hz_lm_pt(
-                caption, lyrics, temperature, cfg_scale, negative_prompt,
-                top_k, top_p, repetition_penalty,
-                use_constrained_decoding, constrained_decoding_debug
+                caption=caption,
+                lyrics=lyrics,
+                temperature=temperature,
+                cfg_scale=cfg_scale,
+                negative_prompt=negative_prompt,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                use_constrained_decoding=use_constrained_decoding,
+                constrained_decoding_debug=constrained_decoding_debug,
+                metadata_temperature=metadata_temperature,
+                codes_temperature=codes_temperature,
             )
 
+    def generate_with_stop_condition(
+        self,
+        caption: str,
+        lyrics: str,
+        infer_type: str,
+        temperature: float = 0.6,
+        cfg_scale: float = 1.0,
+        negative_prompt: str = "NO USER INPUT",
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: float = 1.0,
+        use_constrained_decoding: bool = True,
+        constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
+    ) -> Tuple[Dict[str, Any], str, str]:
+        """Feishu-compatible LM generation.
+
+        - infer_type='dit': stop at </think> and return metas only (no audio codes)
+        - infer_type='llm_dit': normal generation (metas + audio codes)
+        """
+        infer_type = (infer_type or "").strip().lower()
+        if infer_type not in {"dit", "llm_dit"}:
+            return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
+
+        if infer_type == "llm_dit":
+            return self.generate_with_5hz_lm(
+                caption=caption,
+                lyrics=lyrics,
+                temperature=temperature,
+                cfg_scale=cfg_scale,
+                negative_prompt=negative_prompt,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                use_constrained_decoding=use_constrained_decoding,
+                constrained_decoding_debug=constrained_decoding_debug,
+                metadata_temperature=metadata_temperature,
+                codes_temperature=codes_temperature,
+            )
+
+        # dit: generate and truncate at reasoning end tag
+        formatted_prompt = self.build_formatted_prompt(caption, lyrics)
+        output_text, status = self.generate_from_formatted_prompt(
+            formatted_prompt,
+            cfg={
+                "temperature": temperature,
+                "cfg_scale": cfg_scale,
+                "negative_prompt": negative_prompt,
+                "top_k": top_k,
+                "top_p": top_p,
+                "repetition_penalty": repetition_penalty,
+            },
+            use_constrained_decoding=use_constrained_decoding,
+            constrained_decoding_debug=constrained_decoding_debug,
+        )
+        if not output_text:
+            return {}, "", status
+
+        if self.STOP_REASONING_TAG in output_text:
+            stop_idx = output_text.find(self.STOP_REASONING_TAG)
+            output_text = output_text[: stop_idx + len(self.STOP_REASONING_TAG)]
+
+        metadata, _audio_codes = self.parse_lm_output(output_text)
+        return metadata, "", status
+
     def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False) -> str:
         """
         Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
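Editor's note: a usage sketch for the new entry point added above (illustrative values; `handler` stands for an already-initialized LLMHandler and is an assumption of the example). With infer_type='dit' generation halts at </think> and only metadata comes back; 'llm_dit' also returns the audio code string.

```python
metadata, audio_codes, status = handler.generate_with_stop_condition(
    caption="upbeat pop song",
    lyrics="Hello world",
    infer_type="dit",          # metas only; use "llm_dit" for metas + codes
    temperature=0.6,
    cfg_scale=1.0,
)
print(status)
print(metadata.get("bpm"), metadata.get("genres"))
print(audio_codes)  # "" when infer_type='dit'
```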