add MetadataConstrainedLogitsProcessor

- .gitignore +2 -0
- API.md +41 -0
- acestep/genres_vocab.txt +0 -0
- acestep/llm_inference.py +647 -57
.gitignore CHANGED

@@ -218,3 +218,5 @@ checkpoints.7z
 README_old.md
 discord_bot/
 feishu_bot/
+tmp*
+torchinductor_root/
API.md CHANGED

@@ -61,6 +61,22 @@ Suitable for passing only text parameters, or referencing audio file paths that
 | `seed` | int | `-1` | Specify seed (when use_random_seed=false) |
 | `batch_size` | int | null | Batch generation count |
 
+**5Hz LM Parameters (Optional, server-side codes generation)**:
+
+If you want the server to generate `audio_code_string` using the 5Hz LM (equivalent to Gradio's **Generate LM Hints** button), set `use_5hz_lm=true`.
+
+| Parameter Name | Type | Default | Description |
+| :--- | :--- | :--- | :--- |
+| `use_5hz_lm` | bool | `false` | Enable server-side 5Hz LM code generation |
+| `lm_model_path` | string | null | 5Hz LM checkpoint dir name (e.g. `acestep-5Hz-lm-0.6B`) |
+| `lm_backend` | string | `"vllm"` | `vllm` or `pt` |
+| `lm_temperature` | float | `0.6` | Sampling temperature |
+| `lm_cfg_scale` | float | `1.0` | CFG scale (>1 enables CFG) |
+| `lm_negative_prompt` | string | `"NO USER INPUT"` | Negative prompt used by CFG |
+| `lm_top_k` | int | null | Top-k (0/null disables) |
+| `lm_top_p` | float | null | Top-p (>=1/null disables) |
+| `lm_repetition_penalty` | float | `1.0` | Repetition penalty |
+
 **Edit/Reference Audio Parameters** (requires absolute path on server):
 
 | Parameter Name | Type | Default | Description |

@@ -108,6 +124,31 @@ curl -X POST http://localhost:8001/v1/music/generate \
 }'
 ```
 
+**JSON Method (server-side 5Hz LM)**:
+
+```bash
+curl -X POST http://localhost:8001/v1/music/generate \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "caption": "upbeat pop song",
+    "lyrics": "Hello world",
+    "use_5hz_lm": true,
+    "lm_temperature": 0.6,
+    "lm_cfg_scale": 1.0,
+    "lm_top_k": 0,
+    "lm_top_p": 1.0,
+    "lm_repetition_penalty": 1.0
+  }'
+```
+
+When `use_5hz_lm=true` and the server generates LM codes, the job `result` will also include the following optional fields:
+
+- `bpm`
+- `duration`
+- `genres`
+- `keyscale`
+- `timesignature`
+
 > Note: If you use `curl -d` but **forget** to add `-H 'Content-Type: application/json'`, curl will default to sending `application/x-www-form-urlencoded`, and older server versions will return 415.
 
 **Form Method (no file upload, application/x-www-form-urlencoded)**:
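For completeness, here is a minimal Python client sketch for the JSON method above. It assumes only what the docs state (the endpoint URL, the request fields, and that the response is JSON); any polling or job-result handling is deployment-specific and left out.

```python
# Minimal client sketch for the JSON method documented above.
# Assumption: the server at localhost:8001 exposes POST /v1/music/generate
# exactly as described; the response schema beyond "JSON" is not assumed.
import requests

payload = {
    "caption": "upbeat pop song",
    "lyrics": "Hello world",
    "use_5hz_lm": True,
    "lm_temperature": 0.6,
    "lm_cfg_scale": 1.0,
    "lm_top_k": 0,
    "lm_top_p": 1.0,
    "lm_repetition_penalty": 1.0,
}

resp = requests.post(
    "http://localhost:8001/v1/music/generate",
    json=payload,  # json= sets Content-Type: application/json, avoiding the 415 pitfall noted above
    timeout=600,
)
resp.raise_for_status()
# When use_5hz_lm=true, the job result may also carry bpm/duration/genres/keyscale/timesignature
print(resp.json())
```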
acestep/genres_vocab.txt ADDED

The diff for this file is too large to render. See raw diff.
acestep/llm_inference.py CHANGED

@@ -72,18 +72,35 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         tokenizer: AutoTokenizer,
         enabled: bool = True,
         debug: bool = False,
+        genres_vocab_path: Optional[str] = None,
+        skip_genres: bool = True,
     ):
         """
         Initialize the constrained logits processor.
 
+        This processor should be initialized once when loading the LLM and reused
+        for all generations. Use update_caption() before each generation to update
+        the caption-based genre filtering.
+
         Args:
             tokenizer: The tokenizer to use for encoding/decoding
             enabled: Whether to enable constrained decoding
             debug: Whether to print debug information
+            genres_vocab_path: Path to genres vocabulary file (one genre per line)
+                If None, defaults to "acestep/genres_vocab.txt"
+            skip_genres: Whether to skip genres generation in metadata (default True)
         """
         self.tokenizer = tokenizer
         self.enabled = enabled
         self.debug = debug
+        self.skip_genres = skip_genres
+        self.caption: Optional[str] = None  # Set via update_caption() before each generation
+
+        # Temperature settings for different generation phases (set per-generation)
+        # If set, the processor will apply temperature scaling (divide logits by temperature)
+        # Note: Set base sampler temperature to 1.0 when using processor-based temperature
+        self.metadata_temperature: Optional[float] = None
+        self.codes_temperature: Optional[float] = None
 
         # Current state
         self.state = FSMState.THINK_TAG

@@ -93,6 +110,23 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         # Pre-compute token IDs for efficiency
         self._precompute_tokens()
 
+        # Genres vocabulary for constrained decoding
+        self.genres_vocab_path = genres_vocab_path or os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
+        )
+        self.genres_vocab: List[str] = []  # Full vocab
+        self.genres_vocab_mtime: float = 0.0
+        self.genres_trie: Dict = {}  # Trie for full vocab (fallback)
+        self.caption_genres_trie: Dict = {}  # Trie for caption-matched genres (priority)
+        self.caption_matched_genres: List[str] = []  # Genres matched from caption
+        self._char_to_tokens: Dict[str, set] = {}  # Precomputed char -> token IDs mapping
+
+        # Precompute token mappings once (O(vocab_size), runs once at init)
+        self._precompute_char_token_mapping()
+        self._load_genres_vocab()
+
+        # Note: Caption-based genre filtering is initialized via update_caption() before each generation
+
         # Field definitions
         self.field_specs = {
             "bpm": {"min": 30, "max": 300},

@@ -117,7 +151,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.THINK_END_TAG: "</think>",
         }
 
-        # State transitions
+        # State transitions - build dynamically based on skip_genres
+        self._build_state_transitions()
+
+    def _build_state_transitions(self):
+        """Build state transition map based on skip_genres setting."""
         self.next_state = {
             FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
             FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME,

@@ -126,10 +164,6 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.NEWLINE_AFTER_BPM: FSMState.DURATION_NAME,
             FSMState.DURATION_NAME: FSMState.DURATION_VALUE,
             FSMState.DURATION_VALUE: FSMState.NEWLINE_AFTER_DURATION,
-            FSMState.NEWLINE_AFTER_DURATION: FSMState.GENRES_NAME,
-            FSMState.GENRES_NAME: FSMState.GENRES_VALUE,
-            FSMState.GENRES_VALUE: FSMState.NEWLINE_AFTER_GENRES,
-            FSMState.NEWLINE_AFTER_GENRES: FSMState.KEYSCALE_NAME,
             FSMState.KEYSCALE_NAME: FSMState.KEYSCALE_VALUE,
             FSMState.KEYSCALE_VALUE: FSMState.NEWLINE_AFTER_KEYSCALE,
             FSMState.NEWLINE_AFTER_KEYSCALE: FSMState.TIMESIG_NAME,

@@ -139,6 +173,21 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
             FSMState.CODES_GENERATION: FSMState.COMPLETED,
         }
+
+        if self.skip_genres:
+            # Skip genres: NEWLINE_AFTER_DURATION -> KEYSCALE_NAME directly
+            self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.KEYSCALE_NAME
+        else:
+            # Include genres in the flow
+            self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.GENRES_NAME
+            self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
+            self.next_state[FSMState.GENRES_VALUE] = FSMState.NEWLINE_AFTER_GENRES
+            self.next_state[FSMState.NEWLINE_AFTER_GENRES] = FSMState.KEYSCALE_NAME
+
+    def set_skip_genres(self, skip: bool):
+        """Set whether to skip genres generation and rebuild state transitions."""
+        self.skip_genres = skip
+        self._build_state_transitions()
 
     def _precompute_tokens(self):
         """Pre-compute commonly used token IDs for efficiency."""
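The transition table above drives a character-level FSM: in each "name" state the processor forces a fixed string (e.g. `bpm: `), then hands off to a value state. A toy sketch of that core trick, not the repo's exact code (the real processor precomputes token text to avoid decoding the vocabulary on every step):

```python
# Toy sketch of FSM-style constrained decoding: at a fixed-string state, only
# tokens whose decoded text continues the required string are allowed.
# Illustrative names; O(vocab) per step, unlike the precomputed real version.
import torch

def mask_to_fixed_string(scores, tokenizer, fixed_str, pos):
    """Allow only tokens that continue fixed_str from character position pos."""
    remaining = fixed_str[pos:]
    mask = torch.full_like(scores, float("-inf"))
    for token_id in range(scores.shape[-1]):
        text = tokenizer.decode([token_id])
        # Legal iff the token's text is a prefix of what is still required
        if text and remaining.startswith(text):
            mask[..., token_id] = 0.0
    return scores + mask
```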
@@ -189,6 +238,328 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
         # Vocab size
         self.vocab_size = len(self.tokenizer)
+
+        # Comma token for multi-genre support
+        comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
+        self.comma_token = comma_tokens[-1] if comma_tokens else None
+
+    def _load_genres_vocab(self):
+        """
+        Load genres vocabulary from file. Supports hot reload by checking file mtime.
+        File format: one genre per line, lines starting with # are comments.
+        """
+        if not os.path.exists(self.genres_vocab_path):
+            if self.debug:
+                logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
+            return
+
+        try:
+            mtime = os.path.getmtime(self.genres_vocab_path)
+            if mtime <= self.genres_vocab_mtime:
+                return  # File hasn't changed
+
+            with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
+                genres = []
+                for line in f:
+                    line = line.strip()
+                    if line and not line.startswith('#'):
+                        genres.append(line.lower())
+
+            self.genres_vocab = genres
+            self.genres_vocab_mtime = mtime
+            self._build_genres_trie()
+
+            if self.debug:
+                logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
+        except Exception as e:
+            logger.warning(f"Failed to load genres vocab: {e}")
+
+    def _build_genres_trie(self):
+        """
+        Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
+        Each node is a dict with:
+        - '_end': True if this node represents a complete genre
+        - other keys: next characters in the trie
+        """
+        self.genres_trie = {}
+
+        for genre in self.genres_vocab:
+            node = self.genres_trie
+            for char in genre:
+                if char not in node:
+                    node[char] = {}
+                node = node[char]
+            node['_end'] = True  # Mark end of a complete genre
+
+        if self.debug:
+            logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
+
+    def _extract_caption_genres(self, caption: str):
+        """
+        Extract genres from the user's caption that match entries in the vocabulary.
+        This creates a smaller trie for faster and more relevant genre generation.
+
+        Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
+        1. Extract words/phrases from caption
+        2. For each word, use trie to find all vocab entries that START with this word
+        3. Build a separate trie from matched genres
+        """
+        if not caption or not self.genres_vocab:
+            return
+
+        caption_lower = caption.lower()
+        matched_genres = set()
+
+        # Extract words from caption (split by common delimiters)
+        import re
+        words = re.split(r'[,\s\-_/\\|]+', caption_lower)
+        words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
+
+        # For each word, find genres in trie that start with this word
+        for word in words:
+            # Find all genres starting with this word using trie traversal
+            node = self._get_genres_trie_node(word)
+            if node is not None:
+                # Collect all complete genres under this node
+                self._collect_complete_genres(node, word, matched_genres)
+
+        # Also check if any word appears as a complete entry in the vocab
+        # This is a quick check for common single-word genres
+        genres_set = set(self.genres_vocab)
+        for word in words:
+            if word in genres_set:
+                matched_genres.add(word)
+
+        if not matched_genres:
+            if self.debug:
+                logger.debug("No genres matched in caption, using full vocab")
+            return
+
+        # Build a trie from matched genres
+        self.caption_matched_genres = list(matched_genres)
+        self.caption_genres_trie = {}
+
+        for genre in matched_genres:
+            node = self.caption_genres_trie
+            for char in genre:
+                if char not in node:
+                    node[char] = {}
+                node = node[char]
+            node['_end'] = True
+
+        if self.debug:
+            logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
+
+    def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
+        """
+        Recursively collect all complete genres under a trie node.
+        Limited depth to avoid too many matches.
+        """
+        if max_depth <= 0:
+            return
+
+        if node.get('_end', False):
+            result.add(prefix)
+
+        # Limit total collected genres to avoid slowdown
+        if len(result) >= 100:
+            return
+
+        for char, child_node in node.items():
+            if char not in ('_end', '_tokens'):
+                self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
+
+    def _precompute_char_token_mapping(self):
+        """
+        Precompute mapping from characters to token IDs and token decoded texts.
+        This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
+
+        Time complexity: O(vocab_size) - runs once during initialization
+
+        Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
+        We need to handle both the raw first char and the first non-space char.
+        """
+        self._char_to_tokens: Dict[str, set] = {}
+        self._token_to_text: Dict[int, str] = {}  # Precomputed decoded text for each token
+
+        # For each token in vocabulary, get its decoded text
+        for token_id in range(self.vocab_size):
+            try:
+                text = self.tokenizer.decode([token_id])
+
+                if not text:
+                    continue
+
+                # Store the decoded text (normalized to lowercase)
+                # Keep leading spaces for proper concatenation (e.g., " rock" in "pop rock")
+                # Only rstrip trailing whitespace, unless it's a pure whitespace token
+                text_lower = text.lower()
+                if text_lower.strip():  # Has non-whitespace content
+                    normalized_text = text_lower.rstrip()
+                else:  # Pure whitespace token
+                    normalized_text = " "  # Normalize to single space
+                self._token_to_text[token_id] = normalized_text
+
+                # Map first character (including space) to this token
+                first_char = text[0].lower()
+                if first_char not in self._char_to_tokens:
+                    self._char_to_tokens[first_char] = set()
+                self._char_to_tokens[first_char].add(token_id)
+
+                # Also map first non-space character to this token
+                # This handles tokenizers that add space prefixes (e.g., " pop" -> maps to 'p')
+                stripped_text = text.lstrip()
+                if stripped_text and stripped_text != text:
+                    first_nonspace_char = stripped_text[0].lower()
+                    if first_nonspace_char not in self._char_to_tokens:
+                        self._char_to_tokens[first_nonspace_char] = set()
+                    self._char_to_tokens[first_nonspace_char].add(token_id)
+
+            except Exception:
+                continue
+
+        if self.debug:
+            logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
+
+    def _try_reload_genres_vocab(self):
+        """Check if genres vocab file has been updated and reload if necessary."""
+        if not os.path.exists(self.genres_vocab_path):
+            return
+
+        try:
+            mtime = os.path.getmtime(self.genres_vocab_path)
+            if mtime > self.genres_vocab_mtime:
+                self._load_genres_vocab()
+        except Exception:
+            pass  # Ignore errors during hot reload check
+
+    def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
+        """
+        Get the trie node for a given prefix.
+        Returns None if the prefix is not valid (no genres start with this prefix).
+        """
+        node = self.genres_trie
+        for char in prefix.lower():
+            if char not in node:
+                return None
+            node = node[char]
+        return node
+
+    def _is_complete_genre(self, text: str) -> bool:
+        """Check if the given text is a complete genre in the vocabulary."""
+        node = self._get_genres_trie_node(text.strip())
+        return node is not None and node.get('_end', False)
+
+    def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
+        """Get a trie node from a specific trie (helper for caption vs full trie)."""
+        node = trie
+        for char in prefix.lower():
+            if char not in node:
+                return None
+            node = node[char]
+        return node
+
+    def _get_allowed_genres_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for genres field based on trie matching.
+
+        The entire genres string (including commas) must match a complete entry in the vocab.
+        For example, if vocab contains "pop, rock, jazz", the generated string must exactly
+        match that entry - we don't treat commas as separators for individual genres.
+
+        Strategy:
+        1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
+        2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
+        3. Get valid next characters from current trie node
+        4. For each candidate token, verify the full decoded text forms a valid trie prefix
+        """
+        if not self.genres_vocab:
+            # No vocab loaded, allow all except newline if empty
+            return []
+
+        # Use the full accumulated value (don't split by comma - treat as single entry)
+        accumulated = self.accumulated_value.lower()
+        current_genre_prefix = accumulated.strip()
+
+        # Determine which trie to use: caption-matched (priority) or full vocab (fallback)
+        use_caption_trie = False
+        current_node = None
+
+        # Try caption-matched trie first if available
+        if self.caption_genres_trie:
+            if current_genre_prefix == "":
+                current_node = self.caption_genres_trie
+                use_caption_trie = True
+            else:
+                current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
+                if current_node is not None:
+                    use_caption_trie = True
+
+        # Fallback to full vocab trie
+        if current_node is None:
+            if current_genre_prefix == "":
+                current_node = self.genres_trie
+            else:
+                current_node = self._get_genres_trie_node(current_genre_prefix)
+
+        if current_node is None:
+            # Invalid prefix, force newline to end
+            if self.newline_token:
+                return [self.newline_token]
+            return []
+
+        # Get valid next characters from trie node
+        valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
+
+        # If current value is a complete genre, allow newline to end
+        is_complete = current_node.get('_end', False)
+
+        if not valid_next_chars:
+            # No more characters to match, only allow newline if complete
+            allowed = set()
+            if is_complete and self.newline_token:
+                allowed.add(self.newline_token)
+            return list(allowed)
+
+        # Collect candidate tokens based on first character
+        candidate_tokens = set()
+        for char in valid_next_chars:
+            if char in self._char_to_tokens:
+                candidate_tokens.update(self._char_to_tokens[char])
+
+        # Select the appropriate trie for validation
+        active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
+
+        # Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
+        allowed = set()
+        for token_id in candidate_tokens:
+            # Use precomputed decoded text (already normalized)
+            decoded_normalized = self._token_to_text.get(token_id, "")
+
+            if not decoded_normalized or not decoded_normalized.strip():
+                # Token decodes to empty or only whitespace - allow if space/comma is a valid next char
+                if ' ' in valid_next_chars or ',' in valid_next_chars:
+                    allowed.add(token_id)
+                continue
+
+            # Build new prefix by appending decoded token
+            # (space/comma-prefixed tokens such as " rock" concatenate the same way)
+            new_prefix = current_genre_prefix + decoded_normalized
+
+            # Check if new_prefix is a valid prefix in the active trie
+            new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
+            if new_node is not None:
+                allowed.add(token_id)
+
+        # If current value is a complete genre, also allow newline
+        if is_complete and self.newline_token:
+            allowed.add(self.newline_token)
+
+        return list(allowed)
 
     def reset(self):
         """Reset the processor state for a new generation."""
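The trie machinery above reduces to a simple idea: a prefix tree answers both "which characters may legally follow this partial value?" and "may the value terminate here?". A self-contained toy demo of that idea (toy vocab, illustrative names):

```python
# Standalone demo of trie-based prefix matching as used for the genres field.
def build_trie(entries):
    trie = {}
    for entry in entries:
        node = trie
        for ch in entry:
            node = node.setdefault(ch, {})
        node["_end"] = True  # a complete entry ends at this node
    return trie

def walk(trie, prefix):
    node = trie
    for ch in prefix:
        node = node.get(ch)
        if node is None:
            return None  # no vocab entry starts with this prefix
    return node

trie = build_trie(["pop", "pop rock", "jazz"])
node = walk(trie, "pop")
print(sorted(k for k in node if k != "_end"))  # [' '] -> only "pop rock" may continue
print(node.get("_end", False))                 # True -> "pop" may also terminate (newline allowed)
```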
@@ -196,6 +567,27 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         self.position_in_state = 0
         self.accumulated_value = ""
 
+    def update_caption(self, caption: Optional[str]):
+        """
+        Update the caption and rebuild the caption-matched genres trie.
+        Call this before each generation to prioritize genres from the new caption.
+
+        Args:
+            caption: User's input caption. If None or empty, clears caption matching.
+        """
+        # Check for hot reload of genres vocabulary
+        self._try_reload_genres_vocab()
+
+        self.caption = caption
+        self.caption_genres_trie = {}
+        self.caption_matched_genres = []
+
+        if caption:
+            self._extract_caption_genres(caption)
+
+        # Also reset FSM state for new generation
+        self.reset()
+
     def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
         """
         Get the token IDs that can continue the fixed string from current position.

@@ -400,13 +792,14 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             scores: [batch_size, vocab_size] logits for next token
 
         Returns:
-            Modified scores with invalid tokens masked to -inf
+            Modified scores with invalid tokens masked to -inf and temperature scaling applied
         """
         if not self.enabled:
-            return scores
+            return self._apply_temperature_scaling(scores)
 
         if self.state == FSMState.COMPLETED or self.state == FSMState.CODES_GENERATION:
-            return scores
+            # No constraints in codes generation phase, but still apply temperature
+            return self._apply_temperature_scaling(scores)
 
         batch_size = scores.shape[0]

@@ -414,7 +807,39 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         for b in range(batch_size):
             scores[b] = self._process_single_sequence(input_ids[b], scores[b:b+1])
 
-        return scores
+        # Apply temperature scaling after constraint masking
+        return self._apply_temperature_scaling(scores)
+
+    def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Apply temperature scaling based on current generation phase.
+
+        Temperature scaling: logits = logits / temperature
+        - Lower temperature (< 1.0) makes distribution sharper (more deterministic)
+        - Higher temperature (> 1.0) makes distribution flatter (more diverse)
+
+        Args:
+            scores: [batch_size, vocab_size] logits
+
+        Returns:
+            Temperature-scaled logits
+        """
+        # Determine which temperature to use based on current state
+        if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED:
+            temperature = self.codes_temperature
+        else:
+            temperature = self.metadata_temperature
+
+        # If no temperature is set for this phase, return scores unchanged
+        if temperature is None:
+            return scores
+
+        # Avoid division by zero
+        if temperature <= 0:
+            temperature = 1e-6
+
+        # Apply temperature scaling
+        return scores / temperature
 
     def _process_single_sequence(
         self,
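The phase-dependent temperature trick above is easy to see in isolation: dividing logits by a temperature before softmax sharpens (< 1.0) or flattens (> 1.0) the distribution, while the base sampler runs at 1.0. A minimal runnable sketch (illustrative names, not the repo's code):

```python
# Minimal sketch of phase-dependent temperature scaling.
import torch

def scale(logits, temperature):
    if temperature is None:
        return logits  # phase temperature unset: leave logits alone
    return logits / max(temperature, 1e-6)  # guard against division by zero

logits = torch.tensor([2.0, 1.0, 0.5])
print(torch.softmax(scale(logits, 0.5), dim=-1))   # sharper: suits metadata accuracy
print(torch.softmax(scale(logits, None), dim=-1))  # unscaled: base sampler behavior
```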
@@ -482,17 +907,38 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
             scores = scores + mask
 
         elif self.state == FSMState.GENRES_VALUE:
-            if
+            # Try to hot-reload genres vocab if file has changed
+            self._try_reload_genres_vocab()
+
+            # Get allowed tokens based on genres vocabulary
+            allowed = self._get_allowed_genres_tokens()
+
+            if allowed:
+                # Use vocabulary-constrained decoding
+                for t in allowed:
+                    mask[0, t] = 0
+                scores = scores + mask
+            elif self.genres_vocab:
+                # Vocab is loaded but no valid continuation found
+                # Force newline to end the field
                 if self.newline_token:
                     mask[0, self.newline_token] = 0
-                self.
+                if self.debug:
+                    logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
                 scores = scores + mask
             else:
-                #
-                if
+                # Fallback: no vocab loaded, use probability-based ending
+                if self._should_end_text_field(scores):
                     if self.newline_token:
-
-
+                        mask[0, self.newline_token] = 0
+                        self._transition_to_next_state()
+                    scores = scores + mask
+                else:
+                    # Allow any token except newline if we don't have content yet
+                    if not self.accumulated_value.strip():
+                        if self.newline_token:
+                            scores[0, self.newline_token] = float('-inf')
+                    # Otherwise, don't constrain (fallback behavior)
 
         elif self.state == FSMState.KEYSCALE_VALUE:
             if self._is_keyscale_complete():

@@ -591,6 +1037,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
 class LLMHandler:
     """5Hz LM Handler for audio code generation"""
+
+    STOP_REASONING_TAG = "</think>"
 
     def __init__(self):
         """Initialize LLMHandler with default values"""
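The masking convention used throughout these branches is additive: allowed token IDs get 0, everything else -inf, and the mask is added to the raw logits so disallowed tokens can never be sampled. A self-contained sketch of the pattern (toy sizes, illustrative names):

```python
# Additive logits masking as used above: 0 for allowed tokens, -inf otherwise.
import torch

vocab_size = 8
scores = torch.randn(1, vocab_size)  # raw logits for the next token
allowed = [2, 5]                     # token IDs the FSM permits at this step

mask = torch.full_like(scores, float("-inf"))
for t in allowed:
    mask[0, t] = 0.0

constrained = scores + mask
print(torch.softmax(constrained, dim=-1))  # probability mass only on tokens 2 and 5
```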
@@ -602,6 +1050,9 @@ class LLMHandler:
         self.device = "cpu"
         self.dtype = torch.float32
         self.offload_to_cpu = False
+
+        # Shared constrained decoding processor (initialized once when LLM is loaded)
+        self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
 
     def get_available_5hz_lm_models(self) -> List[str]:
         """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""

@@ -690,6 +1141,16 @@ class LLMHandler:
         logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
         self.llm_tokenizer = llm_tokenizer
 
+        # Initialize shared constrained decoding processor (one-time initialization)
+        logger.info("Initializing constrained decoding processor...")
+        processor_start = time.time()
+        self.constrained_processor = MetadataConstrainedLogitsProcessor(
+            tokenizer=self.llm_tokenizer,
+            enabled=True,
+            debug=False,
+        )
+        logger.info(f"Constrained processor initialized in {time.time() - processor_start:.2f} seconds")
+
         # Initialize based on user-selected backend
         if backend == "vllm":
             # Try to initialize with vllm

@@ -795,13 +1256,15 @@ class LLMHandler:
         repetition_penalty: float = 1.0,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with vllm backend
 
         Args:
             caption: Text caption for music generation
             lyrics: Lyrics for music generation
-            temperature:
+            temperature: Base sampling temperature (used if phase-specific temps not set)
             cfg_scale: CFG scale (>1.0 enables CFG)
             negative_prompt: Negative prompt for CFG
             top_k: Top-k sampling parameter

@@ -809,6 +1272,10 @@ class LLMHandler:
             repetition_penalty: Repetition penalty
             use_constrained_decoding: Whether to use FSM-based constrained decoding
             constrained_decoding_debug: Whether to print debug info for constrained decoding
+            metadata_temperature: Temperature for metadata generation (lower = more accurate)
+                If None, uses base temperature
+            codes_temperature: Temperature for audio codes generation (higher = more diverse)
+                If None, uses base temperature
         """
         try:
             from nanovllm import SamplingParams

@@ -816,20 +1283,28 @@ class LLMHandler:
         formatted_prompt = self.build_formatted_prompt(caption, lyrics)
         logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
 
-        #
+        # Determine effective temperature for sampler
+        # If using phase-specific temperatures, set sampler temp to 1.0 (processor handles it)
+        use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
+        effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
+
+        # Use shared constrained decoding processor if enabled
         constrained_processor = None
         update_state_fn = None
-        if use_constrained_decoding:
-
-
-
-
-
+        if use_constrained_decoding or use_phase_temperatures:
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
+            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
+            self.constrained_processor.update_caption(caption)
+
+            constrained_processor = self.constrained_processor
             update_state_fn = constrained_processor.update_state
 
         sampling_params = SamplingParams(
             max_tokens=self.max_model_len-64,
-            temperature=
+            temperature=effective_sampler_temp,
             cfg_scale=cfg_scale,
             top_k=top_k,
             top_p=top_p,

@@ -880,21 +1355,31 @@ class LLMHandler:
         repetition_penalty: float,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> str:
         """Shared vllm path: accept prebuilt formatted prompt and return text."""
         from nanovllm import SamplingParams
 
-        #
+        # Determine effective temperature for sampler
+        use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
+        effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
+
+        # Use shared constrained processor if enabled
         constrained_processor = None
-        if use_constrained_decoding:
-
-
-
-
+        if use_constrained_decoding or use_phase_temperatures:
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
+            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
+            self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+
+            constrained_processor = self.constrained_processor
 
         sampling_params = SamplingParams(
             max_tokens=self.max_model_len - 64,
-            temperature=
+            temperature=effective_sampler_temp,
             cfg_scale=cfg_scale,
             top_k=top_k,
             top_p=top_p,

@@ -940,13 +1425,15 @@ class LLMHandler:
         repetition_penalty: float = 1.0,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM with PyTorch backend
 
         Args:
             caption: Text caption for music generation
             lyrics: Lyrics for music generation
-            temperature:
+            temperature: Base sampling temperature (used if phase-specific temps not set)
             cfg_scale: CFG scale (>1.0 enables CFG)
             negative_prompt: Negative prompt for CFG
             top_k: Top-k sampling parameter

@@ -954,6 +1441,10 @@ class LLMHandler:
             repetition_penalty: Repetition penalty
             use_constrained_decoding: Whether to use FSM-based constrained decoding
             constrained_decoding_debug: Whether to print debug info for constrained decoding
+            metadata_temperature: Temperature for metadata generation (lower = more accurate)
+                If None, uses base temperature
+            codes_temperature: Temperature for audio codes generation (higher = more diverse)
+                If None, uses base temperature
         """
         try:
             formatted_prompt = self.build_formatted_prompt(caption, lyrics)

@@ -992,14 +1483,21 @@ class LLMHandler:
 
         streamer = TqdmTokenStreamer(total=max_new_tokens)
 
-        #
+        # Determine if using phase-specific temperatures
+        use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
+        effective_temperature = 1.0 if use_phase_temperatures else temperature
+
+        # Use shared constrained decoding processor if enabled
         constrained_processor = None
-        if use_constrained_decoding:
-
-
-
-
-
+        if use_constrained_decoding or use_phase_temperatures:
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
+            self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
+            self.constrained_processor.update_caption(caption)
+
+            constrained_processor = self.constrained_processor
 
         # Build logits processor list (only for CFG and repetition penalty)
         logits_processor = LogitsProcessorList()

@@ -1036,7 +1534,7 @@ class LLMHandler:
             batch_input_ids=batch_input_ids,
             batch_attention_mask=batch_attention_mask,
             max_new_tokens=max_new_tokens,
-            temperature=
+            temperature=effective_temperature,
             cfg_scale=cfg_scale,
             top_k=top_k,
             top_p=top_p,

@@ -1048,8 +1546,8 @@ class LLMHandler:
 
             # Extract only the conditional output (first in batch)
             outputs = outputs[0:1]  # Keep only conditional output
-        elif use_constrained_decoding:
-            # Use custom generation loop for constrained decoding (non-CFG)
+        elif use_constrained_decoding or use_phase_temperatures:
+            # Use custom generation loop for constrained decoding or phase temperatures (non-CFG)
             input_ids = inputs['input_ids']
             attention_mask = inputs.get('attention_mask', None)

@@ -1057,7 +1555,7 @@ class LLMHandler:
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_new_tokens=max_new_tokens,
-                temperature=
+                temperature=effective_temperature,
                 top_k=top_k,
                 top_p=top_p,
                 repetition_penalty=repetition_penalty,

@@ -1071,8 +1569,8 @@ class LLMHandler:
             outputs = self.llm.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
-                temperature=
-                do_sample=True if
+                temperature=effective_temperature if effective_temperature > 0 else 1.0,
+                do_sample=True if effective_temperature > 0 else False,
                 top_k=top_k if top_k is not None and top_k > 0 else None,
                 top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
                 logits_processor=logits_processor if len(logits_processor) > 0 else None,

@@ -1134,13 +1632,15 @@ class LLMHandler:
             truncation=True,
         )
 
-        #
+        # Use shared constrained processor if enabled
         constrained_processor = None
         if use_constrained_decoding:
-
-
-
-            )
+            # Use shared processor, just update caption and settings
+            self.constrained_processor.enabled = use_constrained_decoding
+            self.constrained_processor.debug = constrained_decoding_debug
+            self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
+
+            constrained_processor = self.constrained_processor
 
         with self._load_model_context():
             inputs = {k: v.to(self.device) for k, v in inputs.items()}

@@ -1262,13 +1762,15 @@ class LLMHandler:
         repetition_penalty: float = 1.0,
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
     ) -> Tuple[Dict[str, Any], str, str]:
         """Generate metadata and audio codes using 5Hz LM
 
         Args:
             caption: Text caption for music generation
             lyrics: Lyrics for music generation
-            temperature:
+            temperature: Base sampling temperature (used if phase-specific temps not set)
             cfg_scale: CFG scale (>1.0 enables CFG)
             negative_prompt: Negative prompt for CFG
             top_k: Top-k sampling parameter

@@ -1276,6 +1778,10 @@ class LLMHandler:
             repetition_penalty: Repetition penalty
             use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
             constrained_decoding_debug: Whether to print debug info for constrained decoding
+            metadata_temperature: Temperature for metadata generation (lower = more accurate)
+                Recommended: 0.3-0.5 for accurate metadata
+            codes_temperature: Temperature for audio codes generation (higher = more diverse)
+                Recommended: 0.7-1.0 for diverse codes
         """
         # Check if 5Hz LM is initialized
         if not hasattr(self, 'llm_initialized') or not self.llm_initialized:

@@ -1293,17 +1799,101 @@ class LLMHandler:
 
         if self.llm_backend == "vllm":
             return self.generate_with_5hz_lm_vllm(
-                caption,
-
-
+                caption=caption,
+                lyrics=lyrics,
+                temperature=temperature,
+                cfg_scale=cfg_scale,
+                negative_prompt=negative_prompt,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                use_constrained_decoding=use_constrained_decoding,
+                constrained_decoding_debug=constrained_decoding_debug,
+                metadata_temperature=metadata_temperature,
+                codes_temperature=codes_temperature,
             )
         else:
            return self.generate_with_5hz_lm_pt(
-                caption,
-
-
+                caption=caption,
+                lyrics=lyrics,
+                temperature=temperature,
+                cfg_scale=cfg_scale,
+                negative_prompt=negative_prompt,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                use_constrained_decoding=use_constrained_decoding,
+                constrained_decoding_debug=constrained_decoding_debug,
+                metadata_temperature=metadata_temperature,
+                codes_temperature=codes_temperature,
            )
 
+    def generate_with_stop_condition(
+        self,
+        caption: str,
+        lyrics: str,
+        infer_type: str,
+        temperature: float = 0.6,
+        cfg_scale: float = 1.0,
+        negative_prompt: str = "NO USER INPUT",
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: float = 1.0,
+        use_constrained_decoding: bool = True,
+        constrained_decoding_debug: bool = False,
+        metadata_temperature: Optional[float] = 0.85,
+        codes_temperature: Optional[float] = None,
+    ) -> Tuple[Dict[str, Any], str, str]:
+        """Feishu-compatible LM generation.
+
+        - infer_type='dit': stop at </think> and return metas only (no audio codes)
+        - infer_type='llm_dit': normal generation (metas + audio codes)
+        """
+        infer_type = (infer_type or "").strip().lower()
+        if infer_type not in {"dit", "llm_dit"}:
+            return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
+
+        if infer_type == "llm_dit":
+            return self.generate_with_5hz_lm(
+                caption=caption,
+                lyrics=lyrics,
+                temperature=temperature,
+                cfg_scale=cfg_scale,
+                negative_prompt=negative_prompt,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                use_constrained_decoding=use_constrained_decoding,
+                constrained_decoding_debug=constrained_decoding_debug,
+                metadata_temperature=metadata_temperature,
+                codes_temperature=codes_temperature,
+            )
+
+        # dit: generate and truncate at reasoning end tag
+        formatted_prompt = self.build_formatted_prompt(caption, lyrics)
+        output_text, status = self.generate_from_formatted_prompt(
+            formatted_prompt,
+            cfg={
+                "temperature": temperature,
+                "cfg_scale": cfg_scale,
+                "negative_prompt": negative_prompt,
+                "top_k": top_k,
+                "top_p": top_p,
+                "repetition_penalty": repetition_penalty,
+            },
+            use_constrained_decoding=use_constrained_decoding,
+            constrained_decoding_debug=constrained_decoding_debug,
+        )
+        if not output_text:
+            return {}, "", status
+
+        if self.STOP_REASONING_TAG in output_text:
+            stop_idx = output_text.find(self.STOP_REASONING_TAG)
+            output_text = output_text[: stop_idx + len(self.STOP_REASONING_TAG)]
+
+        metadata, _audio_codes = self.parse_lm_output(output_text)
+        return metadata, "", status
+
     def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False) -> str:
         """
         Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
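A hypothetical call sketch for the entry points this diff adds, assuming an `LLMHandler` whose model and tokenizer have already been initialized (the loading steps are elided); argument values mirror the defaults shown above:

```python
# Hypothetical usage of the new LLMHandler API from this diff.
handler = LLMHandler()
# ... model/tokenizer initialization elided (deployment-specific) ...

# Full pipeline: metadata + audio codes, with phase-specific temperatures.
metadata, audio_codes, status = handler.generate_with_5hz_lm(
    caption="upbeat pop song",
    lyrics="Hello world",
    temperature=0.6,
    metadata_temperature=0.85,  # sharper sampling for more accurate metadata
    codes_temperature=None,     # codes phase falls back to the base temperature
)

# Metadata only: truncate at </think>, skipping audio-code generation.
metadata, _, status = handler.generate_with_stop_condition(
    caption="upbeat pop song",
    lyrics="Hello world",
    infer_type="dit",
)
```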
|
|
| 72 |
tokenizer: AutoTokenizer,
|
| 73 |
enabled: bool = True,
|
| 74 |
debug: bool = False,
|
| 75 |
+
genres_vocab_path: Optional[str] = None,
|
| 76 |
+
skip_genres: bool = True,
|
| 77 |
):
|
| 78 |
"""
|
| 79 |
Initialize the constrained logits processor.
|
| 80 |
|
| 81 |
+
This processor should be initialized once when loading the LLM and reused
|
| 82 |
+
for all generations. Use update_caption() before each generation to update
|
| 83 |
+
the caption-based genre filtering.
|
| 84 |
+
|
| 85 |
Args:
|
| 86 |
tokenizer: The tokenizer to use for encoding/decoding
|
| 87 |
enabled: Whether to enable constrained decoding
|
| 88 |
debug: Whether to print debug information
|
| 89 |
+
genres_vocab_path: Path to genres vocabulary file (one genre per line)
|
| 90 |
+
If None, defaults to "acestep/genres_vocab.txt"
|
| 91 |
+
skip_genres: Whether to skip genres generation in metadata (default True)
|
| 92 |
"""
|
| 93 |
self.tokenizer = tokenizer
|
| 94 |
self.enabled = enabled
|
| 95 |
self.debug = debug
|
| 96 |
+
self.skip_genres = skip_genres
|
| 97 |
+
self.caption: Optional[str] = None # Set via update_caption() before each generation
|
| 98 |
+
|
| 99 |
+
# Temperature settings for different generation phases (set per-generation)
|
| 100 |
+
# If set, the processor will apply temperature scaling (divide logits by temperature)
|
| 101 |
+
# Note: Set base sampler temperature to 1.0 when using processor-based temperature
|
| 102 |
+
self.metadata_temperature: Optional[float] = None
|
| 103 |
+
self.codes_temperature: Optional[float] = None
|
| 104 |
|
| 105 |
# Current state
|
| 106 |
self.state = FSMState.THINK_TAG
|
|
|
|
| 110 |
# Pre-compute token IDs for efficiency
|
| 111 |
self._precompute_tokens()
|
| 112 |
|
| 113 |
+
# Genres vocabulary for constrained decoding
|
| 114 |
+
self.genres_vocab_path = genres_vocab_path or os.path.join(
|
| 115 |
+
os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
|
| 116 |
+
)
|
| 117 |
+
self.genres_vocab: List[str] = [] # Full vocab
|
| 118 |
+
self.genres_vocab_mtime: float = 0.0
|
| 119 |
+
self.genres_trie: Dict = {} # Trie for full vocab (fallback)
|
| 120 |
+
self.caption_genres_trie: Dict = {} # Trie for caption-matched genres (priority)
|
| 121 |
+
self.caption_matched_genres: List[str] = [] # Genres matched from caption
|
| 122 |
+
self._char_to_tokens: Dict[str, set] = {} # Precomputed char -> token IDs mapping
|
| 123 |
+
|
| 124 |
+
# Precompute token mappings once (O(vocab_size), runs once at init)
|
| 125 |
+
self._precompute_char_token_mapping()
|
| 126 |
+
self._load_genres_vocab()
|
| 127 |
+
|
| 128 |
+
# Note: Caption-based genre filtering is initialized via update_caption() before each generation
|
| 129 |
+
|
| 130 |
# Field definitions
|
| 131 |
self.field_specs = {
|
| 132 |
"bpm": {"min": 30, "max": 300},
|
|
|
|
| 151 |
FSMState.THINK_END_TAG: "</think>",
|
| 152 |
}
|
| 153 |
|
| 154 |
+
# State transitions - build dynamically based on skip_genres
|
| 155 |
+
self._build_state_transitions()
|
| 156 |
+
|
| 157 |
+
def _build_state_transitions(self):
|
| 158 |
+
"""Build state transition map based on skip_genres setting."""
|
| 159 |
self.next_state = {
|
| 160 |
FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
|
| 161 |
FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME,
|
|
|
|
| 164 |
FSMState.NEWLINE_AFTER_BPM: FSMState.DURATION_NAME,
|
| 165 |
FSMState.DURATION_NAME: FSMState.DURATION_VALUE,
|
| 166 |
FSMState.DURATION_VALUE: FSMState.NEWLINE_AFTER_DURATION,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
FSMState.KEYSCALE_NAME: FSMState.KEYSCALE_VALUE,
|
| 168 |
FSMState.KEYSCALE_VALUE: FSMState.NEWLINE_AFTER_KEYSCALE,
|
| 169 |
FSMState.NEWLINE_AFTER_KEYSCALE: FSMState.TIMESIG_NAME,
|
|
|
|
| 173 |
FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
|
| 174 |
FSMState.CODES_GENERATION: FSMState.COMPLETED,
|
| 175 |
}
|
| 176 |
+
|
| 177 |
+
if self.skip_genres:
|
| 178 |
+
# Skip genres: NEWLINE_AFTER_DURATION -> KEYSCALE_NAME directly
|
| 179 |
+
self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.KEYSCALE_NAME
|
| 180 |
+
else:
|
| 181 |
+
# Include genres in the flow
|
| 182 |
+
self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.GENRES_NAME
|
| 183 |
+
self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
|
| 184 |
+
self.next_state[FSMState.GENRES_VALUE] = FSMState.NEWLINE_AFTER_GENRES
|
| 185 |
+
self.next_state[FSMState.NEWLINE_AFTER_GENRES] = FSMState.KEYSCALE_NAME
|
| 186 |
+
|
| 187 |
+
def set_skip_genres(self, skip: bool):
|
| 188 |
+
"""Set whether to skip genres generation and rebuild state transitions."""
|
| 189 |
+
self.skip_genres = skip
|
| 190 |
+
self._build_state_transitions()
|
| 191 |
|
| 192 |
def _precompute_tokens(self):
|
| 193 |
"""Pre-compute commonly used token IDs for efficiency."""
|
|
|
|
| 238 |
|
| 239 |
# Vocab size
|
| 240 |
self.vocab_size = len(self.tokenizer)
|
| 241 |
+
|
| 242 |
+
# Comma token for multi-genre support
|
| 243 |
+
comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
|
| 244 |
+
self.comma_token = comma_tokens[-1] if comma_tokens else None
|
| 245 |
+
|
| 246 |
+
def _load_genres_vocab(self):
|
| 247 |
+
"""
|
| 248 |
+
Load genres vocabulary from file. Supports hot reload by checking file mtime.
|
| 249 |
+
File format: one genre per line, lines starting with # are comments.
|
| 250 |
+
"""
|
| 251 |
+
if not os.path.exists(self.genres_vocab_path):
|
| 252 |
+
if self.debug:
|
| 253 |
+
logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
|
| 254 |
+
return
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
mtime = os.path.getmtime(self.genres_vocab_path)
|
| 258 |
+
if mtime <= self.genres_vocab_mtime:
|
| 259 |
+
return # File hasn't changed
|
| 260 |
+
|
| 261 |
+
with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
|
| 262 |
+
genres = []
|
| 263 |
+
for line in f:
|
| 264 |
+
line = line.strip()
|
| 265 |
+
if line and not line.startswith('#'):
|
| 266 |
+
genres.append(line.lower())
|
| 267 |
+
|
| 268 |
+
self.genres_vocab = genres
|
| 269 |
+
self.genres_vocab_mtime = mtime
|
| 270 |
+
self._build_genres_trie()
|
| 271 |
+
|
| 272 |
+
if self.debug:
|
| 273 |
+
logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
|
| 274 |
+
except Exception as e:
|
| 275 |
+
logger.warning(f"Failed to load genres vocab: {e}")
|
| 276 |
+
|
| 277 |
+
def _build_genres_trie(self):
|
| 278 |
+
"""
|
| 279 |
+
Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
|
| 280 |
+
Each node is a dict with:
|
| 281 |
+
- '_end': True if this node represents a complete genre
|
| 282 |
+
- other keys: next characters in the trie
|
| 283 |
+
"""
|
| 284 |
+
self.genres_trie = {}
|
| 285 |
+
|
| 286 |
+
for genre in self.genres_vocab:
|
| 287 |
+
node = self.genres_trie
|
| 288 |
+
for char in genre:
|
| 289 |
+
if char not in node:
|
| 290 |
+
node[char] = {}
|
| 291 |
+
node = node[char]
|
| 292 |
+
node['_end'] = True # Mark end of a complete genre
|
| 293 |
+
|
| 294 |
+
if self.debug:
|
| 295 |
+
logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
|
| 296 |
+
|
| 297 |
+
def _extract_caption_genres(self, caption: str):
|
| 298 |
+
"""
|
| 299 |
+
Extract genres from the user's caption that match entries in the vocabulary.
|
| 300 |
+
This creates a smaller trie for faster and more relevant genre generation.
|
| 301 |
+
|
| 302 |
+
Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
|
| 303 |
+
1. Extract words/phrases from caption
|
| 304 |
+
2. For each word, use trie to find all vocab entries that START with this word
|
| 305 |
+
3. Build a separate trie from matched genres
|
| 306 |
+
"""
|
| 307 |
+
if not caption or not self.genres_vocab:
|
| 308 |
+
return
|
| 309 |
+
|
| 310 |
+
caption_lower = caption.lower()
|
| 311 |
+
matched_genres = set()
|
| 312 |
+
|
| 313 |
+
# Extract words from caption (split by common delimiters)
|
| 314 |
+
import re
|
| 315 |
+
words = re.split(r'[,\s\-_/\\|]+', caption_lower)
|
| 316 |
+
words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
|
| 317 |
+
|
| 318 |
+
# For each word, find genres in trie that start with this word
|
| 319 |
+
for word in words:
|
| 320 |
+
# Find all genres starting with this word using trie traversal
|
| 321 |
+
node = self._get_genres_trie_node(word)
|
| 322 |
+
if node is not None:
|
| 323 |
+
# Collect all complete genres under this node
|
| 324 |
+
self._collect_complete_genres(node, word, matched_genres)
|
| 325 |
+
|
| 326 |
+
# Also check if any word appears as a substring in short genres (< 20 chars)
|
| 327 |
+
# This is a quick check for common single-word genres
|
| 328 |
+
genres_set = set(self.genres_vocab)
|
| 329 |
+
for word in words:
|
| 330 |
+
if word in genres_set:
|
| 331 |
+
matched_genres.add(word)
|
| 332 |
+
|
| 333 |
+
if not matched_genres:
|
| 334 |
+
if self.debug:
|
| 335 |
+
logger.debug(f"No genres matched in caption, using full vocab")
|
| 336 |
+
return
|
| 337 |
+
|
| 338 |
+
# Build a trie from matched genres
|
| 339 |
+
self.caption_matched_genres = list(matched_genres)
|
| 340 |
+
self.caption_genres_trie = {}
|
| 341 |
+
|
| 342 |
+
for genre in matched_genres:
|
| 343 |
+
node = self.caption_genres_trie
|
| 344 |
+
for char in genre:
|
| 345 |
+
if char not in node:
|
| 346 |
+
node[char] = {}
|
| 347 |
+
node = node[char]
|
| 348 |
+
node['_end'] = True
|
| 349 |
+
|
| 350 |
+
if self.debug:
|
| 351 |
+
logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
|
| 352 |
+
|
| 353 |
+
def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
|
| 354 |
+
"""
|
| 355 |
+
Recursively collect all complete genres under a trie node.
|
| 356 |
+
Limited depth to avoid too many matches.
|
| 357 |
+
"""
|
| 358 |
+
if max_depth <= 0:
|
| 359 |
+
return
|
| 360 |
+
|
| 361 |
+
if node.get('_end', False):
|
| 362 |
+
result.add(prefix)
|
| 363 |
+
|
| 364 |
+
# Limit total collected genres to avoid slowdown
|
| 365 |
+
if len(result) >= 100:
|
| 366 |
+
return
|
| 367 |
+
|
| 368 |
+
for char, child_node in node.items():
|
| 369 |
+
if char not in ('_end', '_tokens'):
|
| 370 |
+
self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
|
| 371 |
+
|
| 372 |
+
def _precompute_char_token_mapping(self):
|
| 373 |
+
"""
|
| 374 |
+
Precompute mapping from characters to token IDs and token decoded texts.
|
| 375 |
+
This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
|
| 376 |
+
|
| 377 |
+
Time complexity: O(vocab_size) - runs once during initialization
|
| 378 |
+
|
| 379 |
+
Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
|
| 380 |
+
We need to handle both the raw first char and the first non-space char.
|
| 381 |
+
"""
|
| 382 |
+
self._char_to_tokens: Dict[str, set] = {}
|
| 383 |
+
self._token_to_text: Dict[int, str] = {} # Precomputed decoded text for each token
|
| 384 |
+
|
| 385 |
+
# For each token in vocabulary, get its decoded text
|
| 386 |
+
for token_id in range(self.vocab_size):
|
| 387 |
+
try:
|
| 388 |
+
text = self.tokenizer.decode([token_id])
|
| 389 |
+
|
| 390 |
+
if not text:
|
| 391 |
+
continue
|
| 392 |
+
|
| 393 |
+
# Store the decoded text (normalized to lowercase)
|
| 394 |
+
# Keep leading spaces for proper concatenation (e.g., " rock" in "pop rock")
|
| 395 |
+
# Only rstrip trailing whitespace, unless it's a pure whitespace token
|
| 396 |
+
text_lower = text.lower()
|
| 397 |
+
if text_lower.strip(): # Has non-whitespace content
|
| 398 |
+
normalized_text = text_lower.rstrip()
|
| 399 |
+
else: # Pure whitespace token
|
| 400 |
+
normalized_text = " " # Normalize to single space
|
| 401 |
+
self._token_to_text[token_id] = normalized_text
|
| 402 |
+
|
| 403 |
+
# Map first character (including space) to this token
|
| 404 |
+
first_char = text[0].lower()
|
| 405 |
+
if first_char not in self._char_to_tokens:
|
| 406 |
+
self._char_to_tokens[first_char] = set()
|
| 407 |
+
self._char_to_tokens[first_char].add(token_id)
|
| 408 |
+
|
| 409 |
+
# Also map first non-space character to this token
|
| 410 |
+
# This handles tokenizers that add space prefixes (e.g., " pop" -> maps to 'p')
|
| 411 |
+
stripped_text = text.lstrip()
|
| 412 |
+
if stripped_text and stripped_text != text:
|
| 413 |
+
first_nonspace_char = stripped_text[0].lower()
|
| 414 |
+
if first_nonspace_char not in self._char_to_tokens:
|
| 415 |
+
self._char_to_tokens[first_nonspace_char] = set()
|
| 416 |
+
self._char_to_tokens[first_nonspace_char].add(token_id)
|
| 417 |
+
|
| 418 |
+
except Exception:
|
| 419 |
+
continue
|
| 420 |
+
|
| 421 |
+
if self.debug:
|
| 422 |
+
logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
|
| 423 |
+
|
| 424 |
+
def _try_reload_genres_vocab(self):
|
| 425 |
+
"""Check if genres vocab file has been updated and reload if necessary."""
|
| 426 |
+
if not os.path.exists(self.genres_vocab_path):
|
| 427 |
+
return
|
| 428 |
+
|
| 429 |
+
try:
|
| 430 |
+
mtime = os.path.getmtime(self.genres_vocab_path)
|
| 431 |
+
if mtime > self.genres_vocab_mtime:
|
| 432 |
+
self._load_genres_vocab()
|
| 433 |
+
except Exception:
|
| 434 |
+
pass # Ignore errors during hot reload check
|
| 435 |
+
|
| 436 |
+
def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
|
| 437 |
+
"""
|
| 438 |
+
Get the trie node for a given prefix.
|
| 439 |
+
Returns None if the prefix is not valid (no genres start with this prefix).
|
| 440 |
+
"""
|
| 441 |
+
node = self.genres_trie
|
| 442 |
+
for char in prefix.lower():
|
| 443 |
+
if char not in node:
|
| 444 |
+
return None
|
| 445 |
+
node = node[char]
|
| 446 |
+
return node
|
| 447 |
+
|
| 448 |
+
def _is_complete_genre(self, text: str) -> bool:
|
| 449 |
+
"""Check if the given text is a complete genre in the vocabulary."""
|
| 450 |
+
node = self._get_genres_trie_node(text.strip())
|
| 451 |
+
return node is not None and node.get('_end', False)
|
| 452 |
+
|
| 453 |
+
def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
|
| 454 |
+
"""Get a trie node from a specific trie (helper for caption vs full trie)."""
|
| 455 |
+
node = trie
|
| 456 |
+
for char in prefix.lower():
|
| 457 |
+
if char not in node:
|
| 458 |
+
return None
|
| 459 |
+
node = node[char]
|
| 460 |
+
return node
|
| 461 |
+
|
| 462 |
+
def _get_allowed_genres_tokens(self) -> List[int]:
|
| 463 |
+
"""
|
| 464 |
+
Get allowed tokens for genres field based on trie matching.
|
| 465 |
+
|
| 466 |
+
The entire genres string (including commas) must match a complete entry in the vocab.
|
| 467 |
+
For example, if vocab contains "pop, rock, jazz", the generated string must exactly
|
| 468 |
+
match that entry - we don't treat commas as separators for individual genres.
|
| 469 |
+
|
| 470 |
+
Strategy:
|
| 471 |
+
1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
|
| 472 |
+
2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
|
| 473 |
+
3. Get valid next characters from current trie node
|
| 474 |
+
4. For each candidate token, verify the full decoded text forms a valid trie prefix
|
| 475 |
+
"""
|
| 476 |
+
if not self.genres_vocab:
|
| 477 |
+
# No vocab loaded, allow all except newline if empty
|
| 478 |
+
return []
|
| 479 |
+
|
| 480 |
+
# Use the full accumulated value (don't split by comma - treat as single entry)
|
| 481 |
+
accumulated = self.accumulated_value.lower()
|
| 482 |
+
current_genre_prefix = accumulated.strip()
|
| 483 |
+
|
| 484 |
+
# Determine which trie to use: caption-matched (priority) or full vocab (fallback)
|
| 485 |
+
use_caption_trie = False
|
| 486 |
+
current_node = None
|
| 487 |
+
|
| 488 |
+
# Try caption-matched trie first if available
|
| 489 |
+
if self.caption_genres_trie:
|
| 490 |
+
if current_genre_prefix == "":
|
| 491 |
+
current_node = self.caption_genres_trie
|
| 492 |
+
use_caption_trie = True
|
| 493 |
+
else:
|
| 494 |
+
current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
|
| 495 |
+
if current_node is not None:
|
| 496 |
+
use_caption_trie = True
|
| 497 |
+
|
| 498 |
+
# Fallback to full vocab trie
|
| 499 |
+
if current_node is None:
|
| 500 |
+
if current_genre_prefix == "":
|
| 501 |
+
current_node = self.genres_trie
|
| 502 |
+
else:
|
| 503 |
+
current_node = self._get_genres_trie_node(current_genre_prefix)
|
| 504 |
+
|
| 505 |
+
if current_node is None:
|
| 506 |
+
# Invalid prefix, force newline to end
|
| 507 |
+
if self.newline_token:
|
| 508 |
+
return [self.newline_token]
|
| 509 |
+
return []
|
| 510 |
+
|
| 511 |
+
# Get valid next characters from trie node
|
| 512 |
+
valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
|
| 513 |
+
|
| 514 |
+
# If current value is a complete genre, allow newline to end
|
| 515 |
+
is_complete = current_node.get('_end', False)
|
| 516 |
+
|
| 517 |
+
if not valid_next_chars:
|
| 518 |
+
# No more characters to match, only allow newline if complete
|
| 519 |
+
allowed = set()
|
| 520 |
+
if is_complete and self.newline_token:
|
| 521 |
+
allowed.add(self.newline_token)
|
| 522 |
+
return list(allowed)
|
| 523 |
+
|
| 524 |
+
# Collect candidate tokens based on first character
|
| 525 |
+
candidate_tokens = set()
|
| 526 |
+
for char in valid_next_chars:
|
| 527 |
+
if char in self._char_to_tokens:
|
| 528 |
+
candidate_tokens.update(self._char_to_tokens[char])
|
| 529 |
+
|
| 530 |
+
# Select the appropriate trie for validation
|
| 531 |
+
active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
|
| 532 |
+
|
| 533 |
+
# Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
|
| 534 |
+
allowed = set()
|
| 535 |
+
for token_id in candidate_tokens:
|
| 536 |
+
# Use precomputed decoded text (already normalized)
|
| 537 |
+
decoded_normalized = self._token_to_text.get(token_id, "")
|
| 538 |
+
|
| 539 |
+
if not decoded_normalized or not decoded_normalized.strip():
|
| 540 |
+
# Token decodes to empty or only whitespace - allow if space/comma is a valid next char
|
| 541 |
+
if ' ' in valid_next_chars or ',' in valid_next_chars:
|
| 542 |
+
allowed.add(token_id)
|
| 543 |
+
continue
|
| 544 |
+
|
| 545 |
+
# Build new prefix by appending decoded token
|
| 546 |
+
# Handle space-prefixed tokens (e.g., " rock" from "pop rock")
|
| 547 |
+
if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
|
| 548 |
+
# Token has leading space/comma - append directly
|
| 549 |
+
new_prefix = current_genre_prefix + decoded_normalized
|
| 550 |
+
else:
|
| 551 |
+
new_prefix = current_genre_prefix + decoded_normalized
|
| 552 |
+
|
| 553 |
+
# Check if new_prefix is a valid prefix in the active trie
|
| 554 |
+
new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
|
| 555 |
+
if new_node is not None:
|
| 556 |
+
allowed.add(token_id)
|
| 557 |
+
|
| 558 |
+
# If current value is a complete genre, also allow newline
|
| 559 |
+
if is_complete and self.newline_token:
|
| 560 |
+
allowed.add(self.newline_token)
|
| 561 |
+
|
| 562 |
+
return list(allowed)
|
| 563 |
|
| 564 |
def reset(self):
|
| 565 |
"""Reset the processor state for a new generation."""
|
|
|
|
| 567 |
self.position_in_state = 0
|
| 568 |
self.accumulated_value = ""
|
| 569 |
|
| 570 |
+
def update_caption(self, caption: Optional[str]):
|
| 571 |
+
"""
|
| 572 |
+
Update the caption and rebuild the caption-matched genres trie.
|
| 573 |
+
Call this before each generation to prioritize genres from the new caption.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
caption: User's input caption. If None or empty, clears caption matching.
|
| 577 |
+
"""
|
| 578 |
+
# Check for hot reload of genres vocabulary
|
| 579 |
+
self._try_reload_genres_vocab()
|
| 580 |
+
|
| 581 |
+
self.caption = caption
|
| 582 |
+
self.caption_genres_trie = {}
|
| 583 |
+
self.caption_matched_genres = []
|
| 584 |
+
|
| 585 |
+
if caption:
|
| 586 |
+
self._extract_caption_genres(caption)
|
| 587 |
+
|
| 588 |
+
# Also reset FSM state for new generation
|
| 589 |
+
self.reset()
|
| 590 |
+
|
| 591 |
def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
|
| 592 |
"""
|
| 593 |
Get the token IDs that can continue the fixed string from current position.
|
|
|
|
| 792 |
scores: [batch_size, vocab_size] logits for next token
|
| 793 |
|
| 794 |
Returns:
|
| 795 |
+
Modified scores with invalid tokens masked to -inf and temperature scaling applied
|
| 796 |
"""
|
| 797 |
if not self.enabled:
|
| 798 |
+
return self._apply_temperature_scaling(scores)
|
| 799 |
|
| 800 |
if self.state == FSMState.COMPLETED or self.state == FSMState.CODES_GENERATION:
|
| 801 |
+
# No constraints in codes generation phase, but still apply temperature
|
| 802 |
+
return self._apply_temperature_scaling(scores)
|
| 803 |
|
| 804 |
batch_size = scores.shape[0]
|
| 805 |
|
|
|
|
| 807 |
for b in range(batch_size):
|
| 808 |
scores[b] = self._process_single_sequence(input_ids[b], scores[b:b+1])
|
| 809 |
|
| 810 |
+
# Apply temperature scaling after constraint masking
|
| 811 |
+
return self._apply_temperature_scaling(scores)
|
| 812 |
+
|
| 813 |
+
def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 814 |
+
"""
|
| 815 |
+
Apply temperature scaling based on current generation phase.
|
| 816 |
+
|
| 817 |
+
Temperature scaling: logits = logits / temperature
|
| 818 |
+
- Lower temperature (< 1.0) makes distribution sharper (more deterministic)
|
| 819 |
+
- Higher temperature (> 1.0) makes distribution flatter (more diverse)
|
| 820 |
+
|
| 821 |
+
Args:
|
| 822 |
+
scores: [batch_size, vocab_size] logits
|
| 823 |
+
|
| 824 |
+
Returns:
|
| 825 |
+
Temperature-scaled logits
|
| 826 |
+
"""
|
| 827 |
+
# Determine which temperature to use based on current state
|
| 828 |
+
if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED:
|
| 829 |
+
temperature = self.codes_temperature
|
| 830 |
+
else:
|
| 831 |
+
temperature = self.metadata_temperature
|
| 832 |
+
|
| 833 |
+
# If no temperature is set for this phase, return scores unchanged
|
| 834 |
+
if temperature is None:
|
| 835 |
+
return scores
|
| 836 |
+
|
| 837 |
+
# Avoid division by zero
|
| 838 |
+
if temperature <= 0:
|
| 839 |
+
temperature = 1e-6
|
| 840 |
+
|
| 841 |
+
# Apply temperature scaling
|
| 842 |
+
return scores / temperature
|
| 843 |
|
| 844 |
def _process_single_sequence(
|
| 845 |
self,
|
|
|
|
| 907 |
scores = scores + mask
|
| 908 |
|
| 909 |
elif self.state == FSMState.GENRES_VALUE:
|
| 910 |
+
# Try to hot-reload genres vocab if file has changed
|
| 911 |
+
self._try_reload_genres_vocab()
|
| 912 |
+
|
| 913 |
+
# Get allowed tokens based on genres vocabulary
|
| 914 |
+
allowed = self._get_allowed_genres_tokens()
|
| 915 |
+
|
| 916 |
+
if allowed:
|
| 917 |
+
# Use vocabulary-constrained decoding
|
| 918 |
+
for t in allowed:
|
| 919 |
+
mask[0, t] = 0
|
| 920 |
+
scores = scores + mask
|
| 921 |
+
elif self.genres_vocab:
|
| 922 |
+
# Vocab is loaded but no valid continuation found
|
| 923 |
+
# Force newline to end the field
|
| 924 |
if self.newline_token:
|
| 925 |
mask[0, self.newline_token] = 0
|
| 926 |
+
if self.debug:
|
| 927 |
+
logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
|
| 928 |
scores = scores + mask
|
| 929 |
else:
|
| 930 |
+
# Fallback: no vocab loaded, use probability-based ending
|
| 931 |
+
if self._should_end_text_field(scores):
|
| 932 |
if self.newline_token:
|
| 933 |
+
mask[0, self.newline_token] = 0
|
| 934 |
+
self._transition_to_next_state()
|
| 935 |
+
scores = scores + mask
|
| 936 |
+
else:
|
| 937 |
+
# Allow any token except newline if we don't have content yet
|
| 938 |
+
if not self.accumulated_value.strip():
|
| 939 |
+
if self.newline_token:
|
| 940 |
+
scores[0, self.newline_token] = float('-inf')
|
| 941 |
+
# Otherwise, don't constrain (fallback behavior)
|
| 942 |
|
| 943 |
elif self.state == FSMState.KEYSCALE_VALUE:
|
| 944 |
if self._is_keyscale_complete():
|
|
|
|
| 1037 |
|
| 1038 |
class LLMHandler:
|
| 1039 |
"""5Hz LM Handler for audio code generation"""
|
| 1040 |
+
|
| 1041 |
+
STOP_REASONING_TAG = "</think>"
|
| 1042 |
|
| 1043 |
def __init__(self):
|
| 1044 |
"""Initialize LLMHandler with default values"""
|
|
|
|
| 1050 |
self.device = "cpu"
|
| 1051 |
self.dtype = torch.float32
|
| 1052 |
self.offload_to_cpu = False
|
| 1053 |
+
|
| 1054 |
+
# Shared constrained decoding processor (initialized once when LLM is loaded)
|
| 1055 |
+
self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
|
| 1056 |
|
| 1057 |
def get_available_5hz_lm_models(self) -> List[str]:
|
| 1058 |
"""Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
|
|
|
|
| 1141 |
logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
|
| 1142 |
self.llm_tokenizer = llm_tokenizer
|
| 1143 |
|
| 1144 |
+
# Initialize shared constrained decoding processor (one-time initialization)
|
| 1145 |
+
logger.info("Initializing constrained decoding processor...")
|
| 1146 |
+
processor_start = time.time()
|
| 1147 |
+
self.constrained_processor = MetadataConstrainedLogitsProcessor(
|
| 1148 |
+
tokenizer=self.llm_tokenizer,
|
| 1149 |
+
enabled=True,
|
| 1150 |
+
debug=False,
|
| 1151 |
+
)
|
| 1152 |
+
logger.info(f"Constrained processor initialized in {time.time() - processor_start:.2f} seconds")
|
| 1153 |
+
|
| 1154 |
# Initialize based on user-selected backend
|
| 1155 |
if backend == "vllm":
|
| 1156 |
# Try to initialize with vllm
|
|
|
|
| 1256 |
repetition_penalty: float = 1.0,
|
| 1257 |
use_constrained_decoding: bool = True,
|
| 1258 |
constrained_decoding_debug: bool = False,
|
| 1259 |
+
metadata_temperature: Optional[float] = 0.85,
|
| 1260 |
+
codes_temperature: Optional[float] = None,
|
| 1261 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 1262 |
"""Generate metadata and audio codes using 5Hz LM with vllm backend
|
| 1263 |
|
| 1264 |
Args:
|
| 1265 |
caption: Text caption for music generation
|
| 1266 |
lyrics: Lyrics for music generation
|
| 1267 |
+
temperature: Base sampling temperature (used if phase-specific temps not set)
|
| 1268 |
cfg_scale: CFG scale (>1.0 enables CFG)
|
| 1269 |
negative_prompt: Negative prompt for CFG
|
| 1270 |
top_k: Top-k sampling parameter
|
|
|
|
| 1272 |
repetition_penalty: Repetition penalty
|
| 1273 |
use_constrained_decoding: Whether to use FSM-based constrained decoding
|
| 1274 |
constrained_decoding_debug: Whether to print debug info for constrained decoding
|
| 1275 |
+
metadata_temperature: Temperature for metadata generation (lower = more accurate)
|
| 1276 |
+
If None, uses base temperature
|
| 1277 |
+
codes_temperature: Temperature for audio codes generation (higher = more diverse)
|
| 1278 |
+
If None, uses base temperature
|
| 1279 |
"""
|
| 1280 |
try:
|
| 1281 |
from nanovllm import SamplingParams
|
|
|
|
| 1283 |
formatted_prompt = self.build_formatted_prompt(caption, lyrics)
|
| 1284 |
logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
|
| 1285 |
|
| 1286 |
+
# Determine effective temperature for sampler
|
| 1287 |
+
# If using phase-specific temperatures, set sampler temp to 1.0 (processor handles it)
|
| 1288 |
+
use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
|
| 1289 |
+
effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
|
| 1290 |
+
|
| 1291 |
+
# Use shared constrained decoding processor if enabled
|
| 1292 |
constrained_processor = None
|
| 1293 |
update_state_fn = None
|
| 1294 |
+
if use_constrained_decoding or use_phase_temperatures:
|
| 1295 |
+
# Use shared processor, just update caption and settings
|
| 1296 |
+
self.constrained_processor.enabled = use_constrained_decoding
|
| 1297 |
+
self.constrained_processor.debug = constrained_decoding_debug
|
| 1298 |
+
self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
|
| 1299 |
+
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 1300 |
+
self.constrained_processor.update_caption(caption)
|
| 1301 |
+
|
| 1302 |
+
constrained_processor = self.constrained_processor
|
| 1303 |
update_state_fn = constrained_processor.update_state
|
| 1304 |
|
| 1305 |
sampling_params = SamplingParams(
|
| 1306 |
max_tokens=self.max_model_len-64,
|
| 1307 |
+
temperature=effective_sampler_temp,
|
| 1308 |
cfg_scale=cfg_scale,
|
| 1309 |
top_k=top_k,
|
| 1310 |
top_p=top_p,
|
|
|
|
| 1355 |
repetition_penalty: float,
|
| 1356 |
use_constrained_decoding: bool = True,
|
| 1357 |
constrained_decoding_debug: bool = False,
|
| 1358 |
+
metadata_temperature: Optional[float] = 0.85,
|
| 1359 |
+
codes_temperature: Optional[float] = None,
|
| 1360 |
) -> str:
|
| 1361 |
"""Shared vllm path: accept prebuilt formatted prompt and return text."""
|
| 1362 |
from nanovllm import SamplingParams
|
| 1363 |
|
| 1364 |
+
# Determine effective temperature for sampler
|
| 1365 |
+
use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
|
| 1366 |
+
effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
|
| 1367 |
+
|
| 1368 |
+
# Use shared constrained processor if enabled
|
| 1369 |
constrained_processor = None
|
| 1370 |
+
if use_constrained_decoding or use_phase_temperatures:
|
| 1371 |
+
# Use shared processor, just update caption and settings
|
| 1372 |
+
self.constrained_processor.enabled = use_constrained_decoding
|
| 1373 |
+
self.constrained_processor.debug = constrained_decoding_debug
|
| 1374 |
+
self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
|
| 1375 |
+
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 1376 |
+
self.constrained_processor.update_caption(formatted_prompt) # Use formatted prompt for genre extraction
|
| 1377 |
+
|
| 1378 |
+
constrained_processor = self.constrained_processor
|
| 1379 |
|
| 1380 |
sampling_params = SamplingParams(
|
| 1381 |
max_tokens=self.max_model_len - 64,
|
| 1382 |
+
temperature=effective_sampler_temp,
|
| 1383 |
cfg_scale=cfg_scale,
|
| 1384 |
top_k=top_k,
|
| 1385 |
top_p=top_p,
|
|
|
|
| 1425 |
repetition_penalty: float = 1.0,
|
| 1426 |
use_constrained_decoding: bool = True,
|
| 1427 |
constrained_decoding_debug: bool = False,
|
| 1428 |
+
metadata_temperature: Optional[float] = 0.85,
|
| 1429 |
+
codes_temperature: Optional[float] = None,
|
| 1430 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 1431 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend
|
| 1432 |
|
| 1433 |
Args:
|
| 1434 |
caption: Text caption for music generation
|
| 1435 |
lyrics: Lyrics for music generation
|
| 1436 |
+
temperature: Base sampling temperature (used if phase-specific temps not set)
|
| 1437 |
cfg_scale: CFG scale (>1.0 enables CFG)
|
| 1438 |
negative_prompt: Negative prompt for CFG
|
| 1439 |
top_k: Top-k sampling parameter
|
|
|
|
| 1441 |
repetition_penalty: Repetition penalty
|
| 1442 |
use_constrained_decoding: Whether to use FSM-based constrained decoding
|
| 1443 |
constrained_decoding_debug: Whether to print debug info for constrained decoding
|
| 1444 |
+
metadata_temperature: Temperature for metadata generation (lower = more accurate)
|
| 1445 |
+
If None, uses base temperature
|
| 1446 |
+
codes_temperature: Temperature for audio codes generation (higher = more diverse)
|
| 1447 |
+
If None, uses base temperature
|
| 1448 |
"""
|
| 1449 |
try:
|
| 1450 |
formatted_prompt = self.build_formatted_prompt(caption, lyrics)
|
|
|
|
| 1483 |
|
| 1484 |
streamer = TqdmTokenStreamer(total=max_new_tokens)
|
| 1485 |
|
| 1486 |
+
# Determine if using phase-specific temperatures
|
| 1487 |
+
use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
|
| 1488 |
+
effective_temperature = 1.0 if use_phase_temperatures else temperature
|
| 1489 |
+
|
| 1490 |
+
# Use shared constrained decoding processor if enabled
|
| 1491 |
constrained_processor = None
|
| 1492 |
+
if use_constrained_decoding or use_phase_temperatures:
|
| 1493 |
+
# Use shared processor, just update caption and settings
|
| 1494 |
+
self.constrained_processor.enabled = use_constrained_decoding
|
| 1495 |
+
self.constrained_processor.debug = constrained_decoding_debug
|
| 1496 |
+
self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
|
| 1497 |
+
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 1498 |
+
self.constrained_processor.update_caption(caption)
|
| 1499 |
+
|
| 1500 |
+
constrained_processor = self.constrained_processor
|
| 1501 |
|
| 1502 |
# Build logits processor list (only for CFG and repetition penalty)
|
| 1503 |
logits_processor = LogitsProcessorList()
|
|
|
|
| 1534 |
batch_input_ids=batch_input_ids,
|
| 1535 |
batch_attention_mask=batch_attention_mask,
|
| 1536 |
max_new_tokens=max_new_tokens,
|
| 1537 |
+
temperature=effective_temperature,
|
| 1538 |
cfg_scale=cfg_scale,
|
| 1539 |
top_k=top_k,
|
| 1540 |
top_p=top_p,
|
|
|
|
| 1546 |
|
| 1547 |
# Extract only the conditional output (first in batch)
|
| 1548 |
outputs = outputs[0:1] # Keep only conditional output
|
| 1549 |
+
elif use_constrained_decoding or use_phase_temperatures:
|
| 1550 |
+
# Use custom generation loop for constrained decoding or phase temperatures (non-CFG)
|
| 1551 |
input_ids = inputs['input_ids']
|
| 1552 |
attention_mask = inputs.get('attention_mask', None)
|
| 1553 |
|
|
|
|
| 1555 |
input_ids=input_ids,
|
| 1556 |
attention_mask=attention_mask,
|
| 1557 |
max_new_tokens=max_new_tokens,
|
| 1558 |
+
temperature=effective_temperature,
|
| 1559 |
top_k=top_k,
|
| 1560 |
top_p=top_p,
|
| 1561 |
repetition_penalty=repetition_penalty,
|
|
|
|
| 1569 |
outputs = self.llm.generate(
|
| 1570 |
**inputs,
|
| 1571 |
max_new_tokens=max_new_tokens,
|
| 1572 |
+
temperature=effective_temperature if effective_temperature > 0 else 1.0,
|
| 1573 |
+
do_sample=True if effective_temperature > 0 else False,
|
| 1574 |
top_k=top_k if top_k is not None and top_k > 0 else None,
|
| 1575 |
top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
|
| 1576 |
logits_processor=logits_processor if len(logits_processor) > 0 else None,
|
|
|
|
| 1632 |
truncation=True,
|
| 1633 |
)
|
| 1634 |
|
| 1635 |
+
# Use shared constrained processor if enabled
|
| 1636 |
constrained_processor = None
|
| 1637 |
if use_constrained_decoding:
|
| 1638 |
+
# Use shared processor, just update caption and settings
|
| 1639 |
+
self.constrained_processor.enabled = use_constrained_decoding
|
| 1640 |
+
self.constrained_processor.debug = constrained_decoding_debug
|
| 1641 |
+
self.constrained_processor.update_caption(formatted_prompt) # Use formatted prompt for genre extraction
|
| 1642 |
+
|
| 1643 |
+
constrained_processor = self.constrained_processor
|
| 1644 |
|
| 1645 |
with self._load_model_context():
|
| 1646 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
| 1762 |
repetition_penalty: float = 1.0,
|
| 1763 |
use_constrained_decoding: bool = True,
|
| 1764 |
constrained_decoding_debug: bool = False,
|
| 1765 |
+
metadata_temperature: Optional[float] = 0.85,
|
| 1766 |
+
codes_temperature: Optional[float] = None,
|
| 1767 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 1768 |
"""Generate metadata and audio codes using 5Hz LM
|
| 1769 |
|
| 1770 |
Args:
|
| 1771 |
caption: Text caption for music generation
|
| 1772 |
lyrics: Lyrics for music generation
|
| 1773 |
+
temperature: Base sampling temperature (used if phase-specific temps not set)
|
| 1774 |
cfg_scale: CFG scale (>1.0 enables CFG)
|
| 1775 |
negative_prompt: Negative prompt for CFG
|
| 1776 |
top_k: Top-k sampling parameter
|
|
|
|
| 1778 |
repetition_penalty: Repetition penalty
|
| 1779 |
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
|
| 1780 |
constrained_decoding_debug: Whether to print debug info for constrained decoding
|
| 1781 |
+
metadata_temperature: Temperature for metadata generation (lower = more accurate)
|
| 1782 |
+
Recommended: 0.3-0.5 for accurate metadata
|
| 1783 |
+
codes_temperature: Temperature for audio codes generation (higher = more diverse)
|
| 1784 |
+
Recommended: 0.7-1.0 for diverse codes
|
| 1785 |
"""
|
| 1786 |
# Check if 5Hz LM is initialized
|
| 1787 |
if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
|
|
|
|
| 1799 |
|
| 1800 |
if self.llm_backend == "vllm":
|
| 1801 |
return self.generate_with_5hz_lm_vllm(
|
| 1802 |
+
caption=caption,
|
| 1803 |
+
lyrics=lyrics,
|
| 1804 |
+
temperature=temperature,
|
| 1805 |
+
cfg_scale=cfg_scale,
|
| 1806 |
+
negative_prompt=negative_prompt,
|
| 1807 |
+
top_k=top_k,
|
| 1808 |
+
top_p=top_p,
|
| 1809 |
+
repetition_penalty=repetition_penalty,
|
| 1810 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 1811 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1812 |
+
metadata_temperature=metadata_temperature,
|
| 1813 |
+
codes_temperature=codes_temperature,
|
| 1814 |
)
|
| 1815 |
else:
|
| 1816 |
return self.generate_with_5hz_lm_pt(
|
| 1817 |
+
caption=caption,
|
| 1818 |
+
lyrics=lyrics,
|
| 1819 |
+
temperature=temperature,
|
| 1820 |
+
cfg_scale=cfg_scale,
|
| 1821 |
+
negative_prompt=negative_prompt,
|
| 1822 |
+
top_k=top_k,
|
| 1823 |
+
top_p=top_p,
|
| 1824 |
+
repetition_penalty=repetition_penalty,
|
| 1825 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 1826 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1827 |
+
metadata_temperature=metadata_temperature,
|
| 1828 |
+
codes_temperature=codes_temperature,
|
| 1829 |
)
|
| 1830 |
|
| 1831 |
+
def generate_with_stop_condition(
|
| 1832 |
+
self,
|
| 1833 |
+
caption: str,
|
| 1834 |
+
lyrics: str,
|
| 1835 |
+
infer_type: str,
|
| 1836 |
+
temperature: float = 0.6,
|
| 1837 |
+
cfg_scale: float = 1.0,
|
| 1838 |
+
negative_prompt: str = "NO USER INPUT",
|
| 1839 |
+
top_k: Optional[int] = None,
|
| 1840 |
+
top_p: Optional[float] = None,
|
| 1841 |
+
repetition_penalty: float = 1.0,
|
| 1842 |
+
use_constrained_decoding: bool = True,
|
| 1843 |
+
constrained_decoding_debug: bool = False,
|
| 1844 |
+
metadata_temperature: Optional[float] = 0.85,
|
| 1845 |
+
codes_temperature: Optional[float] = None,
|
| 1846 |
+
) -> Tuple[Dict[str, Any], str, str]:
|
| 1847 |
+
"""Feishu-compatible LM generation.
|
| 1848 |
+
|
| 1849 |
+
- infer_type='dit': stop at </think> and return metas only (no audio codes)
|
| 1850 |
+
- infer_type='llm_dit': normal generation (metas + audio codes)
|
| 1851 |
+
"""
|
| 1852 |
+
infer_type = (infer_type or "").strip().lower()
|
| 1853 |
+
if infer_type not in {"dit", "llm_dit"}:
|
| 1854 |
+
return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
|
| 1855 |
+
|
| 1856 |
+
if infer_type == "llm_dit":
|
| 1857 |
+
return self.generate_with_5hz_lm(
|
| 1858 |
+
caption=caption,
|
| 1859 |
+
lyrics=lyrics,
|
| 1860 |
+
temperature=temperature,
|
| 1861 |
+
cfg_scale=cfg_scale,
|
| 1862 |
+
negative_prompt=negative_prompt,
|
| 1863 |
+
top_k=top_k,
|
| 1864 |
+
top_p=top_p,
|
| 1865 |
+
repetition_penalty=repetition_penalty,
|
| 1866 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 1867 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1868 |
+
metadata_temperature=metadata_temperature,
|
| 1869 |
+
codes_temperature=codes_temperature,
|
| 1870 |
+
)
|
| 1871 |
+
|
| 1872 |
+
# dit: generate and truncate at reasoning end tag
|
| 1873 |
+
formatted_prompt = self.build_formatted_prompt(caption, lyrics)
|
| 1874 |
+
output_text, status = self.generate_from_formatted_prompt(
|
| 1875 |
+
formatted_prompt,
|
| 1876 |
+
cfg={
|
| 1877 |
+
"temperature": temperature,
|
| 1878 |
+
"cfg_scale": cfg_scale,
|
| 1879 |
+
"negative_prompt": negative_prompt,
|
| 1880 |
+
"top_k": top_k,
|
| 1881 |
+
"top_p": top_p,
|
| 1882 |
+
"repetition_penalty": repetition_penalty,
|
| 1883 |
+
},
|
| 1884 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 1885 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1886 |
+
)
|
| 1887 |
+
if not output_text:
|
| 1888 |
+
return {}, "", status
|
| 1889 |
+
|
| 1890 |
+
if self.STOP_REASONING_TAG in output_text:
|
| 1891 |
+
stop_idx = output_text.find(self.STOP_REASONING_TAG)
|
| 1892 |
+
output_text = output_text[: stop_idx + len(self.STOP_REASONING_TAG)]
|
| 1893 |
+
|
| 1894 |
+
metadata, _audio_codes = self.parse_lm_output(output_text)
|
| 1895 |
+
return metadata, "", status
|
| 1896 |
+
|
| 1897 |
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False) -> str:
|
| 1898 |
"""
|
| 1899 |
Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
|