cot caption & language LM
- acestep/api_server.py +7 -3
- acestep/constants.py +97 -0
- acestep/constrained_logits_processor.py +385 -23
- acestep/gradio_ui.py +48 -17
- acestep/handler.py +84 -33
- acestep/llm_inference.py +86 -24
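
The point of this change is to extend the LM's constrained chain-of-thought block with two new fields, caption and language. Per the processor docstring in acestep/constrained_logits_processor.py below, the enforced block now looks roughly like this (the values here are illustrative only):

    <think>
    bpm: 120
    caption: An emotional ballad that opens with solo piano and builds into a full arrangement.
    duration: 234
    keyscale: G major
    language: en
    timesignature: 4
    </think>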
acestep/api_server.py
CHANGED
@@ -35,6 +35,10 @@ from starlette.datastructures import UploadFile as StarletteUploadFile
 
 from acestep.handler import AceStepHandler
 from acestep.llm_inference import LLMHandler
+from acestep.constants import (
+    DEFAULT_DIT_INSTRUCTION,
+    DEFAULT_LM_INSTRUCTION,
+)
 
 
 JobStatus = Literal["queued", "running", "succeeded", "failed"]
@@ -70,7 +74,7 @@ class GenerateMusicRequest(BaseModel):
     repainting_start: float = 0.0
     repainting_end: Optional[float] = None
 
-    instruction: str = …
+    instruction: str = DEFAULT_DIT_INSTRUCTION
     audio_cover_strength: float = 1.0
     task_type: str = "text2music"
 
@@ -102,8 +106,8 @@ class GenerateMusicRequest(BaseModel):
 _LM_DEFAULT_TEMPERATURE = 0.85
 _LM_DEFAULT_CFG_SCALE = 2.0
 _LM_DEFAULT_TOP_P = 0.9
-_DEFAULT_DIT_INSTRUCTION = …
-_DEFAULT_LM_INSTRUCTION = …
+_DEFAULT_DIT_INSTRUCTION = DEFAULT_DIT_INSTRUCTION
+_DEFAULT_LM_INSTRUCTION = DEFAULT_LM_INSTRUCTION
 
 
 class CreateJobResponse(BaseModel):
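
With the shared constants in place, the API server's fallback instructions resolve to the strings defined in acestep/constants.py below; a quick sketch:

    from acestep.constants import DEFAULT_DIT_INSTRUCTION, DEFAULT_LM_INSTRUCTION

    # Value used when a GenerateMusicRequest omits "instruction"
    print(DEFAULT_DIT_INSTRUCTION)  # Fill the audio semantic mask based on the given conditions:
    # Default instruction for the LM path
    print(DEFAULT_LM_INSTRUCTION)   # Generate audio semantic tokens based on the given conditions: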
acestep/constants.py
ADDED
@@ -0,0 +1,97 @@
+"""
+Constants for ACE-Step
+Centralized constants used across the codebase
+"""
+
+# ==============================================================================
+# Language Constants
+# ==============================================================================
+
+VALID_LANGUAGES = [
+    'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
+    'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
+    'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
+    'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
+    'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
+    'unknown'
+]
+
+
+# ==============================================================================
+# Keyscale Constants
+# ==============================================================================
+
+KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
+KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
+KEYSCALE_MODES = ['major', 'minor']
+
+# Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
+VALID_KEYSCALES = set()
+for note in KEYSCALE_NOTES:
+    for acc in KEYSCALE_ACCIDENTALS:
+        for mode in KEYSCALE_MODES:
+            VALID_KEYSCALES.add(f"{note}{acc} {mode}")
+
+
+# ==============================================================================
+# Metadata Range Constants
+# ==============================================================================
+
+# BPM (Beats Per Minute) range
+BPM_MIN = 30
+BPM_MAX = 300
+
+# Duration range (in seconds)
+DURATION_MIN = 10
+DURATION_MAX = 600
+
+# Valid time signatures
+VALID_TIME_SIGNATURES = [2, 3, 4, 6]
+
+
+# ==============================================================================
+# Task Type Constants
+# ==============================================================================
+
+TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
+
+# Task types available for turbo models (subset)
+TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
+
+# Task types available for base models (full set)
+TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
+
+
+# ==============================================================================
+# Instruction Constants
+# ==============================================================================
+
+# Default instructions
+DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
+DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
+
+# Instruction templates for each task type
+# Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
+# These should be formatted using .format() or f-strings when used
+TASK_INSTRUCTIONS = {
+    "text2music": "Fill the audio semantic mask based on the given conditions:",
+    "repaint": "Repaint the mask area based on the given conditions:",
+    "cover": "Generate audio semantic tokens based on the given conditions:",
+    "extract": "Extract the {TRACK_NAME} track from the audio:",
+    "extract_default": "Extract the track from the audio:",
+    "lego": "Generate the {TRACK_NAME} track based on the audio context:",
+    "lego_default": "Generate the track based on the audio context:",
+    "complete": "Complete the input track with {TRACK_CLASSES}:",
+    "complete_default": "Complete the input track:",
+}
+
+
+# ==============================================================================
+# Track/Instrument Constants
+# ==============================================================================
+
+TRACK_NAMES = [
+    "woodwinds", "brass", "fx", "synth", "strings", "percussion",
+    "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
+]
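
A minimal sketch of how these constants can be used to validate user-supplied metadata before it reaches the constrained decoder (the validate_metadata helper is illustrative, not part of this commit):

    from acestep.constants import (
        BPM_MIN, BPM_MAX, DURATION_MIN, DURATION_MAX,
        VALID_KEYSCALES, VALID_LANGUAGES, VALID_TIME_SIGNATURES,
    )

    def validate_metadata(bpm: int, duration: int, keyscale: str, language: str, timesig: int) -> bool:
        """Illustrative helper: check metadata against the shared ranges and sets."""
        return (
            BPM_MIN <= bpm <= BPM_MAX
            and DURATION_MIN <= duration <= DURATION_MAX
            and keyscale in VALID_KEYSCALES         # e.g. "G major", "F# minor"
            and language in VALID_LANGUAGES         # ISO-style codes, plus "unknown"
            and timesig in VALID_TIME_SIGNATURES    # 2, 3, 4 or 6
        )

    print(validate_metadata(120, 234, "G major", "en", 4))  # True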
acestep/constrained_logits_processor.py
CHANGED
@@ -6,6 +6,18 @@ from transformers import AutoTokenizer
 from transformers.generation.logits_process import LogitsProcessor
 import os
 import torch
+from acestep.constants import (
+    VALID_LANGUAGES,
+    KEYSCALE_NOTES,
+    KEYSCALE_ACCIDENTALS,
+    KEYSCALE_MODES,
+    VALID_KEYSCALES,
+    BPM_MIN,
+    BPM_MAX,
+    DURATION_MIN,
+    DURATION_MAX,
+    VALID_TIME_SIGNATURES,
+)
 
 
 # ==============================================================================
@@ -18,6 +30,8 @@ class FSMState(Enum):
     BPM_NAME = auto()              # Generating "bpm: "
     BPM_VALUE = auto()             # Generating numeric value 30-300
     NEWLINE_AFTER_BPM = auto()     # Generating "\n" after bpm value
+    CAPTION_NAME = auto()          # Generating "caption: "
+    CAPTION_VALUE = auto()         # Generating caption text (no code blocks/newlines)
     DURATION_NAME = auto()         # Generating "duration: "
     DURATION_VALUE = auto()        # Generating numeric value 10-600
     NEWLINE_AFTER_DURATION = auto()
@@ -27,6 +41,8 @@ class FSMState(Enum):
     KEYSCALE_NAME = auto()         # Generating "keyscale: "
     KEYSCALE_VALUE = auto()        # Generating keyscale pattern
     NEWLINE_AFTER_KEYSCALE = auto()
+    LANGUAGE_NAME = auto()         # Generating "language: "
+    LANGUAGE_VALUE = auto()        # Generating language code (en, zh, ja, etc.)
     TIMESIG_NAME = auto()          # Generating "timesignature: "
     TIMESIG_VALUE = auto()         # Generating 2, 3, 4, or 6
     NEWLINE_AFTER_TIMESIG = auto()
@@ -42,15 +58,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
     This processor enforces the following format:
     <think>
     bpm: [30-300]
+    caption: [text without code blocks, ends with period + newline]
     duration: [10-600]
-    genres: [any non-empty string]
     keyscale: [A-G][#/♭]? [major/minor]
+    language: [en/zh/ja/ko/es/fr/de/uk/ru/...]
    timesignature: [2/3/4/6]
     </think>
 
     It uses token masking (setting invalid token logits to -inf) to enforce constraints.
     For numeric fields, it uses early-blocking to prevent out-of-range values.
     For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit).
+    For caption field, it blocks code blocks and newlines, and only transitions when
+    the previous token was a period and newline has the highest probability.
     """
 
     def __init__(
@@ -80,15 +99,19 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.enabled = enabled
         self.debug = debug
         self.skip_genres = skip_genres
+        self.skip_caption = False   # Set to True to skip caption field generation
+        self.skip_language = False  # Set to True to skip language field generation
         self.caption: Optional[str] = None  # Set via update_caption() before each generation
 
         # User-provided metadata fields (optional)
         # If provided, these fields will be used directly instead of generating
-        # Format: {"bpm": "120", "duration": "234", "keyscale": "G major", "timesignature": "4", "genres": "Pop Rock"}
+        # Format: {"bpm": "120", "caption": "...", "duration": "234", "keyscale": "G major", "language": "en", "timesignature": "4", "genres": "Pop Rock"}
         self.user_provided_metadata: Dict[str, Optional[str]] = {
             "bpm": None,
+            "caption": None,
             "duration": None,
             "keyscale": None,
+            "language": None,
             "timesignature": None,
             "genres": None,
         }
@@ -114,6 +137,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.accumulated_value = ""  # For numeric/text value accumulation (legacy, for compatibility)
         self.accumulated_token_ids: List[int] = []  # Token ID sequence for keyscale (and other fields)
 
+        # Caption generation state tracking
+        self.caption_after_newline = False  # Track if we're right after a newline in caption
+        self.caption_token_count = 0        # Track token count for caption (max 512)
+
         # Token queue for user-provided fields (injected directly without generation)
         self.user_field_token_queue: List[int] = []
         self.current_user_field: Optional[str] = None  # Current field being injected
@@ -137,9 +164,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         # Field definitions (needed before building prefix trees)
         self.field_specs = {
-            "bpm": {"min": …},
-            "duration": {"min": …},
-            "timesignature": {"valid_values": …},
+            "bpm": {"min": BPM_MIN, "max": BPM_MAX},
+            "duration": {"min": DURATION_MIN, "max": DURATION_MAX},
+            "timesignature": {"valid_values": VALID_TIME_SIGNATURES},
         }
 
         # Build valid numeric values for BPM, Duration, Timesignature
@@ -170,6 +197,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             context_prefix_for_tokenization="timesignature: "
         )
 
+        # Build language prefix tree (similar to keyscale but for language codes)
+        self.language_prefix_tree = self._build_language_prefix_tree()
+
         self._load_genres_vocab()
 
         # Note: Caption-based genre filtering is initialized via update_caption() before each generation
@@ -182,9 +212,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.THINK_TAG: "<think>",
             FSMState.NEWLINE_AFTER_THINK: "\n",
             FSMState.BPM_NAME: "bpm:",
+            FSMState.CAPTION_NAME: "caption:",
             FSMState.DURATION_NAME: "duration:",
             FSMState.GENRES_NAME: "genres:",
             FSMState.KEYSCALE_NAME: "keyscale:",
+            FSMState.LANGUAGE_NAME: "language:",
             FSMState.TIMESIG_NAME: "timesignature:",
             FSMState.THINK_END_TAG: "</think>",
         }
@@ -198,17 +230,21 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         even if the field is user-provided (we still need to generate the field name).
 
         Args:
-            current_field: Current field name ("bpm", "duration", "genres", "keyscale", "timesignature")
+            current_field: Current field name ("bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature")
 
         Returns:
             Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields
         """
-        …
+        # New field order: bpm -> caption -> duration -> keyscale -> language -> timesignature
+        # genres is optional and can be skipped
+        field_order = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"]
         field_to_state = {
             "bpm": FSMState.BPM_NAME,
+            "caption": FSMState.CAPTION_NAME,
             "duration": FSMState.DURATION_NAME,
             "genres": FSMState.GENRES_NAME,
             "keyscale": FSMState.KEYSCALE_NAME,
+            "language": FSMState.LANGUAGE_NAME,
             "timesignature": FSMState.TIMESIG_NAME,
         }
 
@@ -221,9 +257,13 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         for i in range(current_idx + 1, len(field_order)):
             field = field_order[i]
 
-            # Skip …
+            # Skip fields based on flags
             if field == "genres" and self.skip_genres:
                 continue
+            if field == "caption" and self.skip_caption:
+                continue
+            if field == "language" and self.skip_language:
+                continue
 
             # Return the next field's NAME state (even if user-provided, we still generate field name)
             return field_to_state[field]
@@ -241,12 +281,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         }
 
         # Build transitions for all fields (even if user-provided, we still need to generate field name)
-        # Field order: bpm -> duration -> genres -> keyscale -> timesignature
+        # Field order: bpm -> caption -> duration -> genres -> keyscale -> language -> timesignature
 
-        # BPM field: NAME -> VALUE -> next field
+        # BPM field: NAME -> VALUE -> next field (caption or duration)
         self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE
         self.next_state[FSMState.BPM_VALUE] = self._get_next_field_state("bpm")
 
+        # Caption field (only if not skipped): NAME -> VALUE -> next field (duration)
+        if not self.skip_caption:
+            self.next_state[FSMState.CAPTION_NAME] = FSMState.CAPTION_VALUE
+            self.next_state[FSMState.CAPTION_VALUE] = self._get_next_field_state("caption")
+
         # Duration field: NAME -> VALUE -> next field
         self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE
         self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration")
@@ -256,10 +301,15 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
         self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres")
 
-        # Keyscale field: NAME -> VALUE -> next field
+        # Keyscale field: NAME -> VALUE -> next field (language or timesignature)
         self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE
         self.next_state[FSMState.KEYSCALE_VALUE] = self._get_next_field_state("keyscale")
 
+        # Language field (only if not skipped): NAME -> VALUE -> next field (timesignature)
+        if not self.skip_language:
+            self.next_state[FSMState.LANGUAGE_NAME] = FSMState.LANGUAGE_VALUE
+            self.next_state[FSMState.LANGUAGE_VALUE] = self._get_next_field_state("language")
+
         # Timesignature field: NAME -> VALUE -> THINK_END_TAG
         self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE
         self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG
@@ -269,6 +319,49 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.skip_genres = skip
         self._build_state_transitions()
 
+    def set_skip_caption(self, skip: bool):
+        """Set whether to skip caption generation and rebuild state transitions."""
+        self.skip_caption = skip
+        self._build_state_transitions()
+
+    def set_skip_language(self, skip: bool):
+        """Set whether to skip language generation and rebuild state transitions."""
+        self.skip_language = skip
+        self._build_state_transitions()
+
+    @staticmethod
+    def postprocess_caption(caption: str) -> str:
+        """
+        Post-process caption to remove YAML multi-line formatting.
+        Converts YAML-style multi-line text (with newlines and leading spaces)
+        to a single-line string.
+
+        Example:
+            Input: "An emotional ballad.\\n  The track opens with piano.\\n  More text."
+            Output: "An emotional ballad. The track opens with piano. More text."
+
+        Args:
+            caption: Raw caption text with possible YAML formatting
+
+        Returns:
+            Clean single-line caption
+        """
+        if not caption:
+            return caption
+
+        # Split by newlines
+        lines = caption.split('\n')
+
+        # Process each line: strip leading/trailing whitespace
+        cleaned_lines = []
+        for line in lines:
+            stripped = line.strip()
+            if stripped:
+                cleaned_lines.append(stripped)
+
+        # Join with single space
+        return ' '.join(cleaned_lines)
+
     def set_stop_at_reasoning(self, stop: bool):
         """
         Set whether to stop generation after </think> tag.
@@ -287,8 +380,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         Args:
             metadata: Dictionary with optional fields:
                 - "bpm": Optional[str] - e.g., "120"
+                - "caption": Optional[str] - e.g., "A melodic piano piece..."
                 - "duration": Optional[str] - e.g., "234"
                 - "keyscale": Optional[str] - e.g., "G major"
+                - "language": Optional[str] - e.g., "en"
                 - "timesignature": Optional[str] - e.g., "4"
                 - "genres": Optional[str] - e.g., "Pop Rock"
                 If None, clears all user-provided metadata.
@@ -297,7 +392,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             metadata = {}
 
         # Update user-provided metadata
-        for field in ["bpm", "duration", "keyscale", "timesignature", "genres"]:
+        for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]:
             if field in metadata:
                 self.user_provided_metadata[field] = metadata[field]
             else:
@@ -328,7 +423,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         # Note tokens for keyscale (A-G)
         self.note_tokens = {}
-        for note in …:
+        for note in KEYSCALE_NOTES:
             tokens = self.tokenizer.encode(note, add_special_tokens=False)
             if tokens:
                 self.note_tokens[note] = tokens[-1]
@@ -370,21 +465,80 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # EOS token for duration-constrained codes generation
         self.eos_token_id = self.tokenizer.eos_token_id
 
+        # Period token for caption field transition logic
+        period_tokens = self.tokenizer.encode(".", add_special_tokens=False)
+        self.period_token = period_tokens[-1] if period_tokens else None
+
+        # Backtick tokens for blocking code blocks in caption
+        backtick_tokens = self.tokenizer.encode("`", add_special_tokens=False)
+        self.backtick_token = backtick_tokens[-1] if backtick_tokens else None
+
+        # Valid language codes (ISO 639-1 and common variants)
+        self.valid_languages = VALID_LANGUAGES
+
+        # Precompute audio code token IDs (tokens matching <|audio_code_\d+|>)
+        # These should be blocked during caption generation
+        self.audio_code_token_ids: Set[int] = set()
+        self._precompute_audio_code_tokens()
+
+        # Precompute audio code mask for efficient blocking (O(1) instead of O(n))
+        # This mask will be added to scores during caption generation
+        self.audio_code_mask: Optional[torch.Tensor] = None
+        self._build_audio_code_mask()
+
         # Build valid keyscales set (prefix tree will be built after _char_to_tokens is initialized)
         # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
-        …
-        accidentals = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
-        modes = ['major', 'minor']
-
-        self.valid_keyscales = set()
-        for note in notes:
-            for acc in accidentals:
-                for mode in modes:
-                    self.valid_keyscales.add(f"{note}{acc} {mode}")
+        self.valid_keyscales = VALID_KEYSCALES.copy()
 
         # keyscale_prefix_tree will be built in _precompute_char_token_mapping() after _char_to_tokens is ready
         # Numeric prefix trees will be built after field_specs is defined
 
+    def _precompute_audio_code_tokens(self):
+        """
+        Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
+        These tokens should be blocked during caption generation.
+        """
+        import re
+        audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
+
+        # Iterate through vocabulary to find audio code tokens
+        for token_id in range(self.vocab_size):
+            try:
+                token_text = self.tokenizer.decode([token_id])
+                if audio_code_pattern.match(token_text):
+                    self.audio_code_token_ids.add(token_id)
+            except Exception:
+                continue
+
+        if self.debug:
+            logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
+
+    def _build_audio_code_mask(self):
+        """
+        Build a precomputed mask tensor for blocking audio code tokens.
+        This mask can be added to scores in O(1) time instead of O(n) loop.
+
+        The mask is [1, vocab_size] tensor with -inf at audio code token positions.
+        """
+        if not self.audio_code_token_ids:
+            self.audio_code_mask = None
+            return
+
+        # Create mask tensor: 0 everywhere, -inf at audio code positions
+        # Use float32 for compatibility with most model dtypes
+        mask = torch.zeros(1, self.vocab_size, dtype=torch.float32)
+
+        # Convert set to list for indexing
+        audio_code_indices = list(self.audio_code_token_ids)
+
+        # Set -inf at audio code token positions
+        mask[0, audio_code_indices] = float('-inf')
+
+        self.audio_code_mask = mask
+
+        if self.debug:
+            logger.debug(f"Built audio code mask for {len(self.audio_code_token_ids)} tokens")
+
     def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
         """
         Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
@@ -560,6 +714,68 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         return prefix_to_tokens
 
+    def _build_language_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
+        """
+        Build language prefix to allowed tokens mapping based on ACTUAL tokenization.
+        Similar to keyscale prefix tree but for language codes.
+
+        Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
+        """
+        prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
+
+        context_prefix_for_matching = "language:"
+        context_prefix_for_tokenization = "language: "
+
+        context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
+
+        if self.debug:
+            context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
+            logger.debug(f"Context for matching 'language:' tokenizes to {context_token_ids} -> {context_tokens_str}")
+
+        for lang in self.valid_languages:
+            full_text = context_prefix_for_tokenization + lang
+            full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+            context_end_idx = None
+            if len(full_token_ids) >= len(context_token_ids):
+                if full_token_ids[:len(context_token_ids)] == context_token_ids:
+                    context_end_idx = len(context_token_ids)
+
+            if context_end_idx is None:
+                if self.debug:
+                    logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
+                continue
+
+            lang_token_ids = full_token_ids[context_end_idx:]
+
+            if not lang_token_ids:
+                if self.debug:
+                    logger.warning(f"No tokens extracted for language '{lang}', skipping")
+                continue
+
+            for i in range(len(lang_token_ids) + 1):
+                token_prefix = tuple(lang_token_ids[:i])
+
+                if token_prefix not in prefix_to_tokens:
+                    prefix_to_tokens[token_prefix] = set()
+
+                if i < len(lang_token_ids):
+                    next_token_id = lang_token_ids[i]
+                    prefix_to_tokens[token_prefix].add(next_token_id)
+                else:
+                    if self.newline_token:
+                        prefix_to_tokens[token_prefix].add(self.newline_token)
+
+        if self.debug:
+            logger.debug(f"Built language prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
+            empty_prefix = tuple()
+            if empty_prefix in prefix_to_tokens:
+                first_tokens = prefix_to_tokens[empty_prefix]
+                decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
+                logger.debug(f"First tokens allowed for language (empty prefix): {decoded_first}")
+
+        return prefix_to_tokens
+
     def diagnose_keyscale_prefix_tree(self):
         """
         Diagnose the keyscale prefix tree to help debug generation bias.
@@ -926,6 +1142,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.codes_count = 0                # Reset codes counter
         self.user_field_token_queue = []    # Reset user field token queue
         self.current_user_field = None      # Reset current user field
+        self.caption_after_newline = False  # Reset caption newline tracking
+        self.caption_token_count = 0        # Reset caption token count
 
     def set_target_duration(self, duration: Optional[float]):
         """
@@ -1170,6 +1388,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             return self.newline_token in self.keyscale_prefix_tree[token_prefix]
         return False
 
+    def _get_allowed_language_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for language field using the precomputed prefix tree.
+        Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+        Similar to keyscale.
+        """
+        token_prefix = tuple(self.accumulated_token_ids)
+
+        if token_prefix in self.language_prefix_tree:
+            return list(self.language_prefix_tree[token_prefix])
+
+        # Fallback: no valid continuation found
+        return []
+
     def _get_allowed_timesig_tokens(self) -> List[int]:
         """
         Get allowed tokens for timesignature field using the precomputed prefix tree.
@@ -1269,7 +1501,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         Uses the same tokenization logic as prefix tree building.
 
         Args:
-            field_name: Field name ("bpm", "duration", "keyscale", "timesignature", "genres")
+            field_name: Field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres")
 
         Returns:
             List of token IDs for the complete field, or None if field is not provided
@@ -1281,8 +1513,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Build full field string with space (matching prefix tree tokenization)
         field_to_prefix = {
            "bpm": "bpm: ",
+            "caption": "caption: ",
            "duration": "duration: ",
            "keyscale": "keyscale: ",
+            "language": "language: ",
            "timesignature": "timesignature: ",
            "genres": "genres: ",
        }
@@ -1410,6 +1644,67 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
             scores = scores + mask
 
+        elif self.state == FSMState.CAPTION_VALUE:
+            # Caption field generation with YAML format support:
+            # - Allow newlines and spaces (YAML multi-line formatting)
+            # - Block audio codes and backticks
+            # - Max 512 tokens
+            # - Transition when model wants to generate next field (non-indented line)
+
+            # Check if field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["caption"] is not None and not self.user_field_token_queue and not self.accumulated_value:
+                # Initialize token queue with field value tokens (value + newline)
+                value = self.user_provided_metadata["caption"]
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "caption"
+                    # Inject first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # Check if we should transition after a newline (non-indented line = new field)
+            if self.caption_after_newline:
+                # Get top token from current scores
+                top_token_id = torch.argmax(scores[0]).item()
+                top_token_text = self.tokenizer.decode([top_token_id])
+
+                # If top token does NOT start with space/tab, it's a new field (like "duration:")
+                if len(top_token_text) > 0 and top_token_text[0] not in ' \t':
+                    # Caption is ending, transition to next field
+                    self.caption_after_newline = False
+                    self._transition_to_next_state()
+                    # Process with new state (DURATION_NAME)
+                    return self._process_single_sequence(input_ids, scores)
+                else:
+                    # It's indentation, continue caption
+                    self.caption_after_newline = False
+
+            # Block backticks (code blocks)
+            if self.backtick_token is not None:
+                scores[0, self.backtick_token] = float('-inf')
+
+            # Block ALL audio code tokens (critical - these should never appear in caption)
+            # Use precomputed mask for O(1) performance instead of O(n) loop
+            if self.audio_code_mask is not None:
+                # Move mask to same device/dtype as scores if needed
+                if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
+                    self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
+                scores = scores + self.audio_code_mask
+
+            # Enforce 512 token limit for caption
+            if self.caption_token_count >= 512:
+                # Force end by only allowing newline
+                if self.newline_token is not None:
+                    mask[0, self.newline_token] = 0
+                    scores = scores + mask
+                return scores
+
+            # Allow natural generation (with blocked audio codes and backticks)
+            return scores
+
         elif self.state == FSMState.DURATION_VALUE:
             # Check if field is user-provided and we haven't started injecting yet
             if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
@@ -1539,6 +1834,43 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 mask[0, self.newline_token] = 0
             scores = scores + mask
 
+        elif self.state == FSMState.LANGUAGE_VALUE:
+            # Language field: similar to keyscale, uses prefix tree
+
+            # Check if field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["language"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+                # Initialize token queue with field value tokens (value + newline)
+                value = self.user_provided_metadata["language"]
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "language"
+                    # Inject first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # Check if current token sequence is complete (allows newline)
+            token_prefix = tuple(self.accumulated_token_ids)
+            if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
+                # Complete language, allow newline
+                if self.newline_token:
+                    mask[0, self.newline_token] = 0
+                scores = scores + mask
+            else:
+                # Not complete, allow valid continuation tokens
+                allowed = self._get_allowed_language_tokens()
+                if allowed:
+                    for t in allowed:
+                        mask[0, t] = 0
+                    scores = scores + mask
+                else:
+                    # No valid tokens found - force newline to end field
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    scores = scores + mask
+
         elif self.state == FSMState.TIMESIG_VALUE:
             # Check if field is user-provided and we haven't started injecting yet
             if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
@@ -1587,6 +1919,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.position_in_state = 0
         self.accumulated_value = ""         # Legacy, kept for compatibility
         self.accumulated_token_ids = []     # Reset token ID sequence for new field
+        self.caption_after_newline = False  # Reset caption newline tracking
+        self.caption_token_count = 0        # Reset caption token count
         if self.debug:
             logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}")
 
@@ -1703,6 +2037,22 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             # Genres still uses string-based trie, so keep accumulated_value
             self.accumulated_value += token_str
 
+        elif self.state == FSMState.CAPTION_VALUE:
+            # Track token count for 512 limit
+            self.caption_token_count += 1
+
+            # Accumulate caption text
+            self.accumulated_value += token_str
+
+            # Track if this token is a newline (for transition detection)
+            if generated_token_id == self.newline_token:
+                # Mark that we need to check next token for field transition
+                self.caption_after_newline = True
+            else:
+                # Not a newline - if we were after newline and this is not space,
+                # transition already happened in _process_single_sequence
+                self.caption_after_newline = False
+
         elif self.state == FSMState.KEYSCALE_VALUE:
             if generated_token_id == self.newline_token:
                 # Newline ends the field
@@ -1718,4 +2068,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.accumulated_token_ids.append(generated_token_id)
                 # Also update legacy accumulated_value for compatibility
                 self.accumulated_value += token_str
+
+        elif self.state == FSMState.LANGUAGE_VALUE:
+            if generated_token_id == self.newline_token:
+                # Newline ends the field
+                self._transition_to_next_state()
+                if self.state in self.fixed_strings:
+                    return
+            else:
+                # Add token ID to sequence (for prefix tree lookup)
+                self.accumulated_token_ids.append(generated_token_id)
+                # Also update legacy accumulated_value for compatibility
+                self.accumulated_value += token_str
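
The new caption/language behaviour is driven through the processor's public surface shown above; a short usage sketch (the constructor arguments are not visible in this diff, so only the new methods are exercised):

    from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor

    # Captions are generated as YAML-style multi-line text inside <think>;
    # the new static helper flattens them to a single line.
    raw = "An emotional ballad.\n  The track opens with piano.\n  More text."
    print(MetadataConstrainedLogitsProcessor.postprocess_caption(raw))
    # -> "An emotional ballad. The track opens with piano. More text."

    # On an already-constructed processor instance, the new fields can be
    # toggled per generation (state transitions are rebuilt automatically):
    # processor.set_skip_caption(False)    # keep "caption:" in the CoT block
    # processor.set_skip_language(False)   # keep "language:" in the CoT block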
acestep/gradio_ui.py
CHANGED
@@ -8,6 +8,14 @@ import random
 import glob
 import gradio as gr
 from typing import Callable, Optional, Tuple
+from acestep.constants import (
+    VALID_LANGUAGES,
+    TRACK_NAMES,
+    TASK_TYPES,
+    TASK_TYPES_TURBO,
+    TASK_TYPES_BASE,
+    DEFAULT_DIT_INSTRUCTION,
+)
 
 
 def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None) -> gr.Blocks:
@@ -296,9 +304,9 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
     # Determine initial task_type choices based on default model
     default_model_lower = (default_model or "").lower()
     if "turbo" in default_model_lower:
-        initial_task_choices = …
+        initial_task_choices = TASK_TYPES_TURBO
     else:
-        initial_task_choices = …
+        initial_task_choices = TASK_TYPES_BASE
 
     with gr.Row():
         with gr.Column(scale=2):
@@ -311,15 +319,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
             with gr.Column(scale=8):
                 instruction_display_gen = gr.Textbox(
                     label="Instruction",
-                    value=…,
+                    value=DEFAULT_DIT_INSTRUCTION,
                     interactive=False,
                     lines=1,
                     info="Instruction is automatically generated based on task type",
                 )
 
                 track_name = gr.Dropdown(
-                    choices=[…,
-                             "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
+                    choices=TRACK_NAMES,
                     value=None,
                     label="Track Name",
                     info="Select track name for lego/extract tasks",
@@ -327,8 +334,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
                 )
 
                 complete_track_classes = gr.CheckboxGroup(
-                    choices=[…,
-                             "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
+                    choices=TRACK_NAMES,
                     label="Track Names",
                     info="Select multiple track classes for complete task",
                     visible=False
@@ -410,8 +416,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
         with gr.Accordion("⚙️ Optional Parameters", open=True):
             with gr.Row():
                 vocal_language = gr.Dropdown(
-                    choices=…,
-                    value="…",
+                    choices=VALID_LANGUAGES,
+                    value="unknown",
                     label="Vocal Language (optional)",
                     allow_custom_value=True,
                     info="use `unknown` for inst"
@@ -567,6 +573,20 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
                     scale=2,
                 )
 
+                with gr.Row():
+                    use_cot_caption = gr.Checkbox(
+                        label="CoT Caption",
+                        value=True,
+                        info="Generate caption in CoT (chain-of-thought)",
+                        scale=1,
+                    )
+                    use_cot_language = gr.Checkbox(
+                        label="CoT Language",
+                        value=True,
+                        info="Generate language in CoT (chain-of-thought)",
+                        scale=1,
+                    )
+
                 with gr.Row():
                     audio_cover_strength = gr.Slider(
                         minimum=0.0,
@@ -625,6 +645,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
         "lm_top_k": lm_top_k,
         "lm_top_p": lm_top_p,
         "lm_negative_prompt": lm_negative_prompt,
+        …
+        …
         "repainting_group": repainting_group,
         "repainting_start": repainting_start,
         "repainting_end": repainting_end,
@@ -824,7 +846,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase…
                 gr.update(visible=False),  # use_adg
                 gr.update(visible=False),  # cfg_interval_start
                 gr.update(visible=False),  # cfg_interval_end
-                gr.update(choices=…)
+                gr.update(choices=…)
             )
         elif "base" in config_path_lower:
             # Base model: max 100 steps, show CFG/ADG, show all task types
@@ -834,7 +856,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase…
                 gr.update(visible=True),   # use_adg
                 gr.update(visible=True),   # cfg_interval_start
                 gr.update(visible=True),   # cfg_interval_end
-                gr.update(choices=…)
+                gr.update(choices=…)
            )
        else:
            # Default to turbo settings
@@ -844,7 +866,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase…
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
-                gr.update(choices=…)
+                gr.update(choices=…)
            )
 
    generation_section["config_path"].change(
@@ -965,6 +987,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase…
        instruction_display_gen, audio_cover_strength, task_type,
        use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
        think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
+        …
        progress=gr.Progress(track_tqdm=True)
    ):
        # If think is enabled (llm_dit mode), generate audio codes using LM first
@@ -1019,6 +1042,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase…
                top_k=top_k_value,
                top_p=top_p_value,
                user_metadata=user_metadata_to_pass,
+                …
+                …
            )
 
            # Store LM-generated metadata and audio codes for display
@@ -1076,14 +1101,18 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase…
            metadata_lines = []
            if lm_generated_metadata.get('bpm'):
                metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
-            if lm_generated_metadata.get('…
-                metadata_lines.append(f"- **…
-            if lm_generated_metadata.get('timesignature'):
-                metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
+            …
            if lm_generated_metadata.get('duration'):
                metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
            if lm_generated_metadata.get('genres'):
                metadata_lines.append(f"- **Genres:** {lm_generated_metadata['genres']}")
 
            if metadata_lines:
                metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
@@ -1140,7 +1169,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase…
            generation_section["lm_cfg_scale"],
            generation_section["lm_top_k"],
            generation_section["lm_top_p"],
-            generation_section["lm_negative_prompt"]
+            …
        ],
        outputs=[
            results_section["generated_audio_1"],
|
| 646 |
"lm_top_p": lm_top_p,
|
| 647 |
"lm_negative_prompt": lm_negative_prompt,
|
| 648 |
+
"use_cot_caption": use_cot_caption,
|
| 649 |
+
"use_cot_language": use_cot_language,
|
| 650 |
"repainting_group": repainting_group,
|
| 651 |
"repainting_start": repainting_start,
|
| 652 |
"repainting_end": repainting_end,
|
|
|
|
| 846 |
gr.update(visible=False), # use_adg
|
| 847 |
gr.update(visible=False), # cfg_interval_start
|
| 848 |
gr.update(visible=False), # cfg_interval_end
|
| 849 |
+
gr.update(choices=TASK_TYPES_TURBO), # task_type
|
| 850 |
)
|
| 851 |
elif "base" in config_path_lower:
|
| 852 |
# Base model: max 100 steps, show CFG/ADG, show all task types
|
|
|
|
| 856 |
gr.update(visible=True), # use_adg
|
| 857 |
gr.update(visible=True), # cfg_interval_start
|
| 858 |
gr.update(visible=True), # cfg_interval_end
|
| 859 |
+
gr.update(choices=TASK_TYPES_BASE), # task_type
|
| 860 |
)
|
| 861 |
else:
|
| 862 |
# Default to turbo settings
|
|
|
|
| 866 |
gr.update(visible=False),
|
| 867 |
gr.update(visible=False),
|
| 868 |
gr.update(visible=False),
|
| 869 |
+
gr.update(choices=TASK_TYPES_TURBO), # task_type
|
| 870 |
)
|
| 871 |
|
| 872 |
generation_section["config_path"].change(
|
|
|
|
| 987 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 988 |
use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
|
| 989 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 990 |
+
use_cot_caption, use_cot_language,
|
| 991 |
progress=gr.Progress(track_tqdm=True)
|
| 992 |
):
|
| 993 |
# If think is enabled (llm_dit mode), generate audio codes using LM first
|
|
|
|
| 1042 |
top_k=top_k_value,
|
| 1043 |
top_p=top_p_value,
|
| 1044 |
user_metadata=user_metadata_to_pass,
|
| 1045 |
+
use_cot_caption=use_cot_caption,
|
| 1046 |
+
use_cot_language=use_cot_language,
|
| 1047 |
)
|
| 1048 |
|
| 1049 |
# Store LM-generated metadata and audio codes for display
|
|
|
|
| 1101 |
metadata_lines = []
|
| 1102 |
if lm_generated_metadata.get('bpm'):
|
| 1103 |
metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
|
| 1104 |
+
if lm_generated_metadata.get('caption'):
|
| 1105 |
+
metadata_lines.append(f"- **User Query Rewritten Caption:** {lm_generated_metadata['caption']}")
|
|
|
|
|
|
|
| 1106 |
if lm_generated_metadata.get('duration'):
|
| 1107 |
metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
|
| 1108 |
if lm_generated_metadata.get('genres'):
|
| 1109 |
metadata_lines.append(f"- **Genres:** {lm_generated_metadata['genres']}")
|
| 1110 |
+
if lm_generated_metadata.get('keyscale'):
|
| 1111 |
+
metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
|
| 1112 |
+
if lm_generated_metadata.get('language'):
|
| 1113 |
+
metadata_lines.append(f"- **Language:** {lm_generated_metadata['language']}")
|
| 1114 |
+
if lm_generated_metadata.get('timesignature'):
|
| 1115 |
+
metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
|
| 1116 |
|
| 1117 |
if metadata_lines:
|
| 1118 |
metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
|
|
|
|
| 1169 |
generation_section["lm_cfg_scale"],
|
| 1170 |
generation_section["lm_top_k"],
|
| 1171 |
generation_section["lm_top_p"],
|
| 1172 |
+
generation_section["lm_negative_prompt"],
|
| 1173 |
+
generation_section["use_cot_caption"],
|
| 1174 |
+
generation_section["use_cot_language"]
|
| 1175 |
],
|
| 1176 |
outputs=[
|
| 1177 |
results_section["generated_audio_1"],
|
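Note: the UI now reads its dropdown and checkbox choices from acestep/constants.py instead of hard-coding them. A minimal sketch of the relevant entries is shown below; TRACK_NAMES mirrors the list removed from handler.py in this same commit, while the exact TASK_TYPES_TURBO / TASK_TYPES_BASE split is not visible in this diff, so those values are only an illustrative assumption.

```python
# Hypothetical sketch of the acestep/constants.py entries consumed by gradio_ui.py.
# TRACK_NAMES matches the list removed from handler.py; the TASK_TYPES_* split is
# an assumption, the authoritative values live in acestep/constants.py.

# Stem/track classes selectable for lego/extract/complete tasks
TRACK_NAMES = [
    "woodwinds", "brass", "fx", "synth", "strings", "percussion",
    "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals",
]

# Task types handled by AceStepHandler.get_instruction_for_task
TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]

# Assumed split: turbo checkpoints expose a reduced task list, base exposes all
TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
TASK_TYPES_BASE = TASK_TYPES
```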
acestep/handler.py
CHANGED

@@ -23,6 +23,11 @@ import warnings
 from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
 from transformers.generation.streamers import BaseStreamer
 from diffusers.models import AutoencoderOobleck
+from acestep.constants import (
+    TASK_INSTRUCTIONS,
+    TRACK_NAMES,
+    DEFAULT_DIT_INSTRUCTION,
+)


 warnings.filterwarnings("ignore")
@@ -519,10 +524,11 @@
         Args:
             task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching.
             instruction: Instruction text; default fallback matches service_generate behavior.
-            caption: Caption string.
+            caption: Caption string (fallback if not in metas).
             lyrics: Lyrics string.
             metas: Metadata (str or dict); follows _parse_metas formatting.
+                May contain 'caption' and 'language' fields from LM CoT output.
+            vocal_language: Language code for lyrics section (fallback if not in metas).

         Returns:
             (caption_input_text, lyrics_input_text)
@@ -533,18 +539,45 @@
                 instruction=None,
                 caption="A calm piano melody",
                 lyrics="la la la",
-                metas={"bpm": 90, "duration": 45},
+                metas={"bpm": 90, "duration": 45, "caption": "LM generated caption", "language": "en"},
                 vocal_language="en",
             )
         """
         # Align instruction formatting with _prepare_batch
-        final_instruction = instruction or
+        final_instruction = instruction or DEFAULT_DIT_INSTRUCTION
         if not final_instruction.endswith(":"):
             final_instruction = final_instruction + ":"

+        # Extract caption and language from metas if available (from LM CoT output)
+        # Fallback to user-provided values if not in metas
+        actual_caption = caption
+        actual_language = vocal_language
+
+        if metas is not None:
+            # Parse metas to dict if it's a string
+            if isinstance(metas, str):
+                # Try to parse as dict-like string or use as-is
+                parsed_metas = self._parse_metas([metas])
+                if parsed_metas and isinstance(parsed_metas[0], dict):
+                    meta_dict = parsed_metas[0]
+                else:
+                    meta_dict = {}
+            elif isinstance(metas, dict):
+                meta_dict = metas
+            else:
+                meta_dict = {}
+
+            # Extract caption from metas if available
+            if 'caption' in meta_dict and meta_dict['caption']:
+                actual_caption = str(meta_dict['caption'])
+
+            # Extract language from metas if available
+            if 'language' in meta_dict and meta_dict['language']:
+                actual_language = str(meta_dict['language'])
+
         parsed_meta = self._parse_metas([metas])[0]
-        caption_input = SFT_GEN_PROMPT.format(final_instruction,
-        lyrics_input = f"# Languages\n{
+        caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
+        lyrics_input = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
         return caption_input, lyrics_input

     def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -679,41 +712,36 @@
         track_name: Optional[str] = None,
         complete_track_classes: Optional[List[str]] = None
     ) -> str:
-        TRACK_NAMES = [
-            "woodwinds", "brass", "fx", "synth", "strings", "percussion",
-            "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
-        ]
-
         if task_type == "text2music":
-            return "
+            return TASK_INSTRUCTIONS["text2music"]
         elif task_type == "repaint":
-            return "
+            return TASK_INSTRUCTIONS["repaint"]
         elif task_type == "cover":
-            return "
+            return TASK_INSTRUCTIONS["cover"]
         elif task_type == "extract":
             if track_name:
                 # Convert to uppercase
                 track_name_upper = track_name.upper()
-                return
+                return TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name_upper)
             else:
-                return "
+                return TASK_INSTRUCTIONS["extract_default"]
         elif task_type == "lego":
             if track_name:
                 # Convert to uppercase
                 track_name_upper = track_name.upper()
-                return
+                return TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name_upper)
             else:
-                return "
+                return TASK_INSTRUCTIONS["lego_default"]
         elif task_type == "complete":
             if complete_track_classes and len(complete_track_classes) > 0:
                 # Convert to uppercase and join with " | "
                 track_classes_upper = [t.upper() for t in complete_track_classes]
                 complete_track_classes_str = " | ".join(track_classes_upper)
-                return
+                return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=complete_track_classes_str)
             else:
-                return "
+                return TASK_INSTRUCTIONS["complete_default"]
         else:
-            return "
+            return TASK_INSTRUCTIONS["text2music"]

     def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
         if audio_file is None:
@@ -1247,7 +1275,7 @@
         # Process instructions early so we can use them for task type detection
         # Use custom instructions if provided, otherwise use default
         if instructions is None:
-            instructions = [
+            instructions = [DEFAULT_DIT_INSTRUCTION] * batch_size

         # Ensure instructions list has the same length as batch_size
         if len(instructions) != batch_size:
@@ -1257,7 +1285,7 @@
             # Pad or truncate to match batch_size
             instructions = instructions[:batch_size]
             while len(instructions) < batch_size:
-                instructions.append(
+                instructions.append(DEFAULT_DIT_INSTRUCTION)

         # Generate chunk_masks and spans based on repainting parameters
         # Also determine if this is a cover task (target audio provided without repainting)
@@ -1415,13 +1443,29 @@

         for i in range(batch_size):
             # Use custom instruction for this batch item
-            instruction = instructions[i] if i < len(instructions) else
+            instruction = instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION
             # Ensure instruction ends with ":"
             if not instruction.endswith(":"):
                 instruction = instruction + ":"

+            # Extract caption and language from metas if available (from LM CoT output)
+            # Fallback to user-provided values if not in metas
+            actual_caption = captions[i]
+            actual_language = vocal_languages[i]
+
+            # Check if metas contains caption/language from LM CoT
+            if i < len(parsed_metas) and parsed_metas[i]:
+                meta_dict = parsed_metas[i]
+                if isinstance(meta_dict, dict):
+                    # Extract caption from metas if available
+                    if 'caption' in meta_dict and meta_dict['caption']:
+                        actual_caption = str(meta_dict['caption'])
+                    # Extract language from metas if available
+                    if 'language' in meta_dict and meta_dict['language']:
+                        actual_language = str(meta_dict['language'])
+
+            # Format text prompt with custom instruction (using LM-generated caption if available)
+            text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])

             # Tokenize text
             text_inputs_dict = self.text_tokenizer(
@@ -1434,8 +1478,8 @@
             text_token_ids = text_inputs_dict.input_ids[0]
             text_attention_mask = text_inputs_dict.attention_mask[0].bool()

-            # Format and tokenize lyrics
-            lyrics_text = f"# Languages\n{
+            # Format and tokenize lyrics (using LM-generated language if available)
+            lyrics_text = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics[i]}<|endoftext|>"
             lyrics_inputs_dict = self.text_tokenizer(
                 lyrics_text,
                 padding="longest",
@@ -1495,10 +1539,17 @@
             non_cover_text_attention_masks = []
             for i in range(batch_size):
                 # Use custom instruction for this batch item
-                instruction =
+                instruction = DEFAULT_DIT_INSTRUCTION
+
+                # Extract caption from metas if available (from LM CoT output)
+                actual_caption = captions[i]
+                if i < len(parsed_metas) and parsed_metas[i]:
+                    meta_dict = parsed_metas[i]
+                    if isinstance(meta_dict, dict) and 'caption' in meta_dict and meta_dict['caption']:
+                        actual_caption = str(meta_dict['caption'])

-                # Format text prompt with custom instruction
-                text_prompt = SFT_GEN_PROMPT.format(instruction,
+                # Format text prompt with custom instruction (using LM-generated caption if available)
+                text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])

                 # Tokenize text
                 text_inputs_dict = self.text_tokenizer(
@@ -1991,7 +2042,7 @@
         audio_code_string: Union[str, List[str]] = "",
         repainting_start: float = 0.0,
         repainting_end: Optional[float] = None,
-        instruction: str =
+        instruction: str = DEFAULT_DIT_INSTRUCTION,
         audio_cover_strength: float = 1.0,
         task_type: str = "text2music",
         use_adg: bool = False,
@@ -2030,7 +2081,7 @@
             # User has provided audio codes, switch to cover task
             task_type = "cover"
             # Update instruction for cover task
-            instruction = "
+            instruction = TASK_INSTRUCTIONS["cover"]

         logger.info("[generate_music] Starting generation...")
         if progress:
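The handler changes above boil down to one precedence rule: when the metas dict carries `caption` or `language` fields produced by the LM's chain-of-thought, those values override the user-supplied caption and vocal_language before the DiT prompt strings are built. A self-contained sketch of that rule follows; the helper name is illustrative and not the handler's actual API.

```python
from typing import Dict, Optional, Tuple


def resolve_caption_and_language(
    caption: str,
    vocal_language: str,
    metas: Optional[Dict],
) -> Tuple[str, str]:
    """Precedence rule from handler.py: LM CoT fields in metas win over user inputs."""
    actual_caption = caption
    actual_language = vocal_language
    if isinstance(metas, dict):
        if metas.get("caption"):
            actual_caption = str(metas["caption"])
        if metas.get("language"):
            actual_language = str(metas["language"])
    return actual_caption, actual_language


# Example: the LM rewrote the caption and detected the vocal language in its CoT
caption, language = resolve_caption_and_language(
    caption="A calm piano melody",
    vocal_language="unknown",
    metas={"bpm": 90, "caption": "LM generated caption", "language": "en"},
)
# The lyrics header is then built the same way handler.py builds it
lyrics_input = f"# Languages\n{language}\n\n# Lyric\nla la la<|endoftext|>"
print(caption)       # LM generated caption
print(lyrics_input)  # header uses "en" instead of "unknown"
```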
acestep/llm_inference.py
CHANGED

@@ -17,6 +17,7 @@ from transformers.generation.logits_process import (
     RepetitionPenaltyLogitsProcessor,
 )
 from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor
+from acestep.constants import DEFAULT_LM_INSTRUCTION


 class LLMHandler:
@@ -244,6 +245,8 @@
         target_duration: Optional[float] = None,
         user_metadata: Optional[Dict[str, Optional[str]]] = None,
         stop_at_reasoning: bool = False,
+        skip_caption: bool = False,
+        skip_language: bool = False,
     ) -> str:
         """Shared vllm path: accept prebuilt formatted prompt and return text."""
         from nanovllm import SamplingParams
@@ -265,6 +268,9 @@
         # Always call set_user_metadata to ensure previous settings are cleared if None
         self.constrained_processor.set_user_metadata(user_metadata)
         self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
+        # Set skip_caption and skip_language based on flags
+        self.constrained_processor.set_skip_caption(skip_caption)
+        self.constrained_processor.set_skip_language(skip_language)

         constrained_processor = self.constrained_processor

@@ -318,6 +324,8 @@
         target_duration: Optional[float] = None,
         user_metadata: Optional[Dict[str, Optional[str]]] = None,
         stop_at_reasoning: bool = False,
+        skip_caption: bool = False,
+        skip_language: bool = False,
     ) -> str:
         """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
         inputs = self.llm_tokenizer(
@@ -338,6 +346,9 @@
         # Always call set_user_metadata to ensure previous settings are cleared if None
         self.constrained_processor.set_user_metadata(user_metadata)
         self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
+        # Set skip_caption and skip_language based on flags
+        self.constrained_processor.set_skip_caption(skip_caption)
+        self.constrained_processor.set_skip_language(skip_language)

         constrained_processor = self.constrained_processor

@@ -472,6 +483,8 @@
         constrained_decoding_debug: bool = False,
         target_duration: Optional[float] = None,
         user_metadata: Optional[Dict[str, Optional[str]]] = None,
+        use_cot_caption: bool = True,
+        use_cot_language: bool = True,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Feishu-compatible LM generation.

@@ -483,6 +496,8 @@
                 5 codes = 1 second. If specified, blocks EOS until target reached.
             user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature).
                 If specified, constrained decoding will inject these values directly.
+            use_cot_caption: Whether to generate caption in CoT (default True).
+            use_cot_language: Whether to generate language in CoT (default True).
         """
         infer_type = (infer_type or "").strip().lower()
         if infer_type not in {"dit", "llm_dit"}:
@@ -509,6 +524,8 @@
                 "repetition_penalty": repetition_penalty,
                 "target_duration": target_duration,
                 "user_metadata": user_metadata,
+                "skip_caption": not use_cot_caption,
+                "skip_language": not use_cot_language,
             },
             use_constrained_decoding=use_constrained_decoding,
             constrained_decoding_debug=constrained_decoding_debug,
@@ -540,7 +557,7 @@
         prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
         return self.llm_tokenizer.apply_chat_template(
             [
-                {"role": "system", "content": "# Instruction\
+                {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"},
                 {"role": "user", "content": prompt},
             ],
             tokenize=False,
@@ -591,6 +608,8 @@
         repetition_penalty = cfg.get("repetition_penalty", 1.0)
         target_duration = cfg.get("target_duration")
         user_metadata = cfg.get("user_metadata")  # User-provided metadata fields
+        skip_caption = cfg.get("skip_caption", False)  # Skip caption generation in CoT
+        skip_language = cfg.get("skip_language", False)  # Skip language generation in CoT

         try:
             if self.llm_backend == "vllm":
@@ -607,6 +626,8 @@
                     target_duration=target_duration,
                     user_metadata=user_metadata,
                     stop_at_reasoning=stop_at_reasoning,
+                    skip_caption=skip_caption,
+                    skip_language=skip_language,
                 )
                 return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"

@@ -624,6 +645,8 @@
                     target_duration=target_duration,
                     user_metadata=user_metadata,
                     stop_at_reasoning=stop_at_reasoning,
+                    skip_caption=skip_caption,
+                    skip_language=skip_language,
                 )
                 return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"

@@ -928,9 +951,11 @@
         Expected format:
             <think>
             bpm: 73
+            caption: A calm piano melody
             duration: 273
             genres: Chinese folk
             keyscale: G major
+            language: en
             timesignature: 4
             </think>

@@ -973,32 +998,69 @@
         lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
         reasoning_text = lines_before_codes.strip()

-        # Parse metadata fields
+        # Parse metadata fields with YAML multi-line value support
         if reasoning_text:
+            lines = reasoning_text.split('\n')
+            current_key = None
+            current_value_lines = []
+
+            def save_current_field():
+                """Save the accumulated field value"""
+                nonlocal current_key, current_value_lines
+                if current_key and current_value_lines:
+                    # Join multi-line value
+                    value = '\n'.join(current_value_lines)
+
+                    if current_key == 'bpm':
+                        try:
+                            metadata['bpm'] = int(value.strip())
+                        except:
+                            metadata['bpm'] = value.strip()
+                    elif current_key == 'caption':
+                        # Post-process caption to remove YAML multi-line formatting
+                        metadata['caption'] = MetadataConstrainedLogitsProcessor.postprocess_caption(value)
+                    elif current_key == 'duration':
+                        try:
+                            metadata['duration'] = int(value.strip())
+                        except:
+                            metadata['duration'] = value.strip()
+                    elif current_key == 'genres':
+                        metadata['genres'] = value.strip()
+                    elif current_key == 'keyscale':
+                        metadata['keyscale'] = value.strip()
+                    elif current_key == 'language':
+                        metadata['language'] = value.strip()
+                    elif current_key == 'timesignature':
+                        metadata['timesignature'] = value.strip()
+
+                current_key = None
+                current_value_lines = []
+
+            for line in lines:
+                # Skip lines starting with '<' (tags)
+                if line.strip().startswith('<'):
+                    continue
+
+                # Check if this is a new field (no leading spaces and contains ':')
+                if line and not line[0].isspace() and ':' in line:
+                    # Save previous field if any
+                    save_current_field()
+
+                    # Parse new field
                     parts = line.split(':', 1)
                     if len(parts) == 2:
-                        value
-                        if
-                        metadata['duration'] = value
-                    elif key == 'genres':
-                        metadata['genres'] = value
-                    elif key == 'keyscale':
-                        metadata['keyscale'] = value
-                    elif key == 'timesignature':
-                        metadata['timesignature'] = value
+                        current_key = parts[0].strip().lower()
+                        # First line of value (after colon)
+                        first_value = parts[1]
+                        if first_value.strip():
+                            current_value_lines.append(first_value)
+                elif line.startswith(' ') or line.startswith('\t'):
+                    # Continuation line (YAML multi-line value)
+                    if current_key:
+                        current_value_lines.append(line)
+
+            # Don't forget to save the last field
+            save_current_field()

         return metadata, audio_codes
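To make the new parsing behaviour concrete, below is a self-contained sketch of the multi-line-aware `<think>` metadata parser added in this commit. It is trimmed to return a plain dict; the real implementation additionally routes the caption through MetadataConstrainedLogitsProcessor.postprocess_caption and extracts the trailing audio codes.

```python
def parse_think_metadata(reasoning_text: str) -> dict:
    """Parse 'key: value' fields from an LM <think> block, allowing indented
    continuation lines (YAML-style multi-line values). Sketch only."""
    metadata = {}
    current_key = None
    current_value_lines = []

    def save_current_field():
        nonlocal current_key, current_value_lines
        if current_key and current_value_lines:
            value = "\n".join(current_value_lines).strip()
            if current_key in ("bpm", "duration"):
                try:
                    value = int(value)
                except ValueError:
                    pass
            metadata[current_key] = value
        current_key = None
        current_value_lines = []

    for line in reasoning_text.split("\n"):
        if line.strip().startswith("<"):        # skip <think> / </think> tags
            continue
        if line and not line[0].isspace() and ":" in line:
            save_current_field()                 # a new "key: value" field starts
            key, _, first_value = line.partition(":")
            current_key = key.strip().lower()
            if first_value.strip():
                current_value_lines.append(first_value)
        elif line.startswith((" ", "\t")) and current_key:
            current_value_lines.append(line)     # indented continuation line
    save_current_field()                         # flush the last field
    return metadata


example = """<think>
bpm: 73
caption: A calm piano melody
duration: 273
genres: Chinese folk
keyscale: G major
language: en
timesignature: 4
</think>"""
print(parse_think_metadata(example))
# {'bpm': 73, 'caption': 'A calm piano melody', 'duration': 273, 'genres': 'Chinese folk', ...}
```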