fix bugs and test profile
- acestep/audio_utils.py +3 -76
- acestep/gradio_ui/events/__init__.py +3 -2
- acestep/gradio_ui/events/results_handlers.py +345 -381
- acestep/gradio_ui/interfaces/result.py +16 -8
- acestep/handler.py +3 -25
- acestep/inference.py +234 -354
- acestep/llm_inference.py +117 -26
- acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +35 -23
- acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +10 -24
- profile_inference.py +621 -162
acestep/audio_utils.py
CHANGED
@@ -98,7 +98,6 @@ class AudioSaver:
                 channels_first=True,
                 backend='ffmpeg',
                 compression=config,
-                buffer_size=65536
             )
         elif format in ["flac", "wav"]:
             # FLAC and WAV use soundfile backend (fastest)
@@ -107,8 +106,7 @@ class AudioSaver:
                 audio_tensor,
                 sample_rate,
                 channels_first=True,
-                backend='soundfile',
-                buffer_size=65536
+                backend='ffmpeg',
             )
         else:
             # Other formats use default backend
@@ -117,7 +115,6 @@ class AudioSaver:
                 audio_tensor,
                 sample_rate,
                 channels_first=True,
-                buffer_size=65536
             )
 
         logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
@@ -247,87 +244,17 @@ def get_audio_file_hash(audio_file) -> str:
     return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
 
 
-def generate_uuid_from_params(
-    captions: str,
-    lyrics: str,
-    bpm: Optional[int],
-    key_scale: str,
-    time_signature: str,
-    vocal_language: str,
-    inference_steps: int,
-    guidance_scale: float,
-    seed: Union[str, float, int],
-    audio_duration: Optional[float],
-    audio_code_string: Union[str, List[str]],
-    repainting_start: float,
-    repainting_end: Optional[float],
-    instruction: str,
-    audio_cover_strength: float,
-    task_type: str,
-    use_adg: bool,
-    cfg_interval_start: float,
-    cfg_interval_end: float,
-    audio_format: str,
-    reference_audio=None,
-    src_audio=None,
-    batch_index: int = 0,
-) -> str:
+def generate_uuid_from_params(params_dict) -> str:
     """
     Generate deterministic UUID from generation parameters.
     Same parameters will always generate the same UUID.
 
     Args:
-        captions: Caption text
-        lyrics: Lyrics text
-        bpm: BPM value
-        key_scale: Musical key and scale
-        time_signature: Time signature
-        vocal_language: Vocal language code
-        inference_steps: Number of inference steps
-        guidance_scale: Guidance scale
-        seed: Random seed
-        audio_duration: Audio duration in seconds
-        audio_code_string: Audio code string or list
-        repainting_start: Repainting start time
-        repainting_end: Repainting end time
-        instruction: Task instruction
-        audio_cover_strength: Audio cover strength
-        task_type: Task type
-        use_adg: Whether to use ADG
-        cfg_interval_start: CFG interval start
-        cfg_interval_end: CFG interval end
-        audio_format: Audio format
-        reference_audio: Reference audio file path
-        src_audio: Source audio file path
-        batch_index: Index in batch (for audio_code_string list access)
+        params_dict: Dictionary of parameters
 
     Returns:
         UUID string
     """
-    params_dict = {
-        "captions": captions or "",
-        "lyrics": lyrics or "",
-        "bpm": bpm,
-        "key_scale": key_scale or "",
-        "time_signature": time_signature or "",
-        "vocal_language": vocal_language or "",
-        "inference_steps": inference_steps,
-        "guidance_scale": guidance_scale,
-        "seed": seed,
-        "audio_duration": audio_duration,
-        "audio_code_string": audio_code_string if isinstance(audio_code_string, str) else (audio_code_string[batch_index] if isinstance(audio_code_string, list) and batch_index < len(audio_code_string) else ""),
-        "repainting_start": repainting_start,
-        "repainting_end": repainting_end,
-        "instruction": instruction or "",
-        "audio_cover_strength": audio_cover_strength,
-        "task_type": task_type or "",
-        "use_adg": use_adg,
-        "cfg_interval_start": cfg_interval_start,
-        "cfg_interval_end": cfg_interval_end,
-        "audio_format": audio_format or "",
-        "reference_audio_hash": get_audio_file_hash(reference_audio),
-        "src_audio_hash": get_audio_file_hash(src_audio),
-    }
 
     params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
     hash_obj = hashlib.sha256(params_json.encode('utf-8'))
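The refactor collapses the 20-odd positional arguments into a single `params_dict`, so the UUID now hashes whatever dictionary the caller assembles. A minimal sketch of the idea, assuming the same JSON-then-SHA-256 scheme the diff shows (the final UUID-folding step is cut off in the diff, so the return expression here is an assumption):

import hashlib
import json
import uuid

def generate_uuid_from_params(params_dict) -> str:
    # Sort keys so logically equal dicts always serialize identically.
    params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
    hash_obj = hashlib.sha256(params_json.encode("utf-8"))
    # Fold the digest into a UUID-shaped string (plausible final step;
    # the diff is truncated before the actual return).
    return str(uuid.UUID(hash_obj.hexdigest()[:32]))

# Same inputs -> same UUID, regardless of key insertion order.
a = generate_uuid_from_params({"captions": "lofi piano", "seed": 42})
b = generate_uuid_from_params({"seed": 42, "captions": "lofi piano"})
assert a == b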
acestep/gradio_ui/events/__init__.py
CHANGED
@@ -331,10 +331,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
         ],
         outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
     )
-
+    def generation_wrapper(*args):
+        yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
     # ========== Generation Handler ==========
     generation_section["generate_btn"].click(
-        fn=
+        fn=generation_wrapper,
         inputs=[
             generation_section["captions"],
             generation_section["lyrics"],
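The named `generation_wrapper` matters because Gradio only streams intermediate UI updates when the event handler is itself a generator function; `yield from` preserves that property while injecting `dit_handler` and `llm_handler`. A standalone sketch of the pattern (component names are placeholders, not from this repo):

import gradio as gr

def slow_task():
    # A generator handler: each yield becomes one UI update.
    for i in range(3):
        yield f"step {i + 1}/3"

with gr.Blocks() as demo:
    status = gr.Textbox(label="status")
    btn = gr.Button("run")

    def wrapper(*args):
        # yield from keeps this a generator function, so Gradio streams
        # every intermediate value instead of waiting for the last one.
        yield from slow_task()

    btn.click(fn=wrapper, inputs=[], outputs=[status])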
acestep/gradio_ui/events/results_handlers.py
CHANGED
@@ -10,9 +10,123 @@ import tempfile
 import shutil
 import zipfile
 import time as time_module
+from typing import Dict, Any, Optional
 import gradio as gr
 from loguru import logger
 from acestep.gradio_ui.i18n import t
+from acestep.inference import generate_music, GenerationParams, GenerationConfig
+from acestep.audio_utils import save_audio
+
+
+def _build_generation_info(
+    lm_metadata: Optional[Dict[str, Any]],
+    time_costs: Dict[str, float],
+    seed_value: str,
+    inference_steps: int,
+    num_audios: int,
+) -> str:
+    """Build generation info string from result data.
+
+    Args:
+        lm_metadata: LM-generated metadata dictionary
+        time_costs: Unified time costs dictionary
+        seed_value: Seed value string
+        inference_steps: Number of inference steps
+        num_audios: Number of generated audios
+
+    Returns:
+        Formatted generation info string
+    """
+    info_parts = []
+
+    # Part 1: LM-generated metadata (if available)
+    if lm_metadata:
+        metadata_lines = []
+        if lm_metadata.get('bpm'):
+            metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
+        if lm_metadata.get('caption'):
+            metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
+        if lm_metadata.get('lyrics'):
+            metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
+        if lm_metadata.get('duration'):
+            metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
+        if lm_metadata.get('keyscale'):
+            metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
+        if lm_metadata.get('language'):
+            metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
+        if lm_metadata.get('timesignature'):
+            metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
+
+        if metadata_lines:
+            metadata_section = "**š¤ LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
+            info_parts.append(metadata_section)
+
+    # Part 2: Time costs (formatted and beautified)
+    if time_costs:
+        time_lines = []
+
+        # LM time costs
+        lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
+        lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
+        lm_total = time_costs.get('lm_total_time', 0.0)
+
+        if lm_total > 0:
+            time_lines.append("**š§  LM Time:**")
+            if lm_phase1 > 0:
+                time_lines.append(f"  - Phase 1 (CoT): {lm_phase1:.2f}s")
+            if lm_phase2 > 0:
+                time_lines.append(f"  - Phase 2 (Codes): {lm_phase2:.2f}s")
+            time_lines.append(f"  - Total: {lm_total:.2f}s")
+
+        # DiT time costs
+        dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
+        dit_model = time_costs.get('dit_model_time_cost', 0.0)
+        dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
+        dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
+        dit_total = time_costs.get('dit_total_time_cost', 0.0)
+        if dit_total > 0:
+            time_lines.append("\n**šµ DiT Time:**")
+            if dit_encoder > 0:
+                time_lines.append(f"  - Encoder: {dit_encoder:.2f}s")
+            if dit_model > 0:
+                time_lines.append(f"  - Model: {dit_model:.2f}s")
+            if dit_vae_decode > 0:
+                time_lines.append(f"  - VAE Decode: {dit_vae_decode:.2f}s")
+            if dit_offload > 0:
+                time_lines.append(f"  - Offload: {dit_offload:.2f}s")
+            time_lines.append(f"  - Total: {dit_total:.2f}s")
+
+        # Post-processing time costs
+        audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
+        auto_score_time = time_costs.get('auto_score_time', 0.0)
+
+        if audio_conversion_time > 0 or auto_score_time > 0:
+            time_lines.append("\n**š§ Post-processing Time:**")
+            if audio_conversion_time > 0:
+                time_lines.append(f"  - Audio Conversion: {audio_conversion_time:.2f}s")
+            if auto_score_time > 0:
+                time_lines.append(f"  - Auto Score: {auto_score_time:.2f}s")
+
+        # Pipeline total
+        pipeline_total = time_costs.get('pipeline_total_time', 0.0)
+        if pipeline_total > 0:
+            time_lines.append(f"\n**ā±ļø Pipeline Total: {pipeline_total:.2f}s**")
+
+        if time_lines:
+            time_section = "\n".join(time_lines)
+            info_parts.append(time_section)
+
+    # Part 3: Generation summary
+    summary_lines = [
+        "**šµ Generation Complete**",
+        f"  - **Seeds:** {seed_value}",
+        f"  - **Steps:** {inference_steps}",
+        f"  - **Audio Count:** {num_audios} audio(s)",
+    ]
+    info_parts.append("\n".join(summary_lines))
+
+    # Combine all parts
+    return "\n\n".join(info_parts)
 
 
 def store_batch_in_queue(
@@ -254,383 +368,205 @@ def generate_with_progress(
     auto_score,
     score_scale,
     lm_batch_chunk_size,
-    progress=gr.Progress(track_tqdm=True)
+    progress=gr.Progress(track_tqdm=True),
 ):
     """Generate audio with progress tracking"""
-            think_checkbox
+
+    # step 1: prepare inputs
+    gen_params = GenerationParams(
+        task_type=task_type,
+        instruction=instruction_display_gen,
+        reference_audio=reference_audio,
+        src_audio=src_audio,
+        audio_codes=text2music_audio_code_string if not think_checkbox else "",
+        caption=captions or "",
+        lyrics=lyrics or "",
+        instrumental=False,
+        vocal_language=vocal_language,
+        bpm=bpm,
+        keyscale=key_scale,
+        timesignature=time_signature,
+        duration=audio_duration,
+        inference_steps=inference_steps,
+        guidance_scale=guidance_scale,
+        use_adg=use_adg,
+        cfg_interval_start=cfg_interval_start,
+        cfg_interval_end=cfg_interval_end,
+        repainting_start=repainting_start,
+        repainting_end=repainting_end,
+        audio_cover_strength=audio_cover_strength,
+        thinking=think_checkbox,
+        lm_temperature=lm_temperature,
+        lm_cfg_scale=lm_cfg_scale,
+        lm_top_k=lm_top_k,
+        lm_top_p=lm_top_p,
+        lm_negative_prompt=lm_negative_prompt,
+        use_cot_metas=use_cot_metas,
+        use_cot_caption=use_cot_caption,
+        use_cot_language=use_cot_language,
+        use_constrained_decoding=True,
+    )
+    # seed string to list
+    if isinstance(seed, str) and seed.strip():
+        if "," in seed:
+            seed_list = [int(s.strip()) for s in seed.split(",")]
+        else:
+            seed_list = [int(seed.strip())]
+    else:
+        seed_list = None
+    gen_config = GenerationConfig(
+        batch_size=batch_size_input,
+        allow_lm_batch=allow_lm_batch,
+        use_random_seed=random_seed_checkbox,
+        seeds=seed_list,
+        lm_batch_chunk_size=lm_batch_chunk_size,
+        constrained_decoding_debug=constrained_decoding_debug,
+        audio_format=audio_format,
+    )
+    result = generate_music(
+        dit_handler,
+        llm_handler,
+        params=gen_params,
+        config=gen_config,
+        progress=progress,
     )
 
-        for chunk_idx in range(num_chunks):
-            chunk_start = chunk_idx * max_inference_batch_size
-            chunk_end = min(chunk_start + max_inference_batch_size, batch_size_input)
-            chunk_size = chunk_end - chunk_start
-            chunk_seeds = actual_seed_list[chunk_start:chunk_end]
-
-            logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
-
-            # Generate batch
-            metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition(
-                caption=captions or "",
-                lyrics=lyrics or "",
-                infer_type="llm_dit",
-                temperature=lm_temperature,
-                cfg_scale=lm_cfg_scale,
-                negative_prompt=lm_negative_prompt,
-                top_k=top_k_value,
-                top_p=top_p_value,
-                user_metadata=user_metadata_to_pass,
-                use_cot_caption=use_cot_caption,
-                use_cot_language=use_cot_language,
-                is_format_caption=is_format_caption,
-                constrained_decoding_debug=constrained_decoding_debug,
-                batch_size=chunk_size,
-                seeds=chunk_seeds,
-            )
-
-            all_metadata_list.extend(metadata_list)
-            all_audio_codes_list.extend(audio_codes_list)
-
-        # Use first metadata as representative (all are same)
-        lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
-
-        # Store audio codes list for later use
-        lm_generated_audio_codes_list = all_audio_codes_list
-
-        if not key_scale and lm_generated_metadata.get('keyscale'):
-            key_scale_value = lm_generated_metadata.get('keyscale', lm_generated_metadata.get('key_scale', ""))
-            if key_scale_value != "N/A":
-                key_scale = key_scale_value
-        if not time_signature and lm_generated_metadata.get('timesignature'):
-            time_signature_value = lm_generated_metadata.get('timesignature', lm_generated_metadata.get('time_signature', ""))
-            if time_signature_value != "N/A":
-                time_signature = time_signature_value
-        if audio_duration is None or audio_duration <= 0:
-            audio_duration_value = lm_generated_metadata.get('duration', -1)
-            if audio_duration_value != "N/A" and audio_duration_value != "":
-                try:
-                    audio_duration = float(audio_duration_value)
-                except:
-                    pass
-    else:
-        # SEQUENTIAL LM GENERATION (current behavior, when allow_lm_batch is False)
-        # Phase 1: Generate CoT metadata
-        phase1_start = time_module.time()
-        metadata, _, status = llm_handler.generate_with_stop_condition(
-            caption=captions or "",
-            lyrics=lyrics or "",
-            infer_type="dit",  # Only generate metadata in Phase 1
-            temperature=lm_temperature,
-            cfg_scale=lm_cfg_scale,
-            negative_prompt=lm_negative_prompt,
-            top_k=top_k_value,
-            top_p=top_p_value,
-            user_metadata=user_metadata_to_pass,
-            use_cot_caption=use_cot_caption,
-            use_cot_language=use_cot_language,
-            is_format_caption=is_format_caption,
-            constrained_decoding_debug=constrained_decoding_debug,
-        )
-        lm_phase1_time = time_module.time() - phase1_start
-        logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
-
-        lm_phase2_time = time_module.time() - phase2_start
-        logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
-
-        # Store LM-generated metadata and audio codes for display
-        lm_generated_metadata = metadata
-        if audio_codes:
-            audio_code_string_to_use = audio_codes
-            lm_generated_audio_codes = audio_codes
-    # Update metadata fields only if they are empty/None (user didn't provide them)
-    if bpm is None and metadata.get('bpm'):
-        bpm_value = metadata.get('bpm')
-        if bpm_value != "N/A" and bpm_value != "":
-            try:
-                bpm = int(bpm_value)
-            except:
-                pass
-    if not key_scale and metadata.get('keyscale'):
-        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
-        if key_scale_value != "N/A":
-            key_scale = key_scale_value
-    if not time_signature and metadata.get('timesignature'):
-        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
-        if time_signature_value != "N/A":
-            time_signature = time_signature_value
-    if audio_duration is None or audio_duration <= 0:
-        audio_duration_value = metadata.get('duration', -1)
-        if audio_duration_value != "N/A" and audio_duration_value != "":
-            try:
-                audio_duration = float(audio_duration_value)
-            except:
-                pass
-
-    # Call generate_music and get results
-    result = dit_handler.generate_music(
-        captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
-        time_signature=time_signature, vocal_language=vocal_language,
-        inference_steps=inference_steps, guidance_scale=guidance_scale,
-        use_random_seed=random_seed_checkbox, seed=seed,
-        reference_audio=reference_audio, audio_duration=audio_duration,
-        batch_size=batch_size_input, src_audio=src_audio,
-        audio_code_string=audio_code_string_to_use,
-        repainting_start=repainting_start, repainting_end=repainting_end,
-        instruction=instruction_display_gen, audio_cover_strength=audio_cover_strength,
-        task_type=task_type, use_adg=use_adg,
-        cfg_interval_start=cfg_interval_start, cfg_interval_end=cfg_interval_end,
-        audio_format=audio_format, lm_temperature=lm_temperature,
-        progress=progress
-    )
-
-    # Extract results from new dict structure
-    if not isinstance(result, dict):
-        # Fallback for old tuple format (should not happen)
-        first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
-            align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
-    else:
-        audios = result.get("audios", [])
-        all_audio_paths = [audio.get("path") for audio in audios]
-        first_audio = all_audio_paths[0] if len(all_audio_paths) > 0 else None
-        second_audio = all_audio_paths[1] if len(all_audio_paths) > 1 else None
-        generation_info = result.get("generation_info", "")
-        status_message = result.get("status_message", "")
-        seed_value_for_ui = result.get("extra_outputs", {}).get("seed_value", "")
-        # Legacy alignment fields (no longer used)
-        align_score_1 = ""
-        align_text_1 = ""
-        align_plot_1 = None
-        align_score_2 = ""
-        align_text_2 = ""
-        align_plot_2 = None
-
-    # Extract LM timing from status if available and prepend to generation_info
-    if status:
-        import re
-        # Try to extract timing info from status using regex
-        # Expected format: "Phase1: X.XXs" and "Phase2: X.XXs"
-        phase1_match = re.search(r'Phase1:\s*([\d.]+)s', status)
-        phase2_match = re.search(r'Phase2:\s*([\d.]+)s', status)
-
-        if phase1_match or phase2_match:
-            lm_timing_section = "\n\n**š¤ LM Timing:**\n"
-            lm_total = 0.0
-            if phase1_match:
-                phase1_time = float(phase1_match.group(1))
-                lm_timing_section += f"  - Phase 1 (CoT Metadata): {phase1_time:.2f}s\n"
-                lm_total += phase1_time
-            if phase2_match:
-                phase2_time = float(phase2_match.group(1))
-                lm_timing_section += f"  - Phase 2 (Audio Codes): {phase2_time:.2f}s\n"
-                lm_total += phase2_time
-            if lm_total > 0:
-                lm_timing_section += f"  - Total LM Time: {lm_total:.2f}s\n"
-            generation_info = lm_timing_section + "\n" + generation_info
-
-    # Append LM-generated metadata to generation_info if available
-    if lm_generated_metadata:
-        metadata_lines = []
-        if lm_generated_metadata.get('bpm'):
-            metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
-        if lm_generated_metadata.get('caption'):
-            metadata_lines.append(f"- **User Query Rewritten Caption:** {lm_generated_metadata['caption']}")
-        if lm_generated_metadata.get('duration'):
-            metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
-        if lm_generated_metadata.get('keyscale'):
-            metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
-        if lm_generated_metadata.get('language'):
-            metadata_lines.append(f"- **Language:** {lm_generated_metadata['language']}")
-        if lm_generated_metadata.get('timesignature'):
-            metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
-
-        if metadata_lines:
-            metadata_section = "\n\n**š¤ LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
-            generation_info = metadata_section + "\n\n" + generation_info
-
-    # Update audio codes in UI if LM generated them
-    codes_outputs = [""] * 8  # Codes for 8 components
-    if should_use_lm_batch and lm_generated_audio_codes_list:
-        # Batch mode: update individual codes inputs
-        for idx in range(min(len(lm_generated_audio_codes_list), 8)):
-            codes_outputs[idx] = lm_generated_audio_codes_list[idx]
-        # For single codes input, show first one
-        updated_audio_codes = lm_generated_audio_codes_list[0] if lm_generated_audio_codes_list else text2music_audio_code_string
-    else:
-        # Single mode: update main codes input
-        updated_audio_codes = lm_generated_audio_codes if lm_generated_audio_codes else text2music_audio_code_string
-
-    # AUTO-SCORING
-    score_displays = [""] * 8  # Scores for 8 components
-    if auto_score and all_audio_paths:
-        logger.info(f"Auto-scoring enabled, calculating quality scores for {batch_size_input} generated audios...")
-
-        # Determine which audio codes to use for scoring
-        if should_use_lm_batch and lm_generated_audio_codes_list:
-            codes_list = lm_generated_audio_codes_list
-        elif audio_code_string_to_use and isinstance(audio_code_string_to_use, list):
-            codes_list = audio_code_string_to_use
+    audio_outputs = [None] * 8
+    all_audio_paths = []
+    final_codes_list = [""] * 8
+    final_scores_list = [""] * 8
+
+    # Build generation_info from result data
+    status_message = result.status_message
+    seed_value_for_ui = result.extra_outputs.get("seed_value", "")
+    lm_generated_metadata = result.extra_outputs.get("lm_metadata", {})
+    time_costs = result.extra_outputs.get("time_costs", {}).copy()
+
+    # Initialize post-processing timing
+    audio_conversion_start_time = time_module.time()
+    total_auto_score_time = 0.0
+
+    align_score_1 = ""
+    align_text_1 = ""
+    align_plot_1 = None
+    align_score_2 = ""
+    align_text_2 = ""
+    align_plot_2 = None
+    updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
+
+    if not result.success:
+        # Build generation_info string for error case
+        generation_info = _build_generation_info(
+            lm_metadata=lm_generated_metadata,
+            time_costs=time_costs,
+            seed_value=seed_value_for_ui,
+            inference_steps=inference_steps,
+            num_audios=0,
+        )
+        yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 25
+        return
+
+    audios = result.audios
+    progress(0.99, "Converting audio to mp3...")
+    for i in range(8):
+        if i < len(audios):
+            key = audios[i]["key"]
+            audio_tensor = audios[i]["tensor"]
+            sample_rate = audios[i]["sample_rate"]
+            audio_params = audios[i]["params"]
+            temp_dir = tempfile.mkdtemp(f"acestep_gradio_results/")
+            os.makedirs(temp_dir, exist_ok=True)
+            json_path = os.path.join(temp_dir, f"{key}.json")
+            audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
+            save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
+            audio_outputs[i] = audio_path
+            all_audio_paths.append(audio_path)
 
+            code_str = audio_params.get("audio_codes", "")
+            final_codes_list[i] = code_str
 
-        for idx in range(min(len(all_audio_paths), 8)):
-            audio_outputs[idx] = all_audio_paths[idx]
+            scores_ui_updates = [gr.skip()] * 8
+            score_str = "Done!"
+            if auto_score:
+                auto_score_start = time_module.time()
+                score_str = calculate_score_handler(llm_handler, code_str, captions, lyrics, lm_generated_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale)
+                auto_score_end = time_module.time()
+                total_auto_score_time += (auto_score_end - auto_score_start)
+            scores_ui_updates[i] = score_str
+            final_scores_list[i] = score_str
 
-        audio_outputs[3],  # generated_audio_4
-        audio_outputs[4],  # generated_audio_5
-        audio_outputs[5],  # generated_audio_6
-        audio_outputs[6],  # generated_audio_7
-        audio_outputs[7],  # generated_audio_8
-        all_audio_paths,  # generated_audio_batch
-        generation_info,
-        seed_value_for_ui,
-        align_score_1,
-        score_displays[2],  # score_display_3
-        score_displays[3],  # score_display_4
-        score_displays[4],  # score_display_5
-        score_displays[5],  # score_display_6
-        score_displays[6],  # score_display_7
-        score_displays[7],  # score_display_8
-        updated_audio_codes,  # Update main audio codes in UI
-        codes_outputs[0],  # text2music_audio_code_string_1
-        codes_outputs[1],  # text2music_audio_code_string_2
-        codes_outputs[2],  # text2music_audio_code_string_3
-        codes_outputs[3],  # text2music_audio_code_string_4
-        codes_outputs[4],  # text2music_audio_code_string_5
-        codes_outputs[5],  # text2music_audio_code_string_6
-        codes_outputs[6],  # text2music_audio_code_string_7
-        codes_outputs[7],  # text2music_audio_code_string_8
-        lm_generated_metadata,  # Store metadata for "Send to src audio" buttons
-        is_format_caption,  # Keep is_format_caption unchanged
+            status_message = f"Encoding & Ready: {i+1}/{len(audios)}"
+            current_audio_updates = [gr.skip()] * 8
+            current_audio_updates[i] = audio_path
+
+            audio_codes_ui_updates = [gr.skip()] * 8
+            audio_codes_ui_updates[i] = code_str
+            yield (
+                current_audio_updates[0], current_audio_updates[1], current_audio_updates[2], current_audio_updates[3],
+                current_audio_updates[4], current_audio_updates[5], current_audio_updates[6], current_audio_updates[7],
+                all_audio_paths,  # Real-time update of Batch File list
+                generation_info,
+                status_message,
+                seed_value_for_ui,
+                # Align plot placeholders (assume no need to update in real time)
+                gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+                # Scores
+                scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
+                updated_audio_codes,
+                # Codes
+                audio_codes_ui_updates[0], audio_codes_ui_updates[1], audio_codes_ui_updates[2], audio_codes_ui_updates[3],
+                audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
+                lm_generated_metadata,
+                is_format_caption,
+            )
+        else:
+            # If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
+            pass
+        time_module.sleep(0.1)
+
+    # Record audio conversion time
+    audio_conversion_end_time = time_module.time()
+    audio_conversion_time = audio_conversion_end_time - audio_conversion_start_time
+
+    # Add post-processing times to time_costs
+    if audio_conversion_time > 0:
+        time_costs['audio_conversion_time'] = audio_conversion_time
+    if total_auto_score_time > 0:
+        time_costs['auto_score_time'] = total_auto_score_time
+
+    # Update pipeline total time to include post-processing
+    if 'pipeline_total_time' in time_costs:
+        time_costs['pipeline_total_time'] += audio_conversion_time + total_auto_score_time
+
+    # Rebuild generation_info with complete timing information
+    generation_info = _build_generation_info(
+        lm_metadata=lm_generated_metadata,
+        time_costs=time_costs,
+        seed_value=seed_value_for_ui,
+        inference_steps=inference_steps,
+        num_audios=len(result.audios),
+    )
+
+    yield (
+        gr.skip(), gr.skip(), gr.skip(), gr.skip(),  # Audio 1-4: SKIP
+        gr.skip(), gr.skip(), gr.skip(), gr.skip(),  # Audio 5-8: SKIP
+        all_audio_paths,
+        generation_info,
+        "Generation Complete",
+        seed_value_for_ui,
+        align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2,
+        final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
+        final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
+        updated_audio_codes,
+        final_codes_list[0], final_codes_list[1], final_codes_list[2], final_codes_list[3],
+        final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
+        lm_generated_metadata,
+        is_format_caption,
     )
 
 
+
 def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
     """
     Calculate PMI-based quality score for generated audio.
@@ -773,7 +709,9 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
     if stored_allow_lm_batch and isinstance(stored_codes, list):
         # Batch mode: use specific sample's codes
        if 0 <= sample_idx - 1 < len(stored_codes):
-            audio_codes_str = stored_codes[sample_idx - 1]
+            code_item = stored_codes[sample_idx - 1]
+            # Ensure it's a string (handle cases where dict was mistakenly stored)
+            audio_codes_str = code_item if isinstance(code_item, str) else ""
     else:
         # Single mode: all samples use same codes
         audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
@@ -885,7 +823,7 @@ def generate_with_batch_management(
     Wrapper for generate_with_progress that adds batch queue management
     """
     # Call the original generation function
-    result = generate_with_progress(
+    generator = generate_with_progress(
        dit_handler, llm_handler,
        captions, lyrics, bpm, key_scale, time_signature, vocal_language,
        inference_steps, guidance_scale, random_seed_checkbox, seed,
@@ -902,23 +840,41 @@ def generate_with_batch_management(
         lm_batch_chunk_size,
         progress
     )
+    final_result_from_inner = None
+    for partial_result in generator:
+        final_result_from_inner = partial_result
+        # current_batch_index, total_batches, batch_queue, next_params,
+        # batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
+        yield partial_result + (
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
+        )
+    result = final_result_from_inner
+    all_audio_paths = result[8]
+
+    if all_audio_paths is None:
+
+        yield result + (
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
+        )
+        return
+
+    # Extract results from generation (access result by index)
     generation_info = result[9]
     seed_value_for_ui = result[11]
-    lm_generated_metadata = result[34]
+    lm_generated_metadata = result[35]  # Fixed: lm_metadata is at index 35, not 34
 
     # Extract codes
     generated_codes_single = result[26]
     generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
+
     # Determine which codes to store based on mode
     if allow_lm_batch and batch_size_input >= 2:
         codes_to_store = generated_codes_batch[:int(batch_size_input)]
     else:
         codes_to_store = generated_codes_single
+
     # Save parameters for history
     saved_params = {
         "captions": captions,
@@ -964,6 +920,7 @@ def generate_with_batch_management(
     }
 
     # Next batch parameters (with cleared codes & random seed)
+    # Next batch parameters
     next_params = saved_params.copy()
     next_params["text2music_audio_code_string"] = ""
     next_params["random_seed_checkbox"] = True
@@ -996,9 +953,10 @@ def generate_with_batch_management(
     next_batch_status_text = ""
     if autogen_checkbox:
         next_batch_status_text = t("messages.autogen_enabled")
-
-    #
+
+    # 4. Yield final result (includes Batch UI updates)
+    # The result here is already a tuple structure
+    yield result + (
         current_batch_index,
         total_batches,
         batch_queue,
@@ -1114,7 +1072,8 @@ def generate_next_batch_background(
     params.setdefault("complete_track_classes", [])
 
     # Call generate_with_progress with the saved parameters
-    result = generate_with_progress(
+    # Note: generate_with_progress is a generator, need to iterate through it
+    generator = generate_with_progress(
        dit_handler,
        llm_handler,
        captions=params.get("captions"),
@@ -1159,15 +1118,20 @@ def generate_next_batch_background(
         progress=progress
     )
 
-    #
+    # Consume generator to get final result (similar to generate_with_batch_management)
+    final_result = None
+    for partial_result in generator:
+        final_result = partial_result
+
+    # Extract results from final_result
+    all_audio_paths = final_result[8]  # generated_audio_batch
+    generation_info = final_result[9]
+    seed_value_for_ui = final_result[11]
+    lm_generated_metadata = final_result[35]  # Fixed: lm_metadata is at index 35, not 34
 
     # Extract codes
-    generated_codes_single = result[26]
-    generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
+    generated_codes_single = final_result[26]
+    generated_codes_batch = [final_result[27], final_result[28], final_result[29], final_result[30], final_result[31], final_result[32], final_result[33], final_result[34]]
 
     # Determine which codes to store
     batch_size = params.get("batch_size_input", 2)
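The reworked handler streams one yield per finished sample, filling every output slot it is not touching with `gr.skip()` so the other players keep their state. A reduced sketch of that update pattern, assuming the fixed bank of eight output components this UI uses (`gr.skip()` is the no-op update sentinel in recent Gradio releases):

import gradio as gr

NUM_SLOTS = 8

def stream_results(paths):
    # paths: audio files as they become available, at most NUM_SLOTS of them.
    for i, path in enumerate(paths):
        updates = [gr.skip()] * NUM_SLOTS  # leave the other players untouched
        updates[i] = path                  # refresh only the finished slot
        yield tuple(updates) + (f"Encoding & Ready: {i + 1}/{len(paths)}",)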
acestep/gradio_ui/interfaces/result.py
CHANGED
@@ -28,7 +28,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_1 = gr.Audio(
             label=t("results.generated_music", n=1),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_1 = gr.Button(
@@ -58,7 +59,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_2 = gr.Audio(
             label=t("results.generated_music", n=2),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_2 = gr.Button(
@@ -88,7 +90,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_3 = gr.Audio(
             label=t("results.generated_music", n=3),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_3 = gr.Button(
@@ -118,7 +121,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_4 = gr.Audio(
             label=t("results.generated_music", n=4),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_4 = gr.Button(
@@ -151,7 +155,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_5 = gr.Audio(
             label=t("results.generated_music", n=5),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -166,7 +171,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_6 = gr.Audio(
             label=t("results.generated_music", n=6),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -181,7 +187,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_7 = gr.Audio(
             label=t("results.generated_music", n=7),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -196,7 +203,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_8 = gr.Audio(
             label=t("results.generated_music", n=8),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
        )
        with gr.Row(equal_height=True):
            send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
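Each of the eight players gets the same two-line change: the component stays display-only and its built-in download button is hidden, presumably because downloads now go through the batch file list. A one-component sketch of the resulting configuration:

import gradio as gr

generated_audio = gr.Audio(
    label="Generated Music 1",
    type="filepath",             # handler yields a file path, not raw samples
    interactive=False,           # display-only; users cannot upload into it
    show_download_button=False,  # hide the per-player download control
)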
acestep/handler.py
CHANGED
|
@@ -2077,7 +2077,6 @@ class AceStepHandler:
|
|
| 2077 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 2078 |
return {
|
| 2079 |
"audios": [],
|
| 2080 |
-
"generation_info": "",
|
| 2081 |
"status_message": "ā Model not fully initialized. Please initialize all components first.",
|
| 2082 |
"extra_outputs": {},
|
| 2083 |
"success": False,
|
|
@@ -2101,7 +2100,7 @@ class AceStepHandler:
|
|
| 2101 |
|
| 2102 |
logger.info("[generate_music] Starting generation...")
|
| 2103 |
if progress:
|
| 2104 |
-
progress(0.
|
| 2105 |
logger.info("[generate_music] Preparing inputs...")
|
| 2106 |
|
| 2107 |
# Reset offload cost
|
|
@@ -2123,8 +2122,6 @@ class AceStepHandler:
|
|
| 2123 |
repainting_end = None
|
| 2124 |
|
| 2125 |
try:
|
| 2126 |
-
progress(0.1, desc="Preparing inputs...")
|
| 2127 |
-
|
| 2128 |
# 1. Process reference audio
|
| 2129 |
refer_audios = None
|
| 2130 |
if reference_audio is not None:
|
|
@@ -2176,7 +2173,7 @@ class AceStepHandler:
|
|
| 2176 |
can_use_repainting
|
| 2177 |
)
|
| 2178 |
|
| 2179 |
-
progress(0.
|
| 2180 |
|
| 2181 |
# Prepare audio_code_hints - use if audio_code_string is provided
|
| 2182 |
# This works for both text2music (auto-switched to cover) and cover tasks
|
|
@@ -2245,7 +2242,7 @@ class AceStepHandler:
|
|
| 2245 |
|
| 2246 |
logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
|
| 2247 |
if progress:
|
| 2248 |
-
progress(0.
|
| 2249 |
|
| 2250 |
# Prepare audio tensors (no file I/O here, no UUID generation)
|
| 2251 |
# pred_wavs is already [batch, channels, samples] format
|
|
@@ -2257,23 +2254,6 @@ class AceStepHandler:
|
|
| 2257 |
audio_tensor = pred_wavs[i].cpu().float()
|
| 2258 |
audio_tensors.append(audio_tensor)
|
| 2259 |
|
| 2260 |
-
# Format time costs if available
|
| 2261 |
-
time_costs_str = ""
|
| 2262 |
-
if time_costs:
|
| 2263 |
-
if isinstance(time_costs, dict):
|
| 2264 |
-
time_costs_str = "\n\n**ā±ļø Time Costs:**\n"
|
| 2265 |
-
for key, value in time_costs.items():
|
| 2266 |
-
# Format key: encoder_time_cost -> Encoder
|
| 2267 |
-
formatted_key = key.replace("_time_cost", "").replace("_", " ").title()
|
| 2268 |
-
time_costs_str += f" - {formatted_key}: {value:.2f}s\n"
|
| 2269 |
-
elif isinstance(time_costs, (int, float)):
|
| 2270 |
-
time_costs_str = f"\n\n**ā±ļø Time Cost:** {time_costs:.2f}s"
|
| 2271 |
-
|
| 2272 |
-
generation_info = f"""**šµ Generation Complete**
|
| 2273 |
-
|
| 2274 |
-
**Seeds:** {seed_value_for_ui}
|
| 2275 |
-
**Steps:** {inference_steps}
|
| 2276 |
-
**Audio Count:** {len(audio_tensors)} audio(s){time_costs_str}"""
|
| 2277 |
status_message = f"ā
Generation completed successfully!"
|
| 2278 |
logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
|
| 2279 |
|
|
@@ -2307,7 +2287,6 @@ class AceStepHandler:
|
|
| 2307 |
|
| 2308 |
return {
|
| 2309 |
"audios": audios,
|
| 2310 |
-
"generation_info": generation_info,
|
| 2311 |
"status_message": status_message,
|
| 2312 |
"extra_outputs": extra_outputs,
|
| 2313 |
"success": True,
|
|
@@ -2319,7 +2298,6 @@ class AceStepHandler:
|
|
| 2319 |
logger.exception("[generate_music] Generation failed")
|
| 2320 |
return {
|
| 2321 |
"audios": [],
|
| 2322 |
-
"generation_info": "",
|
| 2323 |
"status_message": error_msg,
|
| 2324 |
"extra_outputs": {},
|
| 2325 |
"success": False,
|
|
|
|
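With generation_info dropped from the handler's return payload, timing now travels through extra_outputs["time_costs"] using the unified lm_*/dit_* keys introduced in acestep/inference.py below. A minimal sketch of rebuilding a readable summary on the caller side; the helper name and the exact key set are assumptions:

def format_time_costs(time_costs: dict) -> str:
    """Render the unified time_costs dict as a readable bullet list (helper name assumed)."""
    if not time_costs:
        return ""
    lines = ["Time Costs:"]
    for key, value in time_costs.items():
        # e.g. "dit_total_time_cost" -> "Dit Total", "lm_phase1_time" -> "Lm Phase1"
        label = key.replace("_time_cost", "").replace("_time", "").replace("_", " ").title()
        lines.append(f"  - {label}: {value:.2f}s")
    return "\n".join(lines)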
acestep/inference.py
CHANGED
@@ -67,19 +67,19 @@ class GenerationParams:
     # Required Inputs
     task_type: str = "text2music"
     instruction: str = "Fill the audio semantic mask based on the given conditions:"
-
+
     # Audio Uploads
     reference_audio: Optional[str] = None
     src_audio: Optional[str] = None
-
+
     # LM Codes Hints
     audio_codes: str = ""
-
+
     # Text Inputs
     caption: str = ""
     lyrics: str = ""
     instrumental: bool = False
-
+
     # Metadata
     vocal_language: str = "unknown"
     bpm: Optional[int] = None
@@ -98,7 +98,7 @@ class GenerationParams:
     repainting_start: float = 0.0
     repainting_end: float = -1
     audio_cover_strength: float = 1.0
-
+
     # 5Hz Language Model Parameters
     thinking: bool = True
     lm_temperature: float = 0.85
@@ -108,8 +108,18 @@ class GenerationParams:
     lm_negative_prompt: str = "NO USER INPUT"
     use_cot_metas: bool = True
     use_cot_caption: bool = True
+    use_cot_lyrics: bool = False  # TODO: not used yet
     use_cot_language: bool = True
-
+    use_constrained_decoding: bool = True
+
+    cot_bpm: Optional[int] = None
+    cot_keyscale: str = ""
+    cot_timesignature: str = ""
+    cot_duration: Optional[float] = None
+    cot_vocal_language: str = "unknown"
+    cot_caption: str = ""
+    cot_lyrics: str = ""
+
     def to_dict(self) -> Dict[str, Any]:
         """Convert config to dictionary for JSON serialization."""
         return asdict(self)
@@ -123,25 +133,27 @@ class GenerationConfig:
         batch_size: Number of audio samples to generate
         allow_lm_batch: Whether to allow batch processing in LM
         use_random_seed: Whether to use random seed
-
+        seeds: Seed(s) for batch generation. Can be:
             - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
             - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
             - int: Single seed value (will be converted to list and padded)
         lm_batch_chunk_size: Batch chunk size for LM processing
-        is_format_caption: Whether to format caption
         constrained_decoding_debug: Whether to enable constrained decoding debug
         audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
     """
     batch_size: int = 2
     allow_lm_batch: bool = False
     use_random_seed: bool = True
-
+    seeds: Optional[List[int]] = None
     lm_batch_chunk_size: int = 8
-    is_format_caption: bool = False
-    use_constrained_decoding: bool = True
     constrained_decoding_debug: bool = False
     audio_format: str = "flac"  # Default to FLAC for fast saving
 
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert config to dictionary for JSON serialization."""
+        return asdict(self)
+
+
 @dataclass
 class GenerationResult:
     """Result of music generation.
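The new seeds field accepts None, a single int, or a list, padded out to batch_size. A standalone sketch of that padding rule; the real logic lives in dit_handler.prepare_seeds, so treat this version as an assumption:

import random
from typing import List, Optional, Union

def normalize_seeds(seeds: Optional[Union[int, List[int]]], batch_size: int) -> List[int]:
    """Coerce seeds to exactly batch_size entries, padding with random seeds."""
    if seeds is None:
        seed_list: List[int] = []
    elif isinstance(seeds, int):
        seed_list = [seeds]
    else:
        seed_list = list(seeds)
    while len(seed_list) < batch_size:
        seed_list.append(random.randint(0, 2**32 - 1))
    return seed_list[:batch_size]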
@@ -149,34 +161,80 @@
 
     Attributes:
         # Audio Outputs
         audios: List of audio dictionaries with paths, keys, params
-        generation_info: Markdown-formatted generation information
         status_message: Status message from generation
         extra_outputs: Extra outputs from generation
         success: Whether generation completed successfully
         error: Error message if generation failed
     """
-
+
     # Audio Outputs
     audios: List[Dict[str, Any]] = field(default_factory=list)
     # Generation Information
-    generation_info: str = ""
     status_message: str = ""
     extra_outputs: Dict[str, Any] = field(default_factory=dict)
     # Success Status
     success: bool = True
     error: Optional[str] = None
-
+
     def to_dict(self) -> Dict[str, Any]:
         """Convert result to dictionary for JSON serialization."""
         return asdict(self)
 
 
+def _update_metadata_from_lm(
+    metadata: Dict[str, Any],
+    bpm: Optional[int],
+    key_scale: str,
+    time_signature: str,
+    audio_duration: Optional[float],
+    vocal_language: str,
+    caption: str,
+    lyrics: str,
+) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
+    """Update metadata fields from LM output if not provided by user."""
+
+    if bpm is None and metadata.get('bpm'):
+        bpm_value = metadata.get('bpm')
+        if bpm_value not in ["N/A", ""]:
+            try:
+                bpm = int(bpm_value)
+            except (ValueError, TypeError):
+                pass
+
+    if not key_scale and metadata.get('keyscale'):
+        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
+        if key_scale_value != "N/A":
+            key_scale = key_scale_value
+
+    if not time_signature and metadata.get('timesignature'):
+        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
+        if time_signature_value != "N/A":
+            time_signature = time_signature_value
+
+    if audio_duration is None or audio_duration <= 0:
+        audio_duration_value = metadata.get('duration', -1)
+        if audio_duration_value not in ["N/A", ""]:
+            try:
+                audio_duration = float(audio_duration_value)
+            except (ValueError, TypeError):
+                pass
+
+    if not vocal_language and metadata.get('vocal_language'):
+        vocal_language = metadata.get('vocal_language')
+    if not caption and metadata.get('caption'):
+        caption = metadata.get('caption')
+    if not lyrics and metadata.get('lyrics'):
+        lyrics = metadata.get('lyrics')
+    return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
+
+
 def generate_music(
     dit_handler,
     llm_handler,
     params: GenerationParams,
     config: GenerationConfig,
     save_dir: Optional[str] = None,
+    progress=None,
 ) -> GenerationResult:
     """Generate music using ACE-Step model with optional LM reasoning.
 
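For example, user-provided fields win and only the empty ones are backfilled from the LM metadata (values here are illustrative):

bpm, key_scale, time_sig, duration, language, caption, lyrics = _update_metadata_from_lm(
    metadata={"bpm": "120", "keyscale": "C major", "duration": "30"},
    bpm=95,               # user-provided, kept as 95
    key_scale="",         # empty, backfilled with "C major"
    time_signature="",    # no LM value, stays ""
    audio_duration=None,  # backfilled with 30.0
    vocal_language="",
    caption="",
    lyrics="",
)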
@@ -194,24 +252,31 @@ def generate_music
        audio_code_string_to_use = params.audio_codes
        lm_generated_metadata = None
        lm_generated_audio_codes_list = []
-
+        lm_total_time_costs = {
+            "phase1_time": 0.0,
+            "phase2_time": 0.0,
+            "total_time": 0.0,
+        }
+
        # Extract mutable copies of metadata (will be updated by LM if needed)
        bpm = params.bpm
        key_scale = params.keyscale
        time_signature = params.timesignature
        audio_duration = params.duration
-
+        dit_input_caption = params.caption
+        dit_input_vocal_language = params.vocal_language
+        dit_input_lyrics = params.lyrics
        # Determine if we need to generate audio codes
        # If user has provided audio_codes, we don't need to generate them
        # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
        user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
-
+
        # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
        # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
        # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
        # Note: This logic can be refined based on specific requirements
        need_audio_codes = not user_provided_audio_codes
-
+
        # Determine if we should use chunk-based LM generation (always use chunks for consistency)
        # Determine actual batch size for chunk processing
        actual_batch_size = config.batch_size if config.batch_size is not None else 1
@@ -219,80 +284,75 @@ def generate_music
        # Prepare seeds for batch generation
        # Use config.seed if provided, otherwise fallback to params.seed
        # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
-        seed_for_generation =
-        if config.
-            if isinstance(config.
                # Convert List[int] to comma-separated string
-                seed_for_generation = ",".join(str(s) for s in config.
-
-            # Single int seed
-            seed_for_generation = config.seed
-
+        seed_for_generation = ""
+        if config.seeds is not None and len(config.seeds) > 0:
+            if isinstance(config.seeds, list):
+                seed_for_generation = ",".join(str(s) for s in config.seeds)
+
        # Use dit_handler.prepare_seeds to handle seed list generation and padding
        # This will handle all the logic: padding with random seeds if needed, etc.
-        actual_seed_list, _ = dit_handler.prepare_seeds(
-            actual_batch_size, seed_for_generation, config.use_random_seed
-        )
+        actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
 
        # LM-based Chain-of-Thought reasoning
-
-
-
-
-
+        use_lm = params.thinking and llm_handler.llm_initialized
+        lm_status = []
+        if use_lm:
+            # Convert sampling parameters - handle None values safely
+            top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
+            top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
+
            # Build user_metadata from user-provided values
            user_metadata = {}
            if bpm is not None:
                try:
                    bpm_value = float(bpm)
                    if bpm_value > 0:
-                        user_metadata['bpm'] =
+                        user_metadata['bpm'] = int(bpm_value)
                except (ValueError, TypeError):
                    pass
-
+
            if key_scale and key_scale.strip():
                key_scale_clean = key_scale.strip()
                if key_scale_clean.lower() not in ["n/a", ""]:
                    user_metadata['keyscale'] = key_scale_clean
-
+
            if time_signature and time_signature.strip():
                time_sig_clean = time_signature.strip()
                if time_sig_clean.lower() not in ["n/a", ""]:
                    user_metadata['timesignature'] = time_sig_clean
-
+
            if audio_duration is not None:
                try:
                    duration_value = float(audio_duration)
                    if duration_value > 0:
-                        user_metadata['duration'] =
+                        user_metadata['duration'] = int(duration_value)
                except (ValueError, TypeError):
                    pass
-
+
            user_metadata_to_pass = user_metadata if user_metadata else None
-
+
            # Determine infer_type based on whether we need audio codes
            # - "llm_dit": generates both metas and audio codes (two-phase internally)
            # - "dit": generates only metas (single phase)
            infer_type = "llm_dit" if need_audio_codes else "dit"
-
+
            # Use chunk size from config, or default to batch_size if not set
            max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
            num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
-
+
            all_metadata_list = []
            all_audio_codes_list = []
-
+
            for chunk_idx in range(num_chunks):
                chunk_start = chunk_idx * max_inference_batch_size
                chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
                chunk_size = chunk_end - chunk_start
                chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
-
-                logger.info(
-
-
-                )
-
+
+                logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
+                            f"(size: {chunk_size}, seeds: {chunk_seeds})")
+
                # Use the determined infer_type
                # - "llm_dit" will internally run two phases (metas + codes)
                # - "dit" will only run phase 1 (metas only)
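As a worked example of the chunking above: batch_size=10 with lm_batch_chunk_size=8 gives math.ceil(10 / 8) = 2 LM calls, sized 8 and 2:

import math

actual_batch_size = 10
max_inference_batch_size = 8
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)  # 2
for chunk_idx in range(num_chunks):
    chunk_start = chunk_idx * max_inference_batch_size
    chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
    print(chunk_idx, chunk_start, chunk_end)  # (0, 0, 8) then (1, 8, 10)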
@@ -308,25 +368,54 @@ def generate_music
                    user_metadata=user_metadata_to_pass,
                    use_cot_caption=params.use_cot_caption,
                    use_cot_language=params.use_cot_language,
-
-                    use_constrained_decoding=
+                    use_cot_metas=params.use_cot_metas,
+                    use_constrained_decoding=params.use_constrained_decoding,
                    constrained_decoding_debug=config.constrained_decoding_debug,
                    batch_size=chunk_size,
                    seeds=chunk_seeds,
+                    progress=progress,
                )
-
+
+                # Check if LM generation failed
+                if not result.get("success", False):
+                    error_msg = result.get("error", "Unknown LM error")
+                    lm_status.append(f"❌ LM Error: {error_msg}")
+                    # Return early with error
+                    return GenerationResult(
+                        audios=[],
+                        status_message=f"❌ LM generation failed: {error_msg}",
+                        extra_outputs={},
+                        success=False,
+                        error=error_msg,
+                    )
+
+                # Extract metadata and audio_codes from result dict
                if chunk_size > 1:
-                    metadata_list
+                    metadata_list = result.get("metadata", [])
+                    audio_codes_list = result.get("audio_codes", [])
                    all_metadata_list.extend(metadata_list)
                    all_audio_codes_list.extend(audio_codes_list)
                else:
-                    metadata
+                    metadata = result.get("metadata", {})
+                    audio_codes = result.get("audio_codes", "")
                    all_metadata_list.append(metadata)
                    all_audio_codes_list.append(audio_codes)
-
+
+                # Collect time costs from LM extra_outputs
+                lm_extra = result.get("extra_outputs", {})
+                lm_chunk_time_costs = lm_extra.get("time_costs", {})
+                if lm_chunk_time_costs:
+                    # Accumulate time costs from all chunks
+                    for key in ["phase1_time", "phase2_time", "total_time"]:
+                        if key in lm_chunk_time_costs:
+                            lm_total_time_costs[key] += lm_chunk_time_costs[key]
+
+                    time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
+                    lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
+
            lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
            lm_generated_audio_codes_list = all_audio_codes_list
+
            # Set audio_code_string_to_use based on infer_type
            if infer_type == "llm_dit":
                # If batch mode, use list; otherwise use single string
@@ -337,23 +426,48 @@ def generate_music
        else:
            # For "dit" mode, keep user-provided codes or empty
            audio_code_string_to_use = params.audio_codes
-
+
        # Update metadata from LM if not provided by user
        if lm_generated_metadata:
-            bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
-                lm_generated_metadata,
-
+            bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
+                metadata=lm_generated_metadata,
+                bpm=bpm,
+                key_scale=key_scale,
+                time_signature=time_signature,
+                audio_duration=audio_duration,
+                vocal_language=dit_input_vocal_language,
+                caption=dit_input_caption,
+                lyrics=dit_input_lyrics)
+            if not params.bpm:
+                params.cot_bpm = bpm
+            if not params.keyscale:
+                params.cot_keyscale = key_scale
+            if not params.timesignature:
+                params.cot_timesignature = time_signature
+            if not params.duration:
+                params.cot_duration = audio_duration
+            if not params.vocal_language:
+                params.cot_vocal_language = vocal_language
+            if not params.caption:
+                params.cot_caption = caption
+            if not params.lyrics:
+                params.cot_lyrics = lyrics
+
+            # set cot caption and language if needed
+            if params.use_cot_caption:
+                dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
+            if params.use_cot_language:
+                dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
 
        # Phase 2: DiT music generation
        # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
        result = dit_handler.generate_music(
-            captions=
-            lyrics=
+            captions=dit_input_caption,
+            lyrics=dit_input_lyrics,
            bpm=bpm,
            key_scale=key_scale,
            time_signature=time_signature,
-            vocal_language=
+            vocal_language=dit_input_vocal_language,
            inference_steps=params.inference_steps,
            guidance_scale=params.guidance_scale,
            use_random_seed=config.use_random_seed,
@@ -371,110 +485,80 @@ def generate_music
            use_adg=params.use_adg,
            cfg_interval_start=params.cfg_interval_start,
            cfg_interval_end=params.cfg_interval_end,
+            progress=progress,
        )
-
+
        # Check if generation failed
        if not result.get("success", False):
            return GenerationResult(
                audios=[],
-                generation_info=result.get("generation_info", ""),
                status_message=result.get("status_message", ""),
                extra_outputs={},
                success=False,
                error=result.get("error"),
            )
-
+
        # Extract results from dit_handler.generate_music dict
        dit_audios = result.get("audios", [])
-        generation_info = result.get("generation_info", "")
        status_message = result.get("status_message", "")
        dit_extra_outputs = result.get("extra_outputs", {})
-
-        # Append LM metadata to generation info
-        if lm_generated_metadata:
-            generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
-
+
        # Use the seed list already prepared above (from config.seed or params.seed fallback)
        # actual_seed_list was computed earlier using dit_handler.prepare_seeds
        seed_list = actual_seed_list
-
+
        # Get base params dictionary
        base_params_dict = params.to_dict()
-
+
        # Save audio files using AudioSaver (format from config)
        audio_format = config.audio_format if config.audio_format else "flac"
        audio_saver = AudioSaver(default_format=audio_format)
-
+
        # Use handler's temp_dir for saving files
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
-
+
        # Build audios list for GenerationResult with params and save files
        # Audio saving and UUID generation handled here, outside of handler
        audios = []
        for idx, dit_audio in enumerate(dit_audios):
            # Create a copy of params dict for this audio
            audio_params = base_params_dict.copy()
-
+
            # Update audio-specific values
            audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
-
+
            # Add audio codes if batch mode
            if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
                audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
-
+
            # Get audio tensor and metadata
            audio_tensor = dit_audio.get("tensor")
            sample_rate = dit_audio.get("sample_rate", 48000)
-
+
            # Generate UUID for this audio (moved from handler)
            batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
-            audio_code_str = lm_generated_audio_codes_list[idx] if (
+            audio_code_str = lm_generated_audio_codes_list[idx] if (
+                lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
            if isinstance(audio_code_str, list):
                audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
-
-            audio_key = generate_uuid_from_params(
-                captions=params.caption,
-                lyrics=params.lyrics,
-                bpm=bpm,
-                key_scale=key_scale,
-                time_signature=time_signature,
-                vocal_language=params.vocal_language,
-                inference_steps=params.inference_steps,
-                guidance_scale=params.guidance_scale,
-                seed=batch_seed,
-                audio_duration=audio_duration,
-                audio_code_string=audio_code_str,
-                repainting_start=params.repainting_start,
-                repainting_end=params.repainting_end,
-                instruction=params.instruction,
-                audio_cover_strength=params.audio_cover_strength,
-                task_type=params.task_type,
-                use_adg=params.use_adg,
-                cfg_interval_start=params.cfg_interval_start,
-                cfg_interval_end=params.cfg_interval_end,
-                audio_format=audio_format,
-                reference_audio=params.reference_audio,
-                src_audio=params.src_audio,
-                batch_index=idx,
-            )
-
+
+            audio_key = generate_uuid_from_params(audio_params)
+
            # Save audio file (handled outside handler)
            audio_path = None
            if audio_tensor is not None and save_dir is not None:
                try:
                    audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
-                    audio_path = audio_saver.save_audio(
-
-
-
-
-                        channels_first=True
-                    )
+                    audio_path = audio_saver.save_audio(audio_tensor,
+                                                        audio_file,
+                                                        sample_rate=sample_rate,
+                                                        format=audio_format,
+                                                        channels_first=True)
                except Exception as e:
                    logger.error(f"[generate_music] Failed to save audio file: {e}")
                    audio_path = ""  # Fallback to empty path
-
+
            audio_dict = {
                "path": audio_path or "",  # File path (saved here, not in handler)
                "tensor": audio_tensor,  # Audio tensor [channels, samples], CPU, float32
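generate_uuid_from_params now takes the whole per-audio params dict instead of two dozen keyword arguments. Assuming it keeps the deterministic same-params-same-UUID contract documented in acestep/audio_utils.py, a plausible sketch (the actual implementation may differ):

import hashlib
import json

def uuid_from_params(params: dict) -> str:
    """Deterministic key: identical params dicts always map to the same digest (sketch only)."""
    canonical = json.dumps(params, sort_keys=True, default=str)
    return hashlib.md5(canonical.encode("utf-8")).hexdigest()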
@@ -482,259 +566,55 @@ def generate_music
                "sample_rate": sample_rate,
                "params": audio_params,
            }
-
+
            audios.append(audio_dict)
-
+
        # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
        extra_outputs = dit_extra_outputs.copy()
        extra_outputs["lm_metadata"] = lm_generated_metadata
-
+
+        # Merge time_costs from both LM and DiT into a unified dictionary
+        unified_time_costs = {}
+
+        # Add LM time costs (if LM was used)
+        if use_lm and lm_total_time_costs:
+            for key, value in lm_total_time_costs.items():
+                unified_time_costs[f"lm_{key}"] = value
+
+        # Add DiT time costs (if available)
+        dit_time_costs = dit_extra_outputs.get("time_costs", {})
+        if dit_time_costs:
+            for key, value in dit_time_costs.items():
+                unified_time_costs[f"dit_{key}"] = value
+
+        # Calculate total pipeline time
+        if unified_time_costs:
+            lm_total = unified_time_costs.get("lm_total_time", 0.0)
+            dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
+            unified_time_costs["pipeline_total_time"] = lm_total + dit_total
+
+        # Update extra_outputs with unified time_costs
+        extra_outputs["time_costs"] = unified_time_costs
+
+        if lm_status:
+            status_message = "\n".join(lm_status) + "\n" + status_message
+        else:
+            status_message = status_message
        # Create and return GenerationResult
        return GenerationResult(
            audios=audios,
-            generation_info=generation_info,
            status_message=status_message,
            extra_outputs=extra_outputs,
            success=True,
            error=None,
        )
-
+
    except Exception as e:
        logger.exception("Music generation failed")
        return GenerationResult(
            audios=[],
-            generation_info=f"❌ Generation failed: {str(e)}",
            status_message=f"Error: {str(e)}",
            extra_outputs={},
            success=False,
            error=str(e),
        )
-
-
-def _update_metadata_from_lm(
-    metadata: Dict[str, Any],
-    bpm: Optional[int],
-    key_scale: str,
-    time_signature: str,
-    audio_duration: Optional[float],
-) -> Tuple[Optional[int], str, str, Optional[float]]:
-    """Update metadata fields from LM output if not provided by user."""
-
-    if bpm is None and metadata.get('bpm'):
-        bpm_value = metadata.get('bpm')
-        if bpm_value not in ["N/A", ""]:
-            try:
-                bpm = int(bpm_value)
-            except (ValueError, TypeError):
-                pass
-
-    if not key_scale and metadata.get('keyscale'):
-        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
-        if key_scale_value != "N/A":
-            key_scale = key_scale_value
-
-    if not time_signature and metadata.get('timesignature'):
-        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
-        if time_signature_value != "N/A":
-            time_signature = time_signature_value
-
-    if audio_duration is None or audio_duration <= 0:
-        audio_duration_value = metadata.get('duration', -1)
-        if audio_duration_value not in ["N/A", ""]:
-            try:
-                audio_duration = float(audio_duration_value)
-            except (ValueError, TypeError):
-                pass
-
-    return bpm, key_scale, time_signature, audio_duration
-
-
-def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any]) -> str:
-    """Append LM-generated metadata to generation info string."""
-
-    metadata_lines = []
-    if metadata.get('bpm'):
-        metadata_lines.append(f"- **BPM:** {metadata['bpm']}")
-    if metadata.get('caption'):
-        metadata_lines.append(f"- **Refined Caption:** {metadata['caption']}")
-    if metadata.get('duration'):
-        metadata_lines.append(f"- **Duration:** {metadata['duration']} seconds")
-    if metadata.get('keyscale'):
-        metadata_lines.append(f"- **Key Scale:** {metadata['keyscale']}")
-    if metadata.get('language'):
-        metadata_lines.append(f"- **Language:** {metadata['language']}")
-    if metadata.get('timesignature'):
-        metadata_lines.append(f"- **Time Signature:** {metadata['timesignature']}")
-
-    if metadata_lines:
-        metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
-        return metadata_section + "\n\n" + generation_info
-
-    return generation_info
-
-
-# ============================================================================
-# LEGACY GRADIO UI COMPATIBILITY LAYER
-# ============================================================================
-
-def generate_for_gradio(
-    dit_handler,
-    llm_handler,
-    captions,
-    lyrics,
-    bpm,
-    key_scale,
-    time_signature,
-    vocal_language,
-    inference_steps,
-    guidance_scale,
-    random_seed_checkbox,
-    seed,
-    reference_audio,
-    audio_duration,
-    batch_size_input,
-    src_audio,
-    text2music_audio_code_string,
-    repainting_start,
-    repainting_end,
-    instruction_display_gen,
-    audio_cover_strength,
-    task_type,
-    use_adg,
-    cfg_interval_start,
-    cfg_interval_end,
-    audio_format,
-    lm_temperature,
-    think_checkbox,
-    lm_cfg_scale,
-    lm_top_k,
-    lm_top_p,
-    lm_negative_prompt,
-    use_cot_metas,
-    use_cot_caption,
-    use_cot_language,
-    is_format_caption,
-    constrained_decoding_debug,
-    allow_lm_batch,
-    lm_batch_chunk_size,
-):
-    """Legacy Gradio UI compatibility wrapper.
-
-    This function maintains backward compatibility with the Gradio UI.
-    For new integrations, use generate_music() with GenerationConfig instead.
-
-    Returns:
-        Tuple with 28 elements for Gradio UI component updates
-    """
-
-    # Convert legacy parameters to GenerationParams and GenerationConfig
-    params = GenerationParams(
-        caption=captions,
-        lyrics=lyrics,
-        bpm=bpm,
-        keyscale=key_scale,
-        timesignature=time_signature,
-        vocal_language=vocal_language,
-        audio_codes=text2music_audio_code_string,
-        duration=audio_duration,
-        inference_steps=inference_steps,
-        guidance_scale=guidance_scale,
-        seed=seed,
-        use_adg=use_adg,
-        cfg_interval_start=cfg_interval_start,
-        cfg_interval_end=cfg_interval_end,
-        audio_format=audio_format,
-        task_type=task_type,
-        reference_audio=reference_audio,
-        src_audio=src_audio,
-        repainting_start=repainting_start,
-        repainting_end=repainting_end,
-        audio_cover_strength=audio_cover_strength,
-        instruction=instruction_display_gen,
-        thinking=think_checkbox,
-        lm_temperature=lm_temperature,
-        lm_cfg_scale=lm_cfg_scale,
-        lm_top_k=lm_top_k,
-        lm_top_p=lm_top_p,
-        lm_negative_prompt=lm_negative_prompt,
-        use_cot_metas=use_cot_metas,
-        use_cot_caption=use_cot_caption,
-        use_cot_language=use_cot_language,
-    )
-
-    config = GenerationConfig(batch_size=1)
-    config.batch_size = batch_size_input
-    config.use_random_seed = random_seed_checkbox
-    config.allow_lm_batch = allow_lm_batch
-    config.lm_batch_chunk_size = lm_batch_chunk_size
-    config.is_format_caption = is_format_caption
-    config.constrained_decoding_debug = constrained_decoding_debug
-
-    # Call new API
-    result = generate_music(dit_handler, llm_handler, params, config)
-
-    # Extract audio paths from result.audios
-    audio_paths = [audio["path"] for audio in result.audios]
-
-    # Extract extra outputs
-    extra_outputs = result.extra_outputs
-    seed_value = extra_outputs.get("seed_value", "")
-    lm_metadata = extra_outputs.get("lm_metadata", None)
-
-    # Legacy alignment fields (no longer used, set to empty/None)
-    align_score_1 = ""
-    align_text_1 = ""
-    align_plot_1 = None
-    align_score_2 = ""
-    align_text_2 = ""
-    align_plot_2 = None
-
-    # Determine which codes to update in UI
-    if config.allow_lm_batch and lm_metadata:
-        # Batch mode: extract codes from metadata if available
-        lm_codes_list = lm_metadata.get('audio_codes_list', [])
-        updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
-        codes_outputs = (lm_codes_list + [""] * 8)[:8]
-    else:
-        # Single mode
-        lm_codes = lm_metadata.get('audio_codes', '') if lm_metadata else ''
-        updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
-        codes_outputs = [""] * 8
-
-    # Prepare audio outputs (up to 8)
-    audio_outputs = (audio_paths + [None] * 8)[:8]
-
-    # Return tuple for Gradio UI (28 elements)
-    return (
-        audio_outputs[0],  # generated_audio_1
-        audio_outputs[1],  # generated_audio_2
-        audio_outputs[2],  # generated_audio_3
-        audio_outputs[3],  # generated_audio_4
-        audio_outputs[4],  # generated_audio_5
-        audio_outputs[5],  # generated_audio_6
-        audio_outputs[6],  # generated_audio_7
-        audio_outputs[7],  # generated_audio_8
-        audio_paths,  # generated_audio_batch
-        result.generation_info,
-        result.status_message,
-        seed_value,
-        align_score_1,
-        align_text_1,
-        align_plot_1,
-        align_score_2,
-        align_text_2,
-        align_plot_2,
-        updated_audio_codes,  # Update main audio codes in UI
-        codes_outputs[0],  # text2music_audio_code_string_1
-        codes_outputs[1],  # text2music_audio_code_string_2
-        codes_outputs[2],  # text2music_audio_code_string_3
-        codes_outputs[3],  # text2music_audio_code_string_4
-        codes_outputs[4],  # text2music_audio_code_string_5
-        codes_outputs[5],  # text2music_audio_code_string_6
-        codes_outputs[6],  # text2music_audio_code_string_7
-        codes_outputs[7],  # text2music_audio_code_string_8
-        lm_metadata,  # Store metadata for "Send to src audio" buttons
-        is_format_caption,  # Keep is_format_caption unchanged
-    )
-
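With the legacy wrapper removed, the replacement is the direct API its docstring already pointed to: generate_music() with GenerationParams and GenerationConfig. A minimal sketch, with handler setup omitted and argument values purely illustrative:

params = GenerationParams(caption="warm lo-fi piano", lyrics="", duration=30)
config = GenerationConfig(batch_size=2, audio_format="flac")
result = generate_music(dit_handler, llm_handler, params, config, save_dir="./outputs")
if result.success:
    for audio in result.audios:
        print(audio["path"], audio["params"]["seed"])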
acestep/llm_inference.py
CHANGED
@@ -5,6 +5,7 @@ Handles all LM-related operations including initialization and generation
 import os
 import traceback
 import time
+import random
 from typing import Optional, Dict, Any, Tuple, List, Union
 from contextlib import contextmanager
 
@@ -309,6 +310,7 @@
 
        logger.info("loading 5Hz LM tokenizer...")
        start_time = time.time()
+        # TODO: load tokenizer too slow, not found solution yet
        llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
        logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
        self.llm_tokenizer = llm_tokenizer
@@ -796,12 +798,13 @@
        constrained_decoding_debug: bool = False,
        target_duration: Optional[float] = None,
        user_metadata: Optional[Dict[str, Optional[str]]] = None,
+        use_cot_metas: bool = True,
        use_cot_caption: bool = True,
        use_cot_language: bool = True,
-        is_format_caption: bool = False,
        batch_size: Optional[int] = None,
        seeds: Optional[List[int]] = None,
-
+        progress=None,
+    ) -> Dict[str, Any]:
        """Two-phase LM generation: CoT generation followed by audio codes generation.
 
        - infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
@@ -817,20 +820,30 @@
            batch_size: Optional batch size for batch generation. If None or 1, returns single result.
                If > 1, returns batch results (lists).
            seeds: Optional list of seeds for batch generation (for reproducibility).
-                Only used when batch_size > 1.
+                Only used when batch_size > 1. TODO: not used yet
 
        Returns:
-
-
+            Dictionary containing:
+            - metadata: Dict or List[Dict] - Generated metadata
+            - audio_codes: str or List[str] - Generated audio codes
+            - success: bool - Whether generation succeeded
+            - error: Optional[str] - Error message if failed
+            - extra_outputs: Dict with time_costs and other info
        """
-
-
-
+        if progress is None:
+            def progress(*args, **kwargs):
+                pass
+
        infer_type = (infer_type or "").strip().lower()
        if infer_type not in {"dit", "llm_dit"}:
-
-
-
+            error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
+            return {
+                "metadata": [] if (batch_size and batch_size > 1) else {},
 
        # Determine if batch mode
        is_batch = batch_size and batch_size > 1
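A sketch of consuming the new dict contract from the caller's side, mirroring the checks added in acestep/inference.py (the unpacking helper itself is hypothetical):

def unwrap_lm_result(result: dict):
    """Validate and unpack the dict returned by the two-phase LM entry point."""
    if not result.get("success", False):
        raise RuntimeError(result.get("error", "Unknown LM error"))
    metadata = result["metadata"]        # Dict, or List[Dict] in batch mode
    audio_codes = result["audio_codes"]  # str, or List[str] in batch mode
    time_costs = result.get("extra_outputs", {}).get("time_costs", {})
    return metadata, audio_codes, time_costs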
@@ -854,7 +867,8 @@
 
        # ========== PHASE 1: CoT Generation ==========
        # Skip CoT if all metadata are user-provided OR caption is already formatted
-
        if is_batch:
            logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
        else:
@@ -893,9 +907,13 @@
        phase1_time = time.time() - phase1_start
 
        if not cot_output_text:
-
-
-
 
        # Parse metadata from CoT output
        metadata, _ = self.parse_lm_output(cot_output_text)
@@ -915,11 +933,31 @@
        if infer_type == "dit":
            if is_batch:
                metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
-
-
            else:
-
-
 
        # ========== PHASE 2: Audio Codes Generation ==========
        if is_batch:
@@ -935,6 +973,7 @@
        formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
        logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
 
        if is_batch:
            # Batch mode: generate codes for all items
            formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
@@ -978,9 +1017,21 @@
                    seeds=seeds,
                )
            except Exception as e:
-                error_msg = f"
                logger.error(error_msg)
-                return
 
            # Parse audio codes from each output
            audio_codes_list = []
@@ -996,8 +1047,22 @@
            codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
            logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
 
-
-            return
        else:
            # Single mode: generate codes for one item
            codes_output_text, status = self.generate_from_formatted_prompt(
@@ -1025,7 +1090,20 @@ class LLMHandler:
|
|
| 1025 |
)
|
| 1026 |
|
| 1027 |
if not codes_output_text:
|
| 1028 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1029 |
|
| 1030 |
phase2_time = time.time() - phase2_start
|
| 1031 |
|
|
@@ -1035,8 +1113,21 @@ class LLMHandler:
|
|
| 1035 |
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 1036 |
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
|
| 1037 |
|
| 1038 |
-
|
| 1039 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1040 |
|
| 1041 |
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
|
| 1042 |
"""
|
|
|
|
| 5 |
import os
|
| 6 |
import traceback
|
| 7 |
import time
|
| 8 |
+
import random
|
| 9 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 10 |
from contextlib import contextmanager
|
| 11 |
|
|
|
|
| 310 |
|
| 311 |
logger.info("loading 5Hz LM tokenizer...")
|
| 312 |
start_time = time.time()
|
| 313 |
+
# TODO: load tokenizer too slow, not found solution yet
|
| 314 |
llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
|
| 315 |
logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
|
| 316 |
self.llm_tokenizer = llm_tokenizer
|
|
|
|
| 798 |
constrained_decoding_debug: bool = False,
|
| 799 |
target_duration: Optional[float] = None,
|
| 800 |
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 801 |
+
use_cot_metas: bool = True,
|
| 802 |
use_cot_caption: bool = True,
|
| 803 |
use_cot_language: bool = True,
|
|
|
|
| 804 |
batch_size: Optional[int] = None,
|
| 805 |
seeds: Optional[List[int]] = None,
|
| 806 |
+
progress=None,
|
| 807 |
+
) -> Dict[str, Any]:
|
| 808 |
"""Two-phase LM generation: CoT generation followed by audio codes generation.
|
| 809 |
|
| 810 |
- infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
|
|
|
|
| 820 |
batch_size: Optional batch size for batch generation. If None or 1, returns single result.
|
| 821 |
If > 1, returns batch results (lists).
|
| 822 |
seeds: Optional list of seeds for batch generation (for reproducibility).
|
| 823 |
+
Only used when batch_size > 1. TODO: not used yet
|
| 824 |
|
| 825 |
Returns:
|
| 826 |
+
Dictionary containing:
|
| 827 |
+
- metadata: Dict or List[Dict] - Generated metadata
|
| 828 |
+
- audio_codes: str or List[str] - Generated audio codes
|
| 829 |
+
- success: bool - Whether generation succeeded
|
| 830 |
+
- error: Optional[str] - Error message if failed
|
| 831 |
+
- extra_outputs: Dict with time_costs and other info
|
| 832 |
"""
|
| 833 |
+
if progress is None:
|
| 834 |
+
def progress(*args, **kwargs):
|
| 835 |
+
pass
|
| 836 |
+
|
| 837 |
infer_type = (infer_type or "").strip().lower()
|
| 838 |
if infer_type not in {"dit", "llm_dit"}:
|
| 839 |
+
error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
|
| 840 |
+
return {
|
| 841 |
+
"metadata": [] if (batch_size and batch_size > 1) else {},
|
| 842 |
+
"audio_codes": [] if (batch_size and batch_size > 1) else "",
|
| 843 |
+
"success": False,
|
| 844 |
+
"error": error_msg,
|
| 845 |
+
"extra_outputs": {"time_costs": {}},
|
| 846 |
+
}
|
| 847 |
|
| 848 |
# Determine if batch mode
|
| 849 |
is_batch = batch_size and batch_size > 1
|
|
|
|
| 867 |
|
| 868 |
# ========== PHASE 1: CoT Generation ==========
|
| 869 |
# Skip CoT if all metadata are user-provided OR caption is already formatted
|
| 870 |
+
progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...")
|
| 871 |
+
if not has_all_metas and use_cot_metas:
|
| 872 |
if is_batch:
|
| 873 |
logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
|
| 874 |
else:
|
|
|
|
| 907 |
phase1_time = time.time() - phase1_start
|
| 908 |
|
| 909 |
if not cot_output_text:
|
| 910 |
+
return {
|
| 911 |
+
"metadata": [] if is_batch else {},
|
| 912 |
+
"audio_codes": [] if is_batch else "",
|
| 913 |
+
"success": False,
|
| 914 |
+
"error": status,
|
| 915 |
+
"extra_outputs": {"time_costs": {"phase1_time": phase1_time}},
|
| 916 |
+
}
|
| 917 |
|
| 918 |
# Parse metadata from CoT output
|
| 919 |
metadata, _ = self.parse_lm_output(cot_output_text)
|
|
|
|
| 933 |
if infer_type == "dit":
|
| 934 |
if is_batch:
|
| 935 |
metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
|
| 936 |
+
return {
|
| 937 |
+
"metadata": metadata_list,
|
| 938 |
+
"audio_codes": [""] * actual_batch_size,
|
| 939 |
+
"success": True,
|
| 940 |
+
"error": None,
|
| 941 |
+
"extra_outputs": {
|
| 942 |
+
"time_costs": {
|
| 943 |
+
"phase1_time": phase1_time,
|
| 944 |
+
"total_time": phase1_time,
|
| 945 |
+
}
|
| 946 |
+
},
|
| 947 |
+
}
|
| 948 |
else:
|
| 949 |
+
return {
|
| 950 |
+
"metadata": metadata,
|
| 951 |
+
"audio_codes": "",
|
| 952 |
+
"success": True,
|
| 953 |
+
"error": None,
|
| 954 |
+
"extra_outputs": {
|
| 955 |
+
"time_costs": {
|
| 956 |
+
"phase1_time": phase1_time,
|
| 957 |
+
"total_time": phase1_time,
|
| 958 |
+
}
|
| 959 |
+
},
|
| 960 |
+
}
|
| 961 |
|
| 962 |
# ========== PHASE 2: Audio Codes Generation ==========
|
| 963 |
if is_batch:
|
|
|
|
| 973 |
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
|
| 974 |
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
|
| 975 |
|
| 976 |
+
progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
|
| 977 |
if is_batch:
|
| 978 |
# Batch mode: generate codes for all items
|
| 979 |
formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
|
|
|
|
| 1017 |
seeds=seeds,
|
| 1018 |
)
|
| 1019 |
except Exception as e:
|
| 1020 |
+
error_msg = f"Error in batch codes generation: {str(e)}"
|
| 1021 |
logger.error(error_msg)
|
| 1022 |
+
return {
|
| 1023 |
+
"metadata": [],
|
| 1024 |
+
"audio_codes": [],
|
| 1025 |
+
"success": False,
|
| 1026 |
+
"error": error_msg,
|
| 1027 |
+
"extra_outputs": {
|
| 1028 |
+
"time_costs": {
|
| 1029 |
+
"phase1_time": phase1_time,
|
| 1030 |
+
"phase2_time": 0.0,
|
| 1031 |
+
"total_time": phase1_time,
|
| 1032 |
+
}
|
| 1033 |
+
},
|
| 1034 |
+
}
|
| 1035 |
|
| 1036 |
# Parse audio codes from each output
|
| 1037 |
audio_codes_list = []
|
|
|
|
| 1047 |
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
|
| 1048 |
logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
|
| 1049 |
|
| 1050 |
+
total_time = phase1_time + phase2_time
|
| 1051 |
+
return {
|
| 1052 |
+
"metadata": metadata_list,
|
| 1053 |
+
"audio_codes": audio_codes_list,
|
| 1054 |
+
"success": True,
|
| 1055 |
+
"error": None,
|
| 1056 |
+
"extra_outputs": {
|
| 1057 |
+
"time_costs": {
|
| 1058 |
+
"phase1_time": phase1_time,
|
| 1059 |
+
"phase2_time": phase2_time,
|
| 1060 |
+
"total_time": total_time,
|
| 1061 |
+
},
|
| 1062 |
+
"codes_counts": codes_counts,
|
| 1063 |
+
"total_codes": sum(codes_counts),
|
| 1064 |
+
},
|
| 1065 |
+
}
|
| 1066 |
else:
|
| 1067 |
# Single mode: generate codes for one item
|
| 1068 |
codes_output_text, status = self.generate_from_formatted_prompt(
|
|
|
|
| 1090 |
)
|
| 1091 |
|
| 1092 |
if not codes_output_text:
|
| 1093 |
+
total_time = phase1_time + phase2_time
|
| 1094 |
+
return {
|
| 1095 |
+
"metadata": metadata,
|
| 1096 |
+
"audio_codes": "",
|
| 1097 |
+
"success": False,
|
| 1098 |
+
"error": status,
|
| 1099 |
+
"extra_outputs": {
|
| 1100 |
+
"time_costs": {
|
| 1101 |
+
"phase1_time": phase1_time,
|
| 1102 |
+
"phase2_time": phase2_time,
|
| 1103 |
+
"total_time": total_time,
|
| 1104 |
+
}
|
| 1105 |
+
},
|
| 1106 |
+
}
|
| 1107 |
|
| 1108 |
phase2_time = time.time() - phase2_start
|
| 1109 |
|
|
|
|
| 1113 |
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 1114 |
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
|
| 1115 |
|
| 1116 |
+
total_time = phase1_time + phase2_time
|
| 1117 |
+
return {
|
| 1118 |
+
"metadata": metadata,
|
| 1119 |
+
"audio_codes": audio_codes,
|
| 1120 |
+
"success": True,
|
| 1121 |
+
"error": None,
|
| 1122 |
+
"extra_outputs": {
|
| 1123 |
+
"time_costs": {
|
| 1124 |
+
"phase1_time": phase1_time,
|
| 1125 |
+
"phase2_time": phase2_time,
|
| 1126 |
+
"total_time": total_time,
|
| 1127 |
+
},
|
| 1128 |
+
"codes_count": codes_count,
|
| 1129 |
+
},
|
| 1130 |
+
}
|
| 1131 |
|
| 1132 |
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
|
| 1133 |
"""
|
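The structured result dict gives callers a single contract for both single and batch mode. A minimal sketch of consuming it; only the result keys and the parameter names used here (infer_type, caption, lyrics, batch_size) come from the diff above, while the surrounding setup and error policy are illustrative:

def run_llm_phase(llm_handler, caption: str, lyrics: str, batch_size: int = 1):
    result = llm_handler.generate_with_stop_condition(
        infer_type="llm_dit",   # "dit" stops after Phase 1 and returns metas only
        caption=caption,
        lyrics=lyrics,
        batch_size=batch_size,
    )

    if not result["success"]:
        # One uniform failure path instead of per-phase sentinel values
        raise RuntimeError(f"LM generation failed: {result['error']}")

    costs = result["extra_outputs"]["time_costs"]
    print(f"phase1={costs.get('phase1_time', 0.0):.2f}s "
          f"total={costs.get('total_time', 0.0):.2f}s")

    # Batch mode returns parallel lists; single mode a dict and a string
    if batch_size > 1:
        return list(zip(result["metadata"], result["audio_codes"]))
    return [(result["metadata"], result["audio_codes"])]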
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
CHANGED

@@ -93,6 +93,8 @@ class ModelRunner:
     def _allocate_sample_buffers(self):
         """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
         max_bs = self.config.max_num_seqs
+        max_tokens = self.config.max_num_batched_tokens
+        max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size

         # Pre-allocate pinned memory buffers on CPU for fast transfer
         # Must explicitly specify device="cpu" since default device may be "cuda"

@@ -107,6 +109,19 @@
         self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
         self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
         self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
+
+        # Pre-allocate prefill buffers on CPU with pinned memory (avoids repeated tensor creation)
+        self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
+        self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
+        self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
+        self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
+
+        # Pre-allocate block tables buffer (shared by both decode and prefill)
+        self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
+
+        # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
+        # Max length is max_model_len since sequences can be that long
+        self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)

     def exit(self):
         if self.world_size > 1:

@@ -227,7 +242,7 @@
             if i != seq.num_blocks - 1:
                 end = start + self.block_size
             else:
-                end = start + seq.last_block_num_tokens
+                end = start + seq.last_block_num_tokens
             slot_mapping.extend(list(range(start, end)))
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:  # prefix cache
             block_tables = self.prepare_block_tables(seqs)

@@ -269,19 +284,28 @@
         target_seqs = seqs

         # Fill pre-allocated CPU buffers
+        top_ks_is_zero = True
+        top_ps_is_one = True
+        repetition_penalties_is_one = True
         for i, seq in enumerate(target_seqs):
             self._cpu_temperatures[i] = seq.temperature
             self._cpu_cfg_scales[i] = seq.cfg_scale
             self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
+            if seq.top_k is not None and seq.top_k > 0:
+                top_ks_is_zero = False
             self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
+            if seq.top_p is not None and seq.top_p != 1.0:  # fix: was `== 1.0`, which inverted the flag
+                top_ps_is_one = False
             self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
+            if seq.repetition_penalty is not None and seq.repetition_penalty != 1.0:  # fix: was `== 1.0`
+                repetition_penalties_is_one = False

         # Transfer to GPU using sliced views (single batched transfer)
         temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
         cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
-        top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True)
-        top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True)
-        repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True)
+        top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
+        top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
+        repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None

         return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties

@@ -309,27 +333,15 @@
         [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
         where uncond_seq_i is the paired unconditional sequence of cond_seq_i."""
         # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
-        is_cfg_batch = False
-        if len(seqs) > 0:
-            # CFG batch if first sequence has cfg_scale > 1.0 and paired_seq
-            if seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
-                is_cfg_batch = True
-                # Verify batch structure: first half conditional, second half unconditional
-                num_cond = len(seqs) // 2
-                for i in range(num_cond):
-                    if seqs[i].is_unconditional or seqs[i + num_cond].is_unconditional == False:
-                        is_cfg_batch = False
-                        break
-
+        is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
         if is_cfg_batch:
             # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
             num_cond = len(seqs) // 2
             cond_seqs = seqs[:num_cond]
-            uncond_seqs = seqs[num_cond:]
+            # uncond_seqs = seqs[num_cond:]

             # Prepare inputs for both conditional and unconditional (they're already in the batch)
-            input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
-                                    else self.prepare_decode(seqs))
+            input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
             sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
             if sample_params is not None:
                 temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params

@@ -380,7 +392,7 @@
                 logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])

             # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
-            cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
+            # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)

             # Sample from CFG logits
             token_ids_cfg = self.sampler(

@@ -389,7 +401,7 @@
                 top_ks=top_ks if top_ks is not None else None,
                 top_ps=top_ps if top_ps is not None else None,
                 repetition_penalties=None,  # Already applied above
-                input_ids=cond_input_ids,
+                # input_ids=cond_input_ids,
             ).tolist()

             # Update logits processor state after sampling

@@ -448,7 +460,7 @@
             logits[i] = processed[0]

         # Prepare input_ids for sampler
-        seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
+        # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)

         token_ids = self.sampler(
             logits,

@@ -456,7 +468,7 @@
             top_ks=top_ks if top_ks is not None else None,
             top_ps=top_ps if top_ps is not None else None,
             repetition_penalties=None,  # Already applied above
-            input_ids=seq_input_ids,
+            # input_ids=seq_input_ids,
         ).tolist()

         # Update logits processor state after sampling
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py
CHANGED

@@ -85,6 +85,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()

+    @torch.compile
     def forward(
         self,
         logits: torch.Tensor,

@@ -102,27 +103,12 @@
         """
         # Apply temperature
         logits = logits.float().div_(temperatures.unsqueeze(dim=1))
-        logits = apply_top_k_top_p(
-            logits,
-            top_ks if need_topk else None,
-            top_ps if need_topp else None,
-        )
-
-        # Sample using compiled function
-        return self._sample(logits)
-
-    @torch.compile(dynamic=True)
-    def _sample(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compiled sampling kernel - no graph breaks here."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
-        q = torch.empty_like(probs).exponential_()
-        return probs.div(q).argmax(dim=-1)
+
+        logits = apply_top_k_top_p(
+            logits,
+            top_ks,
+            top_ps,
+        )
+        probs = torch.softmax(logits, dim=-1)
+        # Exponential-race (Gumbel-max) sampling: argmax of p_i / E_i with E_i ~ Exp(1)
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
+        return sample_tokens
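The new sampling one-liner is the exponential-race form of the Gumbel-max trick: for independent E_i ~ Exp(1), argmax_i p_i / E_i is exactly a draw from Categorical(p), because E_i / p_i ~ Exp(rate p_i) and the index attaining the minimum of independent exponentials is selected with probability proportional to its rate. A standalone empirical check of that equivalence (not part of the repo):

import torch

# Check that argmax_i p_i / E_i with E_i ~ Exp(1) samples from Categorical(p)
torch.manual_seed(0)
p = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000

probs = p.expand(n, -1).contiguous()
noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
samples = probs.div(noise).argmax(dim=-1)

freq = torch.bincount(samples, minlength=4).float() / n
print("target:   ", p.tolist())
print("empirical:", [round(x, 3) for x in freq.tolist()])
# Expected: empirical frequencies close to [0.1, 0.2, 0.3, 0.4]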
profile_inference.py
CHANGED
|
@@ -1,223 +1,682 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
Usage:
|
| 6 |
-
python profile_inference.py
|
| 7 |
-
python profile_inference.py --warmup
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
-
import cProfile
|
| 11 |
-
import pstats
|
| 12 |
-
import io
|
| 13 |
import time
|
| 14 |
import argparse
|
| 15 |
import sys
|
| 16 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Add project root to path
|
| 19 |
project_root = os.path.abspath(os.path.dirname(__file__))
|
| 20 |
if project_root not in sys.path:
|
| 21 |
sys.path.insert(0, project_root)
|
| 22 |
|
|
|
|
| 23 |
from acestep.inference import generate_music, GenerationParams, GenerationConfig
|
| 24 |
from acestep.handler import AceStepHandler
|
| 25 |
from acestep.llm_inference import LLMHandler
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def profile_with_cprofile(dit_handler, llm_handler, params, config, warmup=False):
|
| 31 |
-
"""Profile using Python's built-in cProfile.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
warmup: If True, run once for warmup before profiling (default: False)
|
| 35 |
-
"""
|
| 36 |
-
print("=" * 80)
|
| 37 |
-
print("Profiling with cProfile")
|
| 38 |
-
print("=" * 80)
|
| 39 |
-
|
| 40 |
-
# Warmup run (to exclude PyTorch compilation overhead)
|
| 41 |
-
if warmup:
|
| 42 |
-
print("\n[Warmup] Running first generation to warm up (PyTorch compilation, etc.)...")
|
| 43 |
-
warmup_start = time.time()
|
| 44 |
-
params.use_cot_metas = False
|
| 45 |
-
config.is_format_caption = True
|
| 46 |
-
config.use_constrained_decoding = False
|
| 47 |
-
warmup_result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
|
| 48 |
-
warmup_time = time.time() - warmup_start
|
| 49 |
-
print(f"[Warmup] Completed in {warmup_time:.2f}s")
|
| 50 |
-
if not warmup_result.success:
|
| 51 |
-
print(f"[Warmup] ā Warmup generation failed: {warmup_result.error}")
|
| 52 |
-
return warmup_result
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
ps.print_stats(30)
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
with open(output_file, 'w') as f:
|
| 84 |
-
# Create a new Stats object with file as stream
|
| 85 |
-
ps_file = pstats.Stats(profiler, stream=f)
|
| 86 |
-
ps_file.sort_stats('cumulative')
|
| 87 |
-
ps_file.print_stats()
|
| 88 |
-
print(f"\nDetailed profile saved to: {output_file}")
|
| 89 |
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
type=str,
|
| 116 |
-
default="acestep-5Hz-lm-0.6B-v3",
|
| 117 |
-
help="LM model path"
|
| 118 |
)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
type=str,
|
| 122 |
-
default="vllm",
|
| 123 |
-
help="LM backend"
|
| 124 |
)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
args = parser.parse_args()
|
| 132 |
|
| 133 |
-
# Initialize
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
dit_handler = AceStepHandler()
|
| 136 |
llm_handler = LLMHandler()
|
| 137 |
|
| 138 |
-
|
| 139 |
-
print(" - Initializing DiT model...")
|
| 140 |
status_dit, success_dit = dit_handler.initialize_service(
|
| 141 |
project_root=project_root,
|
| 142 |
config_path=args.config_path,
|
| 143 |
device=args.device,
|
|
|
|
| 144 |
)
|
| 145 |
if not success_dit:
|
| 146 |
-
print(f" ā
|
| 147 |
sys.exit(1)
|
| 148 |
-
print("
|
| 149 |
-
|
| 150 |
-
# Initialize LLM
|
| 151 |
-
print(" - Initializing LLM model...")
|
| 152 |
-
status_llm, success_llm = llm_handler.initialize(
|
| 153 |
-
checkpoint_dir=args.checkpoint_dir,
|
| 154 |
-
lm_model_path=args.lm_model,
|
| 155 |
-
backend=args.lm_backend,
|
| 156 |
-
device=args.device,
|
| 157 |
-
)
|
| 158 |
-
if success_llm:
|
| 159 |
-
print(" ā LM model initialized")
|
| 160 |
-
else:
|
| 161 |
-
print(f" ā LM initialization failed: {status_llm}")
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
bpm=data.get('bpm'),
|
| 178 |
-
keyscale=data.get('keyscale', ''),
|
| 179 |
-
timesignature=time_sig,
|
| 180 |
-
vocal_language=data.get('language', 'unknown'),
|
| 181 |
-
duration=data.get('duration'),
|
| 182 |
-
thinking=data.get('think', False),
|
| 183 |
-
inference_steps=data.get('inference_steps', 8),
|
| 184 |
-
seed=42,
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
config = GenerationConfig()
|
| 188 |
-
config.batch_size = data.get('batch_size', 1)
|
| 189 |
-
|
| 190 |
-
return params, config
|
| 191 |
-
|
| 192 |
-
except Exception as e:
|
| 193 |
-
print(f" ā Failed to load example file: {e}")
|
| 194 |
-
return None, None
|
| 195 |
-
|
| 196 |
-
# Load production example (same as acestep/inference.py)
|
| 197 |
-
example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
|
| 198 |
|
|
|
|
|
|
|
| 199 |
if not os.path.exists(example_file):
|
| 200 |
-
print(f"\n
|
| 201 |
-
print(" Please ensure the examples directory exists.")
|
| 202 |
sys.exit(1)
|
| 203 |
|
| 204 |
-
print(f"\n
|
| 205 |
params, config = load_example_config(example_file)
|
| 206 |
|
| 207 |
if not params or not config:
|
| 208 |
-
print("
|
| 209 |
sys.exit(1)
|
| 210 |
|
| 211 |
-
print("
|
| 212 |
-
print("
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
|
| 221 |
if __name__ == "__main__":
|
| 222 |
main()
|
| 223 |
-
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Enhanced profiling script for ACE-Step inference with deep LLM analysis
|
| 4 |
+
|
| 5 |
+
This script helps diagnose why LLM generation is slow by tracking:
|
| 6 |
+
1. Total tokens generated vs expected throughput (200 tokens/sec baseline)
|
| 7 |
+
2. Per-iteration timing to detect compilation overhead or slow operations
|
| 8 |
+
3. Constrained decoding overhead
|
| 9 |
+
4. CFG overhead (2x forward passes)
|
| 10 |
+
5. Model forward time vs sampling/processing time
|
| 11 |
|
| 12 |
Usage:
|
| 13 |
+
python profile_inference.py # Standard profiling with warmup
|
| 14 |
+
python profile_inference.py --no-warmup # Profile first run (includes compilation)
|
| 15 |
+
python profile_inference.py --llm-debug # Deep LLM performance debugging
|
| 16 |
+
python profile_inference.py --detailed # Add cProfile function-level analysis
|
| 17 |
+
|
| 18 |
+
Inference mode options:
|
| 19 |
+
python profile_inference.py --thinking # Enable CoT for code generation
|
| 20 |
+
python profile_inference.py --use-constrained-decoding # Use FSM constrained decoding
|
| 21 |
+
python profile_inference.py --use-cot-metas # Enable LM to generate metadata via CoT
|
| 22 |
"""
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
import time
|
| 25 |
import argparse
|
| 26 |
import sys
|
| 27 |
import os
|
| 28 |
+
from contextlib import contextmanager
|
| 29 |
+
from collections import defaultdict
|
| 30 |
+
import json
|
| 31 |
+
from typing import Tuple, Dict, Any, List
|
| 32 |
+
from functools import wraps
|
| 33 |
|
| 34 |
# Add project root to path
|
| 35 |
project_root = os.path.abspath(os.path.dirname(__file__))
|
| 36 |
if project_root not in sys.path:
|
| 37 |
sys.path.insert(0, project_root)
|
| 38 |
|
| 39 |
+
import torch
|
| 40 |
from acestep.inference import generate_music, GenerationParams, GenerationConfig
|
| 41 |
from acestep.handler import AceStepHandler
|
| 42 |
from acestep.llm_inference import LLMHandler
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PreciseTimer:
|
| 46 |
+
"""High-precision timer with CUDA synchronization for accurate GPU timing"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
def __init__(self, device="cuda"):
|
| 49 |
+
self.device = device
|
| 50 |
+
self.timings = defaultdict(list)
|
| 51 |
+
self.enabled = True
|
| 52 |
+
|
| 53 |
+
def sync(self):
|
| 54 |
+
"""Synchronize CUDA operations for accurate timing"""
|
| 55 |
+
if self.enabled and self.device.startswith("cuda") and torch.cuda.is_available():
|
| 56 |
+
torch.cuda.synchronize()
|
| 57 |
|
| 58 |
+
@contextmanager
|
| 59 |
+
def time(self, name: str):
|
| 60 |
+
"""Time a code section with CUDA synchronization"""
|
| 61 |
+
if not self.enabled:
|
| 62 |
+
yield
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
self.sync()
|
| 66 |
+
start = time.perf_counter()
|
| 67 |
+
try:
|
| 68 |
+
yield
|
| 69 |
+
finally:
|
| 70 |
+
self.sync()
|
| 71 |
+
elapsed = time.perf_counter() - start
|
| 72 |
+
self.timings[name].append(elapsed)
|
| 73 |
+
|
| 74 |
+
def get_total(self, name: str) -> float:
|
| 75 |
+
"""Get total accumulated time for a section"""
|
| 76 |
+
return sum(self.timings.get(name, []))
|
| 77 |
|
| 78 |
+
def get_mean(self, name: str) -> float:
|
| 79 |
+
"""Get mean time per call for a section"""
|
| 80 |
+
times = self.timings.get(name, [])
|
| 81 |
+
return sum(times) / len(times) if times else 0.0
|
| 82 |
|
| 83 |
+
def get_count(self, name: str) -> int:
|
| 84 |
+
"""Get number of calls for a section"""
|
| 85 |
+
return len(self.timings.get(name, []))
|
|
|
|
| 86 |
|
| 87 |
+
def get_all(self, name: str) -> List[float]:
|
| 88 |
+
"""Get all timing samples for a section"""
|
| 89 |
+
return self.timings.get(name, [])
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class LLMDebugger:
|
| 93 |
+
"""Track detailed LLM performance metrics to diagnose slow generation"""
|
| 94 |
|
| 95 |
+
def __init__(self):
|
| 96 |
+
self.reset()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
def reset(self):
|
| 99 |
+
"""Reset all metrics"""
|
| 100 |
+
self.total_tokens = 0
|
| 101 |
+
self.generation_start = None
|
| 102 |
+
self.generation_end = None
|
| 103 |
+
self.output_text = ""
|
| 104 |
+
self.prompt_length = 0
|
| 105 |
+
|
| 106 |
+
def start(self, prompt_length: int = 0):
|
| 107 |
+
"""Mark generation start"""
|
| 108 |
+
self.generation_start = time.perf_counter()
|
| 109 |
+
self.prompt_length = prompt_length
|
| 110 |
+
|
| 111 |
+
def end(self, output_text: str = ""):
|
| 112 |
+
"""Mark generation end and store output"""
|
| 113 |
+
self.generation_end = time.perf_counter()
|
| 114 |
+
self.output_text = output_text
|
| 115 |
+
|
| 116 |
+
def set_token_count(self, count: int):
|
| 117 |
+
"""Set total token count"""
|
| 118 |
+
self.total_tokens = count
|
| 119 |
+
|
| 120 |
+
def get_throughput(self) -> float:
|
| 121 |
+
"""Calculate actual tokens per second"""
|
| 122 |
+
if self.generation_start and self.generation_end and self.total_tokens > 0:
|
| 123 |
+
total_time = self.generation_end - self.generation_start
|
| 124 |
+
if total_time > 0:
|
| 125 |
+
return self.total_tokens / total_time
|
| 126 |
+
return 0.0
|
| 127 |
+
|
| 128 |
+
def print_analysis(self):
|
| 129 |
+
"""Print detailed LLM performance analysis"""
|
| 130 |
+
if not self.generation_start or not self.generation_end:
|
| 131 |
+
return
|
| 132 |
+
|
| 133 |
+
print("\n" + "=" * 100)
|
| 134 |
+
print("š LLM PERFORMANCE DEEP DIVE")
|
| 135 |
+
print("=" * 100)
|
| 136 |
+
|
| 137 |
+
total_time = self.generation_end - self.generation_start
|
| 138 |
+
throughput = self.get_throughput()
|
| 139 |
+
|
| 140 |
+
# Basic metrics table
|
| 141 |
+
print(f"\n{'Metric':<40} {'Value':<20} {'Notes'}")
|
| 142 |
+
print("-" * 100)
|
| 143 |
+
print(f"{'Total Tokens Generated:':<40} {self.total_tokens:<20} (new tokens only)")
|
| 144 |
+
print(f"{'Prompt Length (estimate):':<40} {self.prompt_length:<20} (input tokens)")
|
| 145 |
+
print(f"{'Total Generation Time:':<40} {total_time:<20.3f} seconds")
|
| 146 |
+
print(f"{'Measured Throughput:':<40} {throughput:<20.1f} tokens/sec")
|
| 147 |
+
print(f"{'Expected Throughput:':<40} {'200':<20} tokens/sec (baseline)")
|
| 148 |
+
|
| 149 |
+
# Calculate performance gap
|
| 150 |
+
if throughput > 0:
|
| 151 |
+
slowdown = 200.0 / throughput
|
| 152 |
+
efficiency = (throughput / 200.0) * 100
|
| 153 |
+
print(f"{'Performance vs Baseline:':<40} {efficiency:<20.1f}% of expected")
|
| 154 |
+
print(f"{'Slowdown Factor:':<40} {slowdown:<20.2f}x slower")
|
| 155 |
+
|
| 156 |
+
# Analyze generated output
|
| 157 |
+
if self.output_text:
|
| 158 |
+
print(f"\n{'Output Analysis:':<40}")
|
| 159 |
+
print(f"{' Output length:':<40} {len(self.output_text):<20} characters")
|
| 160 |
+
|
| 161 |
+
# Count audio codes
|
| 162 |
+
import re
|
| 163 |
+
code_pattern = r'<\|audio_code_\d+\|>'
|
| 164 |
+
codes = re.findall(code_pattern, self.output_text)
|
| 165 |
+
if codes:
|
| 166 |
+
print(f"{' Audio codes generated:':<40} {len(codes):<20} codes")
|
| 167 |
+
print(f"{' Expected audio duration:':<40} {f'~{len(codes)/5:.1f}s':<20} (5 codes per second)")
|
| 168 |
+
if total_time > 0:
|
| 169 |
+
print(f"{' Time per audio code:':<40} {f'{total_time/len(codes)*1000:.1f}ms':<20}")
|
| 170 |
+
|
| 171 |
+
# Check for CoT section
|
| 172 |
+
if '<think>' in self.output_text and '</think>' in self.output_text:
|
| 173 |
+
cot_start = self.output_text.find('<think>')
|
| 174 |
+
cot_end = self.output_text.find('</think>') + 8
|
| 175 |
+
cot_section = self.output_text[cot_start:cot_end]
|
| 176 |
+
cot_token_est = len(cot_section) // 4
|
| 177 |
+
print(f"{' CoT section tokens (estimate):':<40} {f'~{cot_token_est}':<20}")
|
| 178 |
+
|
| 179 |
+
# Diagnostic guidance
|
| 180 |
+
print("\n" + "=" * 100)
|
| 181 |
+
print("š§ DIAGNOSTIC GUIDANCE")
|
| 182 |
+
print("=" * 100)
|
| 183 |
+
|
| 184 |
+
if throughput < 50:
|
| 185 |
+
print("\nā ļø CRITICAL: Throughput is extremely low (<50 tokens/sec)")
|
| 186 |
+
print("\nThis is ~4x slower than expected. Likely causes:")
|
| 187 |
+
print(" 1. ā Constrained decoding FSM overhead")
|
| 188 |
+
print(" ā Each token triggers FSM state machine validation")
|
| 189 |
+
print(" ā Try: set use_constrained_decoding=False in config")
|
| 190 |
+
print(" 2. ā CFG with double forward passes")
|
| 191 |
+
print(" ā cfg_scale > 1.0 means running model twice per token")
|
| 192 |
+
print(" ā Check: params.lm_cfg_scale value")
|
| 193 |
+
print(" 3. ā Running in eager mode without compilation")
|
| 194 |
+
print(" ā PyTorch should compile kernels after warmup")
|
| 195 |
+
print(" ā Check: torch._dynamo.config settings")
|
| 196 |
+
|
| 197 |
+
elif throughput < 100:
|
| 198 |
+
print("\nā ļø WARNING: Throughput is low (50-100 tokens/sec)")
|
| 199 |
+
print("\nLikely causes:")
|
| 200 |
+
print(" 1. Constrained decoding overhead (~30-50% slowdown expected)")
|
| 201 |
+
print(" 2. CFG enabled (2x compute per token if cfg_scale > 1.0)")
|
| 202 |
+
print(" 3. Small model or inefficient GPU utilization")
|
| 203 |
+
|
| 204 |
+
elif throughput < 150:
|
| 205 |
+
print("\nā ļø Throughput is below baseline but acceptable (100-150 tokens/sec)")
|
| 206 |
+
print("\nMinor overhead from:")
|
| 207 |
+
print(" - Constrained decoding: ~20-30% overhead")
|
| 208 |
+
print(" - Profiling instrumentation: ~5-10% overhead")
|
| 209 |
+
|
| 210 |
+
else:
|
| 211 |
+
print(f"\nā Throughput is good ({throughput:.1f} tokens/sec)")
|
| 212 |
+
print(" Performance is within acceptable range")
|
| 213 |
|
| 214 |
|
| 215 |
+
# Global instances
|
| 216 |
+
timer = None
|
| 217 |
+
llm_debugger = None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def wrap_method_with_timing(obj, method_name: str, timing_key: str):
|
| 221 |
+
"""Wrap a method with timing instrumentation"""
|
| 222 |
+
original_method = getattr(obj, method_name)
|
| 223 |
+
|
| 224 |
+
@wraps(original_method)
|
| 225 |
+
def timed_wrapper(*args, **kwargs):
|
| 226 |
+
with timer.time(timing_key):
|
| 227 |
+
return original_method(*args, **kwargs)
|
| 228 |
+
|
| 229 |
+
setattr(obj, method_name, timed_wrapper)
|
| 230 |
+
return original_method
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def wrap_llm_with_debug_tracking(llm_handler):
|
| 234 |
+
"""Wrap LLM generation with detailed performance tracking"""
|
| 235 |
+
original_method = llm_handler.generate_with_stop_condition
|
| 236 |
+
|
| 237 |
+
@wraps(original_method)
|
| 238 |
+
def debug_wrapper(*args, **kwargs):
|
| 239 |
+
# Estimate prompt length
|
| 240 |
+
caption = kwargs.get('caption', args[0] if len(args) > 0 else "")
|
| 241 |
+
lyrics = kwargs.get('lyrics', args[1] if len(args) > 1 else "")
|
| 242 |
+
prompt_estimate = len(caption) + len(lyrics)
|
| 243 |
+
prompt_tokens_estimate = prompt_estimate // 4
|
| 244 |
+
|
| 245 |
+
# Start tracking
|
| 246 |
+
llm_debugger.reset()
|
| 247 |
+
llm_debugger.start(prompt_length=prompt_tokens_estimate)
|
| 248 |
+
|
| 249 |
+
# Call original with timing
|
| 250 |
+
with timer.time('llm_inference'):
|
| 251 |
+
result = original_method(*args, **kwargs)
|
| 252 |
+
|
| 253 |
+
# Extract and analyze output
|
| 254 |
+
output_text = ""
|
| 255 |
+
if isinstance(result, tuple) and len(result) >= 2:
|
| 256 |
+
if isinstance(result[1], list):
|
| 257 |
+
# Batch mode
|
| 258 |
+
output_text = "".join(result[1])
|
| 259 |
+
else:
|
| 260 |
+
# Single mode
|
| 261 |
+
cot_output = ""
|
| 262 |
+
if isinstance(result[0], dict):
|
| 263 |
+
for v in result[0].values():
|
| 264 |
+
if isinstance(v, str):
|
| 265 |
+
cot_output += v
|
| 266 |
+
output_text = cot_output + str(result[1])
|
| 267 |
+
|
| 268 |
+
# Count tokens
|
| 269 |
+
import re
|
| 270 |
+
code_pattern = r'<\|audio_code_\d+\|>'
|
| 271 |
+
codes = re.findall(code_pattern, output_text)
|
| 272 |
+
remaining_text = re.sub(code_pattern, '', output_text)
|
| 273 |
+
cot_tokens_estimate = len(remaining_text) // 4
|
| 274 |
+
total_tokens = len(codes) + cot_tokens_estimate
|
| 275 |
+
|
| 276 |
+
llm_debugger.set_token_count(total_tokens)
|
| 277 |
+
llm_debugger.end(output_text)
|
| 278 |
+
|
| 279 |
+
return result
|
| 280 |
+
|
| 281 |
+
llm_handler.generate_with_stop_condition = debug_wrapper
|
| 282 |
+
return original_method
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def instrument_handlers(dit_handler, llm_handler, enable_llm_debug=False):
|
| 286 |
+
"""Add timing instrumentation to handler methods"""
|
| 287 |
+
originals = {}
|
| 288 |
+
|
| 289 |
+
# Instrument LLM
|
| 290 |
+
if llm_handler and llm_handler.llm_initialized:
|
| 291 |
+
if enable_llm_debug:
|
| 292 |
+
originals['llm_generate'] = wrap_llm_with_debug_tracking(llm_handler)
|
| 293 |
+
else:
|
| 294 |
+
originals['llm_generate'] = wrap_method_with_timing(
|
| 295 |
+
llm_handler, 'generate_with_stop_condition', 'llm_inference'
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Instrument DiT handler
|
| 299 |
+
originals['dit_prepare'] = wrap_method_with_timing(
|
| 300 |
+
dit_handler, 'prepare_batch_data', 'prepare_batch_data'
|
| 301 |
)
|
| 302 |
+
originals['dit_generate'] = wrap_method_with_timing(
|
| 303 |
+
dit_handler, 'service_generate', 'dit_inference'
|
|
|
|
|
|
|
|
|
|
| 304 |
)
|
| 305 |
+
originals['dit_decode'] = wrap_method_with_timing(
|
| 306 |
+
dit_handler, 'tiled_decode', 'vae_decode'
|
|
|
|
|
|
|
|
|
|
| 307 |
)
|
| 308 |
+
|
| 309 |
+
return originals
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def restore_handlers(dit_handler, llm_handler, originals):
|
| 313 |
+
"""Restore original handler methods after profiling"""
|
| 314 |
+
if llm_handler and 'llm_generate' in originals:
|
| 315 |
+
llm_handler.generate_with_stop_condition = originals['llm_generate']
|
| 316 |
+
|
| 317 |
+
dit_handler.prepare_batch_data = originals['dit_prepare']
|
| 318 |
+
dit_handler.service_generate = originals['dit_generate']
|
| 319 |
+
dit_handler.tiled_decode = originals['dit_decode']
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def print_profiling_results(total_time: float, show_llm_debug: bool = False):
|
| 323 |
+
"""Print comprehensive profiling results with performance insights"""
|
| 324 |
+
print("\n" + "=" * 100)
|
| 325 |
+
print("šÆ PROFILING RESULTS")
|
| 326 |
+
print("=" * 100)
|
| 327 |
+
|
| 328 |
+
# Define timing categories
|
| 329 |
+
model_sections = {
|
| 330 |
+
'llm_inference': 'LLM Inference (5Hz Language Model)',
|
| 331 |
+
'dit_inference': 'DiT Inference (Diffusion Transformer)',
|
| 332 |
+
'vae_decode': 'VAE Decode (Audio Decoder)',
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
non_model_sections = {
|
| 336 |
+
'prepare_batch_data': 'Prepare Batch Data (embedding, formatting)',
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
# Calculate totals
|
| 340 |
+
model_time = sum(timer.get_total(k) for k in model_sections.keys())
|
| 341 |
+
non_model_time = sum(timer.get_total(k) for k in non_model_sections.keys())
|
| 342 |
+
other_time = total_time - model_time - non_model_time
|
| 343 |
+
|
| 344 |
+
# Print summary table
|
| 345 |
+
print(f"\n{'CATEGORY':<50} {'TIME (s)':<12} {'%':<8} {'CALLS':<8}")
|
| 346 |
+
print("-" * 100)
|
| 347 |
+
|
| 348 |
+
# Model time breakdown
|
| 349 |
+
print(f"\n{'š¤ MODEL TIME (Total)':<50} {model_time:<12.3f} {100*model_time/total_time:>6.1f}% {'':<8}")
|
| 350 |
+
for key, desc in model_sections.items():
|
| 351 |
+
t = timer.get_total(key)
|
| 352 |
+
c = timer.get_count(key)
|
| 353 |
+
if c > 0:
|
| 354 |
+
mean = timer.get_mean(key)
|
| 355 |
+
pct = 100 * t / total_time
|
| 356 |
+
print(f" {'āā ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
|
| 357 |
+
|
| 358 |
+
# Non-model time breakdown
|
| 359 |
+
print(f"\n{'āļø NON-MODEL TIME (Total)':<50} {non_model_time:<12.3f} {100*non_model_time/total_time:>6.1f}% {'':<8}")
|
| 360 |
+
for key, desc in non_model_sections.items():
|
| 361 |
+
t = timer.get_total(key)
|
| 362 |
+
c = timer.get_count(key)
|
| 363 |
+
if c > 0:
|
| 364 |
+
mean = timer.get_mean(key)
|
| 365 |
+
pct = 100 * t / total_time
|
| 366 |
+
print(f" {'āā ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
|
| 367 |
+
|
| 368 |
+
# Other time
|
| 369 |
+
if other_time > 0.01:
|
| 370 |
+
pct = 100 * other_time / total_time
|
| 371 |
+
print(f"\n{'š¦ OTHER TIME (I/O, overhead, audio save)':<50} {other_time:<12.3f} {pct:>6.1f}% {'':<8}")
|
| 372 |
+
|
| 373 |
+
print(f"\n{'š TOTAL TIME':<50} {total_time:<12.3f} {'100.0%':>6} {'':<8}")
|
| 374 |
+
|
| 375 |
+
# Show LLM detailed analysis if enabled
|
| 376 |
+
if show_llm_debug:
|
| 377 |
+
llm_debugger.print_analysis()
|
| 378 |
+
|
| 379 |
+
# Performance insights
|
| 380 |
+
print("\n" + "=" * 100)
|
| 381 |
+
print("š” PERFORMANCE INSIGHTS")
|
| 382 |
+
print("=" * 100)
|
| 383 |
+
|
| 384 |
+
llm_t = timer.get_total('llm_inference')
|
| 385 |
+
dit_t = timer.get_total('dit_inference')
|
| 386 |
+
vae_t = timer.get_total('vae_decode')
|
| 387 |
+
prep_t = timer.get_total('prepare_batch_data')
|
| 388 |
+
|
| 389 |
+
# Model time insights
|
| 390 |
+
if model_time > 0:
|
| 391 |
+
print(f"\nā Model operations: {model_time:.3f}s ({100*model_time/total_time:.1f}% of total)")
|
| 392 |
+
|
| 393 |
+
if llm_t > 0:
|
| 394 |
+
print(f" - LLM: {llm_t:.3f}s ({100*llm_t/model_time:.1f}% of model time)")
|
| 395 |
+
if dit_t > 0:
|
| 396 |
+
print(f" - DiT: {dit_t:.3f}s ({100*dit_t/model_time:.1f}% of model time)")
|
| 397 |
+
if vae_t > 0:
|
| 398 |
+
print(f" - VAE: {vae_t:.3f}s ({100*vae_t/model_time:.1f}% of model time)")
|
| 399 |
+
|
| 400 |
+
# LLM bottleneck analysis
|
| 401 |
+
if llm_t > dit_t and llm_t > 5.0:
|
| 402 |
+
print(f"\nā ļø LLM IS THE BOTTLENECK: {llm_t:.3f}s ({100*llm_t/total_time:.1f}% of total)")
|
| 403 |
+
print(f"\n Possible causes:")
|
| 404 |
+
print(f" 1. Generating too many tokens ā use --llm-debug to verify")
|
| 405 |
+
print(f" 2. Constrained decoding overhead ā FSM validation per token")
|
| 406 |
+
print(f" 3. CFG overhead ā cfg_scale > 1.0 = 2x forward passes")
|
| 407 |
+
print(f" 4. First-token latency ā warmup should help")
|
| 408 |
+
print(f" 5. KV cache inefficiency ā should be ~5-10ms/token")
|
| 409 |
+
|
| 410 |
+
# Non-model insights
|
| 411 |
+
if non_model_time / total_time > 0.1:
|
| 412 |
+
print(f"\nā ļø Non-model operations: {non_model_time:.3f}s ({100*non_model_time/total_time:.1f}%)")
|
| 413 |
+
if prep_t > 0.1:
|
| 414 |
+
print(f" - Batch preparation: {prep_t:.3f}s")
|
| 415 |
+
|
| 416 |
+
# I/O overhead
|
| 417 |
+
if other_time / total_time > 0.2:
|
| 418 |
+
print(f"\nā ļø Overhead/I/O: {other_time:.3f}s ({100*other_time/total_time:.1f}%)")
|
| 419 |
+
|
| 420 |
+
# Recommendations
|
| 421 |
+
print("\n" + "=" * 100)
|
| 422 |
+
print("š OPTIMIZATION RECOMMENDATIONS")
|
| 423 |
+
print("=" * 100)
|
| 424 |
+
|
| 425 |
+
if llm_t > dit_t * 2:
|
| 426 |
+
print("\nšÆ Priority: Optimize LLM")
|
| 427 |
+
print(" 1. Run: python profile_inference.py --llm-debug")
|
| 428 |
+
print(" ā Shows exact token count and throughput")
|
| 429 |
+
print(" 2. Check constrained decoding overhead")
|
| 430 |
+
print(" 3. Check CFG scaling (lm_cfg_scale parameter)")
|
| 431 |
+
print(" 4. Profile nanovllm engine step() timing")
|
| 432 |
+
print(" 5. Compare vllm vs transformers backends")
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def run_profiled_generation(dit_handler, llm_handler, params, config,
|
| 436 |
+
enable_cprofile=False, enable_llm_debug=False):
|
| 437 |
+
"""Execute generation with full profiling instrumentation"""
|
| 438 |
+
# Instrument handlers
|
| 439 |
+
originals = instrument_handlers(dit_handler, llm_handler, enable_llm_debug)
|
| 440 |
+
|
| 441 |
+
try:
|
| 442 |
+
print("\n[Profiling] Starting generation...")
|
| 443 |
+
timer.sync()
|
| 444 |
+
total_start = time.perf_counter()
|
| 445 |
+
|
| 446 |
+
# Optional cProfile
|
| 447 |
+
prof = None
|
| 448 |
+
if enable_cprofile:
|
| 449 |
+
import cProfile
|
| 450 |
+
prof = cProfile.Profile()
|
| 451 |
+
prof.enable()
|
| 452 |
+
|
| 453 |
+
# Run generation
|
| 454 |
+
result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
|
| 455 |
+
|
| 456 |
+
# Stop timing
|
| 457 |
+
timer.sync()
|
| 458 |
+
total_time = time.perf_counter() - total_start
|
| 459 |
+
|
| 460 |
+
# Save cProfile if enabled
|
| 461 |
+
if enable_cprofile and prof:
|
| 462 |
+
prof.disable()
|
| 463 |
+
|
| 464 |
+
import pstats
|
| 465 |
+
import io
|
| 466 |
+
|
| 467 |
+
output_file = "profile_cprofile_detailed.txt"
|
| 468 |
+
with open(output_file, 'w') as f:
|
| 469 |
+
ps = pstats.Stats(prof, stream=f)
|
| 470 |
+
ps.sort_stats('cumulative')
|
| 471 |
+
ps.print_stats(100)
|
| 472 |
+
|
| 473 |
+
# Print top functions
|
| 474 |
+
print("\n" + "=" * 100)
|
| 475 |
+
print("š TOP 20 FUNCTIONS BY CUMULATIVE TIME (cProfile)")
|
| 476 |
+
print("=" * 100)
|
| 477 |
+
s = io.StringIO()
|
| 478 |
+
ps = pstats.Stats(prof, stream=s)
|
| 479 |
+
ps.sort_stats('cumulative')
|
| 480 |
+
ps.print_stats(20)
|
| 481 |
+
print(s.getvalue())
|
| 482 |
+
|
| 483 |
+
print(f"\nFull report: {output_file}")
|
| 484 |
+
|
| 485 |
+
# Print results
|
| 486 |
+
print_profiling_results(total_time, show_llm_debug=enable_llm_debug)
|
| 487 |
+
|
| 488 |
+
return result, total_time
|
| 489 |
+
|
| 490 |
+
finally:
|
| 491 |
+
restore_handlers(dit_handler, llm_handler, originals)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
|
| 495 |
+
"""Load configuration from example JSON file"""
|
| 496 |
+
try:
|
| 497 |
+
with open(example_file, 'r', encoding='utf-8') as f:
|
| 498 |
+
data = json.load(f)
|
| 499 |
+
|
| 500 |
+
params = GenerationParams(
|
| 501 |
+
caption=data.get('caption', ''),
|
| 502 |
+
lyrics=data.get('lyrics', ''),
|
| 503 |
+
bpm=data.get('bpm'),
|
| 504 |
+
keyscale=data.get('keyscale', ''),
|
| 505 |
+
timesignature=data.get('timesignature', ''),
|
| 506 |
+
vocal_language=data.get('language', 'unknown'),
|
| 507 |
+
duration=data.get('duration'),
|
| 508 |
+
thinking=data.get('think', False),
|
| 509 |
+
inference_steps=data.get('inference_steps', 8),
|
| 510 |
+
seed=data.get('seed', 42),
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
config = GenerationConfig(batch_size=data.get('batch_size', 1), seeds=[42])
|
| 514 |
+
|
| 515 |
+
return params, config
|
| 516 |
+
|
| 517 |
+
except Exception as e:
|
| 518 |
+
print(f" ā Failed to load: {e}")
|
| 519 |
+
return None, None
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def main():
|
| 523 |
+
global timer, llm_debugger
|
| 524 |
+
|
| 525 |
+
parser = argparse.ArgumentParser(
|
| 526 |
+
description="Profile ACE-Step inference with LLM debugging"
|
| 527 |
)
|
| 528 |
+
parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints")
|
| 529 |
+
parser.add_argument("--config-path", type=str, default="acestep-v15-turbo-rl")
|
| 530 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 531 |
+
parser.add_argument("--lm-model", type=str, default="acestep-5Hz-lm-0.6B-v3")
|
| 532 |
+
parser.add_argument("--lm-backend", type=str, default="vllm")
|
| 533 |
+
parser.add_argument("--no-warmup", action="store_true")
|
| 534 |
+
parser.add_argument("--detailed", action="store_true")
|
| 535 |
+
parser.add_argument("--llm-debug", action="store_true",
|
| 536 |
+
help="Enable deep LLM debugging (token count, throughput)")
|
| 537 |
+
parser.add_argument("--example", type=str, default="example_05.json")
|
| 538 |
+
|
| 539 |
+
# Inference mode parameters
|
| 540 |
+
parser.add_argument("--thinking", action="store_true",
|
| 541 |
+
help="Enable CoT reasoning for LM to generate audio codes")
|
| 542 |
+
parser.add_argument("--use-constrained-decoding", action="store_true",
|
| 543 |
+
help="Use FSM-based constrained decoding for meta generation")
|
| 544 |
+
parser.add_argument("--use-cot-metas", action="store_true",
|
| 545 |
+
help="Enable LLM to generate music metadata via CoT reasoning")
|
| 546 |
|
| 547 |
args = parser.parse_args()
|
| 548 |
|
| 549 |
+
# Initialize
|
| 550 |
+
timer = PreciseTimer(device=args.device)
|
| 551 |
+
llm_debugger = LLMDebugger()
|
| 552 |
+
|
| 553 |
+
print("=" * 100)
|
| 554 |
+
print("šµ ACE-Step Inference Profiler (LLM Performance Analysis)")
|
| 555 |
+
print("=" * 100)
|
| 556 |
+
print(f"\nConfiguration:")
|
| 557 |
+
print(f" Device: {args.device}")
|
| 558 |
+
print(f" LLM Backend: {args.lm_backend}")
|
| 559 |
+
print(f" LLM Debug: {'Enabled' if args.llm_debug else 'Disabled'}")
|
| 560 |
+
print(f" Warmup: {'Disabled' if args.no_warmup else 'Enabled'}")
|
| 561 |
+
print(f"\nInference Mode:")
|
| 562 |
+
print(f" Thinking (CoT): {'Enabled' if args.thinking else 'Disabled'}")
|
| 563 |
+
print(f" Constrained Decoding: {'Enabled' if args.use_constrained_decoding else 'Disabled'}")
|
| 564 |
+
print(f" Use CoT for Metas: {'Enabled' if args.use_cot_metas else 'Disabled'}")
|
| 565 |
+
|
| 566 |
+
# Initialize models
|
| 567 |
+
print(f"\nInitializing models...")
|
| 568 |
+
|
| 569 |
      dit_handler = AceStepHandler()
      llm_handler = LLMHandler()

+     print("  š¹ Initializing DiT...")
      status_dit, success_dit = dit_handler.initialize_service(
          project_root=project_root,
          config_path=args.config_path,
          device=args.device,
+         use_flash_attention=True,
      )
      if not success_dit:
+         print(f"  ā Failed: {status_dit}")
          sys.exit(1)
+     print(f"  ā DiT ready")

+     print("  š§  Initializing LLM...")
+     if args.thinking or args.use_cot_metas:
+         status_llm, success_llm = llm_handler.initialize(
+             checkpoint_dir=args.checkpoint_dir,
+             lm_model_path=args.lm_model,
+             backend=args.lm_backend,
+             device=args.device,
+         )
+         if success_llm:
+             print(f"  ā LLM ready ({args.lm_backend})")
+         else:
+             print(f"  ā Failed: {status_llm}")
+     else:
+         print(f"  ā LLM not initialized (thinking or use_cot_metas is disabled)")

+     # Load example
+     example_file = os.path.join(project_root, "examples", "text2music", args.example)
      if not os.path.exists(example_file):
+         print(f"\nā Not found: {example_file}")
          sys.exit(1)

+     print(f"\nš Loading: {args.example}")
      params, config = load_example_config(example_file)

      if not params or not config:
+         print("ā Failed to load config")
          sys.exit(1)

+     print(f"  Caption: {params.caption[:60]}...")
+     print(f"  Batch: {config.batch_size}, Steps: {params.inference_steps}, LLM: {params.thinking}")
+
+     # Warmup
+     if not args.no_warmup:
+         print("\n" + "=" * 100)
+         print("š„ WARMUP RUN")
+         print("=" * 100)
+
+         warmup_params = GenerationParams(
+             caption=params.caption,
+             lyrics=params.lyrics,
+             bpm=params.bpm,
+             keyscale=params.keyscale,
+             timesignature=params.timesignature,
+             vocal_language=params.vocal_language,
+             duration=params.duration,
+             thinking=args.thinking,
+             use_cot_metas=args.use_cot_metas,
+             inference_steps=params.inference_steps,
+             seed=params.seed,
+         )
+         warmup_config = GenerationConfig(batch_size=1, seeds=[42])
+         warmup_config.use_constrained_decoding = args.use_constrained_decoding
+
+         warmup_start = time.perf_counter()
+         warmup_result = generate_music(dit_handler, llm_handler, warmup_params, warmup_config, save_dir="./")
+         warmup_time = time.perf_counter() - warmup_start
+
+         print(f"\nā Warmup: {warmup_time:.2f}s")
+         if not warmup_result.success:
+             print(f"ā ļø Warning: {warmup_result.error}")
+
+         # Reset
+         timer = PreciseTimer(device=args.device)
+         llm_debugger = LLMDebugger()
+
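The untimed warmup run exists because the first CUDA pass pays one-time costs (kernel compilation, allocator growth, attention backend selection) that would otherwise pollute the measured run, which is also why the timer and debugger are re-created afterwards. A minimal sketch of the pattern, assuming only PyTorch (run_once is a hypothetical stand-in for any generation call):

    import time
    import torch

    def timed_after_warmup(run_once, device="cuda"):
        # Untimed warmup pass absorbs one-time costs (kernel compilation, allocator growth).
        run_once()
        if device == "cuda":
            torch.cuda.synchronize()  # ensure warmup work has actually finished
        start = time.perf_counter()
        result = run_once()
        if device == "cuda":
            torch.cuda.synchronize()  # wait for async GPU work before stopping the clock
        return result, time.perf_counter() - start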
+     # Profiling run
+     print("\n" + "=" * 100)
+     print("ā±ļø PROFILING RUN")
+     print("=" * 100)

+     # Apply inference mode settings
+     config.use_constrained_decoding = args.use_constrained_decoding
+     # Override thinking and use_cot_metas parameters if specified via CLI
+     if args.thinking:
+         params.thinking = True
+     if args.use_cot_metas:
+         params.use_cot_metas = True

+     result, total_time = run_profiled_generation(
+         dit_handler, llm_handler, params, config,
+         enable_cprofile=args.detailed,
+         enable_llm_debug=args.llm_debug
+     )
+
+     if not result.success:
+         print(f"\nā Failed: {result.error}")
+         sys.exit(1)
+
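run_profiled_generation is defined earlier in this file; per the tip printed below, the --detailed path writes profile_cprofile_detailed.txt, so enable_cprofile presumably wraps the call in the standard-library profiler. A rough sketch of that pattern (profile_call is a hypothetical helper, not the script's actual API):

    import cProfile
    import pstats

    def profile_call(fn, *args, report_path="profile_cprofile_detailed.txt", **kwargs):
        profiler = cProfile.Profile()
        profiler.enable()
        result = fn(*args, **kwargs)
        profiler.disable()
        # Dump the top functions by cumulative time for offline inspection.
        with open(report_path, "w") as f:
            pstats.Stats(profiler, stream=f).sort_stats("cumulative").print_stats(50)
        return result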
+     print(f"\nā
 Success! Generated {len(result.audios)} audio file(s)")
+
+     # Final tips
+     if args.detailed:
+         print("\nš” Check profile_cprofile_detailed.txt for function-level analysis")
+     elif not args.llm_debug:
+         print("\nš” Run with --llm-debug to see LLM token count and throughput analysis")


  if __name__ == "__main__":
      main()