ChuxiJ committed
Commit 1241c80 · 1 Parent(s): 26b4474

cot caption & language LM

acestep/api_server.py CHANGED
@@ -35,6 +35,10 @@ from starlette.datastructures import UploadFile as StarletteUploadFile
 
 from acestep.handler import AceStepHandler
 from acestep.llm_inference import LLMHandler
+from acestep.constants import (
+    DEFAULT_DIT_INSTRUCTION,
+    DEFAULT_LM_INSTRUCTION,
+)
 
 
 JobStatus = Literal["queued", "running", "succeeded", "failed"]
@@ -70,7 +74,7 @@ class GenerateMusicRequest(BaseModel):
     repainting_start: float = 0.0
     repainting_end: Optional[float] = None
 
-    instruction: str = "Fill the audio semantic mask based on the given conditions:"
+    instruction: str = DEFAULT_DIT_INSTRUCTION
     audio_cover_strength: float = 1.0
     task_type: str = "text2music"
 
@@ -102,8 +106,8 @@ class GenerateMusicRequest(BaseModel):
 _LM_DEFAULT_TEMPERATURE = 0.85
 _LM_DEFAULT_CFG_SCALE = 2.0
 _LM_DEFAULT_TOP_P = 0.9
-_DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
-_DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
+_DEFAULT_DIT_INSTRUCTION = DEFAULT_DIT_INSTRUCTION
+_DEFAULT_LM_INSTRUCTION = DEFAULT_LM_INSTRUCTION
 
 
 class CreateJobResponse(BaseModel):
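
Note (not part of the commit): a minimal check that this refactor is behavior-preserving — the centralized constants carry the exact strings the API previously inlined, so GenerateMusicRequest defaults are unchanged. Assumes acestep is importable.

from acestep.constants import DEFAULT_DIT_INSTRUCTION, DEFAULT_LM_INSTRUCTION

assert DEFAULT_DIT_INSTRUCTION == "Fill the audio semantic mask based on the given conditions:"
assert DEFAULT_LM_INSTRUCTION == "Generate audio semantic tokens based on the given conditions:"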
acestep/constants.py ADDED
@@ -0,0 +1,97 @@
+"""
+Constants for ACE-Step
+Centralized constants used across the codebase
+"""
+
+# ==============================================================================
+# Language Constants
+# ==============================================================================
+
+VALID_LANGUAGES = [
+    'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
+    'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
+    'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
+    'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
+    'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
+    'unknown'
+]
+
+
+# ==============================================================================
+# Keyscale Constants
+# ==============================================================================
+
+KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
+KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
+KEYSCALE_MODES = ['major', 'minor']
+
+# Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
+VALID_KEYSCALES = set()
+for note in KEYSCALE_NOTES:
+    for acc in KEYSCALE_ACCIDENTALS:
+        for mode in KEYSCALE_MODES:
+            VALID_KEYSCALES.add(f"{note}{acc} {mode}")
+
+
+# ==============================================================================
+# Metadata Range Constants
+# ==============================================================================
+
+# BPM (Beats Per Minute) range
+BPM_MIN = 30
+BPM_MAX = 300
+
+# Duration range (in seconds)
+DURATION_MIN = 10
+DURATION_MAX = 600
+
+# Valid time signatures
+VALID_TIME_SIGNATURES = [2, 3, 4, 6]
+
+
+# ==============================================================================
+# Task Type Constants
+# ==============================================================================
+
+TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
+
+# Task types available for turbo models (subset)
+TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
+
+# Task types available for base models (full set)
+TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
+
+
+# ==============================================================================
+# Instruction Constants
+# ==============================================================================
+
+# Default instructions
+DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
+DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
+
+# Instruction templates for each task type
+# Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
+# These should be formatted using .format() or f-strings when used
+TASK_INSTRUCTIONS = {
+    "text2music": "Fill the audio semantic mask based on the given conditions:",
+    "repaint": "Repaint the mask area based on the given conditions:",
+    "cover": "Generate audio semantic tokens based on the given conditions:",
+    "extract": "Extract the {TRACK_NAME} track from the audio:",
+    "extract_default": "Extract the track from the audio:",
+    "lego": "Generate the {TRACK_NAME} track based on the audio context:",
+    "lego_default": "Generate the track based on the audio context:",
+    "complete": "Complete the input track with {TRACK_CLASSES}:",
+    "complete_default": "Complete the input track:",
+}
+
+
+# ==============================================================================
+# Track/Instrument Constants
+# ==============================================================================
+
+TRACK_NAMES = [
+    "woodwinds", "brass", "fx", "synth", "strings", "percussion",
+    "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
+]
+
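
Note (not part of the commit): a quick sanity sketch of the constants defined above — the keyscale grammar expands to 7 notes × 5 accidentals × 2 modes = 70 distinct strings, 'unknown' is the language code reserved for instrumentals, and the turbo task list is a subset of the full task list. Assumes the new module is importable.

from acestep.constants import VALID_KEYSCALES, VALID_LANGUAGES, TASK_TYPES, TASK_TYPES_TURBO

assert len(VALID_KEYSCALES) == 70
assert "F# minor" in VALID_KEYSCALES and "Bb major" in VALID_KEYSCALES
assert "unknown" in VALID_LANGUAGES                 # used for instrumental tracks
assert set(TASK_TYPES_TURBO).issubset(TASK_TYPES)   # turbo models expose a subset of tasks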
acestep/constrained_logits_processor.py CHANGED
@@ -6,6 +6,18 @@ from transformers import AutoTokenizer
 from transformers.generation.logits_process import LogitsProcessor
 import os
 import torch
+from acestep.constants import (
+    VALID_LANGUAGES,
+    KEYSCALE_NOTES,
+    KEYSCALE_ACCIDENTALS,
+    KEYSCALE_MODES,
+    VALID_KEYSCALES,
+    BPM_MIN,
+    BPM_MAX,
+    DURATION_MIN,
+    DURATION_MAX,
+    VALID_TIME_SIGNATURES,
+)
 
 
 # ==============================================================================
@@ -18,6 +30,8 @@ class FSMState(Enum):
     BPM_NAME = auto()  # Generating "bpm: "
     BPM_VALUE = auto()  # Generating numeric value 30-300
     NEWLINE_AFTER_BPM = auto()  # Generating "\n" after bpm value
+    CAPTION_NAME = auto()  # Generating "caption: "
+    CAPTION_VALUE = auto()  # Generating caption text (no code blocks/newlines)
     DURATION_NAME = auto()  # Generating "duration: "
     DURATION_VALUE = auto()  # Generating numeric value 10-600
     NEWLINE_AFTER_DURATION = auto()
@@ -27,6 +41,8 @@ class FSMState(Enum):
     KEYSCALE_NAME = auto()  # Generating "keyscale: "
     KEYSCALE_VALUE = auto()  # Generating keyscale pattern
     NEWLINE_AFTER_KEYSCALE = auto()
+    LANGUAGE_NAME = auto()  # Generating "language: "
+    LANGUAGE_VALUE = auto()  # Generating language code (en, zh, ja, etc.)
     TIMESIG_NAME = auto()  # Generating "timesignature: "
     TIMESIG_VALUE = auto()  # Generating 2, 3, 4, or 6
     NEWLINE_AFTER_TIMESIG = auto()
@@ -42,15 +58,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
     This processor enforces the following format:
     <think>
     bpm: [30-300]
+    caption: [text without code blocks, ends with period + newline]
     duration: [10-600]
-    genres: [any non-empty string]
     keyscale: [A-G][#/♭]? [major/minor]
+    language: [en/zh/ja/ko/es/fr/de/uk/ru/...]
     timesignature: [2/3/4/6]
     </think>
 
     It uses token masking (setting invalid token logits to -inf) to enforce constraints.
     For numeric fields, it uses early-blocking to prevent out-of-range values.
     For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit).
+    For caption field, it blocks code blocks and newlines, and only transitions when
+    the previous token was a period and newline has the highest probability.
     """
 
     def __init__(
@@ -80,15 +99,19 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.enabled = enabled
         self.debug = debug
         self.skip_genres = skip_genres
+        self.skip_caption = False  # Set to True to skip caption field generation
+        self.skip_language = False  # Set to True to skip language field generation
         self.caption: Optional[str] = None  # Set via update_caption() before each generation
 
         # User-provided metadata fields (optional)
         # If provided, these fields will be used directly instead of generating
-        # Format: {"bpm": "120", "duration": "234", "keyscale": "G major", "timesignature": "4", "genres": "Pop Rock"}
+        # Format: {"bpm": "120", "caption": "...", "duration": "234", "keyscale": "G major", "language": "en", "timesignature": "4", "genres": "Pop Rock"}
         self.user_provided_metadata: Dict[str, Optional[str]] = {
             "bpm": None,
+            "caption": None,
             "duration": None,
             "keyscale": None,
+            "language": None,
             "timesignature": None,
             "genres": None,
         }
@@ -114,6 +137,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.accumulated_value = ""  # For numeric/text value accumulation (legacy, for compatibility)
         self.accumulated_token_ids: List[int] = []  # Token ID sequence for keyscale (and other fields)
 
+        # Caption generation state tracking
+        self.caption_after_newline = False  # Track if we're right after a newline in caption
+        self.caption_token_count = 0  # Track token count for caption (max 512)
+
         # Token queue for user-provided fields (injected directly without generation)
         self.user_field_token_queue: List[int] = []
         self.current_user_field: Optional[str] = None  # Current field being injected
@@ -137,9 +164,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         # Field definitions (needed before building prefix trees)
         self.field_specs = {
-            "bpm": {"min": 30, "max": 300},
-            "duration": {"min": 10, "max": 600},
-            "timesignature": {"valid_values": [2, 3, 4, 6]},
+            "bpm": {"min": BPM_MIN, "max": BPM_MAX},
+            "duration": {"min": DURATION_MIN, "max": DURATION_MAX},
+            "timesignature": {"valid_values": VALID_TIME_SIGNATURES},
         }
 
         # Build valid numeric values for BPM, Duration, Timesignature
@@ -170,6 +197,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             context_prefix_for_tokenization="timesignature: "
         )
 
+        # Build language prefix tree (similar to keyscale but for language codes)
+        self.language_prefix_tree = self._build_language_prefix_tree()
+
         self._load_genres_vocab()
 
         # Note: Caption-based genre filtering is initialized via update_caption() before each generation
@@ -182,9 +212,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.THINK_TAG: "<think>",
             FSMState.NEWLINE_AFTER_THINK: "\n",
             FSMState.BPM_NAME: "bpm:",
+            FSMState.CAPTION_NAME: "caption:",
             FSMState.DURATION_NAME: "duration:",
            FSMState.GENRES_NAME: "genres:",
             FSMState.KEYSCALE_NAME: "keyscale:",
+            FSMState.LANGUAGE_NAME: "language:",
             FSMState.TIMESIG_NAME: "timesignature:",
             FSMState.THINK_END_TAG: "</think>",
         }
@@ -198,17 +230,21 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         even if the field is user-provided (we still need to generate the field name).
 
         Args:
-            current_field: Current field name ("bpm", "duration", "genres", "keyscale", "timesignature")
+            current_field: Current field name ("bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature")
 
         Returns:
            Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields
         """
-        field_order = ["bpm", "duration", "genres", "keyscale", "timesignature"]
+        # New field order: bpm -> caption -> duration -> keyscale -> language -> timesignature
+        # genres is optional and can be skipped
+        field_order = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"]
         field_to_state = {
             "bpm": FSMState.BPM_NAME,
+            "caption": FSMState.CAPTION_NAME,
             "duration": FSMState.DURATION_NAME,
             "genres": FSMState.GENRES_NAME,
             "keyscale": FSMState.KEYSCALE_NAME,
+            "language": FSMState.LANGUAGE_NAME,
             "timesignature": FSMState.TIMESIG_NAME,
         }
 
@@ -221,9 +257,13 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         for i in range(current_idx + 1, len(field_order)):
             field = field_order[i]
 
-            # Skip genres if skip_genres is True
+            # Skip fields based on flags
             if field == "genres" and self.skip_genres:
                 continue
+            if field == "caption" and self.skip_caption:
+                continue
+            if field == "language" and self.skip_language:
+                continue
 
             # Return the next field's NAME state (even if user-provided, we still generate field name)
             return field_to_state[field]
@@ -241,12 +281,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         }
 
         # Build transitions for all fields (even if user-provided, we still need to generate field name)
-        # Field order: bpm -> duration -> genres -> keyscale -> timesignature
+        # Field order: bpm -> caption -> duration -> genres -> keyscale -> language -> timesignature
 
-        # BPM field: NAME -> VALUE -> next field
+        # BPM field: NAME -> VALUE -> next field (caption or duration)
         self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE
         self.next_state[FSMState.BPM_VALUE] = self._get_next_field_state("bpm")
 
+        # Caption field (only if not skipped): NAME -> VALUE -> next field (duration)
+        if not self.skip_caption:
+            self.next_state[FSMState.CAPTION_NAME] = FSMState.CAPTION_VALUE
+            self.next_state[FSMState.CAPTION_VALUE] = self._get_next_field_state("caption")
+
         # Duration field: NAME -> VALUE -> next field
         self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE
         self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration")
@@ -256,10 +301,15 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
         self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres")
 
-        # Keyscale field: NAME -> VALUE -> next field
+        # Keyscale field: NAME -> VALUE -> next field (language or timesignature)
         self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE
         self.next_state[FSMState.KEYSCALE_VALUE] = self._get_next_field_state("keyscale")
 
+        # Language field (only if not skipped): NAME -> VALUE -> next field (timesignature)
+        if not self.skip_language:
+            self.next_state[FSMState.LANGUAGE_NAME] = FSMState.LANGUAGE_VALUE
+            self.next_state[FSMState.LANGUAGE_VALUE] = self._get_next_field_state("language")
+
         # Timesignature field: NAME -> VALUE -> THINK_END_TAG
         self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE
         self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG
@@ -269,6 +319,49 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.skip_genres = skip
         self._build_state_transitions()
 
+    def set_skip_caption(self, skip: bool):
+        """Set whether to skip caption generation and rebuild state transitions."""
+        self.skip_caption = skip
+        self._build_state_transitions()
+
+    def set_skip_language(self, skip: bool):
+        """Set whether to skip language generation and rebuild state transitions."""
+        self.skip_language = skip
+        self._build_state_transitions()
+
+    @staticmethod
+    def postprocess_caption(caption: str) -> str:
+        """
+        Post-process caption to remove YAML multi-line formatting.
+        Converts YAML-style multi-line text (with newlines and leading spaces)
+        to a single-line string.
+
+        Example:
+            Input: "An emotional ballad.\\n  The track opens with piano.\\n  More text."
+            Output: "An emotional ballad. The track opens with piano. More text."
+
+        Args:
+            caption: Raw caption text with possible YAML formatting
+
+        Returns:
+            Clean single-line caption
+        """
+        if not caption:
+            return caption
+
+        # Split by newlines
+        lines = caption.split('\n')
+
+        # Process each line: strip leading/trailing whitespace
+        cleaned_lines = []
+        for line in lines:
+            stripped = line.strip()
+            if stripped:
+                cleaned_lines.append(stripped)
+
+        # Join with single space
+        return ' '.join(cleaned_lines)
+
     def set_stop_at_reasoning(self, stop: bool):
         """
         Set whether to stop generation after </think> tag.
@@ -287,8 +380,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         Args:
             metadata: Dictionary with optional fields:
                 - "bpm": Optional[str] - e.g., "120"
+                - "caption": Optional[str] - e.g., "A melodic piano piece..."
                 - "duration": Optional[str] - e.g., "234"
                 - "keyscale": Optional[str] - e.g., "G major"
+                - "language": Optional[str] - e.g., "en"
                 - "timesignature": Optional[str] - e.g., "4"
                 - "genres": Optional[str] - e.g., "Pop Rock"
                 If None, clears all user-provided metadata.
@@ -297,7 +392,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             metadata = {}
 
         # Update user-provided metadata
-        for field in ["bpm", "duration", "keyscale", "timesignature", "genres"]:
+        for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]:
            if field in metadata:
                 self.user_provided_metadata[field] = metadata[field]
            else:
@@ -328,7 +423,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         # Note tokens for keyscale (A-G)
         self.note_tokens = {}
-        for note in "ABCDEFG":
+        for note in KEYSCALE_NOTES:
            tokens = self.tokenizer.encode(note, add_special_tokens=False)
            if tokens:
                self.note_tokens[note] = tokens[-1]
@@ -370,21 +465,80 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # EOS token for duration-constrained codes generation
         self.eos_token_id = self.tokenizer.eos_token_id
 
+        # Period token for caption field transition logic
+        period_tokens = self.tokenizer.encode(".", add_special_tokens=False)
+        self.period_token = period_tokens[-1] if period_tokens else None
+
+        # Backtick tokens for blocking code blocks in caption
+        backtick_tokens = self.tokenizer.encode("`", add_special_tokens=False)
+        self.backtick_token = backtick_tokens[-1] if backtick_tokens else None
+
+        # Valid language codes (ISO 639-1 and common variants)
+        self.valid_languages = VALID_LANGUAGES
+
+        # Precompute audio code token IDs (tokens matching <|audio_code_\d+|>)
+        # These should be blocked during caption generation
+        self.audio_code_token_ids: Set[int] = set()
+        self._precompute_audio_code_tokens()
+
+        # Precompute audio code mask for efficient blocking (O(1) instead of O(n))
+        # This mask will be added to scores during caption generation
+        self.audio_code_mask: Optional[torch.Tensor] = None
+        self._build_audio_code_mask()
+
         # Build valid keyscales set (prefix tree will be built after _char_to_tokens is initialized)
         # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
-        notes = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
-        accidentals = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
-        modes = ['major', 'minor']
-
-        self.valid_keyscales = set()
-        for note in notes:
-            for acc in accidentals:
-                for mode in modes:
-                    self.valid_keyscales.add(f"{note}{acc} {mode}")
+        self.valid_keyscales = VALID_KEYSCALES.copy()
 
         # keyscale_prefix_tree will be built in _precompute_char_token_mapping() after _char_to_tokens is ready
         # Numeric prefix trees will be built after field_specs is defined
 
+    def _precompute_audio_code_tokens(self):
+        """
+        Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
+        These tokens should be blocked during caption generation.
+        """
+        import re
+        audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
+
+        # Iterate through vocabulary to find audio code tokens
+        for token_id in range(self.vocab_size):
+            try:
+                token_text = self.tokenizer.decode([token_id])
+                if audio_code_pattern.match(token_text):
+                    self.audio_code_token_ids.add(token_id)
+            except Exception:
+                continue
+
+        if self.debug:
+            logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
+
+    def _build_audio_code_mask(self):
+        """
+        Build a precomputed mask tensor for blocking audio code tokens.
+        This mask can be added to scores in O(1) time instead of O(n) loop.
+
+        The mask is [1, vocab_size] tensor with -inf at audio code token positions.
+        """
+        if not self.audio_code_token_ids:
+            self.audio_code_mask = None
+            return
+
+        # Create mask tensor: 0 everywhere, -inf at audio code positions
+        # Use float32 for compatibility with most model dtypes
+        mask = torch.zeros(1, self.vocab_size, dtype=torch.float32)
+
+        # Convert set to list for indexing
+        audio_code_indices = list(self.audio_code_token_ids)
+
+        # Set -inf at audio code token positions
+        mask[0, audio_code_indices] = float('-inf')
+
+        self.audio_code_mask = mask
+
+        if self.debug:
+            logger.debug(f"Built audio code mask for {len(self.audio_code_token_ids)} tokens")
+
     def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
         """
         Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
@@ -560,6 +714,68 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         return prefix_to_tokens
 
+    def _build_language_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
+        """
+        Build language prefix to allowed tokens mapping based on ACTUAL tokenization.
+        Similar to keyscale prefix tree but for language codes.
+
+        Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
+        """
+        prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
+
+        context_prefix_for_matching = "language:"
+        context_prefix_for_tokenization = "language: "
+
+        context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
+
+        if self.debug:
+            context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
+            logger.debug(f"Context for matching 'language:' tokenizes to {context_token_ids} -> {context_tokens_str}")
+
+        for lang in self.valid_languages:
+            full_text = context_prefix_for_tokenization + lang
+            full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+            context_end_idx = None
+            if len(full_token_ids) >= len(context_token_ids):
+                if full_token_ids[:len(context_token_ids)] == context_token_ids:
+                    context_end_idx = len(context_token_ids)
+
+            if context_end_idx is None:
+                if self.debug:
+                    logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
+                continue
+
+            lang_token_ids = full_token_ids[context_end_idx:]
+
+            if not lang_token_ids:
+                if self.debug:
+                    logger.warning(f"No tokens extracted for language '{lang}', skipping")
+                continue
+
+            for i in range(len(lang_token_ids) + 1):
+                token_prefix = tuple(lang_token_ids[:i])
+
+                if token_prefix not in prefix_to_tokens:
+                    prefix_to_tokens[token_prefix] = set()
+
+                if i < len(lang_token_ids):
+                    next_token_id = lang_token_ids[i]
+                    prefix_to_tokens[token_prefix].add(next_token_id)
+                else:
+                    if self.newline_token:
+                        prefix_to_tokens[token_prefix].add(self.newline_token)
+
+        if self.debug:
+            logger.debug(f"Built language prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
+            empty_prefix = tuple()
+            if empty_prefix in prefix_to_tokens:
+                first_tokens = prefix_to_tokens[empty_prefix]
+                decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
+                logger.debug(f"First tokens allowed for language (empty prefix): {decoded_first}")
+
+        return prefix_to_tokens
+
     def diagnose_keyscale_prefix_tree(self):
         """
         Diagnose the keyscale prefix tree to help debug generation bias.
@@ -926,6 +1142,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.codes_count = 0  # Reset codes counter
         self.user_field_token_queue = []  # Reset user field token queue
         self.current_user_field = None  # Reset current user field
+        self.caption_after_newline = False  # Reset caption newline tracking
+        self.caption_token_count = 0  # Reset caption token count
 
     def set_target_duration(self, duration: Optional[float]):
         """
@@ -1170,6 +1388,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             return self.newline_token in self.keyscale_prefix_tree[token_prefix]
         return False
 
+    def _get_allowed_language_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for language field using the precomputed prefix tree.
+        Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+        Similar to keyscale.
+        """
+        token_prefix = tuple(self.accumulated_token_ids)
+
+        if token_prefix in self.language_prefix_tree:
+            return list(self.language_prefix_tree[token_prefix])
+
+        # Fallback: no valid continuation found
+        return []
+
     def _get_allowed_timesig_tokens(self) -> List[int]:
         """
         Get allowed tokens for timesignature field using the precomputed prefix tree.
@@ -1269,7 +1501,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         Uses the same tokenization logic as prefix tree building.
 
         Args:
-            field_name: Field name ("bpm", "duration", "keyscale", "timesignature", "genres")
+            field_name: Field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres")
 
         Returns:
            List of token IDs for the complete field, or None if field is not provided
@@ -1281,8 +1513,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Build full field string with space (matching prefix tree tokenization)
         field_to_prefix = {
             "bpm": "bpm: ",
+            "caption": "caption: ",
             "duration": "duration: ",
             "keyscale": "keyscale: ",
+            "language": "language: ",
             "timesignature": "timesignature: ",
             "genres": "genres: ",
         }
@@ -1410,6 +1644,67 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
            scores = scores + mask
 
+        elif self.state == FSMState.CAPTION_VALUE:
+            # Caption field generation with YAML format support:
+            # - Allow newlines and spaces (YAML multi-line formatting)
+            # - Block audio codes and backticks
+            # - Max 512 tokens
+            # - Transition when model wants to generate next field (non-indented line)
+
+            # Check if field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["caption"] is not None and not self.user_field_token_queue and not self.accumulated_value:
+                # Initialize token queue with field value tokens (value + newline)
+                value = self.user_provided_metadata["caption"]
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "caption"
+                    # Inject first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # Check if we should transition after a newline (non-indented line = new field)
+            if self.caption_after_newline:
+                # Get top token from current scores
+                top_token_id = torch.argmax(scores[0]).item()
+                top_token_text = self.tokenizer.decode([top_token_id])
+
+                # If top token does NOT start with space/tab, it's a new field (like "duration:")
+                if len(top_token_text) > 0 and top_token_text[0] not in ' \t':
+                    # Caption is ending, transition to next field
+                    self.caption_after_newline = False
+                    self._transition_to_next_state()
+                    # Process with new state (DURATION_NAME)
+                    return self._process_single_sequence(input_ids, scores)
+                else:
+                    # It's indentation, continue caption
+                    self.caption_after_newline = False
+
+            # Block backticks (code blocks)
+            if self.backtick_token is not None:
+                scores[0, self.backtick_token] = float('-inf')
+
+            # Block ALL audio code tokens (critical - these should never appear in caption)
+            # Use precomputed mask for O(1) performance instead of O(n) loop
+            if self.audio_code_mask is not None:
+                # Move mask to same device/dtype as scores if needed
+                if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
+                    self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
+                scores = scores + self.audio_code_mask
+
+            # Enforce 512 token limit for caption
+            if self.caption_token_count >= 512:
+                # Force end by only allowing newline
+                if self.newline_token is not None:
+                    mask[0, self.newline_token] = 0
+                scores = scores + mask
+                return scores
+
+            # Allow natural generation (with blocked audio codes and backticks)
+            return scores
+
         elif self.state == FSMState.DURATION_VALUE:
            # Check if field is user-provided and we haven't started injecting yet
            if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
@@ -1539,6 +1834,43 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                mask[0, self.newline_token] = 0
            scores = scores + mask
 
+        elif self.state == FSMState.LANGUAGE_VALUE:
+            # Language field: similar to keyscale, uses prefix tree
+
+            # Check if field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["language"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+                # Initialize token queue with field value tokens (value + newline)
+                value = self.user_provided_metadata["language"]
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "language"
+                    # Inject first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # Check if current token sequence is complete (allows newline)
+            token_prefix = tuple(self.accumulated_token_ids)
+            if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
+                # Complete language, allow newline
+                if self.newline_token:
+                    mask[0, self.newline_token] = 0
+                scores = scores + mask
+            else:
+                # Not complete, allow valid continuation tokens
+                allowed = self._get_allowed_language_tokens()
+                if allowed:
+                    for t in allowed:
+                        mask[0, t] = 0
+                    scores = scores + mask
+                else:
+                    # No valid tokens found - force newline to end field
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    scores = scores + mask
+
         elif self.state == FSMState.TIMESIG_VALUE:
            # Check if field is user-provided and we haven't started injecting yet
            if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
@@ -1587,6 +1919,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.position_in_state = 0
         self.accumulated_value = ""  # Legacy, kept for compatibility
         self.accumulated_token_ids = []  # Reset token ID sequence for new field
+        self.caption_after_newline = False  # Reset caption newline tracking
+        self.caption_token_count = 0  # Reset caption token count
        if self.debug:
            logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}")
 
@@ -1703,6 +2037,22 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
            # Genres still uses string-based trie, so keep accumulated_value
            self.accumulated_value += token_str
 
+        elif self.state == FSMState.CAPTION_VALUE:
+            # Track token count for 512 limit
+            self.caption_token_count += 1
+
+            # Accumulate caption text
+            self.accumulated_value += token_str
+
+            # Track if this token is a newline (for transition detection)
+            if generated_token_id == self.newline_token:
+                # Mark that we need to check next token for field transition
+                self.caption_after_newline = True
+            else:
+                # Not a newline - if we were after newline and this is not space,
+                # transition already happened in _process_single_sequence
+                self.caption_after_newline = False
+
         elif self.state == FSMState.KEYSCALE_VALUE:
            if generated_token_id == self.newline_token:
                # Newline ends the field
@@ -1718,4 +2068,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                self.accumulated_token_ids.append(generated_token_id)
                # Also update legacy accumulated_value for compatibility
                self.accumulated_value += token_str
+
+        elif self.state == FSMState.LANGUAGE_VALUE:
+            if generated_token_id == self.newline_token:
+                # Newline ends the field
+                self._transition_to_next_state()
+                if self.state in self.fixed_strings:
+                    return
+            else:
+                # Add token ID to sequence (for prefix tree lookup)
+                self.accumulated_token_ids.append(generated_token_id)
+                # Also update legacy accumulated_value for compatibility
+                self.accumulated_value += token_str
 
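
Note (not part of the commit): with the two new fields the processor constrains the CoT block to the order bpm → caption → duration → genres → keyscale → language → timesignature, and the new postprocess_caption helper flattens the YAML-style multi-line captions the model may emit. A minimal usage sketch, assuming the package's torch/transformers dependencies are installed so the module imports cleanly:

from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor

raw = "An emotional ballad.\n  The track opens with piano.\n  More text."
clean = MetadataConstrainedLogitsProcessor.postprocess_caption(raw)
assert clean == "An emotional ballad. The track opens with piano. More text."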
acestep/gradio_ui.py CHANGED
@@ -8,6 +8,14 @@ import random
 import glob
 import gradio as gr
 from typing import Callable, Optional, Tuple
+from acestep.constants import (
+    VALID_LANGUAGES,
+    TRACK_NAMES,
+    TASK_TYPES,
+    TASK_TYPES_TURBO,
+    TASK_TYPES_BASE,
+    DEFAULT_DIT_INSTRUCTION,
+)
 
 
 def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None) -> gr.Blocks:
@@ -296,9 +304,9 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
     # Determine initial task_type choices based on default model
     default_model_lower = (default_model or "").lower()
     if "turbo" in default_model_lower:
-        initial_task_choices = ["text2music", "repaint", "cover"]
+        initial_task_choices = TASK_TYPES_TURBO
     else:
-        initial_task_choices = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
+        initial_task_choices = TASK_TYPES_BASE
 
     with gr.Row():
         with gr.Column(scale=2):
@@ -311,15 +319,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
            with gr.Column(scale=8):
                instruction_display_gen = gr.Textbox(
                    label="Instruction",
-                    value="Fill the audio semantic mask based on the given conditions:",
+                    value=DEFAULT_DIT_INSTRUCTION,
                    interactive=False,
                    lines=1,
                    info="Instruction is automatically generated based on task type",
                )
 
        track_name = gr.Dropdown(
-            choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
-                     "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
+            choices=TRACK_NAMES,
            value=None,
            label="Track Name",
            info="Select track name for lego/extract tasks",
@@ -327,8 +334,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
        )
 
        complete_track_classes = gr.CheckboxGroup(
-            choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
-                     "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
+            choices=TRACK_NAMES,
            label="Track Names",
            info="Select multiple track classes for complete task",
            visible=False
@@ -410,8 +416,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
        with gr.Accordion("⚙️ Optional Parameters", open=True):
            with gr.Row():
                vocal_language = gr.Dropdown(
-                    choices=["en", "zh", "ja", "ko", "es", "fr", "de"],
-                    value="en",
+                    choices=VALID_LANGUAGES,
+                    value="unknown",
                    label="Vocal Language (optional)",
                    allow_custom_value=True,
                    info="use `unknown` for inst"
@@ -567,6 +573,20 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
                    scale=2,
                )
 
+            with gr.Row():
+                use_cot_caption = gr.Checkbox(
+                    label="CoT Caption",
+                    value=True,
+                    info="Generate caption in CoT (chain-of-thought)",
+                    scale=1,
+                )
+                use_cot_language = gr.Checkbox(
+                    label="CoT Language",
+                    value=True,
+                    info="Generate language in CoT (chain-of-thought)",
+                    scale=1,
+                )
+
            with gr.Row():
                audio_cover_strength = gr.Slider(
                    minimum=0.0,
@@ -625,6 +645,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
        "lm_top_k": lm_top_k,
        "lm_top_p": lm_top_p,
        "lm_negative_prompt": lm_negative_prompt,
+        "use_cot_caption": use_cot_caption,
+        "use_cot_language": use_cot_language,
        "repainting_group": repainting_group,
        "repainting_start": repainting_start,
        "repainting_end": repainting_end,
@@ -824,7 +846,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
                gr.update(visible=False),  # use_adg
                gr.update(visible=False),  # cfg_interval_start
                gr.update(visible=False),  # cfg_interval_end
-                gr.update(choices=["text2music", "repaint", "cover"]),  # task_type
+                gr.update(choices=TASK_TYPES_TURBO),  # task_type
            )
        elif "base" in config_path_lower:
            # Base model: max 100 steps, show CFG/ADG, show all task types
@@ -834,7 +856,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
                gr.update(visible=True),  # use_adg
                gr.update(visible=True),  # cfg_interval_start
                gr.update(visible=True),  # cfg_interval_end
-                gr.update(choices=["text2music", "repaint", "cover", "extract", "lego", "complete"]),  # task_type
+                gr.update(choices=TASK_TYPES_BASE),  # task_type
            )
        else:
            # Default to turbo settings
@@ -844,7 +866,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
-                gr.update(choices=["text2music", "repaint", "cover"]),  # task_type
+                gr.update(choices=TASK_TYPES_TURBO),  # task_type
            )
 
    generation_section["config_path"].change(
@@ -965,6 +987,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
        instruction_display_gen, audio_cover_strength, task_type,
        use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
        think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
+        use_cot_caption, use_cot_language,
        progress=gr.Progress(track_tqdm=True)
    ):
        # If think is enabled (llm_dit mode), generate audio codes using LM first
@@ -1019,6 +1042,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
                top_k=top_k_value,
                top_p=top_p_value,
                user_metadata=user_metadata_to_pass,
+                use_cot_caption=use_cot_caption,
+                use_cot_language=use_cot_language,
            )
 
            # Store LM-generated metadata and audio codes for display
@@ -1076,14 +1101,18 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
            metadata_lines = []
            if lm_generated_metadata.get('bpm'):
                metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
-            if lm_generated_metadata.get('keyscale'):
-                metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
-            if lm_generated_metadata.get('timesignature'):
-                metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
+            if lm_generated_metadata.get('caption'):
+                metadata_lines.append(f"- **User Query Rewritten Caption:** {lm_generated_metadata['caption']}")
            if lm_generated_metadata.get('duration'):
                metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
            if lm_generated_metadata.get('genres'):
                metadata_lines.append(f"- **Genres:** {lm_generated_metadata['genres']}")
+            if lm_generated_metadata.get('keyscale'):
+                metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
+            if lm_generated_metadata.get('language'):
+                metadata_lines.append(f"- **Language:** {lm_generated_metadata['language']}")
+            if lm_generated_metadata.get('timesignature'):
+                metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
 
            if metadata_lines:
                metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
@@ -1140,7 +1169,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
            generation_section["lm_cfg_scale"],
            generation_section["lm_top_k"],
            generation_section["lm_top_p"],
-            generation_section["lm_negative_prompt"]
+            generation_section["lm_negative_prompt"],
+            generation_section["use_cot_caption"],
+            generation_section["use_cot_language"]
        ],
        outputs=[
            results_section["generated_audio_1"],
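
Note (assumed wiring, not shown verbatim in this commit): the two new checkboxes are forwarded as use_cot_caption / use_cot_language into the LM generation call; on the processor side the natural mapping is to the new skip flags, roughly:

def apply_cot_flags(processor, use_cot_caption: bool, use_cot_language: bool) -> None:
    # Hypothetical glue: unchecking a CoT box means skipping that field in the <think> block.
    processor.set_skip_caption(not use_cot_caption)
    processor.set_skip_language(not use_cot_language)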
acestep/handler.py CHANGED
@@ -23,6 +23,11 @@ import warnings
23
  from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
24
  from transformers.generation.streamers import BaseStreamer
25
  from diffusers.models import AutoencoderOobleck
 
 
 
 
 
26
 
27
 
28
  warnings.filterwarnings("ignore")
@@ -519,10 +524,11 @@ class AceStepHandler:
519
  Args:
520
  task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching.
521
  instruction: Instruction text; default fallback matches service_generate behavior.
522
- caption: Caption string.
523
  lyrics: Lyrics string.
524
  metas: Metadata (str or dict); follows _parse_metas formatting.
525
- vocal_language: Language code for lyrics section.
 
526
 
527
  Returns:
528
  (caption_input_text, lyrics_input_text)
@@ -533,18 +539,45 @@ class AceStepHandler:
533
  instruction=None,
534
  caption="A calm piano melody",
535
  lyrics="la la la",
536
- metas={"bpm": 90, "duration": 45},
537
  vocal_language="en",
538
  )
539
  """
540
  # Align instruction formatting with _prepare_batch
541
- final_instruction = instruction or "Fill the audio semantic mask based on the given conditions:"
542
  if not final_instruction.endswith(":"):
543
  final_instruction = final_instruction + ":"
544
 
545
  parsed_meta = self._parse_metas([metas])[0]
546
- caption_input = SFT_GEN_PROMPT.format(final_instruction, caption, parsed_meta)
547
- lyrics_input = f"# Languages\n{vocal_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
548
  return caption_input, lyrics_input
549
 
550
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -679,41 +712,36 @@ class AceStepHandler:
679
  track_name: Optional[str] = None,
680
  complete_track_classes: Optional[List[str]] = None
681
  ) -> str:
682
- TRACK_NAMES = [
683
- "woodwinds", "brass", "fx", "synth", "strings", "percussion",
684
- "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
685
- ]
686
-
687
  if task_type == "text2music":
688
- return "Fill the audio semantic mask based on the given conditions:"
689
  elif task_type == "repaint":
690
- return "Repaint the mask area based on the given conditions:"
691
  elif task_type == "cover":
692
- return "Generate audio semantic tokens based on the given conditions:"
693
  elif task_type == "extract":
694
  if track_name:
695
  # Convert to uppercase
696
  track_name_upper = track_name.upper()
697
- return f"Extract the {track_name_upper} track from the audio:"
698
  else:
699
- return "Extract the track from the audio:"
700
  elif task_type == "lego":
701
  if track_name:
702
  # Convert to uppercase
703
  track_name_upper = track_name.upper()
704
- return f"Generate the {track_name_upper} track based on the audio context:"
705
  else:
706
- return "Generate the track based on the audio context:"
707
  elif task_type == "complete":
708
  if complete_track_classes and len(complete_track_classes) > 0:
709
  # Convert to uppercase and join with " | "
710
  track_classes_upper = [t.upper() for t in complete_track_classes]
711
  complete_track_classes_str = " | ".join(track_classes_upper)
712
- return f"Complete the input track with {complete_track_classes_str}:"
713
  else:
714
- return "Complete the input track:"
715
  else:
716
- return "Fill the audio semantic mask based on the given conditions:"
717
 
718
  def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
719
  if audio_file is None:
@@ -1247,7 +1275,7 @@ class AceStepHandler:
1247
  # Process instructions early so we can use them for task type detection
1248
  # Use custom instructions if provided, otherwise use default
1249
  if instructions is None:
1250
- instructions = ["Fill the audio semantic mask based on the given conditions:"] * batch_size
1251
 
1252
  # Ensure instructions list has the same length as batch_size
1253
  if len(instructions) != batch_size:
@@ -1257,7 +1285,7 @@ class AceStepHandler:
1257
  # Pad or truncate to match batch_size
1258
  instructions = instructions[:batch_size]
1259
  while len(instructions) < batch_size:
1260
- instructions.append("Fill the audio semantic mask based on the given conditions:")
1261
 
1262
  # Generate chunk_masks and spans based on repainting parameters
1263
  # Also determine if this is a cover task (target audio provided without repainting)
@@ -1415,13 +1443,29 @@ class AceStepHandler:
1415
 
1416
  for i in range(batch_size):
1417
  # Use custom instruction for this batch item
1418
- instruction = instructions[i] if i < len(instructions) else "Fill the audio semantic mask based on the given conditions:"
1419
  # Ensure instruction ends with ":"
1420
  if not instruction.endswith(":"):
1421
  instruction = instruction + ":"
1422
 
1423
- # Format text prompt with custom instruction
1424
- text_prompt = SFT_GEN_PROMPT.format(instruction, captions[i], parsed_metas[i])
1425
 
1426
  # Tokenize text
1427
  text_inputs_dict = self.text_tokenizer(
@@ -1434,8 +1478,8 @@ class AceStepHandler:
1434
  text_token_ids = text_inputs_dict.input_ids[0]
1435
  text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1436
 
1437
- # Format and tokenize lyrics
1438
- lyrics_text = f"# Languages\n{vocal_languages[i]}\n\n# Lyric\n{lyrics[i]}<|endoftext|>"
1439
  lyrics_inputs_dict = self.text_tokenizer(
1440
  lyrics_text,
1441
  padding="longest",
@@ -1495,10 +1539,17 @@ class AceStepHandler:
1495
  non_cover_text_attention_masks = []
1496
  for i in range(batch_size):
1497
  # Use custom instruction for this batch item
1498
- instruction = "Fill the audio semantic mask based on the given conditions:"
1499
 
1500
- # Format text prompt with custom instruction
1501
- text_prompt = SFT_GEN_PROMPT.format(instruction, captions[i], parsed_metas[i])
1502
 
1503
  # Tokenize text
1504
  text_inputs_dict = self.text_tokenizer(
@@ -1991,7 +2042,7 @@ class AceStepHandler:
1991
  audio_code_string: Union[str, List[str]] = "",
1992
  repainting_start: float = 0.0,
1993
  repainting_end: Optional[float] = None,
1994
- instruction: str = "Fill the audio semantic mask based on the given conditions:",
1995
  audio_cover_strength: float = 1.0,
1996
  task_type: str = "text2music",
1997
  use_adg: bool = False,
@@ -2030,7 +2081,7 @@ class AceStepHandler:
2030
  # User has provided audio codes, switch to cover task
2031
  task_type = "cover"
2032
  # Update instruction for cover task
2033
- instruction = "Generate audio semantic tokens based on the given conditions:"
2034
 
2035
  logger.info("[generate_music] Starting generation...")
2036
  if progress:
 
23
  from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
24
  from transformers.generation.streamers import BaseStreamer
25
  from diffusers.models import AutoencoderOobleck
26
+ from acestep.constants import (
27
+ TASK_INSTRUCTIONS,
28
+ TRACK_NAMES,
29
+ DEFAULT_DIT_INSTRUCTION,
30
+ )
31
 
32
 
33
  warnings.filterwarnings("ignore")
 
524
  Args:
525
  task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching.
526
  instruction: Instruction text; default fallback matches service_generate behavior.
527
+ caption: Caption string (fallback if not in metas).
528
  lyrics: Lyrics string.
529
  metas: Metadata (str or dict); follows _parse_metas formatting.
530
+ May contain 'caption' and 'language' fields from LM CoT output.
531
+ vocal_language: Language code for lyrics section (fallback if not in metas).
532
 
533
  Returns:
534
  (caption_input_text, lyrics_input_text)
 
539
  instruction=None,
540
  caption="A calm piano melody",
541
  lyrics="la la la",
542
+ metas={"bpm": 90, "duration": 45, "caption": "LM generated caption", "language": "en"},
543
  vocal_language="en",
544
  )
545
  """
546
  # Align instruction formatting with _prepare_batch
547
+ final_instruction = instruction or DEFAULT_DIT_INSTRUCTION
548
  if not final_instruction.endswith(":"):
549
  final_instruction = final_instruction + ":"
550
 
551
+ # Extract caption and language from metas if available (from LM CoT output)
552
+ # Fallback to user-provided values if not in metas
553
+ actual_caption = caption
554
+ actual_language = vocal_language
555
+
556
+ if metas is not None:
557
+ # Parse metas to dict if it's a string
558
+ if isinstance(metas, str):
559
+ # Try to parse as dict-like string or use as-is
560
+ parsed_metas = self._parse_metas([metas])
561
+ if parsed_metas and isinstance(parsed_metas[0], dict):
562
+ meta_dict = parsed_metas[0]
563
+ else:
564
+ meta_dict = {}
565
+ elif isinstance(metas, dict):
566
+ meta_dict = metas
567
+ else:
568
+ meta_dict = {}
569
+
570
+ # Extract caption from metas if available
571
+ if 'caption' in meta_dict and meta_dict['caption']:
572
+ actual_caption = str(meta_dict['caption'])
573
+
574
+ # Extract language from metas if available
575
+ if 'language' in meta_dict and meta_dict['language']:
576
+ actual_language = str(meta_dict['language'])
577
+
578
  parsed_meta = self._parse_metas([metas])[0]
579
+ caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
580
+ lyrics_input = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
581
  return caption_input, lyrics_input
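The precedence implemented above can be summarised as: caption/language fields emitted by the LM chain-of-thought (carried in metas) override the user-supplied caption and vocal_language, which remain the fallback. A standalone sketch of just that rule (the helper name and the dict-only metas handling are illustrative simplifications):

from typing import Optional, Tuple

def resolve_caption_and_language(
    caption: str,
    vocal_language: str,
    metas: Optional[dict],
) -> Tuple[str, str]:
    """Prefer LM CoT outputs ('caption'/'language' keys in metas) over user inputs."""
    meta_dict = metas if isinstance(metas, dict) else {}
    actual_caption = str(meta_dict["caption"]) if meta_dict.get("caption") else caption
    actual_language = str(meta_dict["language"]) if meta_dict.get("language") else vocal_language
    return actual_caption, actual_language

# The LM-rewritten caption and detected language win over the raw user inputs:
print(resolve_caption_and_language(
    "A calm piano melody", "en",
    {"bpm": 90, "caption": "Gentle solo piano, 90 BPM, reflective mood", "language": "zh"},
))  # -> ('Gentle solo piano, 90 BPM, reflective mood', 'zh')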
582
 
583
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
 
712
  track_name: Optional[str] = None,
713
  complete_track_classes: Optional[List[str]] = None
714
  ) -> str:
715
  if task_type == "text2music":
716
+ return TASK_INSTRUCTIONS["text2music"]
717
  elif task_type == "repaint":
718
+ return TASK_INSTRUCTIONS["repaint"]
719
  elif task_type == "cover":
720
+ return TASK_INSTRUCTIONS["cover"]
721
  elif task_type == "extract":
722
  if track_name:
723
  # Convert to uppercase
724
  track_name_upper = track_name.upper()
725
+ return TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name_upper)
726
  else:
727
+ return TASK_INSTRUCTIONS["extract_default"]
728
  elif task_type == "lego":
729
  if track_name:
730
  # Convert to uppercase
731
  track_name_upper = track_name.upper()
732
+ return TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name_upper)
733
  else:
734
+ return TASK_INSTRUCTIONS["lego_default"]
735
  elif task_type == "complete":
736
  if complete_track_classes and len(complete_track_classes) > 0:
737
  # Convert to uppercase and join with " | "
738
  track_classes_upper = [t.upper() for t in complete_track_classes]
739
  complete_track_classes_str = " | ".join(track_classes_upper)
740
+ return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=complete_track_classes_str)
741
  else:
742
+ return TASK_INSTRUCTIONS["complete_default"]
743
  else:
744
+ return TASK_INSTRUCTIONS["text2music"]
745
 
746
  def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
747
  if audio_file is None:
 
1275
  # Process instructions early so we can use them for task type detection
1276
  # Use custom instructions if provided, otherwise use default
1277
  if instructions is None:
1278
+ instructions = [DEFAULT_DIT_INSTRUCTION] * batch_size
1279
 
1280
  # Ensure instructions list has the same length as batch_size
1281
  if len(instructions) != batch_size:
 
1285
  # Pad or truncate to match batch_size
1286
  instructions = instructions[:batch_size]
1287
  while len(instructions) < batch_size:
1288
+ instructions.append(DEFAULT_DIT_INSTRUCTION)
1289
 
1290
  # Generate chunk_masks and spans based on repainting parameters
1291
  # Also determine if this is a cover task (target audio provided without repainting)
 
1443
 
1444
  for i in range(batch_size):
1445
  # Use custom instruction for this batch item
1446
+ instruction = instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION
1447
  # Ensure instruction ends with ":"
1448
  if not instruction.endswith(":"):
1449
  instruction = instruction + ":"
1450
 
1451
+ # Extract caption and language from metas if available (from LM CoT output)
1452
+ # Fallback to user-provided values if not in metas
1453
+ actual_caption = captions[i]
1454
+ actual_language = vocal_languages[i]
1455
+
1456
+ # Check if metas contains caption/language from LM CoT
1457
+ if i < len(parsed_metas) and parsed_metas[i]:
1458
+ meta_dict = parsed_metas[i]
1459
+ if isinstance(meta_dict, dict):
1460
+ # Extract caption from metas if available
1461
+ if 'caption' in meta_dict and meta_dict['caption']:
1462
+ actual_caption = str(meta_dict['caption'])
1463
+ # Extract language from metas if available
1464
+ if 'language' in meta_dict and meta_dict['language']:
1465
+ actual_language = str(meta_dict['language'])
1466
+
1467
+ # Format text prompt with custom instruction (using LM-generated caption if available)
1468
+ text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
1469
 
1470
  # Tokenize text
1471
  text_inputs_dict = self.text_tokenizer(
 
1478
  text_token_ids = text_inputs_dict.input_ids[0]
1479
  text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1480
 
1481
+ # Format and tokenize lyrics (using LM-generated language if available)
1482
+ lyrics_text = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics[i]}<|endoftext|>"
1483
  lyrics_inputs_dict = self.text_tokenizer(
1484
  lyrics_text,
1485
  padding="longest",
 
1539
  non_cover_text_attention_masks = []
1540
  for i in range(batch_size):
1541
  # Use custom instruction for this batch item
1542
+ instruction = DEFAULT_DIT_INSTRUCTION
1543
+
1544
+ # Extract caption from metas if available (from LM CoT output)
1545
+ actual_caption = captions[i]
1546
+ if i < len(parsed_metas) and parsed_metas[i]:
1547
+ meta_dict = parsed_metas[i]
1548
+ if isinstance(meta_dict, dict) and 'caption' in meta_dict and meta_dict['caption']:
1549
+ actual_caption = str(meta_dict['caption'])
1550
 
1551
+ # Format text prompt with custom instruction (using LM-generated caption if available)
1552
+ text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
1553
 
1554
  # Tokenize text
1555
  text_inputs_dict = self.text_tokenizer(
 
2042
  audio_code_string: Union[str, List[str]] = "",
2043
  repainting_start: float = 0.0,
2044
  repainting_end: Optional[float] = None,
2045
+ instruction: str = DEFAULT_DIT_INSTRUCTION,
2046
  audio_cover_strength: float = 1.0,
2047
  task_type: str = "text2music",
2048
  use_adg: bool = False,
 
2081
  # User has provided audio codes, switch to cover task
2082
  task_type = "cover"
2083
  # Update instruction for cover task
2084
+ instruction = TASK_INSTRUCTIONS["cover"]
2085
 
2086
  logger.info("[generate_music] Starting generation...")
2087
  if progress:
acestep/llm_inference.py CHANGED
@@ -17,6 +17,7 @@ from transformers.generation.logits_process import (
17
  RepetitionPenaltyLogitsProcessor,
18
  )
19
  from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor
 
20
 
21
 
22
  class LLMHandler:
@@ -244,6 +245,8 @@ class LLMHandler:
244
  target_duration: Optional[float] = None,
245
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
246
  stop_at_reasoning: bool = False,
 
 
247
  ) -> str:
248
  """Shared vllm path: accept prebuilt formatted prompt and return text."""
249
  from nanovllm import SamplingParams
@@ -265,6 +268,9 @@ class LLMHandler:
265
  # Always call set_user_metadata to ensure previous settings are cleared if None
266
  self.constrained_processor.set_user_metadata(user_metadata)
267
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
 
 
 
268
 
269
  constrained_processor = self.constrained_processor
270
 
@@ -318,6 +324,8 @@ class LLMHandler:
318
  target_duration: Optional[float] = None,
319
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
320
  stop_at_reasoning: bool = False,
 
 
321
  ) -> str:
322
  """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
323
  inputs = self.llm_tokenizer(
@@ -338,6 +346,9 @@ class LLMHandler:
338
  # Always call set_user_metadata to ensure previous settings are cleared if None
339
  self.constrained_processor.set_user_metadata(user_metadata)
340
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
 
 
 
341
 
342
  constrained_processor = self.constrained_processor
343
 
@@ -472,6 +483,8 @@ class LLMHandler:
472
  constrained_decoding_debug: bool = False,
473
  target_duration: Optional[float] = None,
474
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
 
 
475
  ) -> Tuple[Dict[str, Any], str, str]:
476
  """Feishu-compatible LM generation.
477
 
@@ -483,6 +496,8 @@ class LLMHandler:
483
  5 codes = 1 second. If specified, blocks EOS until target reached.
484
  user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature).
485
  If specified, constrained decoding will inject these values directly.
 
 
486
  """
487
  infer_type = (infer_type or "").strip().lower()
488
  if infer_type not in {"dit", "llm_dit"}:
@@ -509,6 +524,8 @@ class LLMHandler:
509
  "repetition_penalty": repetition_penalty,
510
  "target_duration": target_duration,
511
  "user_metadata": user_metadata,
 
 
512
  },
513
  use_constrained_decoding=use_constrained_decoding,
514
  constrained_decoding_debug=constrained_decoding_debug,
@@ -540,7 +557,7 @@ class LLMHandler:
540
  prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
541
  return self.llm_tokenizer.apply_chat_template(
542
  [
543
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
544
  {"role": "user", "content": prompt},
545
  ],
546
  tokenize=False,
@@ -591,6 +608,8 @@ class LLMHandler:
591
  repetition_penalty = cfg.get("repetition_penalty", 1.0)
592
  target_duration = cfg.get("target_duration")
593
  user_metadata = cfg.get("user_metadata") # User-provided metadata fields
 
 
594
 
595
  try:
596
  if self.llm_backend == "vllm":
@@ -607,6 +626,8 @@ class LLMHandler:
607
  target_duration=target_duration,
608
  user_metadata=user_metadata,
609
  stop_at_reasoning=stop_at_reasoning,
 
 
610
  )
611
  return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
612
 
@@ -624,6 +645,8 @@ class LLMHandler:
624
  target_duration=target_duration,
625
  user_metadata=user_metadata,
626
  stop_at_reasoning=stop_at_reasoning,
 
 
627
  )
628
  return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
629
 
@@ -928,9 +951,11 @@ class LLMHandler:
928
  Expected format:
929
  <think>
930
  bpm: 73
 
931
  duration: 273
932
  genres: Chinese folk
933
  keyscale: G major
 
934
  timesignature: 4
935
  </think>
936
 
@@ -973,32 +998,69 @@ class LLMHandler:
973
  lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
974
  reasoning_text = lines_before_codes.strip()
975
 
976
- # Parse metadata fields
977
  if reasoning_text:
978
- for line in reasoning_text.split('\n'):
979
- line = line.strip()
980
- if ':' in line and not line.startswith('<'):
981
  parts = line.split(':', 1)
982
  if len(parts) == 2:
983
- key = parts[0].strip().lower()
984
- value = parts[1].strip()
985
-
986
- if key == 'bpm':
987
- try:
988
- metadata['bpm'] = int(value)
989
- except:
990
- metadata['bpm'] = value
991
- elif key == 'duration':
992
- try:
993
- metadata['duration'] = int(value)
994
- except:
995
- metadata['duration'] = value
996
- elif key == 'genres':
997
- metadata['genres'] = value
998
- elif key == 'keyscale':
999
- metadata['keyscale'] = value
1000
- elif key == 'timesignature':
1001
- metadata['timesignature'] = value
1002
 
1003
  return metadata, audio_codes
1004
 
 
17
  RepetitionPenaltyLogitsProcessor,
18
  )
19
  from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor
20
+ from acestep.constants import DEFAULT_LM_INSTRUCTION
21
 
22
 
23
  class LLMHandler:
 
245
  target_duration: Optional[float] = None,
246
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
247
  stop_at_reasoning: bool = False,
248
+ skip_caption: bool = False,
249
+ skip_language: bool = False,
250
  ) -> str:
251
  """Shared vllm path: accept prebuilt formatted prompt and return text."""
252
  from nanovllm import SamplingParams
 
268
  # Always call set_user_metadata to ensure previous settings are cleared if None
269
  self.constrained_processor.set_user_metadata(user_metadata)
270
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
271
+ # Set skip_caption and skip_language based on flags
272
+ self.constrained_processor.set_skip_caption(skip_caption)
273
+ self.constrained_processor.set_skip_language(skip_language)
274
 
275
  constrained_processor = self.constrained_processor
276
 
 
324
  target_duration: Optional[float] = None,
325
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
326
  stop_at_reasoning: bool = False,
327
+ skip_caption: bool = False,
328
+ skip_language: bool = False,
329
  ) -> str:
330
  """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
331
  inputs = self.llm_tokenizer(
 
346
  # Always call set_user_metadata to ensure previous settings are cleared if None
347
  self.constrained_processor.set_user_metadata(user_metadata)
348
  self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
349
+ # Set skip_caption and skip_language based on flags
350
+ self.constrained_processor.set_skip_caption(skip_caption)
351
+ self.constrained_processor.set_skip_language(skip_language)
352
 
353
  constrained_processor = self.constrained_processor
354
 
 
483
  constrained_decoding_debug: bool = False,
484
  target_duration: Optional[float] = None,
485
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
486
+ use_cot_caption: bool = True,
487
+ use_cot_language: bool = True,
488
  ) -> Tuple[Dict[str, Any], str, str]:
489
  """Feishu-compatible LM generation.
490
 
 
496
  5 codes = 1 second. If specified, blocks EOS until target reached.
497
  user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature).
498
  If specified, constrained decoding will inject these values directly.
499
+ use_cot_caption: Whether to generate caption in CoT (default True).
500
+ use_cot_language: Whether to generate language in CoT (default True).
501
  """
502
  infer_type = (infer_type or "").strip().lower()
503
  if infer_type not in {"dit", "llm_dit"}:
 
524
  "repetition_penalty": repetition_penalty,
525
  "target_duration": target_duration,
526
  "user_metadata": user_metadata,
527
+ "skip_caption": not use_cot_caption,
528
+ "skip_language": not use_cot_language,
529
  },
530
  use_constrained_decoding=use_constrained_decoding,
531
  constrained_decoding_debug=constrained_decoding_debug,
 
557
  prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
558
  return self.llm_tokenizer.apply_chat_template(
559
  [
560
+ {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"},
561
  {"role": "user", "content": prompt},
562
  ],
563
  tokenize=False,
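Spelled out, the chat-template input assembled here looks like the snippet below before tokenization; DEFAULT_LM_INSTRUCTION is expanded to the literal it replaces (see the removed system-prompt line earlier in this hunk):

# Illustrative: the message list passed to llm_tokenizer.apply_chat_template(..., tokenize=False).
DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"

caption = "A calm piano melody"
lyrics = "la la la"
prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"

messages = [
    {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"},
    {"role": "user", "content": prompt},
]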
 
608
  repetition_penalty = cfg.get("repetition_penalty", 1.0)
609
  target_duration = cfg.get("target_duration")
610
  user_metadata = cfg.get("user_metadata") # User-provided metadata fields
611
+ skip_caption = cfg.get("skip_caption", False) # Skip caption generation in CoT
612
+ skip_language = cfg.get("skip_language", False) # Skip language generation in CoT
613
 
614
  try:
615
  if self.llm_backend == "vllm":
 
626
  target_duration=target_duration,
627
  user_metadata=user_metadata,
628
  stop_at_reasoning=stop_at_reasoning,
629
+ skip_caption=skip_caption,
630
+ skip_language=skip_language,
631
  )
632
  return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
633
 
 
645
  target_duration=target_duration,
646
  user_metadata=user_metadata,
647
  stop_at_reasoning=stop_at_reasoning,
648
+ skip_caption=skip_caption,
649
+ skip_language=skip_language,
650
  )
651
  return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
652
 
 
951
  Expected format:
952
  <think>
953
  bpm: 73
954
+ caption: A calm piano melody
955
  duration: 273
956
  genres: Chinese folk
957
  keyscale: G major
958
+ language: en
959
  timesignature: 4
960
  </think>
961
 
 
998
  lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
999
  reasoning_text = lines_before_codes.strip()
1000
 
1001
+ # Parse metadata fields with YAML multi-line value support
1002
  if reasoning_text:
1003
+ lines = reasoning_text.split('\n')
1004
+ current_key = None
1005
+ current_value_lines = []
1006
+
1007
+ def save_current_field():
1008
+ """Save the accumulated field value"""
1009
+ nonlocal current_key, current_value_lines
1010
+ if current_key and current_value_lines:
1011
+ # Join multi-line value
1012
+ value = '\n'.join(current_value_lines)
1013
+
1014
+ if current_key == 'bpm':
1015
+ try:
1016
+ metadata['bpm'] = int(value.strip())
1017
+ except (ValueError, TypeError):
1018
+ metadata['bpm'] = value.strip()
1019
+ elif current_key == 'caption':
1020
+ # Post-process caption to remove YAML multi-line formatting
1021
+ metadata['caption'] = MetadataConstrainedLogitsProcessor.postprocess_caption(value)
1022
+ elif current_key == 'duration':
1023
+ try:
1024
+ metadata['duration'] = int(value.strip())
1025
+ except (ValueError, TypeError):
1026
+ metadata['duration'] = value.strip()
1027
+ elif current_key == 'genres':
1028
+ metadata['genres'] = value.strip()
1029
+ elif current_key == 'keyscale':
1030
+ metadata['keyscale'] = value.strip()
1031
+ elif current_key == 'language':
1032
+ metadata['language'] = value.strip()
1033
+ elif current_key == 'timesignature':
1034
+ metadata['timesignature'] = value.strip()
1035
+
1036
+ current_key = None
1037
+ current_value_lines = []
1038
+
1039
+ for line in lines:
1040
+ # Skip lines starting with '<' (tags)
1041
+ if line.strip().startswith('<'):
1042
+ continue
1043
+
1044
+ # Check if this is a new field (no leading spaces and contains ':')
1045
+ if line and not line[0].isspace() and ':' in line:
1046
+ # Save previous field if any
1047
+ save_current_field()
1048
+
1049
+ # Parse new field
1050
  parts = line.split(':', 1)
1051
  if len(parts) == 2:
1052
+ current_key = parts[0].strip().lower()
1053
+ # First line of value (after colon)
1054
+ first_value = parts[1]
1055
+ if first_value.strip():
1056
+ current_value_lines.append(first_value)
1057
+ elif line.startswith(' ') or line.startswith('\t'):
1058
+ # Continuation line (YAML multi-line value)
1059
+ if current_key:
1060
+ current_value_lines.append(line)
1061
+
1062
+ # Don't forget to save the last field
1063
+ save_current_field()
1064
 
1065
  return metadata, audio_codes
1066
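To make the multi-line parsing above concrete, here is a self-contained rerun of the same accumulation logic on a sample <think>-style block. It is simplified: audio codes are ignored and the caption post-processing (MetadataConstrainedLogitsProcessor.postprocess_caption) is left out, so the YAML block marker survives in the raw value; the exact YAML style the LM emits is an assumption.

# Standalone sketch of the field-accumulation loop above (simplified, no postprocessing).
reasoning_text = """bpm: 73
caption: |
  A tender Chinese folk ballad,
  fingerpicked guitar and soft strings
duration: 273
genres: Chinese folk
keyscale: G major
language: zh
timesignature: 4"""

metadata = {}
current_key, current_value_lines = None, []

def save_current_field():
    global current_key, current_value_lines
    if current_key and current_value_lines:
        metadata[current_key] = "\n".join(current_value_lines).strip()
    current_key, current_value_lines = None, []

for line in reasoning_text.split("\n"):
    if line.strip().startswith("<"):            # skip tag lines such as <think>
        continue
    if line and not line[0].isspace() and ":" in line:
        save_current_field()                     # flush the previous field
        key, _, first_value = line.partition(":")
        current_key = key.strip().lower()
        if first_value.strip():
            current_value_lines.append(first_value)
    elif line.startswith((" ", "\t")) and current_key:
        current_value_lines.append(line)         # indented continuation line
save_current_field()

print(metadata["bpm"])       # '73' (the real parser converts this to int)
print(metadata["language"])  # 'zh'
print(metadata["caption"])   # raw multi-line value; the real parser strips the '|' marker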