add 5hz llm
acestep/gradio_ui.py
CHANGED
@@ -204,7 +204,6 @@ def create_generation_section(handler) -> dict:
                 label="Initialize 5Hz LM",
                 value=False,
                 info="Check to initialize 5Hz LM during service initialization",
-                interactive=False
             )

         with gr.Row():
@@ -298,7 +297,7 @@ def create_generation_section(handler) -> dict:
             )

             # 5Hz LM
-            with gr.Row(visible=
+            with gr.Row(visible=True) as use_5hz_lm_row:
                 use_5hz_lm_btn = gr.Button(
                     "Generate LM Hints",
                     variant="secondary",
@@ -748,9 +747,36 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
     def generate_lm_hints_wrapper(caption, lyrics, temperature):
         """Wrapper for 5Hz LM generation"""
         metadata, audio_codes, status = handler.generate_with_5hz_lm(caption, lyrics, temperature)
-
-
-
+
+        # Extract metadata values and map to UI fields
+        # Handle bpm
+        bpm_value = metadata.get('bpm', None)
+        if bpm_value == "N/A" or bpm_value == "":
+            bpm_value = None
+
+        # Handle key_scale (metadata uses 'keyscale')
+        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
+        if key_scale_value == "N/A":
+            key_scale_value = ""
+
+        # Handle time_signature (metadata uses 'timesignature')
+        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
+        if time_signature_value == "N/A":
+            time_signature_value = ""
+
+        # Handle audio_duration (metadata uses 'duration')
+        audio_duration_value = metadata.get('duration', -1)
+        if audio_duration_value == "N/A" or audio_duration_value == "":
+            audio_duration_value = -1
+
+        # Return audio codes and all metadata fields
+        return (
+            audio_codes,           # text2music_audio_code_string
+            bpm_value,             # bpm
+            key_scale_value,       # key_scale
+            time_signature_value,  # time_signature
+            audio_duration_value,  # audio_duration
+        )

     generation_section["use_5hz_lm_btn"].click(
         fn=generate_lm_hints_wrapper,
@@ -759,7 +785,13 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
             generation_section["lyrics"],
             generation_section["lm_temperature"]
         ],
-        outputs=[
+        outputs=[
+            generation_section["text2music_audio_code_string"],
+            generation_section["bpm"],
+            generation_section["key_scale"],
+            generation_section["time_signature"],
+            generation_section["audio_duration"],
+        ]
     )

     # Update instruction and UI visibility based on task type

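For reference, a minimal standalone sketch of the metadata-to-UI mapping the new wrapper performs; the sample dict is hypothetical, with keys following parse_lm_output in acestep/handler.py below.

# Sketch only: mirrors how generate_lm_hints_wrapper normalizes LM metadata for the Gradio fields.
def normalize_lm_metadata(metadata: dict):
    bpm = metadata.get('bpm')
    bpm = None if bpm in ("N/A", "") else bpm
    key_scale = metadata.get('keyscale', metadata.get('key_scale', ""))
    key_scale = "" if key_scale == "N/A" else key_scale
    time_signature = metadata.get('timesignature', metadata.get('time_signature', ""))
    time_signature = "" if time_signature == "N/A" else time_signature
    duration = metadata.get('duration', -1)
    duration = -1 if duration in ("N/A", "") else duration
    return bpm, key_scale, time_signature, duration

# Example with made-up values:
print(normalize_lm_metadata({'bpm': 73, 'keyscale': 'G major', 'timesignature': '4', 'duration': 273}))
# (73, 'G major', '4', 273)
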
acestep/handler.py
CHANGED
@@ -4,6 +4,7 @@ Encapsulates all data processing and business logic as a bridge between model an
 """
 import os
 import math
+from copy import deepcopy
 import tempfile
 import traceback
 import re
@@ -58,9 +59,10 @@ class AceStepHandler:
         self.sample_rate = 48000

         # 5Hz LM related
-        self.
-        self.
-        self.
+        self.llm = None
+        self.llm_tokenizer = None
+        self.llm_initialized = False
+        self.llm_backend = None

         # Reward model (temporarily disabled)
         self.reward_model = None
@@ -218,12 +220,43 @@ class AceStepHandler:
         if init_llm:
             full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
             if os.path.exists(full_lm_model_path):
+                print("loading 5Hz LM tokenizer...")
+                start_time = time.time()
+                llm_tokenizer = deepcopy(self.text_tokenizer)
+                max_audio_length = 2**16 - 1
+                semantic_tokens = [f"<|audio_code_{i}|>" for i in range(max_audio_length)]
+                # 217204
+                llm_tokenizer.add_special_tokens({"additional_special_tokens": semantic_tokens})
+                print(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
+                self.llm_tokenizer = llm_tokenizer
                 if device == "cuda":
                     status_msg = self._initialize_5hz_lm_cuda(full_lm_model_path)
-
-
-
-
+                    print(f"5Hz LM status message: {status_msg}")
+                    # Check if initialization failed (status_msg starts with ❌)
+                    if status_msg.startswith("❌"):
+                        # vllm initialization failed, fallback to PyTorch
+                        if not self.llm_initialized:
+                            try:
+                                self.llm = AutoModel.from_pretrained(full_lm_model_path)
+                                self.llm = self.llm.to(device).to(self.dtype)
+                                self.llm.eval()
+                                self.llm_backend = "pt"
+                                self.llm_initialized = True
+                            except Exception as e:
+                                return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
+                    # If vllm initialization succeeded, self.llm_initialized should already be True
+                else:
+                    # For CPU or other devices, use PyTorch backend
+                    try:
+                        self.llm = AutoModel.from_pretrained(full_lm_model_path)
+                        self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
+                        self.llm = self.llm.to(device).to(self.dtype)
+                        self.llm.eval()
+                        self.llm_backend = "pt"
+                        self.llm_initialized = True
+                    except Exception as e:
+                        return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
+
             else:
                 # 5Hz LM path not found
                 return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
@@ -266,6 +299,10 @@ class AceStepHandler:
             reserved_mem_bytes = torch.cuda.memory_reserved(device)

             total_gpu = total_gpu_mem_bytes / 1024**3
+            low_gpu_memory_mode = False
+            if total_gpu < minimal_gpu:
+                minimal_gpu = 0.5 * total_gpu
+                low_gpu_memory_mode = True
             allocated_gpu = allocated_mem_bytes / 1024**3
             reserved_gpu = reserved_mem_bytes / 1024**3
             available_gpu = total_gpu - reserved_gpu
@@ -275,54 +312,64 @@ class AceStepHandler:
             else:
                 ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu))

-            return ratio
+            return ratio, low_gpu_memory_mode
         except Exception as e:
-            return 0.9
+            return 0.9, low_gpu_memory_mode

     def _initialize_5hz_lm_cuda(self, model_path: str) -> str:
         """Initialize 5Hz LM model"""
+        if not torch.cuda.is_available():
+            self.llm_initialized = False
+            print("CUDA is not available. Please check your GPU setup.")
+            return "❌ CUDA is not available. Please check your GPU setup."
         try:
             from nanovllm import LLM, SamplingParams
-
-
-
-
+        except ImportError:
+            self.llm_initialized = False
+            print("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
+            return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."
+
+        try:
             current_device = torch.cuda.current_device()
             device_name = torch.cuda.get_device_name(current_device)

             torch.cuda.empty_cache()
-            gpu_memory_utilization = self.get_gpu_memory_utilization(
+            gpu_memory_utilization, low_gpu_memory_mode = self.get_gpu_memory_utilization(
                 minimal_gpu=8,
                 min_ratio=0.2,
                 max_ratio=0.9
             )
+            if low_gpu_memory_mode:
+                self.max_model_len = 1024
+            else:
+                self.max_model_len = 2048

+            print(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
+            start_time = time.time()
             self.llm = LLM(
                 model=model_path,
                 enforce_eager=False,
                 tensor_parallel_size=1,
-                max_model_len=
+                max_model_len=self.max_model_len,
                 gpu_memory_utilization=gpu_memory_utilization,
             )
-
+            print(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
+            self.llm.tokenizer = self.llm_tokenizer
             self.llm_initialized = True
+            self.llm_backend = "vllm"
             return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
         except Exception as e:
             self.llm_initialized = False
             error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
             return error_msg
-
-    def
-        """Generate metadata and audio codes using 5Hz LM"""
-        if not self.lm_initialized or self.llm is None:
-            return {}, "", "❌ 5Hz LM not initialized. Please initialize it first."
-
+
+    def generate_with_5hz_lm_vllm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
         try:
             from nanovllm import SamplingParams

             prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"

-            formatted_prompt = self.
+            formatted_prompt = self.llm_tokenizer.apply_chat_template(
                 [
                     {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
                     {"role": "user", "content": prompt}
@@ -330,10 +377,10 @@ class AceStepHandler:
                 tokenize=False,
                 add_generation_prompt=True,
             )
+            print("[debug] formatted_prompt: ", formatted_prompt)

-            sampling_params = SamplingParams(max_tokens=
+            sampling_params = SamplingParams(max_tokens=self.max_model_len, temperature=temperature)
             outputs = self.llm.generate([formatted_prompt], sampling_params)
-
             if isinstance(outputs, list) and len(outputs) > 0:
                 if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
                     output_text = outputs[0].outputs[0].text
@@ -351,22 +398,113 @@ class AceStepHandler:
         except Exception as e:
             error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
             return {}, "", error_msg
+
+    def generate_with_5hz_lm_pt(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
+        try:
+            prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
+
+            formatted_prompt = self.llm_tokenizer.apply_chat_template(
+                [
+                    {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
+                    {"role": "user", "content": prompt}
+                ],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+
+            # Tokenize the prompt
+            inputs = self.llm_tokenizer(
+                formatted_prompt,
+                return_tensors="pt",
+                padding=False,
+                truncation=True,
+            )
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # Generate with the model
+            with torch.no_grad():
+                # Get max_new_tokens from model config or use a default
+                max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
+                if hasattr(self, 'max_model_len'):
+                    max_new_tokens = min(max_new_tokens, self.max_model_len)
+
+                outputs = self.llm.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    do_sample=True if temperature > 0 else False,
+                    pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
+                )
+
+            # Decode the generated tokens
+            # Only decode the newly generated tokens (skip the input prompt)
+            generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
+            output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
+
+            metadata, audio_codes = self.parse_lm_output(output_text)
+            codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
+            return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
+
+        except Exception as e:
+            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
+            return {}, "", error_msg
+
+    def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
+        """Generate metadata and audio codes using 5Hz LM"""
+        # Check if 5Hz LM is initialized
+        if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
+            debug_info = f"llm_initialized={getattr(self, 'llm_initialized', 'not set')}, "
+            debug_info += f"has_llm={hasattr(self, 'llm')}, "
+            debug_info += f"llm_is_none={getattr(self, 'llm', None) is None}, "
+            debug_info += f"llm_backend={getattr(self, 'llm_backend', 'not set')}"
+            return {}, "", f"❌ 5Hz LM not initialized. Please initialize it first. Debug: {debug_info}"
+
+        if not hasattr(self, 'llm') or self.llm is None:
+            return {}, "", "❌ 5Hz LM model not loaded. Please initialize it first."
+
+        if not hasattr(self, 'llm_backend'):
+            return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
+
+        if self.llm_backend == "vllm":
+            return self.generate_with_5hz_lm_vllm(caption, lyrics, temperature)
+        else:
+            return self.generate_with_5hz_lm_pt(caption, lyrics, temperature)

     def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
-        """
+        """
+        Parse LM output to extract metadata and audio codes.
+
+        Expected format:
+        <think>
+        bpm: 73
+        duration: 273
+        genres: Chinese folk
+        keyscale: G major
+        timesignature: 4
+        </think>
+
+        <|audio_code_56535|><|audio_code_62918|>...
+
+        Returns:
+            Tuple of (metadata_dict, audio_codes_string)
+        """
+        debug_output_text = output_text.split("</think>")[0]
+        print(f"Debug output text: {debug_output_text}")
         metadata = {}
         audio_codes = ""

         import re

-        # Extract audio codes
+        # Extract audio codes - find all <|audio_code_XXX|> patterns
         code_pattern = r'<\|audio_code_\d+\|>'
         code_matches = re.findall(code_pattern, output_text)
         if code_matches:
             audio_codes = "".join(code_matches)

-        # Extract metadata
+        # Extract metadata from reasoning section
+        # Try different reasoning tag patterns
         reasoning_patterns = [
+            r'<think>(.*?)</think>',
             r'<think>(.*?)</think>',
             r'<reasoning>(.*?)</reasoning>',
         ]
@@ -378,7 +516,9 @@ class AceStepHandler:
                 reasoning_text = match.group(1).strip()
                 break

+        # If no reasoning tags found, try to parse metadata from the beginning of output
         if not reasoning_text:
+            # Look for metadata lines before audio codes
             lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
             reasoning_text = lines_before_codes.strip()
@@ -402,8 +542,12 @@ class AceStepHandler:
                     metadata['duration'] = int(value)
                 except:
                     metadata['duration'] = value
-            elif key
-                metadata[
+            elif key == 'genres':
+                metadata['genres'] = value
+            elif key == 'keyscale':
+                metadata['keyscale'] = value
+            elif key == 'timesignature':
+                metadata['timesignature'] = value

         return metadata, audio_codes

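The parsing contract in parse_lm_output can be exercised in isolation; a minimal sketch follows, where the sample LM output is fabricated to match the format documented in the docstring above.

import re

# Fabricated example following the documented <think>...</think> + audio-code format
sample = (
    "<think>\nbpm: 73\nduration: 273\ngenres: Chinese folk\n"
    "keyscale: G major\ntimesignature: 4\n</think>\n"
    "<|audio_code_56535|><|audio_code_62918|>"
)

audio_codes = "".join(re.findall(r'<\|audio_code_\d+\|>', sample))
reasoning = re.search(r'<think>(.*?)</think>', sample, re.DOTALL).group(1)
metadata = {k.strip(): v.strip() for k, v in
            (line.split(':', 1) for line in reasoning.strip().splitlines() if ':' in line)}

print(metadata)     # {'bpm': '73', 'duration': '273', 'genres': 'Chinese folk', 'keyscale': 'G major', 'timesignature': '4'}
print(audio_codes)  # <|audio_code_56535|><|audio_code_62918|>
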
acestep/third_parts/nano-vllm/nanovllm/config.py
CHANGED
@@ -1,8 +1,35 @@
 import os
+import socket
 from dataclasses import dataclass
 from transformers import AutoConfig


+def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
+    """Find an available port starting from start_port.
+
+    Args:
+        start_port: The starting port number to check
+        max_attempts: Maximum number of ports to try
+
+    Returns:
+        An available port number
+
+    Raises:
+        RuntimeError: If no available port is found within max_attempts
+    """
+    for i in range(max_attempts):
+        port = start_port + i
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                s.bind(('localhost', port))
+                return port
+        except OSError:
+            # Port is in use, try next one
+            continue
+    raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
+
+
 @dataclass
 class Config:
     model: str
@@ -13,9 +40,10 @@ class Config:
     tensor_parallel_size: int = 1
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
-    eos: int =
+    eos: int = 151643
     kvcache_block_size: int = 256
     num_kvcache_blocks: int = -1
+    dist_port: int | None = None

     def __post_init__(self):
         assert os.path.isdir(self.model)
@@ -24,3 +52,6 @@ class Config:
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
         assert self.max_num_batched_tokens >= self.max_model_len
+        # Auto-find available port if not specified
+        if self.dist_port is None:
+            self.dist_port = find_available_port()

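Usage sketch for the new port helper (only the standard library plus the vendored nanovllm package are assumed):

from nanovllm.config import find_available_port

port = find_available_port(start_port=2333)  # probes 2333, 2334, ... and returns the first port that binds
print(port)
# Config does the same automatically in __post_init__ when dist_port is left as None.

Note that bind-and-close probing is only a best-effort check: another process can still claim the port before it is actually used, which is why llm_engine.py and model_runner.py below add their own retry logic.
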
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py
CHANGED
@@ -21,6 +21,28 @@ class LLMEngine:
         self.ps = []
         self.events = []
         ctx = mp.get_context("spawn")
+
+        # Pre-validate port availability by attempting to bind to it
+        # This helps avoid race conditions when multiple LLMEngine instances start simultaneously
+        import socket
+        from nanovllm.config import find_available_port
+        max_port_retries = 10
+        for port_attempt in range(max_port_retries):
+            try:
+                # Test if port is actually available by binding to it
+                test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                test_socket.bind(('localhost', config.dist_port))
+                test_socket.close()
+                # Port is available, break
+                break
+            except OSError:
+                # Port is in use, find next available
+                if port_attempt < max_port_retries - 1:
+                    config.dist_port = find_available_port(start_port=config.dist_port + 1, max_attempts=10)
+                else:
+                    raise RuntimeError(f"Failed to find available port after {max_port_retries} attempts")
+
         for i in range(1, config.tensor_parallel_size):
             event = ctx.Event()
             process = ctx.Process(target=ModelRunner, args=(config, i, event))
@@ -28,8 +50,7 @@ class LLMEngine:
             self.ps.append(process)
             self.events.append(event)
         self.model_runner = ModelRunner(config, 0, self.events)
-        self.tokenizer =
-        config.eos = self.tokenizer.eos_token_id
+        self.tokenizer = None
         self.scheduler = Scheduler(config)
         atexit.register(self.exit)

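Since the engine no longer loads a tokenizer here (and no longer derives config.eos from one), the caller is expected to attach a tokenizer itself, as handler.py does with self.llm.tokenizer = self.llm_tokenizer; eos now falls back to the Config default (151643). A rough sketch of that contract, with a hypothetical checkpoint path:

from transformers import AutoTokenizer
from nanovllm import LLM  # the vendored nano-vllm fork

llm = LLM(model="/path/to/5hz_lm_checkpoint", tensor_parallel_size=1)          # hypothetical path
llm.tokenizer = AutoTokenizer.from_pretrained("/path/to/5hz_lm_checkpoint")   # caller supplies the tokenizer
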
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
CHANGED
@@ -1,10 +1,11 @@
 import pickle
+import socket
 import torch
 import torch.distributed as dist
 from multiprocessing.synchronize import Event
 from multiprocessing.shared_memory import SharedMemory

-from nanovllm.config import Config
+from nanovllm.config import Config, find_available_port
 from nanovllm.engine.sequence import Sequence
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
@@ -23,7 +24,32 @@ class ModelRunner:
         self.rank = rank
         self.event = event

-
+        # Try to initialize process group with retry logic for port conflicts
+        # Only rank 0 binds to the port, so only rank 0 needs retry logic
+        dist_port = self.config.dist_port
+        max_retries = 10
+        for attempt in range(max_retries):
+            try:
+                dist.init_process_group("nccl", f"tcp://localhost:{dist_port}", world_size=self.world_size, rank=rank)
+                break
+            except RuntimeError as e:
+                if ("EADDRINUSE" in str(e) or "address already in use" in str(e).lower()) and rank == 0:
+                    # Port is in use, try next port (only for rank 0)
+                    if attempt < max_retries - 1:
+                        # Find next available port
+                        dist_port = find_available_port(start_port=dist_port + 1, max_attempts=10)
+                        self.config.dist_port = dist_port
+                        # If we had a previous failed attempt, destroy any partial process group
+                        if dist.is_initialized():
+                            try:
+                                dist.destroy_process_group()
+                            except:
+                                pass
+                    else:
+                        raise RuntimeError(f"Failed to find available port after {max_retries} attempts. Last error: {e}")
+                else:
+                    # Other error or non-rank-0 process, re-raise
+                    raise
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(hf_config.torch_dtype)
@@ -118,9 +144,15 @@ class ModelRunner:
             layer_id += 1

     def prepare_block_tables(self, seqs: list[Sequence]):
-        max_len = max(len(seq.block_table) for seq in seqs)
+        max_len = max(len(seq.block_table) for seq in seqs) if seqs else 0
+        if max_len == 0:
+            # Return empty 2D tensor with correct shape
+            return torch.zeros((len(seqs), 0), dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        # Ensure it's 2D: if only one sequence, shape should be [1, max_len]
+        if block_tables.dim() == 1:
+            block_tables = block_tables.unsqueeze(0)
         return block_tables

     def prepare_prefill(self, seqs: list[Sequence]):
@@ -215,7 +247,29 @@ class ModelRunner:
         graph_vars["slot_mapping"][:bs] = context.slot_mapping
         graph_vars["context_lens"].zero_()
         graph_vars["context_lens"][:bs] = context.context_lens
-
+        # Handle block_tables: ensure it's 2D and size matches
+        if context.block_tables is not None and context.block_tables.numel() > 0:
+            # Ensure block_tables is 2D
+            if context.block_tables.dim() == 1:
+                # Reshape 1D to 2D: [num_blocks] -> [1, num_blocks]
+                block_tables_2d = context.block_tables.unsqueeze(0)
+            else:
+                block_tables_2d = context.block_tables
+
+            # Get dimensions
+            context_bs = block_tables_2d.size(0)
+            context_num_blocks = block_tables_2d.size(1)
+            graph_num_blocks = graph_vars["block_tables"].size(1)
+
+            # Use minimum to avoid size mismatch
+            num_blocks_to_copy = min(context_num_blocks, graph_num_blocks)
+            actual_bs = min(bs, context_bs)
+
+            # Copy block_tables with size matching
+            graph_vars["block_tables"][:actual_bs, :num_blocks_to_copy] = block_tables_2d[:actual_bs, :num_blocks_to_copy]
+            # Fill remaining with -1 if needed
+            if num_blocks_to_copy < graph_num_blocks:
+                graph_vars["block_tables"][:actual_bs, num_blocks_to_copy:] = -1
         graph.replay()
         return self.model.compute_logits(graph_vars["outputs"][:bs])

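A toy, CPU-only illustration of the block-table padding that prepare_block_tables now guards (the block ids are made up):

import torch

block_tables = [[0, 1, 2], [3]]  # hypothetical per-sequence block ids (ragged)
max_len = max(len(bt) for bt in block_tables) if block_tables else 0
padded = [bt + [-1] * (max_len - len(bt)) for bt in block_tables]
t = torch.tensor(padded, dtype=torch.int32)
if t.dim() == 1:   # defensive guard mirroring the patch; keeps a [batch, num_blocks] shape
    t = t.unsqueeze(0)
print(t)           # prints something like tensor([[ 0,  1,  2], [ 3, -1, -1]], dtype=torch.int32)
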
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ diffusers
 gradio
 soundfile
 loguru
-einops
+einops
+accelerator