ChuxiJ committed on
Commit 03f73c6 · 1 Parent(s): 11860f1

fix bugs and test profile
acestep/audio_utils.py CHANGED
@@ -98,7 +98,6 @@ class AudioSaver:
                 channels_first=True,
                 backend='ffmpeg',
                 compression=config,
-                buffer_size=65536
             )
         elif format in ["flac", "wav"]:
             # FLAC and WAV use soundfile backend (fastest)
@@ -107,8 +106,7 @@ class AudioSaver:
                 audio_tensor,
                 sample_rate,
                 channels_first=True,
-                backend='soundfile',
-                buffer_size=65536
+                backend='ffmpeg',
             )
         else:
             # Other formats use default backend
@@ -117,7 +115,6 @@ class AudioSaver:
                 audio_tensor,
                 sample_rate,
                 channels_first=True,
-                buffer_size=65536
             )
 
         logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
@@ -247,87 +244,17 @@ def get_audio_file_hash(audio_file) -> str:
     return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
 
 
-def generate_uuid_from_params(
-    captions: str,
-    lyrics: str,
-    bpm: Optional[int],
-    key_scale: str,
-    time_signature: str,
-    vocal_language: str,
-    inference_steps: int,
-    guidance_scale: float,
-    seed: Union[str, float, int],
-    audio_duration: Optional[float],
-    audio_code_string: Union[str, List[str]],
-    repainting_start: float,
-    repainting_end: Optional[float],
-    instruction: str,
-    audio_cover_strength: float,
-    task_type: str,
-    use_adg: bool,
-    cfg_interval_start: float,
-    cfg_interval_end: float,
-    audio_format: str,
-    reference_audio=None,
-    src_audio=None,
-    batch_index: int = 0,
-) -> str:
+def generate_uuid_from_params(params_dict) -> str:
     """
     Generate deterministic UUID from generation parameters.
     Same parameters will always generate the same UUID.
 
     Args:
-        captions: Music caption
-        lyrics: Lyrics text
-        bpm: BPM value
-        key_scale: Musical key and scale
-        time_signature: Time signature
-        vocal_language: Vocal language code
-        inference_steps: Number of inference steps
-        guidance_scale: Guidance scale
-        seed: Random seed
-        audio_duration: Audio duration in seconds
-        audio_code_string: Audio code string or list
-        repainting_start: Repainting start time
-        repainting_end: Repainting end time
-        instruction: Task instruction
-        audio_cover_strength: Audio cover strength
-        task_type: Task type
-        use_adg: Whether to use ADG
-        cfg_interval_start: CFG interval start
-        cfg_interval_end: CFG interval end
-        audio_format: Audio format
-        reference_audio: Reference audio file path
-        src_audio: Source audio file path
-        batch_index: Index in batch (for audio_code_string list access)
+        params_dict: Dictionary of parameters
 
     Returns:
         UUID string
     """
-    params_dict = {
-        "captions": captions or "",
-        "lyrics": lyrics or "",
-        "bpm": bpm,
-        "key_scale": key_scale or "",
-        "time_signature": time_signature or "",
-        "vocal_language": vocal_language or "",
-        "inference_steps": inference_steps,
-        "guidance_scale": guidance_scale,
-        "seed": seed,
-        "audio_duration": audio_duration,
-        "audio_code_string": audio_code_string if isinstance(audio_code_string, str) else (audio_code_string[batch_index] if isinstance(audio_code_string, list) and batch_index < len(audio_code_string) else ""),
-        "repainting_start": repainting_start,
-        "repainting_end": repainting_end,
-        "instruction": instruction or "",
-        "audio_cover_strength": audio_cover_strength,
-        "task_type": task_type or "",
-        "use_adg": use_adg,
-        "cfg_interval_start": cfg_interval_start,
-        "cfg_interval_end": cfg_interval_end,
-        "audio_format": audio_format or "",
-        "reference_audio_hash": get_audio_file_hash(reference_audio),
-        "src_audio_hash": get_audio_file_hash(src_audio),
-    }
 
     params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
     hash_obj = hashlib.sha256(params_json.encode('utf-8'))
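The refactor above collapses two dozen positional arguments into a single params_dict that the caller assembles, serialized with sort_keys=True so logically equal dicts always hash the same. A minimal sketch of the idea follows; the digest-to-UUID folding and the helper name are assumptions, since the hunk is cut off right after hash_obj:

    import hashlib
    import json
    import uuid

    def params_uuid(params_dict) -> str:
        # Sorted keys make the JSON canonical, so equal dicts hash equally.
        params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
        digest = hashlib.sha256(params_json.encode('utf-8')).hexdigest()
        # One plausible final step: fold the first 128 bits into UUID form.
        return str(uuid.UUID(digest[:32]))

    # Same parameters in any key order yield the same UUID.
    assert params_uuid({"captions": "lofi", "seed": 42}) == params_uuid({"seed": 42, "captions": "lofi"})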
acestep/gradio_ui/events/__init__.py CHANGED
@@ -331,10 +331,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             ],
             outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
         )
-
+    def generation_wrapper(*args):
+        yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
     # ========== Generation Handler ==========
     generation_section["generate_btn"].click(
-        fn=lambda *args: res_h.generate_with_batch_management(dit_handler, llm_handler, *args),
+        fn=generation_wrapper,
        inputs=[
            generation_section["captions"],
            generation_section["lyrics"],
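The swap from a lambda to generation_wrapper matters because generate_with_batch_management is now a generator: Gradio streams partial updates only when the registered handler is itself a generator function, and a lambda cannot contain yield. A standalone sketch of the difference (illustrative names, not from the repo):

    import inspect

    def stream(n):
        for i in range(n):
            yield f"step {i}"

    as_lambda = lambda n: stream(n)   # returns the generator object as one value

    def as_wrapper(n):
        yield from stream(n)          # re-yields every partial update

    print(inspect.isgeneratorfunction(as_lambda))   # False -> no streaming
    print(inspect.isgeneratorfunction(as_wrapper))  # True  -> streamed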
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -10,9 +10,123 @@ import tempfile
 import shutil
 import zipfile
 import time as time_module
+from typing import Dict, Any, Optional
 import gradio as gr
 from loguru import logger
 from acestep.gradio_ui.i18n import t
+from acestep.inference import generate_music, GenerationParams, GenerationConfig
+from acestep.audio_utils import save_audio
+
+
+def _build_generation_info(
+    lm_metadata: Optional[Dict[str, Any]],
+    time_costs: Dict[str, float],
+    seed_value: str,
+    inference_steps: int,
+    num_audios: int,
+) -> str:
+    """Build generation info string from result data.
+
+    Args:
+        lm_metadata: LM-generated metadata dictionary
+        time_costs: Unified time costs dictionary
+        seed_value: Seed value string
+        inference_steps: Number of inference steps
+        num_audios: Number of generated audios
+
+    Returns:
+        Formatted generation info string
+    """
+    info_parts = []
+
+    # Part 1: LM-generated metadata (if available)
+    if lm_metadata:
+        metadata_lines = []
+        if lm_metadata.get('bpm'):
+            metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
+        if lm_metadata.get('caption'):
+            metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
+        if lm_metadata.get('lyrics'):
+            metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
+        if lm_metadata.get('duration'):
+            metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
+        if lm_metadata.get('keyscale'):
+            metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
+        if lm_metadata.get('language'):
+            metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
+        if lm_metadata.get('timesignature'):
+            metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
+
+        if metadata_lines:
+            metadata_section = "**šŸ¤– LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
+            info_parts.append(metadata_section)
+
+    # Part 2: Time costs (formatted and beautified)
+    if time_costs:
+        time_lines = []
+
+        # LM time costs
+        lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
+        lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
+        lm_total = time_costs.get('lm_total_time', 0.0)
+
+        if lm_total > 0:
+            time_lines.append("**🧠 LM Time:**")
+            if lm_phase1 > 0:
+                time_lines.append(f" - Phase 1 (CoT): {lm_phase1:.2f}s")
+            if lm_phase2 > 0:
+                time_lines.append(f" - Phase 2 (Codes): {lm_phase2:.2f}s")
+            time_lines.append(f" - Total: {lm_total:.2f}s")
+
+        # DiT time costs
+        dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
+        dit_model = time_costs.get('dit_model_time_cost', 0.0)
+        dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
+        dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
+        dit_total = time_costs.get('dit_total_time_cost', 0.0)
+        if dit_total > 0:
+            time_lines.append("\n**šŸŽµ DiT Time:**")
+            if dit_encoder > 0:
+                time_lines.append(f" - Encoder: {dit_encoder:.2f}s")
+            if dit_model > 0:
+                time_lines.append(f" - Model: {dit_model:.2f}s")
+            if dit_vae_decode > 0:
+                time_lines.append(f" - VAE Decode: {dit_vae_decode:.2f}s")
+            if dit_offload > 0:
+                time_lines.append(f" - Offload: {dit_offload:.2f}s")
+            time_lines.append(f" - Total: {dit_total:.2f}s")
+
+        # Post-processing time costs
+        audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
+        auto_score_time = time_costs.get('auto_score_time', 0.0)
+
+        if audio_conversion_time > 0 or auto_score_time > 0:
+            time_lines.append("\n**šŸ”§ Post-processing Time:**")
+            if audio_conversion_time > 0:
+                time_lines.append(f" - Audio Conversion: {audio_conversion_time:.2f}s")
+            if auto_score_time > 0:
+                time_lines.append(f" - Auto Score: {auto_score_time:.2f}s")
+
+        # Pipeline total
+        pipeline_total = time_costs.get('pipeline_total_time', 0.0)
+        if pipeline_total > 0:
+            time_lines.append(f"\n**ā±ļø Pipeline Total: {pipeline_total:.2f}s**")
+
+        if time_lines:
+            time_section = "\n".join(time_lines)
+            info_parts.append(time_section)
+
+    # Part 3: Generation summary
+    summary_lines = [
+        "**šŸŽµ Generation Complete**",
+        f" - **Seeds:** {seed_value}",
+        f" - **Steps:** {inference_steps}",
+        f" - **Audio Count:** {num_audios} audio(s)",
+    ]
+    info_parts.append("\n".join(summary_lines))
+
+    # Combine all parts
+    return "\n\n".join(info_parts)
 
 
 def store_batch_in_queue(
@@ -254,383 +368,205 @@ def generate_with_progress(
     auto_score,
     score_scale,
     lm_batch_chunk_size,
-    progress=gr.Progress(track_tqdm=True)
+    progress=gr.Progress(track_tqdm=True),
 ):
     """Generate audio with progress tracking"""
-    # If think is enabled (llm_dit mode) and use_cot_metas is True, generate audio codes using LM first
-    audio_code_string_to_use = text2music_audio_code_string
-    lm_generated_metadata = None  # Store LM-generated metadata for display
-    lm_generated_audio_codes = None  # Store LM-generated audio codes for display
-    lm_generated_audio_codes_list = []  # Store list of audio codes for batch processing
-
-    # Determine if we should use batch LM generation
-    should_use_lm_batch = (
-        think_checkbox and
-        llm_handler.llm_initialized and
-        use_cot_metas and
-        allow_lm_batch and
-        batch_size_input >= 2
+
+    # step 1: prepare inputs
+    # generate_music, GenerationParams, GenerationConfig
+    gen_params = GenerationParams(
+        task_type=task_type,
+        instruction=instruction_display_gen,
+        reference_audio=reference_audio,
+        src_audio=src_audio,
+        audio_codes=text2music_audio_code_string if not think_checkbox else "",
+        caption=captions or "",
+        lyrics=lyrics or "",
+        instrumental=False,
+        vocal_language=vocal_language,
+        bpm=bpm,
+        keyscale=key_scale,
+        timesignature=time_signature,
+        duration=audio_duration,
+        inference_steps=inference_steps,
+        guidance_scale=guidance_scale,
+        use_adg=use_adg,
+        cfg_interval_start=cfg_interval_start,
+        cfg_interval_end=cfg_interval_end,
+        repainting_start=repainting_start,
+        repainting_end=repainting_end,
+        audio_cover_strength=audio_cover_strength,
+        thinking=think_checkbox,
+        lm_temperature=lm_temperature,
+        lm_cfg_scale=lm_cfg_scale,
+        lm_top_k=lm_top_k,
+        lm_top_p=lm_top_p,
+        lm_negative_prompt=lm_negative_prompt,
+        use_cot_metas=use_cot_metas,
+        use_cot_caption=use_cot_caption,
+        use_cot_language=use_cot_language,
+        use_constrained_decoding=True,
+    )
+    # seed string to list
+    if isinstance(seed, str) and seed.strip():
+        if "," in seed:
+            seed_list = [int(s.strip()) for s in seed.split(",")]
+        else:
+            seed_list = [int(seed.strip())]
+    else:
+        seed_list = None
+    gen_config = GenerationConfig(
+        batch_size=batch_size_input,
+        allow_lm_batch=allow_lm_batch,
+        use_random_seed=random_seed_checkbox,
+        seeds=seed_list,
+        lm_batch_chunk_size=lm_batch_chunk_size,
+        constrained_decoding_debug=constrained_decoding_debug,
+        audio_format=audio_format,
+    )
+    result = generate_music(
+        dit_handler,
+        llm_handler,
+        params=gen_params,
+        config=gen_config,
+        progress=progress,
     )
 
-    if think_checkbox and llm_handler.llm_initialized and use_cot_metas:
-        # Convert top_k: 0 means None (disabled)
-        top_k_value = None if lm_top_k == 0 else int(lm_top_k)
-        # Convert top_p: 1.0 means None (disabled)
-        top_p_value = None if lm_top_p >= 1.0 else lm_top_p
-
-        # Build user_metadata from user-provided values (only include non-empty values)
-        user_metadata = {}
-        # Handle bpm: gr.Number can be None, int, float, or string
-        if bpm is not None:
-            try:
-                bpm_value = float(bpm)
-                if bpm_value > 0:
-                    user_metadata['bpm'] = str(int(bpm_value))
-            except (ValueError, TypeError):
-                # If bpm is not a valid number, skip it
-                pass
-        if key_scale and key_scale.strip():
-            key_scale_clean = key_scale.strip()
-            if key_scale_clean.lower() not in ["n/a", ""]:
-                user_metadata['keyscale'] = key_scale_clean
-        if time_signature and time_signature.strip():
-            time_sig_clean = time_signature.strip()
-            if time_sig_clean.lower() not in ["n/a", ""]:
-                user_metadata['timesignature'] = time_sig_clean
-        if audio_duration is not None:
-            try:
-                duration_value = float(audio_duration)
-                if duration_value > 0:
-                    user_metadata['duration'] = str(int(duration_value))
-            except (ValueError, TypeError):
-                # If audio_duration is not a valid number, skip it
-                pass
-
-        # Only pass user_metadata if user provided any values, otherwise let LM generate
-        user_metadata_to_pass = user_metadata if user_metadata else None
-
-        if should_use_lm_batch:
-            # BATCH LM GENERATION
-            logger.info(f"Using LM batch generation for {batch_size_input} items...")
-
-            # Prepare seeds for batch items
-            actual_seed_list, _ = dit_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
-
-            # Split batch into chunks (GPU memory constraint)
-            max_inference_batch_size = int(lm_batch_chunk_size)
-            num_chunks = math.ceil(batch_size_input / max_inference_batch_size)
-
-            all_metadata_list = []
-            all_audio_codes_list = []
-
-            for chunk_idx in range(num_chunks):
-                chunk_start = chunk_idx * max_inference_batch_size
-                chunk_end = min(chunk_start + max_inference_batch_size, batch_size_input)
-                chunk_size = chunk_end - chunk_start
-                chunk_seeds = actual_seed_list[chunk_start:chunk_end]
-
-                logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
-
-                # Generate batch
-                metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition(
-                    caption=captions or "",
-                    lyrics=lyrics or "",
-                    infer_type="llm_dit",
-                    temperature=lm_temperature,
-                    cfg_scale=lm_cfg_scale,
-                    negative_prompt=lm_negative_prompt,
-                    top_k=top_k_value,
-                    top_p=top_p_value,
-                    user_metadata=user_metadata_to_pass,
-                    use_cot_caption=use_cot_caption,
-                    use_cot_language=use_cot_language,
-                    is_format_caption=is_format_caption,
-                    constrained_decoding_debug=constrained_decoding_debug,
-                    batch_size=chunk_size,
-                    seeds=chunk_seeds,
-                )
-
-                all_metadata_list.extend(metadata_list)
-                all_audio_codes_list.extend(audio_codes_list)
-
-            # Use first metadata as representative (all are same)
-            lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
-
-            # Store audio codes list for later use
-            lm_generated_audio_codes_list = all_audio_codes_list
+    audio_outputs = [None] * 8
+    all_audio_paths = []
+    final_codes_list = [""] * 8
+    final_scores_list = [""] * 8
+
+    # Build generation_info from result data
+    status_message = result.status_message
+    seed_value_for_ui = result.extra_outputs.get("seed_value", "")
+    lm_generated_metadata = result.extra_outputs.get("lm_metadata", {})
+    time_costs = result.extra_outputs.get("time_costs", {}).copy()
+
+    # Initialize post-processing timing
+    audio_conversion_start_time = time_module.time()
+    total_auto_score_time = 0.0
+
+    align_score_1 = ""
+    align_text_1 = ""
+    align_plot_1 = None
+    align_score_2 = ""
+    align_text_2 = ""
+    align_plot_2 = None
+    updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
+
+    if not result.success:
+        # Build generation_info string for error case
+        generation_info = _build_generation_info(
+            lm_metadata=lm_generated_metadata,
+            time_costs=time_costs,
+            seed_value=seed_value_for_ui,
+            inference_steps=inference_steps,
+            num_audios=0,
+        )
+        yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 25
+        return
+
+    audios = result.audios
+    progress(0.99, "Converting audio to mp3...")
+    for i in range(8):
+        if i < len(audios):
+            key = audios[i]["key"]
+            audio_tensor = audios[i]["tensor"]
+            sample_rate = audios[i]["sample_rate"]
+            audio_params = audios[i]["params"]
+            temp_dir = tempfile.mkdtemp(f"acestep_gradio_results/")
+            os.makedirs(temp_dir, exist_ok=True)
+            json_path = os.path.join(temp_dir, f"{key}.json")
+            audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
+            save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
+            audio_outputs[i] = audio_path
+            all_audio_paths.append(audio_path)
 
-            # Prepare audio codes for DiT (list of codes, one per batch item)
-            audio_code_string_to_use = all_audio_codes_list
+            code_str = audio_params.get("audio_codes", "")
+            final_codes_list[i] = code_str
 
-            # Update metadata fields from LM if not provided by user
-            if lm_generated_metadata:
-                if bpm is None and lm_generated_metadata.get('bpm'):
-                    bpm_value = lm_generated_metadata.get('bpm')
-                    if bpm_value != "N/A" and bpm_value != "":
-                        try:
-                            bpm = int(bpm_value)
-                        except:
-                            pass
-                if not key_scale and lm_generated_metadata.get('keyscale'):
-                    key_scale_value = lm_generated_metadata.get('keyscale', lm_generated_metadata.get('key_scale', ""))
-                    if key_scale_value != "N/A":
-                        key_scale = key_scale_value
-                if not time_signature and lm_generated_metadata.get('timesignature'):
-                    time_signature_value = lm_generated_metadata.get('timesignature', lm_generated_metadata.get('time_signature', ""))
-                    if time_signature_value != "N/A":
-                        time_signature = time_signature_value
-                if audio_duration is None or audio_duration <= 0:
-                    audio_duration_value = lm_generated_metadata.get('duration', -1)
-                    if audio_duration_value != "N/A" and audio_duration_value != "":
-                        try:
-                            audio_duration = float(audio_duration_value)
-                        except:
-                            pass
-        else:
-            # SEQUENTIAL LM GENERATION (current behavior, when allow_lm_batch is False)
-            # Phase 1: Generate CoT metadata
-            phase1_start = time_module.time()
-            metadata, _, status = llm_handler.generate_with_stop_condition(
-                caption=captions or "",
-                lyrics=lyrics or "",
-                infer_type="dit",  # Only generate metadata in Phase 1
-                temperature=lm_temperature,
-                cfg_scale=lm_cfg_scale,
-                negative_prompt=lm_negative_prompt,
-                top_k=top_k_value,
-                top_p=top_p_value,
-                user_metadata=user_metadata_to_pass,
-                use_cot_caption=use_cot_caption,
-                use_cot_language=use_cot_language,
-                is_format_caption=is_format_caption,
-                constrained_decoding_debug=constrained_decoding_debug,
-            )
-            lm_phase1_time = time_module.time() - phase1_start
-            logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
+            scores_ui_updates = [gr.skip()] * 8
+            score_str = "Done!"
+            if auto_score:
+                auto_score_start = time_module.time()
+                score_str = calculate_score_handler(llm_handler, code_str, captions, lyrics, lm_generated_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale)
+                auto_score_end = time_module.time()
+                total_auto_score_time += (auto_score_end - auto_score_start)
+            scores_ui_updates[i] = score_str
+            final_scores_list[i] = score_str
 
-            # Phase 2: Generate audio codes
-            phase2_start = time_module.time()
-            metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
-                caption=captions or "",
-                lyrics=lyrics or "",
-                infer_type="llm_dit",  # Generate both metadata and codes
-                temperature=lm_temperature,
-                cfg_scale=lm_cfg_scale,
-                negative_prompt=lm_negative_prompt,
-                top_k=top_k_value,
-                top_p=top_p_value,
-                user_metadata=user_metadata_to_pass,
-                use_cot_caption=use_cot_caption,
-                use_cot_language=use_cot_language,
-                is_format_caption=is_format_caption,
-                constrained_decoding_debug=constrained_decoding_debug,
+            status_message = f"Encoding & Ready: {i+1}/{len(audios)}"
+            current_audio_updates = [gr.skip()] * 8
+            current_audio_updates[i] = audio_path
+
+            audio_codes_ui_updates = [gr.skip()] * 8
+            audio_codes_ui_updates[i] = code_str
+            yield (
+                current_audio_updates[0], current_audio_updates[1], current_audio_updates[2], current_audio_updates[3],
+                current_audio_updates[4], current_audio_updates[5], current_audio_updates[6], current_audio_updates[7],
+                all_audio_paths,  # Real-time update of Batch File list
+                generation_info,
+                status_message,
+                seed_value_for_ui,
+                # Align plot placeholders (assume no need to update in real time)
+                gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+                # Scores
+                scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
+                updated_audio_codes,
+                # Codes
+                audio_codes_ui_updates[0], audio_codes_ui_updates[1], audio_codes_ui_updates[2], audio_codes_ui_updates[3],
+                audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
+                lm_generated_metadata,
+                is_format_caption,
             )
-            lm_phase2_time = time_module.time() - phase2_start
-            logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
-
-            # Store LM-generated metadata and audio codes for display
-            lm_generated_metadata = metadata
-            if audio_codes:
-                audio_code_string_to_use = audio_codes
-                lm_generated_audio_codes = audio_codes
-            # Update metadata fields only if they are empty/None (user didn't provide them)
-            if bpm is None and metadata.get('bpm'):
-                bpm_value = metadata.get('bpm')
-                if bpm_value != "N/A" and bpm_value != "":
-                    try:
-                        bpm = int(bpm_value)
-                    except:
-                        pass
-            if not key_scale and metadata.get('keyscale'):
-                key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
-                if key_scale_value != "N/A":
-                    key_scale = key_scale_value
-            if not time_signature and metadata.get('timesignature'):
-                time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
-                if time_signature_value != "N/A":
-                    time_signature = time_signature_value
-            if audio_duration is None or audio_duration <= 0:
-                audio_duration_value = metadata.get('duration', -1)
-                if audio_duration_value != "N/A" and audio_duration_value != "":
-                    try:
-                        audio_duration = float(audio_duration_value)
-                    except:
-                        pass
-
-    # Call generate_music and get results
-    result = dit_handler.generate_music(
-        captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
-        time_signature=time_signature, vocal_language=vocal_language,
-        inference_steps=inference_steps, guidance_scale=guidance_scale,
-        use_random_seed=random_seed_checkbox, seed=seed,
-        reference_audio=reference_audio, audio_duration=audio_duration,
-        batch_size=batch_size_input, src_audio=src_audio,
-        audio_code_string=audio_code_string_to_use,
-        repainting_start=repainting_start, repainting_end=repainting_end,
-        instruction=instruction_display_gen, audio_cover_strength=audio_cover_strength,
-        task_type=task_type, use_adg=use_adg,
-        cfg_interval_start=cfg_interval_start, cfg_interval_end=cfg_interval_end,
-        audio_format=audio_format, lm_temperature=lm_temperature,
-        progress=progress
-    )
-
-    # Extract results from new dict structure
-    if not isinstance(result, dict):
-        # Fallback for old tuple format (should not happen)
-        first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
-            align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
-    else:
-        audios = result.get("audios", [])
-        all_audio_paths = [audio.get("path") for audio in audios]
-        first_audio = all_audio_paths[0] if len(all_audio_paths) > 0 else None
-        second_audio = all_audio_paths[1] if len(all_audio_paths) > 1 else None
-        generation_info = result.get("generation_info", "")
-        status_message = result.get("status_message", "")
-        seed_value_for_ui = result.get("extra_outputs", {}).get("seed_value", "")
-        # Legacy alignment fields (no longer used)
-        align_score_1 = ""
-        align_text_1 = ""
-        align_plot_1 = None
-        align_score_2 = ""
-        align_text_2 = ""
-        align_plot_2 = None
-
-    # Extract LM timing from status if available and prepend to generation_info
-    if status:
-        import re
-        # Try to extract timing info from status using regex
-        # Expected format: "Phase1: X.XXs" and "Phase2: X.XXs"
-        phase1_match = re.search(r'Phase1:\s*([\d.]+)s', status)
-        phase2_match = re.search(r'Phase2:\s*([\d.]+)s', status)
-
-        if phase1_match or phase2_match:
-            lm_timing_section = "\n\n**šŸ¤– LM Timing:**\n"
-            lm_total = 0.0
-            if phase1_match:
-                phase1_time = float(phase1_match.group(1))
-                lm_timing_section += f" - Phase 1 (CoT Metadata): {phase1_time:.2f}s\n"
-                lm_total += phase1_time
-            if phase2_match:
-                phase2_time = float(phase2_match.group(1))
-                lm_timing_section += f" - Phase 2 (Audio Codes): {phase2_time:.2f}s\n"
-                lm_total += phase2_time
-            if lm_total > 0:
-                lm_timing_section += f" - Total LM Time: {lm_total:.2f}s\n"
-            generation_info = lm_timing_section + "\n" + generation_info
-
-    # Append LM-generated metadata to generation_info if available
-    if lm_generated_metadata:
-        metadata_lines = []
-        if lm_generated_metadata.get('bpm'):
-            metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
-        if lm_generated_metadata.get('caption'):
-            metadata_lines.append(f"- **User Query Rewritten Caption:** {lm_generated_metadata['caption']}")
-        if lm_generated_metadata.get('duration'):
-            metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
-        if lm_generated_metadata.get('keyscale'):
-            metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
-        if lm_generated_metadata.get('language'):
-            metadata_lines.append(f"- **Language:** {lm_generated_metadata['language']}")
-        if lm_generated_metadata.get('timesignature'):
-            metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
-
-        if metadata_lines:
-            metadata_section = "\n\n**šŸ¤– LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
-            generation_info = metadata_section + "\n\n" + generation_info
-
-    # Update audio codes in UI if LM generated them
-    codes_outputs = [""] * 8  # Codes for 8 components
-    if should_use_lm_batch and lm_generated_audio_codes_list:
-        # Batch mode: update individual codes inputs
-        for idx in range(min(len(lm_generated_audio_codes_list), 8)):
-            codes_outputs[idx] = lm_generated_audio_codes_list[idx]
-        # For single codes input, show first one
-        updated_audio_codes = lm_generated_audio_codes_list[0] if lm_generated_audio_codes_list else text2music_audio_code_string
-    else:
-        # Single mode: update main codes input
-        updated_audio_codes = lm_generated_audio_codes if lm_generated_audio_codes else text2music_audio_code_string
-
-    # AUTO-SCORING
-    score_displays = [""] * 8  # Scores for 8 components
-    if auto_score and all_audio_paths:
-        logger.info(f"Auto-scoring enabled, calculating quality scores for {batch_size_input} generated audios...")
-
-        # Determine which audio codes to use for scoring
-        if should_use_lm_batch and lm_generated_audio_codes_list:
-            codes_list = lm_generated_audio_codes_list
-        elif audio_code_string_to_use and isinstance(audio_code_string_to_use, list):
-            codes_list = audio_code_string_to_use
         else:
-            # Single code string, replicate for all audios
-            codes_list = [audio_code_string_to_use] * len(all_audio_paths)
-
-        # Calculate scores only for actually generated audios (up to batch_size_input)
-        # Don't score beyond the actual batch size to avoid duplicates
-        actual_audios_to_score = min(len(all_audio_paths), int(batch_size_input))
-        for idx in range(actual_audios_to_score):
-            if idx < len(codes_list) and codes_list[idx]:
-                try:
-                    score_display = calculate_score_handler(
-                        llm_handler,
-                        codes_list[idx],
-                        captions,
-                        lyrics,
-                        lm_generated_metadata,
-                        bpm, key_scale, time_signature, audio_duration, vocal_language,
-                        score_scale
-                    )
-                    score_displays[idx] = score_display
-                    logger.info(f"Auto-scored audio {idx+1}")
-                except Exception as e:
-                    logger.error(f"Auto-scoring failed for audio {idx+1}: {e}")
-                    score_displays[idx] = f"āŒ Auto-scoring failed: {str(e)}"
-
-    # Prepare audio outputs (up to 8)
-    audio_outputs = [None] * 8
-    for idx in range(min(len(all_audio_paths), 8)):
-        audio_outputs[idx] = all_audio_paths[idx]
+            # If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
+            pass
+        time_module.sleep(0.1)
+
+    # Record audio conversion time
+    audio_conversion_end_time = time_module.time()
+    audio_conversion_time = audio_conversion_end_time - audio_conversion_start_time
+
+    # Add post-processing times to time_costs
+    if audio_conversion_time > 0:
+        time_costs['audio_conversion_time'] = audio_conversion_time
+    if total_auto_score_time > 0:
+        time_costs['auto_score_time'] = total_auto_score_time
+
+    # Update pipeline total time to include post-processing
+    if 'pipeline_total_time' in time_costs:
+        time_costs['pipeline_total_time'] += audio_conversion_time + total_auto_score_time
+
+    # Rebuild generation_info with complete timing information
+    generation_info = _build_generation_info(
+        lm_metadata=lm_generated_metadata,
+        time_costs=time_costs,
+        seed_value=seed_value_for_ui,
+        inference_steps=inference_steps,
+        num_audios=len(result.audios),
+    )
 
-    return (
-        audio_outputs[0],  # generated_audio_1
-        audio_outputs[1],  # generated_audio_2
-        audio_outputs[2],  # generated_audio_3
-        audio_outputs[3],  # generated_audio_4
-        audio_outputs[4],  # generated_audio_5
-        audio_outputs[5],  # generated_audio_6
-        audio_outputs[6],  # generated_audio_7
-        audio_outputs[7],  # generated_audio_8
-        all_audio_paths,  # generated_audio_batch
+    yield (
+        gr.skip(), gr.skip(), gr.skip(), gr.skip(),  # Audio 1-4: SKIP
+        gr.skip(), gr.skip(), gr.skip(), gr.skip(),  # Audio 5-8: SKIP
+        all_audio_paths,
         generation_info,
-        status_message,
+        "Generation Complete",
         seed_value_for_ui,
-        align_score_1,
-        align_text_1,
-        align_plot_1,
-        align_score_2,
-        align_text_2,
-        align_plot_2,
-        score_displays[0],  # score_display_1
-        score_displays[1],  # score_display_2
-        score_displays[2],  # score_display_3
-        score_displays[3],  # score_display_4
-        score_displays[4],  # score_display_5
-        score_displays[5],  # score_display_6
-        score_displays[6],  # score_display_7
-        score_displays[7],  # score_display_8
-        updated_audio_codes,  # Update main audio codes in UI
-        codes_outputs[0],  # text2music_audio_code_string_1
-        codes_outputs[1],  # text2music_audio_code_string_2
-        codes_outputs[2],  # text2music_audio_code_string_3
-        codes_outputs[3],  # text2music_audio_code_string_4
-        codes_outputs[4],  # text2music_audio_code_string_5
-        codes_outputs[5],  # text2music_audio_code_string_6
-        codes_outputs[6],  # text2music_audio_code_string_7
-        codes_outputs[7],  # text2music_audio_code_string_8
-        lm_generated_metadata,  # Store metadata for "Send to src audio" buttons
-        is_format_caption,  # Keep is_format_caption unchanged
+        align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2,
+        final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
+        final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
+        updated_audio_codes,
+        final_codes_list[0], final_codes_list[1], final_codes_list[2], final_codes_list[3],
+        final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
+        lm_generated_metadata,
+        is_format_caption,
     )
 
 
+
 def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
     """
     Calculate PMI-based quality score for generated audio.
@@ -773,7 +709,9 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
     if stored_allow_lm_batch and isinstance(stored_codes, list):
         # Batch mode: use specific sample's codes
         if 0 <= sample_idx - 1 < len(stored_codes):
-            audio_codes_str = stored_codes[sample_idx - 1]
+            code_item = stored_codes[sample_idx - 1]
+            # Ensure it's a string (handle cases where dict was mistakenly stored)
+            audio_codes_str = code_item if isinstance(code_item, str) else ""
    else:
        # Single mode: all samples use same codes
        audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
@@ -885,7 +823,7 @@ def generate_with_batch_management(
     Wrapper for generate_with_progress that adds batch queue management
     """
     # Call the original generation function
-    result = generate_with_progress(
+    generator = generate_with_progress(
        dit_handler, llm_handler,
        captions, lyrics, bpm, key_scale, time_signature, vocal_language,
        inference_steps, guidance_scale, random_seed_checkbox, seed,
@@ -902,23 +840,41 @@
         lm_batch_chunk_size,
         progress
     )
-
-    # Extract results from generation
-    all_audio_paths = result[8]  # generated_audio_batch
+    final_result_from_inner = None
+    for partial_result in generator:
+        final_result_from_inner = partial_result
+        # current_batch_index, total_batches, batch_queue, next_params,
+        # batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
+        yield partial_result + (
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
+        )
+    result = final_result_from_inner
+    all_audio_paths = result[8]
+
+    if all_audio_paths is None:
+
+        yield result + (
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(),
+            gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
+        )
+        return
+
+    # Extract results from generation (access result by index)
     generation_info = result[9]
     seed_value_for_ui = result[11]
-    lm_generated_metadata = result[34]  # Index 34 is lm_metadata_state
+    lm_generated_metadata = result[35]  # Fixed: lm_metadata is at index 35, not 34
 
     # Extract codes
     generated_codes_single = result[26]
     generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
-
+
     # Determine which codes to store based on mode
     if allow_lm_batch and batch_size_input >= 2:
        codes_to_store = generated_codes_batch[:int(batch_size_input)]
     else:
        codes_to_store = generated_codes_single
-
+
     # Save parameters for history
     saved_params = {
        "captions": captions,
@@ -964,6 +920,7 @@
     }
 
     # Next batch parameters (with cleared codes & random seed)
+    # Next batch parameters
     next_params = saved_params.copy()
     next_params["text2music_audio_code_string"] = ""
     next_params["random_seed_checkbox"] = True
@@ -996,9 +953,10 @@
     next_batch_status_text = ""
     if autogen_checkbox:
        next_batch_status_text = t("messages.autogen_enabled")
-
-    # Return original results plus batch management state updates
-    return result + (
+
+    # 4. Yield final result (includes Batch UI updates)
+    # The result here is already a tuple structure
+    yield result + (
        current_batch_index,
        total_batches,
        batch_queue,
@@ -1114,7 +1072,8 @@ def generate_next_batch_background(
     params.setdefault("complete_track_classes", [])
 
     # Call generate_with_progress with the saved parameters
-    result = generate_with_progress(
+    # Note: generate_with_progress is a generator, need to iterate through it
+    generator = generate_with_progress(
        dit_handler,
        llm_handler,
        captions=params.get("captions"),
@@ -1159,15 +1118,20 @@
         progress=progress
     )
 
-    # Extract results
-    all_audio_paths = result[8]  # generated_audio_batch
-    generation_info = result[9]
-    seed_value_for_ui = result[11]
-    lm_generated_metadata = result[34]  # Index 34 is lm_metadata_state
+    # Consume generator to get final result (similar to generate_with_batch_management)
+    final_result = None
+    for partial_result in generator:
+        final_result = partial_result
+
+    # Extract results from final_result
+    all_audio_paths = final_result[8]  # generated_audio_batch
+    generation_info = final_result[9]
+    seed_value_for_ui = final_result[11]
+    lm_generated_metadata = final_result[35]  # Fixed: lm_metadata is at index 35, not 34
 
     # Extract codes
-    generated_codes_single = result[26]
-    generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
+    generated_codes_single = final_result[26]
+    generated_codes_batch = [final_result[27], final_result[28], final_result[29], final_result[30], final_result[31], final_result[32], final_result[33], final_result[34]]
 
     # Determine which codes to store
     batch_size = params.get("batch_size_input", 2)
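Both wrappers above follow the same streaming pattern: iterate the inner generator, re-yield each partial tuple padded with gr.skip() so unrelated output components stay untouched, and append the real batch-management values only on the final yield. A reduced sketch of that pattern (toy payloads; assumes a Gradio version that provides gr.skip()):

    import gradio as gr

    def inner():
        # stands in for generate_with_progress
        yield ("audio-1", "encoding 1/2")
        yield ("audio-2", "encoding 2/2")

    def outer():
        # stands in for generate_with_batch_management
        final = None
        for partial in inner():
            final = partial
            # pad the extra batch-management slot so it is left untouched
            yield partial + (gr.skip(),)
        # the final yield fills that slot with real state
        yield final + ("batch-queue-state",)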
acestep/gradio_ui/interfaces/result.py CHANGED
@@ -28,7 +28,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_1 = gr.Audio(
             label=t("results.generated_music", n=1),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_1 = gr.Button(
@@ -58,7 +59,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_2 = gr.Audio(
             label=t("results.generated_music", n=2),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_2 = gr.Button(
@@ -88,7 +90,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_3 = gr.Audio(
             label=t("results.generated_music", n=3),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_3 = gr.Button(
@@ -118,7 +121,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_4 = gr.Audio(
             label=t("results.generated_music", n=4),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_4 = gr.Button(
@@ -151,7 +155,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_5 = gr.Audio(
             label=t("results.generated_music", n=5),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -166,7 +171,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_6 = gr.Audio(
             label=t("results.generated_music", n=6),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -181,7 +187,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_7 = gr.Audio(
             label=t("results.generated_music", n=7),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
@@ -196,7 +203,8 @@ def create_results_section(dit_handler) -> dict:
         generated_audio_8 = gr.Audio(
             label=t("results.generated_music", n=8),
             type="filepath",
-            interactive=False
+            interactive=False,
+            show_download_button=False
         )
         with gr.Row(equal_height=True):
             send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
acestep/handler.py CHANGED
@@ -2077,7 +2077,6 @@ class AceStepHandler:
         if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
             return {
                 "audios": [],
-                "generation_info": "",
                 "status_message": "āŒ Model not fully initialized. Please initialize all components first.",
                 "extra_outputs": {},
                 "success": False,
@@ -2101,7 +2100,7 @@ class AceStepHandler:
 
         logger.info("[generate_music] Starting generation...")
         if progress:
-            progress(0.05, desc="Preparing inputs...")
+            progress(0.51, desc="Preparing inputs...")
         logger.info("[generate_music] Preparing inputs...")
 
         # Reset offload cost
@@ -2123,8 +2122,6 @@ class AceStepHandler:
             repainting_end = None
 
         try:
-            progress(0.1, desc="Preparing inputs...")
-
             # 1. Process reference audio
             refer_audios = None
             if reference_audio is not None:
@@ -2176,7 +2173,7 @@ class AceStepHandler:
                 can_use_repainting
             )
 
-            progress(0.3, desc=f"Generating music (batch size: {actual_batch_size})...")
+            progress(0.52, desc=f"Generating music (batch size: {actual_batch_size})...")
 
             # Prepare audio_code_hints - use if audio_code_string is provided
             # This works for both text2music (auto-switched to cover) and cover tasks
@@ -2245,7 +2242,7 @@ class AceStepHandler:
 
             logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
             if progress:
-                progress(0.9, desc="Preparing audio data...")
+                progress(0.99, desc="Preparing audio data...")
 
             # Prepare audio tensors (no file I/O here, no UUID generation)
             # pred_wavs is already [batch, channels, samples] format
@@ -2257,23 +2254,6 @@ class AceStepHandler:
                 audio_tensor = pred_wavs[i].cpu().float()
                 audio_tensors.append(audio_tensor)
 
-            # Format time costs if available
-            time_costs_str = ""
-            if time_costs:
-                if isinstance(time_costs, dict):
-                    time_costs_str = "\n\n**ā±ļø Time Costs:**\n"
-                    for key, value in time_costs.items():
-                        # Format key: encoder_time_cost -> Encoder
-                        formatted_key = key.replace("_time_cost", "").replace("_", " ").title()
-                        time_costs_str += f" - {formatted_key}: {value:.2f}s\n"
-                elif isinstance(time_costs, (int, float)):
-                    time_costs_str = f"\n\n**ā±ļø Time Cost:** {time_costs:.2f}s"
-
-            generation_info = f"""**šŸŽµ Generation Complete**
-
-**Seeds:** {seed_value_for_ui}
-**Steps:** {inference_steps}
-**Audio Count:** {len(audio_tensors)} audio(s){time_costs_str}"""
             status_message = f"āœ… Generation completed successfully!"
             logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
@@ -2307,7 +2287,6 @@ class AceStepHandler:
 
         return {
             "audios": audios,
-            "generation_info": generation_info,
             "status_message": status_message,
             "extra_outputs": extra_outputs,
             "success": True,
@@ -2319,7 +2298,6 @@ class AceStepHandler:
         logger.exception("[generate_music] Generation failed")
         return {
             "audios": [],
-            "generation_info": "",
            "status_message": error_msg,
            "extra_outputs": {},
            "success": False,
acestep/inference.py CHANGED
@@ -67,19 +67,19 @@ class GenerationParams:
     # Required Inputs
     task_type: str = "text2music"
     instruction: str = "Fill the audio semantic mask based on the given conditions:"
-
+
     # Audio Uploads
     reference_audio: Optional[str] = None
     src_audio: Optional[str] = None
-
+
     # LM Codes Hints
     audio_codes: str = ""
-
+
     # Text Inputs
     caption: str = ""
     lyrics: str = ""
     instrumental: bool = False
-
+
     # Metadata
     vocal_language: str = "unknown"
     bpm: Optional[int] = None
@@ -98,7 +98,7 @@
     repainting_start: float = 0.0
     repainting_end: float = -1
     audio_cover_strength: float = 1.0
-
+
     # 5Hz Language Model Parameters
     thinking: bool = True
     lm_temperature: float = 0.85
@@ -108,8 +108,18 @@
     lm_negative_prompt: str = "NO USER INPUT"
     use_cot_metas: bool = True
     use_cot_caption: bool = True
+    use_cot_lyrics: bool = False  # TODO: not used yet
     use_cot_language: bool = True
-
+    use_constrained_decoding: bool = True
+
+    cot_bpm: Optional[int] = None
+    cot_keyscale: str = ""
+    cot_timesignature: str = ""
+    cot_duration: Optional[float] = None
+    cot_vocal_language: str = "unknown"
+    cot_caption: str = ""
+    cot_lyrics: str = ""
+
     def to_dict(self) -> Dict[str, Any]:
         """Convert config to dictionary for JSON serialization."""
         return asdict(self)
@@ -123,25 +133,27 @@ class GenerationConfig:
        batch_size: Number of audio samples to generate
        allow_lm_batch: Whether to allow batch processing in LM
        use_random_seed: Whether to use random seed
-       seed: Seed(s) for batch generation. Can be:
+       seeds: Seed(s) for batch generation. Can be:
           - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
          - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
          - int: Single seed value (will be converted to list and padded)
        lm_batch_chunk_size: Batch chunk size for LM processing
-       is_format_caption: Whether to format caption
        constrained_decoding_debug: Whether to enable constrained decoding debug
        audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
    """
    batch_size: int = 2
    allow_lm_batch: bool = False
    use_random_seed: bool = True
-   seed: Optional[Union[int, List[int]]] = None
+   seeds: Optional[List[int]] = None
    lm_batch_chunk_size: int = 8
-   is_format_caption: bool = False
-   use_constrained_decoding: bool = True
    constrained_decoding_debug: bool = False
    audio_format: str = "flac"  # Default to FLAC for fast saving
 
+   def to_dict(self) -> Dict[str, Any]:
+       """Convert config to dictionary for JSON serialization."""
+       return asdict(self)
+
+
 @dataclass
 class GenerationResult:
     """Result of music generation.
@@ -149,34 +161,80 @@ class GenerationResult:
    Attributes:
        # Audio Outputs
        audios: List of audio dictionaries with paths, keys, params
-       generation_info: Markdown-formatted generation information
        status_message: Status message from generation
        extra_outputs: Extra outputs from generation
        success: Whether generation completed successfully
        error: Error message if generation failed
    """
-
+
    # Audio Outputs
    audios: List[Dict[str, Any]] = field(default_factory=list)
    # Generation Information
-   generation_info: str = ""
    status_message: str = ""
    extra_outputs: Dict[str, Any] = field(default_factory=dict)
    # Success Status
    success: bool = True
    error: Optional[str] = None
-
+
    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
 
 
+def _update_metadata_from_lm(
+    metadata: Dict[str, Any],
+    bpm: Optional[int],
+    key_scale: str,
+    time_signature: str,
+    audio_duration: Optional[float],
+    vocal_language: str,
+    caption: str,
+    lyrics: str,
+) -> Tuple[Optional[int], str, str, Optional[float]]:
+    """Update metadata fields from LM output if not provided by user."""
+
+    if bpm is None and metadata.get('bpm'):
+        bpm_value = metadata.get('bpm')
+        if bpm_value not in ["N/A", ""]:
+            try:
+                bpm = int(bpm_value)
+            except (ValueError, TypeError):
+                pass
+
+    if not key_scale and metadata.get('keyscale'):
+        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
+        if key_scale_value != "N/A":
+            key_scale = key_scale_value
+
+    if not time_signature and metadata.get('timesignature'):
+        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
+        if time_signature_value != "N/A":
+            time_signature = time_signature_value
+
+    if audio_duration is None or audio_duration <= 0:
+        audio_duration_value = metadata.get('duration', -1)
+        if audio_duration_value not in ["N/A", ""]:
+            try:
+                audio_duration = float(audio_duration_value)
+            except (ValueError, TypeError):
+                pass
+
+    if not vocal_language and metadata.get('vocal_language'):
+        vocal_language = metadata.get('vocal_language')
+    if not caption and metadata.get('caption'):
+        caption = metadata.get('caption')
+    if not lyrics and metadata.get('lyrics'):
+        lyrics = metadata.get('lyrics')
+    return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
+
+
 def generate_music(
     dit_handler,
     llm_handler,
     params: GenerationParams,
     config: GenerationConfig,
     save_dir: Optional[str] = None,
+    progress=None,
 ) -> GenerationResult:
     """Generate music using ACE-Step model with optional LM reasoning.
 
@@ -194,24 +252,31 @@ def generate_music(
        audio_code_string_to_use = params.audio_codes
        lm_generated_metadata = None
        lm_generated_audio_codes_list = []
-
+       lm_total_time_costs = {
+           "phase1_time": 0.0,
+           "phase2_time": 0.0,
+           "total_time": 0.0,
+       }
+
        # Extract mutable copies of metadata (will be updated by LM if needed)
        bpm = params.bpm
        key_scale = params.keyscale
        time_signature = params.timesignature
        audio_duration = params.duration
-
+       dit_input_caption = params.caption
+       dit_input_vocal_language = params.vocal_language
+       dit_input_lyrics = params.lyrics
        # Determine if we need to generate audio codes
        # If user has provided audio_codes, we don't need to generate them
        # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
        user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
-
+
        # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
        # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
        # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
        # Note: This logic can be refined based on specific requirements
        need_audio_codes = not user_provided_audio_codes
-
+
        # Determine if we should use chunk-based LM generation (always use chunks for consistency)
        # Determine actual batch size for chunk processing
        actual_batch_size = config.batch_size if config.batch_size is not None else 1
@@ -219,80 +284,75 @@ def generate_music(
        # Prepare seeds for batch generation
        # Use config.seed if provided, otherwise fallback to params.seed
        # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
-       seed_for_generation = params.seed  # Default fallback
-       if config.seed is not None:
-           if isinstance(config.seed, list):
+       seed_for_generation = ""
+       if config.seeds is not None and len(config.seeds) > 0:
+           if isinstance(config.seeds, list):
                # Convert List[int] to comma-separated string
-               seed_for_generation = ",".join(str(s) for s in config.seed)
-           elif isinstance(config.seed, int):
-               # Single int seed
-               seed_for_generation = config.seed
-
+               seed_for_generation = ",".join(str(s) for s in config.seeds)
+
        # Use dit_handler.prepare_seeds to handle seed list generation and padding
        # This will handle all the logic: padding with random seeds if needed, etc.
-       actual_seed_list, _ = dit_handler.prepare_seeds(
-           actual_batch_size, seed_for_generation, config.use_random_seed
-       )
+       actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
 
        # LM-based Chain-of-Thought reasoning
-       if params.thinking and llm_handler.llm_initialized and params.use_cot_metas:
-           # Convert sampling parameters
-           top_k_value = None if params.lm_top_k == 0 else int(params.lm_top_k)
-           top_p_value = None if params.lm_top_p >= 1.0 else params.lm_top_p
-
+       use_lm = params.thinking and llm_handler.llm_initialized
+       lm_status = []
+       if use_lm:
+           # Convert sampling parameters - handle None values safely
+           top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
+           top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
+
           # Build user_metadata from user-provided values
           user_metadata = {}
           if bpm is not None:
               try:
                   bpm_value = float(bpm)
                   if bpm_value > 0:
-                      user_metadata['bpm'] = str(int(bpm_value))
+                      user_metadata['bpm'] = int(bpm_value)
               except (ValueError, TypeError):
                   pass
-
+
           if key_scale and key_scale.strip():
               key_scale_clean = key_scale.strip()
               if key_scale_clean.lower() not in ["n/a", ""]:
                   user_metadata['keyscale'] = key_scale_clean
-
+
           if time_signature and time_signature.strip():
               time_sig_clean = time_signature.strip()
               if time_sig_clean.lower() not in ["n/a", ""]:
                   user_metadata['timesignature'] = time_sig_clean
-
+
           if audio_duration is not None:
               try:
                   duration_value = float(audio_duration)
                   if duration_value > 0:
-                      user_metadata['duration'] = str(int(duration_value))
+                      user_metadata['duration'] = int(duration_value)
               except (ValueError, TypeError):
                   pass
-
+
           user_metadata_to_pass = user_metadata if user_metadata else None
-
+
           # Determine infer_type based on whether we need audio codes
           # - "llm_dit": generates both metas and audio codes (two-phase internally)
           # - "dit": generates only metas (single phase)
           infer_type = "llm_dit" if need_audio_codes else "dit"
-
+
           # Use chunk size from config, or default to batch_size if not set
           max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
           num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
-
+
           all_metadata_list = []
           all_audio_codes_list = []
-
+
           for chunk_idx in range(num_chunks):
               chunk_start = chunk_idx * max_inference_batch_size
               chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
               chunk_size = chunk_end - chunk_start
               chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
-
-              logger.info(
-                  f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
-                  f"(size: {chunk_size}, seeds: {chunk_seeds})"
-              )
-
+
+              logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
+                          f"(size: {chunk_size}, seeds: {chunk_seeds})")
+
              # Use the determined infer_type
              # - "llm_dit" will internally run two phases (metas + codes)
              # - "dit" will only run phase 1 (metas only)
@@ -308,25 +368,54 @@ def generate_music(
                  user_metadata=user_metadata_to_pass,
                  use_cot_caption=params.use_cot_caption,
                  use_cot_language=params.use_cot_language,
-                 is_format_caption=config.is_format_caption,
-                 use_constrained_decoding=config.use_constrained_decoding,
+                 use_cot_metas=params.use_cot_metas,
+                 use_constrained_decoding=params.use_constrained_decoding,
                  constrained_decoding_debug=config.constrained_decoding_debug,
                  batch_size=chunk_size,
                  seeds=chunk_seeds,
+                 progress=progress,
              )
-
+
+             # Check if LM generation failed
+             if not result.get("success", False):
+                 error_msg = result.get("error", "Unknown LM error")
+                 lm_status.append(f"āŒ LM Error: {error_msg}")
+                 # Return early with error
+                 return GenerationResult(
+                     audios=[],
+                     status_message=f"āŒ LM generation failed: {error_msg}",
+                     extra_outputs={},
+                     success=False,
+                     error=error_msg,
+                 )
+
+             # Extract metadata and audio_codes from result dict
              if chunk_size > 1:
-                 metadata_list, audio_codes_list, status = result
+                 metadata_list = result.get("metadata", [])
+                 audio_codes_list = result.get("audio_codes", [])
                  all_metadata_list.extend(metadata_list)
                  all_audio_codes_list.extend(audio_codes_list)
              else:
-                 metadata, audio_codes, status = result
+                 metadata = result.get("metadata", {})
+                 audio_codes = result.get("audio_codes", "")
                  all_metadata_list.append(metadata)
                  all_audio_codes_list.append(audio_codes)
-
+
+             # Collect time costs from LM extra_outputs
+             lm_extra = result.get("extra_outputs", {})
+             lm_chunk_time_costs = lm_extra.get("time_costs", {})
+             if lm_chunk_time_costs:
+                 # Accumulate time costs from all chunks
+                 for key in ["phase1_time", "phase2_time", "total_time"]:
+                     if key in lm_chunk_time_costs:
+                         lm_total_time_costs[key] += lm_chunk_time_costs[key]
+
+                 time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
+                 lm_status.append(f"āœ… LM chunk {chunk_idx+1}: {time_str}")
+
          lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
          lm_generated_audio_codes_list = all_audio_codes_list
-
+
          # Set audio_code_string_to_use based on infer_type
          if infer_type == "llm_dit":
              # If batch mode, use list; otherwise use single string
@@ -337,23 +426,48 @@ def generate_music(
          else:
              # For "dit" mode, keep user-provided codes or empty
              audio_code_string_to_use = params.audio_codes
-
+
          # Update metadata from LM if not provided by user
          if lm_generated_metadata:
-             bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
-                 lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
-             )
+             bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
+                 metadata=lm_generated_metadata,
+                 bpm=bpm,
+                 key_scale=key_scale,
+                 time_signature=time_signature,
+                 audio_duration=audio_duration,
+                 vocal_language=dit_input_vocal_language,
+                 caption=dit_input_caption,
+                 lyrics=dit_input_lyrics)
+             if not params.bpm:
+                 params.cot_bpm = bpm
+             if not params.keyscale:
+                 params.cot_keyscale = key_scale
+             if not params.timesignature:
+                 params.cot_timesignature = time_signature
+             if not params.duration:
+                 params.cot_duration = audio_duration
+             if not params.vocal_language:
+                 params.cot_vocal_language = vocal_language
+             if not params.caption:
+                 params.cot_caption = caption
+             if not params.lyrics:
+                 params.cot_lyrics = lyrics
+
+             # set cot caption and language if needed
+             if params.use_cot_caption:
+                 dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
+             if params.use_cot_language:
+                 dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
 
        # Phase 2: DiT music generation
        # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
        result = dit_handler.generate_music(
-           captions=params.caption,
-           lyrics=params.lyrics,
+           captions=dit_input_caption,
+           lyrics=dit_input_lyrics,
            bpm=bpm,
            key_scale=key_scale,
            time_signature=time_signature,
-           vocal_language=params.vocal_language,
+           vocal_language=dit_input_vocal_language,
            inference_steps=params.inference_steps,
            guidance_scale=params.guidance_scale,
            use_random_seed=config.use_random_seed,
@@ -371,110 +485,80 @@ def generate_music(
            use_adg=params.use_adg,
            cfg_interval_start=params.cfg_interval_start,
            cfg_interval_end=params.cfg_interval_end,
+           progress=progress,
        )
-
+
        # Check if generation failed
        if not result.get("success", False):
            return GenerationResult(
                audios=[],
-               generation_info=result.get("generation_info", ""),
                status_message=result.get("status_message", ""),
                extra_outputs={},
                success=False,
                error=result.get("error"),
            )
-
+
        # Extract results from dit_handler.generate_music dict
        dit_audios = result.get("audios", [])
-       generation_info = result.get("generation_info", "")
        status_message = result.get("status_message", "")
        dit_extra_outputs = result.get("extra_outputs", {})
-
-       # Append LM metadata to generation info
-       if lm_generated_metadata:
-           generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
-
+
        # Use the seed list already prepared above (from config.seed or params.seed fallback)
        # actual_seed_list was computed earlier using dit_handler.prepare_seeds
        seed_list = actual_seed_list
-
+
        # Get base params dictionary
        base_params_dict = params.to_dict()
-
+
        # Save audio files using AudioSaver (format from config)
        audio_format = config.audio_format if config.audio_format else "flac"
        audio_saver = AudioSaver(default_format=audio_format)
-
+
        # Use handler's temp_dir for saving files
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
-
+
        # Build audios list for GenerationResult with params and save files
        # Audio saving and UUID generation handled here, outside of handler
        audios = []
        for idx, dit_audio in enumerate(dit_audios):
            # Create a copy of params dict for this audio
            audio_params = base_params_dict.copy()
-
+
            # Update audio-specific values
            audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
-
+
            # Add audio codes if batch mode
            if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
                audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
-
+
            # Get audio tensor and metadata
            audio_tensor = dit_audio.get("tensor")
            sample_rate = dit_audio.get("sample_rate", 48000)
-
+
            # Generate UUID for this audio (moved from handler)
            batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
-           audio_code_str = lm_generated_audio_codes_list[idx] if (lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
+           audio_code_str = lm_generated_audio_codes_list[idx] if (
+               lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
            if isinstance(audio_code_str, list):
                audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
-
-           audio_key = generate_uuid_from_params(
-               captions=params.caption,
-               lyrics=params.lyrics,
-               bpm=bpm,
-               key_scale=key_scale,
-               time_signature=time_signature,
-               vocal_language=params.vocal_language,
-               inference_steps=params.inference_steps,
-               guidance_scale=params.guidance_scale,
-               seed=batch_seed,
-               audio_duration=audio_duration,
-               audio_code_string=audio_code_str,
-               repainting_start=params.repainting_start,
-               repainting_end=params.repainting_end,
-               instruction=params.instruction,
-               audio_cover_strength=params.audio_cover_strength,
-               task_type=params.task_type,
-               use_adg=params.use_adg,
-               cfg_interval_start=params.cfg_interval_start,
-               cfg_interval_end=params.cfg_interval_end,
-               audio_format=audio_format,
-               reference_audio=params.reference_audio,
-               src_audio=params.src_audio,
-               batch_index=idx,
-           )
-
+
+           audio_key = generate_uuid_from_params(audio_params)
+
           # Save audio file (handled outside handler)
           audio_path = None
           if audio_tensor is not None and save_dir is not None:
               try:
                   audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
-                  audio_path = audio_saver.save_audio(
-                      audio_tensor,
-                      audio_file,
-                      sample_rate=sample_rate,
-                      format=audio_format,
-                      channels_first=True
-                  )
+                  audio_path = audio_saver.save_audio(audio_tensor,
+                                                      audio_file,
+                                                      sample_rate=sample_rate,
+                                                      format=audio_format,
+                                                      channels_first=True)
               except Exception as e:
                   logger.error(f"[generate_music] Failed to save audio file: {e}")
                   audio_path = ""  # Fallback to empty path
-
+
           audio_dict = {
               "path": audio_path or "",  # File path (saved here, not in handler)
               "tensor": audio_tensor,  # Audio tensor [channels, samples], CPU, float32
@@ -482,259 +566,55 @@ def generate_music(
              "sample_rate": sample_rate,
              "params": audio_params,
          }
-
+
          audios.append(audio_dict)
-
+
       # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
       extra_outputs = dit_extra_outputs.copy()
       extra_outputs["lm_metadata"] = lm_generated_metadata
-
+
+      # Merge time_costs from both LM and DiT into a unified dictionary
+      unified_time_costs = {}
+
+      # Add LM time costs (if LM was used)
+      if use_lm and lm_total_time_costs:
+          for key, value in lm_total_time_costs.items():
+              unified_time_costs[f"lm_{key}"] = value
+
+      # Add DiT time costs (if available)
+      dit_time_costs = dit_extra_outputs.get("time_costs", {})
+      if dit_time_costs:
+          for key, value in dit_time_costs.items():
+              unified_time_costs[f"dit_{key}"] = value
+
+      # Calculate total pipeline time
+      if unified_time_costs:
+          lm_total = unified_time_costs.get("lm_total_time", 0.0)
+          dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
+          unified_time_costs["pipeline_total_time"] = lm_total + dit_total
+
+      # Update extra_outputs with unified time_costs
+      extra_outputs["time_costs"] = unified_time_costs
+
+      if lm_status:
+          status_message = "\n".join(lm_status) + "\n" + status_message
+      else:
+          status_message = status_message
       # Create and return GenerationResult
       return GenerationResult(
           audios=audios,
-          generation_info=generation_info,
           status_message=status_message,
           extra_outputs=extra_outputs,
           success=True,
           error=None,
       )
-
+
   except Exception as e:
       logger.exception("Music generation failed")
       return GenerationResult(
          audios=[],
-         generation_info=f"āŒ Generation failed: {str(e)}",
         status_message=f"Error: {str(e)}",
         extra_outputs={},
         success=False,
         error=str(e),
      )
-
-
-def _update_metadata_from_lm(
-    metadata: Dict[str, Any],
-    bpm: Optional[int],
-    key_scale: str,
-    time_signature: str,
-    audio_duration: Optional[float],
-) -> Tuple[Optional[int], str, str, Optional[float]]:
-    """Update metadata fields from LM output if not provided by user."""
-
-    if bpm is None and metadata.get('bpm'):
-        bpm_value = metadata.get('bpm')
-        if bpm_value not in ["N/A", ""]:
-            try:
-                bpm = int(bpm_value)
-            except (ValueError, TypeError):
-                pass
-
-    if not key_scale and metadata.get('keyscale'):
-        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
-        if key_scale_value != "N/A":
-            key_scale = key_scale_value
-
-    if not time_signature and metadata.get('timesignature'):
-        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
-        if time_signature_value != "N/A":
-            time_signature = time_signature_value
-
-    if audio_duration is None or audio_duration <= 0:
-        audio_duration_value = metadata.get('duration', -1)
-        if audio_duration_value not in ["N/A", ""]:
-            try:
-                audio_duration = float(audio_duration_value)
-            except (ValueError, TypeError):
-                pass
-
-    return bpm, key_scale, time_signature, audio_duration
-
-
-def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any]) -> str:
-    """Append LM-generated metadata to generation info string."""
-
-    metadata_lines = []
-    if metadata.get('bpm'):
-        metadata_lines.append(f"- **BPM:** {metadata['bpm']}")
-    if metadata.get('caption'):
-        metadata_lines.append(f"- **Refined Caption:** {metadata['caption']}")
-    if metadata.get('duration'):
-        metadata_lines.append(f"- **Duration:** {metadata['duration']} seconds")
-    if metadata.get('keyscale'):
-        metadata_lines.append(f"- **Key Scale:** {metadata['keyscale']}")
-    if metadata.get('language'):
-        metadata_lines.append(f"- **Language:** {metadata['language']}")
-    if metadata.get('timesignature'):
-        metadata_lines.append(f"- **Time Signature:** {metadata['timesignature']}")
-
-    if metadata_lines:
-        metadata_section = "\n\n**šŸ¤– LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
-        return metadata_section + "\n\n" + generation_info
-
-    return generation_info
-
-
-# ============================================================================
-# LEGACY GRADIO UI COMPATIBILITY LAYER
-# ============================================================================
-
-def generate_for_gradio(
-    dit_handler,
-    llm_handler,
-    captions,
-    lyrics,
-    bpm,
-    key_scale,
-    time_signature,
-    vocal_language,
-    inference_steps,
-    guidance_scale,
-    random_seed_checkbox,
-    seed,
-    reference_audio,
-    audio_duration,
-    batch_size_input,
-    src_audio,
-    text2music_audio_code_string,
-    repainting_start,
-    repainting_end,
-    instruction_display_gen,
-    audio_cover_strength,
-    task_type,
-    use_adg,
-    cfg_interval_start,
-    cfg_interval_end,
-    audio_format,
-    lm_temperature,
-    think_checkbox,
-    lm_cfg_scale,
-    lm_top_k,
-    lm_top_p,
-    lm_negative_prompt,
-    use_cot_metas,
-    use_cot_caption,
-    use_cot_language,
-    is_format_caption,
-    constrained_decoding_debug,
-    allow_lm_batch,
-    lm_batch_chunk_size,
-):
-    """Legacy Gradio UI compatibility wrapper.
-
-    This function maintains backward compatibility with the Gradio UI.
-    For new integrations, use generate_music() with GenerationConfig instead.
-
-    Returns:
-        Tuple with 28 elements for Gradio UI component updates
-    """
-
-    # Convert legacy parameters to GenerationParams and GenerationConfig
-    params = GenerationParams(
-        caption=captions,
-        lyrics=lyrics,
-        bpm=bpm,
-        keyscale=key_scale,
-        timesignature=time_signature,
-        vocal_language=vocal_language,
-        audio_codes=text2music_audio_code_string,
-        duration=audio_duration,
-        inference_steps=inference_steps,
-        guidance_scale=guidance_scale,
-        seed=seed,
-        use_adg=use_adg,
-        cfg_interval_start=cfg_interval_start,
-        cfg_interval_end=cfg_interval_end,
-        audio_format=audio_format,
-        task_type=task_type,
-        reference_audio=reference_audio,
-        src_audio=src_audio,
-        repainting_start=repainting_start,
-        repainting_end=repainting_end,
-        audio_cover_strength=audio_cover_strength,
-        instruction=instruction_display_gen,
-        thinking=think_checkbox,
-        lm_temperature=lm_temperature,
-        lm_cfg_scale=lm_cfg_scale,
-        lm_top_k=lm_top_k,
-        lm_top_p=lm_top_p,
-        lm_negative_prompt=lm_negative_prompt,
-        use_cot_metas=use_cot_metas,
-        use_cot_caption=use_cot_caption,
-        use_cot_language=use_cot_language,
-    )
-
-    config = GenerationConfig(batch_size=1)
-    config.batch_size = batch_size_input
-    config.use_random_seed = random_seed_checkbox
-    config.allow_lm_batch = allow_lm_batch
-    config.lm_batch_chunk_size = lm_batch_chunk_size
-    config.is_format_caption = is_format_caption
-    config.constrained_decoding_debug = constrained_decoding_debug
-
-    # Call new API
-    result = generate_music(dit_handler, llm_handler, params, config)
-
-    # Extract audio paths from result.audios
-    audio_paths = [audio["path"] for audio in result.audios]
-
-    # Extract extra outputs
-    extra_outputs = result.extra_outputs
-    seed_value = extra_outputs.get("seed_value", "")
-    lm_metadata = extra_outputs.get("lm_metadata", None)
-
-    # Legacy alignment fields (no longer used, set to empty/None)
-    align_score_1 = ""
-    align_text_1 = ""
-    align_plot_1 = None
-    align_score_2 = ""
-    align_text_2 = ""
-    align_plot_2 = None
-
-    # Determine which codes to update in UI
-    if config.allow_lm_batch and lm_metadata:
-        # Batch mode: extract codes from metadata if available
-        lm_codes_list = lm_metadata.get('audio_codes_list', [])
-        updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
-        codes_outputs = (lm_codes_list + [""] * 8)[:8]
-    else:
-        # Single mode
-        lm_codes = lm_metadata.get('audio_codes', '') if lm_metadata else ''
-        updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
-        codes_outputs = [""] * 8
-
-    # Prepare audio outputs (up to 8)
-    audio_outputs = (audio_paths + [None] * 8)[:8]
-
-    # Return tuple for Gradio UI (28 elements)
-    return (
-        audio_outputs[0],  # generated_audio_1
-        audio_outputs[1],  # generated_audio_2
-        audio_outputs[2],  # generated_audio_3
-        audio_outputs[3],  # generated_audio_4
-        audio_outputs[4],  # generated_audio_5
-        audio_outputs[5],  # generated_audio_6
-        audio_outputs[6],  # generated_audio_7
-        audio_outputs[7],  # generated_audio_8
-        audio_paths,  # generated_audio_batch
-        result.generation_info,
-        result.status_message,
-        seed_value,
-        align_score_1,
-        align_text_1,
-        align_plot_1,
-        align_score_2,
-        align_text_2,
-        align_plot_2,
-        updated_audio_codes,  # Update main audio codes in UI
-        codes_outputs[0],  # text2music_audio_code_string_1
-        codes_outputs[1],  # text2music_audio_code_string_2
-        codes_outputs[2],  # text2music_audio_code_string_3
-        codes_outputs[3],  # text2music_audio_code_string_4
-        codes_outputs[4],  # text2music_audio_code_string_5
-        codes_outputs[5],  # text2music_audio_code_string_6
-        codes_outputs[6],  # text2music_audio_code_string_7
-        codes_outputs[7],  # text2music_audio_code_string_8
-        lm_metadata,  # Store metadata for "Send to src audio" buttons
-        is_format_caption,  # Keep is_format_caption unchanged
-    )
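A hedged usage sketch for the refactored API above: `seeds` replaces `seed` on `GenerationConfig`, and `generate_music` now accepts an optional `progress` callback. Handler construction is elided; only names visible in this diff are used.

```python
# Hedged sketch: drives the new API with explicit seeds and a progress callback.
# dit_handler / llm_handler stand in for already-initialized handlers.
from acestep.inference import GenerationParams, GenerationConfig, generate_music

def run_once(dit_handler, llm_handler):
    params = GenerationParams(caption="mellow lofi hip hop", thinking=True)
    config = GenerationConfig(batch_size=2, seeds=[42, 43], use_random_seed=False)

    def progress(fraction, desc=""):
        # Matches both call styles in the diff:
        # progress(0.1, "...") and progress(0.99, desc="...")
        print(f"{fraction:.0%} {desc}")

    result = generate_music(dit_handler, llm_handler, params, config,
                            save_dir="./outputs", progress=progress)
    print(result.status_message)
    # Unified LM + DiT timings introduced in this commit
    print(result.extra_outputs.get("time_costs", {}))
    return result
```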
acestep/llm_inference.py CHANGED
@@ -5,6 +5,7 @@ Handles all LM-related operations including initialization and generation
 import os
 import traceback
 import time
+import random
 from typing import Optional, Dict, Any, Tuple, List, Union
 from contextlib import contextmanager
 
@@ -309,6 +310,7 @@ class LLMHandler:
 
        logger.info("loading 5Hz LM tokenizer...")
        start_time = time.time()
+       # TODO: load tokenizer too slow, not found solution yet
        llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
        logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
        self.llm_tokenizer = llm_tokenizer
@@ -796,12 +798,13 @@ class LLMHandler:
        constrained_decoding_debug: bool = False,
        target_duration: Optional[float] = None,
        user_metadata: Optional[Dict[str, Optional[str]]] = None,
+       use_cot_metas: bool = True,
        use_cot_caption: bool = True,
        use_cot_language: bool = True,
-       is_format_caption: bool = False,
        batch_size: Optional[int] = None,
        seeds: Optional[List[int]] = None,
-   ) -> Union[Tuple[Dict[str, Any], str, str], Tuple[List[Dict[str, Any]], List[str], str]]:
+       progress=None,
+   ) -> Dict[str, Any]:
        """Two-phase LM generation: CoT generation followed by audio codes generation.
 
        - infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
@@ -817,20 +820,30 @@ class LLMHandler:
            batch_size: Optional batch size for batch generation. If None or 1, returns single result.
                If > 1, returns batch results (lists).
            seeds: Optional list of seeds for batch generation (for reproducibility).
-               Only used when batch_size > 1.
+               Only used when batch_size > 1. TODO: not used yet
 
        Returns:
-           If batch_size is None or 1: (metadata, audio_codes, status_msg)
-           If batch_size > 1: (metadata_list, audio_codes_list, status_msg)
+           Dictionary containing:
+           - metadata: Dict or List[Dict] - Generated metadata
+           - audio_codes: str or List[str] - Generated audio codes
+           - success: bool - Whether generation succeeded
+           - error: Optional[str] - Error message if failed
+           - extra_outputs: Dict with time_costs and other info
        """
-       import time
-       import random
-
+       if progress is None:
+           def progress(*args, **kwargs):
+               pass
+
       infer_type = (infer_type or "").strip().lower()
       if infer_type not in {"dit", "llm_dit"}:
-          if batch_size and batch_size > 1:
-              return [], [], f"āŒ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
-          return {}, "", f"āŒ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
+          error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
+          return {
+              "metadata": [] if (batch_size and batch_size > 1) else {},
+              "audio_codes": [] if (batch_size and batch_size > 1) else "",
+              "success": False,
+              "error": error_msg,
+              "extra_outputs": {"time_costs": {}},
+          }
 
       # Determine if batch mode
       is_batch = batch_size and batch_size > 1
@@ -854,7 +867,8 @@ class LLMHandler:
 
       # ========== PHASE 1: CoT Generation ==========
       # Skip CoT if all metadata are user-provided OR caption is already formatted
-      if not has_all_metas and not is_format_caption:
+      progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...")
+      if not has_all_metas and use_cot_metas:
          if is_batch:
              logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
          else:
@@ -893,9 +907,13 @@ class LLMHandler:
          phase1_time = time.time() - phase1_start
 
          if not cot_output_text:
-             if is_batch:
-                 return [], [], status
-             return {}, "", status
+             return {
+                 "metadata": [] if is_batch else {},
+                 "audio_codes": [] if is_batch else "",
+                 "success": False,
+                 "error": status,
+                 "extra_outputs": {"time_costs": {"phase1_time": phase1_time}},
+             }
 
          # Parse metadata from CoT output
          metadata, _ = self.parse_lm_output(cot_output_text)
@@ -915,11 +933,31 @@ class LLMHandler:
      if infer_type == "dit":
          if is_batch:
              metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
-             status_msg = f"āœ… Generated CoT metadata successfully (batch mode)\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
-             return metadata_list, [""] * actual_batch_size, status_msg
+             return {
+                 "metadata": metadata_list,
+                 "audio_codes": [""] * actual_batch_size,
+                 "success": True,
+                 "error": None,
+                 "extra_outputs": {
+                     "time_costs": {
+                         "phase1_time": phase1_time,
+                         "total_time": phase1_time,
+                     }
+                 },
+             }
          else:
-             status_msg = f"āœ… Generated CoT metadata successfully\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
-             return metadata, "", status_msg
+             return {
+                 "metadata": metadata,
+                 "audio_codes": "",
+                 "success": True,
+                 "error": None,
+                 "extra_outputs": {
+                     "time_costs": {
+                         "phase1_time": phase1_time,
+                         "total_time": phase1_time,
+                     }
+                 },
+             }
 
      # ========== PHASE 2: Audio Codes Generation ==========
      if is_batch:
@@ -935,6 +973,7 @@ class LLMHandler:
      formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
      logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
 
+     progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
      if is_batch:
          # Batch mode: generate codes for all items
          formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
@@ -978,9 +1017,21 @@ class LLMHandler:
                  seeds=seeds,
              )
          except Exception as e:
-             error_msg = f"āŒ Error in batch codes generation: {str(e)}"
+             error_msg = f"Error in batch codes generation: {str(e)}"
              logger.error(error_msg)
-             return [], [], error_msg
+             return {
+                 "metadata": [],
+                 "audio_codes": [],
+                 "success": False,
+                 "error": error_msg,
+                 "extra_outputs": {
+                     "time_costs": {
+                         "phase1_time": phase1_time,
+                         "phase2_time": 0.0,
+                         "total_time": phase1_time,
+                     }
+                 },
+             }
 
          # Parse audio codes from each output
          audio_codes_list = []
@@ -996,8 +1047,22 @@ class LLMHandler:
          codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
          logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
 
-         status_msg = f"āœ… Batch generation completed ({actual_batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
-         return metadata_list, audio_codes_list, status_msg
+         total_time = phase1_time + phase2_time
+         return {
+             "metadata": metadata_list,
+             "audio_codes": audio_codes_list,
+             "success": True,
+             "error": None,
+             "extra_outputs": {
+                 "time_costs": {
+                     "phase1_time": phase1_time,
+                     "phase2_time": phase2_time,
+                     "total_time": total_time,
+                 },
+                 "codes_counts": codes_counts,
+                 "total_codes": sum(codes_counts),
+             },
+         }
      else:
          # Single mode: generate codes for one item
          codes_output_text, status = self.generate_from_formatted_prompt(
@@ -1025,7 +1090,20 @@ class LLMHandler:
          )
 
          if not codes_output_text:
-             return metadata, "", status
+             total_time = phase1_time + phase2_time
+             return {
+                 "metadata": metadata,
+                 "audio_codes": "",
+                 "success": False,
+                 "error": status,
+                 "extra_outputs": {
+                     "time_costs": {
+                         "phase1_time": phase1_time,
+                         "phase2_time": phase2_time,
+                         "total_time": total_time,
+                     }
+                 },
+             }
 
          phase2_time = time.time() - phase2_start
 
@@ -1035,8 +1113,21 @@ class LLMHandler:
          codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
          logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
 
-         status_msg = f"āœ… Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
-         return metadata, audio_codes, status_msg
+         total_time = phase1_time + phase2_time
+         return {
+             "metadata": metadata,
+             "audio_codes": audio_codes,
+             "success": True,
+             "error": None,
+             "extra_outputs": {
+                 "time_costs": {
+                     "phase1_time": phase1_time,
+                     "phase2_time": phase2_time,
+                     "total_time": total_time,
+                 },
+                 "codes_count": codes_count,
+             },
+         }
 
    def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
        """
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -93,6 +93,8 @@ class ModelRunner:
     def _allocate_sample_buffers(self):
         """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
         max_bs = self.config.max_num_seqs
+        max_tokens = self.config.max_num_batched_tokens
+        max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
 
         # Pre-allocate pinned memory buffers on CPU for fast transfer
         # Must explicitly specify device="cpu" since default device may be "cuda"
@@ -107,6 +109,19 @@ class ModelRunner:
         self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
         self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
         self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
+
+        # Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
+        self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
+        self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
+        self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
+        self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
+
+        # Pre-allocate block tables buffer (shared by both decode and prefill)
+        self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
+
+        # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
+        # Max length is max_model_len since sequences can be that long
+        self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
 
     def exit(self):
         if self.world_size > 1:
@@ -227,7 +242,7 @@ class ModelRunner:
             if i != seq.num_blocks - 1:
                 end = start + self.block_size
             else:
-                end = start + seq.last_block_num_tokens
+                end = start + seq.last_block_num_tokens
             slot_mapping.extend(list(range(start, end)))
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:  # prefix cache
             block_tables = self.prepare_block_tables(seqs)
@@ -269,19 +284,28 @@ class ModelRunner:
         target_seqs = seqs
 
         # Fill pre-allocated CPU buffers
+        top_ks_is_zero = True
+        top_ps_is_one = True
+        repetition_penalties_is_one = True
         for i, seq in enumerate(target_seqs):
            self._cpu_temperatures[i] = seq.temperature
            self._cpu_cfg_scales[i] = seq.cfg_scale
            self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
+            if seq.top_k is not None and seq.top_k > 0:
+                top_ks_is_zero = False
            self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
+            if seq.top_p is not None and seq.top_p != 1.0:  # flag non-default values (the original diff inverted this check)
+                top_ps_is_one = False
            self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
+            if seq.repetition_penalty is not None and seq.repetition_penalty != 1.0:  # same fix as above
+                repetition_penalties_is_one = False
 
         # Transfer to GPU using sliced views (single batched transfer)
         temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
         cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
-        top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True)
-        top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True)
-        repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True)
+        top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
+        top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
+        repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
 
         return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
 
@@ -309,27 +333,15 @@ class ModelRunner:
         [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
         where uncond_seqi is the paired unconditional sequence of cond_seqi."""
         # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
-        is_cfg_batch = False
-        if len(seqs) > 0:
-            # CFG batch if first sequence has cfg_scale > 1.0 and paired_seq
-            if seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
-                is_cfg_batch = True
-                # Verify batch structure: first half conditional, second half unconditional
-                num_cond = len(seqs) // 2
-                for i in range(num_cond):
-                    if seqs[i].is_unconditional or seqs[i + num_cond].is_unconditional == False:
-                        is_cfg_batch = False
-                        break
-
+        is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
         if is_cfg_batch:
             # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
             num_cond = len(seqs) // 2
             cond_seqs = seqs[:num_cond]
-            uncond_seqs = seqs[num_cond:]
+            # uncond_seqs = seqs[num_cond:]
 
             # Prepare inputs for both conditional and unconditional (they're already in the batch)
-            input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
-                                    else self.prepare_decode(seqs))
+            input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
             sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
             if sample_params is not None:
                 temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
@@ -380,7 +392,7 @@ class ModelRunner:
                 logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
 
             # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
-            cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
+            # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
 
             # Sample from CFG logits
             token_ids_cfg = self.sampler(
@@ -389,7 +401,7 @@ class ModelRunner:
                 top_ks=top_ks if top_ks is not None else None,
                 top_ps=top_ps if top_ps is not None else None,
                 repetition_penalties=None,  # Already applied above
-                input_ids=cond_input_ids,
+                # input_ids=cond_input_ids,
             ).tolist()
 
             # Update logits processor state after sampling
@@ -448,7 +460,7 @@ class ModelRunner:
                 logits[i] = processed[0]
 
             # Prepare input_ids for sampler
-            seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
+            # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
 
             token_ids = self.sampler(
                 logits,
@@ -456,7 +468,7 @@ class ModelRunner:
                 top_ks=top_ks if top_ks is not None else None,
                 top_ps=top_ps if top_ps is not None else None,
                 repetition_penalties=None,  # Already applied above
-                input_ids=seq_input_ids,
+                # input_ids=seq_input_ids,
             ).tolist()
 
             # Update logits processor state after sampling
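
The buffer work above is the standard pinned-memory pattern in PyTorch: fill a page-locked CPU tensor in place, then issue a single asynchronous host-to-device copy of a sliced view instead of allocating a fresh tensor every step. A self-contained sketch of that pattern (illustrative only, not this repo's code):

    import torch

    MAX_BS = 64
    # pin_memory=True gives page-locked host memory, required for truly async copies
    cpu_buf = torch.zeros(MAX_BS, dtype=torch.float32, device="cpu", pin_memory=True)

    def batch_to_gpu(values):
        for i, v in enumerate(values):
            cpu_buf[i] = v                                         # in-place fill, no allocation
        return cpu_buf[:len(values)].cuda(non_blocking=True)       # one async H2D copy

Returning None when every sequence carries the default (top_k == 0, top_p == 1.0, repetition_penalty == 1.0) then lets the sampler skip both the filtering and the transfer for those tensors.
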
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py CHANGED
@@ -85,6 +85,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
 
+    @torch.compile
     def forward(
         self,
         logits: torch.Tensor,
@@ -102,27 +103,12 @@ class Sampler(nn.Module):
         """
         # Apply temperature
         logits = logits.float().div_(temperatures.unsqueeze(dim=1))
-
-        # Check conditions OUTSIDE compiled code to avoid graph breaks
-        # These .any() calls cause CPU-GPU sync, but we do it once here
-        # instead of inside the compiled function
-        need_topk = top_ks is not None and bool((top_ks > 0).any()) and bool((top_ks < logits.shape[1]).any())
-        need_topp = top_ps is not None and bool((top_ps < 1.0).any()) and bool((top_ps > 0.0).any())
-
-        if need_topk or need_topp:
-            # Apply filtering (this part is not compiled due to dynamic control flow)
-            logits = apply_top_k_top_p(
-                logits,
-                top_ks if need_topk else None,
-                top_ps if need_topp else None,
-            )
-
-        # Sample using compiled function
-        return self._sample(logits)
-
-    @torch.compile(dynamic=True)
-    def _sample(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compiled sampling kernel - no graph breaks here."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
-        q = torch.empty_like(probs).exponential_()
-        return probs.div(q).argmax(dim=-1)
+
+        logits = apply_top_k_top_p(
+            logits,
+            top_ks,
+            top_ps,
+        )
+        probs = torch.softmax(logits, dim=-1)
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
+        return sample_tokens
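
The rewritten forward samples with the exponential-race (Gumbel-max) trick: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws from the same categorical distribution as torch.multinomial, but it is branch-free, so the whole forward compiles as one graph. A quick self-check of the equivalence (illustrative, not repo code):

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])
    counts = torch.zeros(3)
    for _ in range(20000):
        # Same sampling kernel as the diff: divide by Exp(1) noise, take argmax
        q = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        counts[probs.div(q).argmax()] += 1
    print(counts / counts.sum())  # approaches tensor([0.1, 0.2, 0.7])

The clamp_min_ mirrors the diff's guard against a zero draw from exponential_, which would otherwise turn every class into inf at once.
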
profile_inference.py CHANGED
@@ -1,223 +1,682 @@
 #!/usr/bin/env python3
 """
-Profiling script for acestep/inference.py using cProfile
+Enhanced profiling script for ACE-Step inference with deep LLM analysis
+
+This script helps diagnose why LLM generation is slow by tracking:
+1. Total tokens generated vs expected throughput (200 tokens/sec baseline)
+2. Per-iteration timing to detect compilation overhead or slow operations
+3. Constrained decoding overhead
+4. CFG overhead (2x forward passes)
+5. Model forward time vs sampling/processing time
 
 Usage:
-    python profile_inference.py
-    python profile_inference.py --warmup
+    python profile_inference.py                              # Standard profiling with warmup
+    python profile_inference.py --no-warmup                  # Profile first run (includes compilation)
+    python profile_inference.py --llm-debug                  # Deep LLM performance debugging
+    python profile_inference.py --detailed                   # Add cProfile function-level analysis
+
+Inference mode options:
+    python profile_inference.py --thinking                   # Enable CoT for code generation
+    python profile_inference.py --use-constrained-decoding   # Use FSM constrained decoding
+    python profile_inference.py --use-cot-metas              # Enable LM to generate metadata via CoT
 """
 
-import cProfile
-import pstats
-import io
 import time
 import argparse
 import sys
 import os
+from contextlib import contextmanager
+from collections import defaultdict
+import json
+from typing import Tuple, Dict, Any, List
+from functools import wraps
 
 # Add project root to path
 project_root = os.path.abspath(os.path.dirname(__file__))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
 
+import torch
 from acestep.inference import generate_music, GenerationParams, GenerationConfig
 from acestep.handler import AceStepHandler
 from acestep.llm_inference import LLMHandler
-import json
-from typing import Tuple
-
-
-def profile_with_cprofile(dit_handler, llm_handler, params, config, warmup=False):
-    """Profile using Python's built-in cProfile.
-
-    Args:
-        warmup: If True, run once for warmup before profiling (default: False)
-    """
-    print("=" * 80)
-    print("Profiling with cProfile")
-    print("=" * 80)
-
-    # Warmup run (to exclude PyTorch compilation overhead)
-    if warmup:
-        print("\n[Warmup] Running first generation to warm up (PyTorch compilation, etc.)...")
-        warmup_start = time.time()
-        params.use_cot_metas = False
-        config.is_format_caption = True
-        config.use_constrained_decoding = False
-        warmup_result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
-        warmup_time = time.time() - warmup_start
-        print(f"[Warmup] Completed in {warmup_time:.2f}s")
-        if not warmup_result.success:
-            print(f"[Warmup] ⚠ Warmup generation failed: {warmup_result.error}")
-            return warmup_result
-
-    # Actual profiling run (first inference)
-    print("\n[Profiling] Running first generation for profiling...")
-    profiler = cProfile.Profile()
-    profiler.enable()
-
-    profiling_start = time.time()
-    try:
-        result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
-    finally:
-        profiler.disable()
-        profiling_time = time.time() - profiling_start
-
-    # Create stats
-    s = io.StringIO()
-    ps = pstats.Stats(profiler, stream=s)
-    ps.sort_stats('cumulative')
-
-    print(f"\n[Profiling] Completed in {profiling_time:.2f}s")
-    print("\nTop 30 functions by cumulative time:")
-    print("-" * 80)
-    ps.print_stats(30)
-
-    print("\nTop 30 functions by total time:")
-    print("-" * 80)
-    ps.sort_stats('tottime')
-    ps.print_stats(30)
-
-    # Save detailed report to file
-    output_file = "profile_cprofile.txt"
-    with open(output_file, 'w') as f:
-        # Create a new Stats object with file as stream
-        ps_file = pstats.Stats(profiler, stream=f)
-        ps_file.sort_stats('cumulative')
-        ps_file.print_stats()
-    print(f"\nDetailed profile saved to: {output_file}")
-
-    return result
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Profile acestep/inference.py")
-    parser.add_argument(
-        "--checkpoint-dir",
-        type=str,
-        default="./checkpoints",
-        help="Path to checkpoints directory"
-    )
-    parser.add_argument(
-        "--config-path",
-        type=str,
-        default="acestep-v15-turbo-rl",
-        help="Model config path"
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda",
-        help="Device to use (cuda/cpu)"
+
+
+class PreciseTimer:
+    """High-precision timer with CUDA synchronization for accurate GPU timing"""
+
+    def __init__(self, device="cuda"):
+        self.device = device
+        self.timings = defaultdict(list)
+        self.enabled = True
+
+    def sync(self):
+        """Synchronize CUDA operations for accurate timing"""
+        if self.enabled and self.device.startswith("cuda") and torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+    @contextmanager
+    def time(self, name: str):
+        """Time a code section with CUDA synchronization"""
+        if not self.enabled:
+            yield
+            return
+
+        self.sync()
+        start = time.perf_counter()
+        try:
+            yield
+        finally:
+            self.sync()
+            elapsed = time.perf_counter() - start
+            self.timings[name].append(elapsed)
+
+    def get_total(self, name: str) -> float:
+        """Get total accumulated time for a section"""
+        return sum(self.timings.get(name, []))
+
+    def get_mean(self, name: str) -> float:
+        """Get mean time per call for a section"""
+        times = self.timings.get(name, [])
+        return sum(times) / len(times) if times else 0.0
+
+    def get_count(self, name: str) -> int:
+        """Get number of calls for a section"""
+        return len(self.timings.get(name, []))
+
+    def get_all(self, name: str) -> List[float]:
+        """Get all timing samples for a section"""
+        return self.timings.get(name, [])
+
+
+class LLMDebugger:
+    """Track detailed LLM performance metrics to diagnose slow generation"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """Reset all metrics"""
+        self.total_tokens = 0
+        self.generation_start = None
+        self.generation_end = None
+        self.output_text = ""
+        self.prompt_length = 0
+
+    def start(self, prompt_length: int = 0):
+        """Mark generation start"""
+        self.generation_start = time.perf_counter()
+        self.prompt_length = prompt_length
+
+    def end(self, output_text: str = ""):
+        """Mark generation end and store output"""
+        self.generation_end = time.perf_counter()
+        self.output_text = output_text
+
+    def set_token_count(self, count: int):
+        """Set total token count"""
+        self.total_tokens = count
+
+    def get_throughput(self) -> float:
+        """Calculate actual tokens per second"""
+        if self.generation_start and self.generation_end and self.total_tokens > 0:
+            total_time = self.generation_end - self.generation_start
+            if total_time > 0:
+                return self.total_tokens / total_time
+        return 0.0
+
+    def print_analysis(self):
+        """Print detailed LLM performance analysis"""
+        if not self.generation_start or not self.generation_end:
+            return
+
+        print("\n" + "=" * 100)
+        print("šŸ” LLM PERFORMANCE DEEP DIVE")
+        print("=" * 100)
+
+        total_time = self.generation_end - self.generation_start
+        throughput = self.get_throughput()
+
+        # Basic metrics table
+        print(f"\n{'Metric':<40} {'Value':<20} {'Notes'}")
+        print("-" * 100)
+        print(f"{'Total Tokens Generated:':<40} {self.total_tokens:<20} (new tokens only)")
+        print(f"{'Prompt Length (estimate):':<40} {self.prompt_length:<20} (input tokens)")
+        print(f"{'Total Generation Time:':<40} {total_time:<20.3f} seconds")
+        print(f"{'Measured Throughput:':<40} {throughput:<20.1f} tokens/sec")
+        print(f"{'Expected Throughput:':<40} {'200':<20} tokens/sec (baseline)")
+
+        # Calculate performance gap
+        if throughput > 0:
+            slowdown = 200.0 / throughput
+            efficiency = (throughput / 200.0) * 100
+            print(f"{'Performance vs Baseline:':<40} {efficiency:<20.1f}% of expected")
+            print(f"{'Slowdown Factor:':<40} {slowdown:<20.2f}x slower")
+
+        # Analyze generated output
+        if self.output_text:
+            print(f"\n{'Output Analysis:':<40}")
+            print(f"{'  Output length:':<40} {len(self.output_text):<20} characters")
+
+            # Count audio codes
+            import re
+            code_pattern = r'<\|audio_code_\d+\|>'
+            codes = re.findall(code_pattern, self.output_text)
+            if codes:
+                print(f"{'  Audio codes generated:':<40} {len(codes):<20} codes")
+                print(f"{'  Expected audio duration:':<40} {f'~{len(codes)/5:.1f}s':<20} (5 codes per second)")
+                if total_time > 0:
+                    print(f"{'  Time per audio code:':<40} {f'{total_time/len(codes)*1000:.1f}ms':<20}")
+
+            # Check for CoT section
+            if '<think>' in self.output_text and '</think>' in self.output_text:
+                cot_start = self.output_text.find('<think>')
+                cot_end = self.output_text.find('</think>') + 8
+                cot_section = self.output_text[cot_start:cot_end]
+                cot_token_est = len(cot_section) // 4
+                print(f"{'  CoT section tokens (estimate):':<40} {f'~{cot_token_est}':<20}")
+
+        # Diagnostic guidance
+        print("\n" + "=" * 100)
+        print("šŸ”§ DIAGNOSTIC GUIDANCE")
+        print("=" * 100)
+
+        if throughput < 50:
+            print("\nāš ļø CRITICAL: Throughput is extremely low (<50 tokens/sec)")
+            print("\nThis is ~4x slower than expected. Likely causes:")
+            print("  1. ā— Constrained decoding FSM overhead")
+            print("     → Each token triggers FSM state machine validation")
+            print("     → Try: set use_constrained_decoding=False in config")
+            print("  2. ā— CFG with double forward passes")
+            print("     → cfg_scale > 1.0 means running model twice per token")
+            print("     → Check: params.lm_cfg_scale value")
+            print("  3. ā— Running in eager mode without compilation")
+            print("     → PyTorch should compile kernels after warmup")
+            print("     → Check: torch._dynamo.config settings")
+
+        elif throughput < 100:
+            print("\nāš ļø WARNING: Throughput is low (50-100 tokens/sec)")
+            print("\nLikely causes:")
+            print("  1. Constrained decoding overhead (~30-50% slowdown expected)")
+            print("  2. CFG enabled (2x compute per token if cfg_scale > 1.0)")
+            print("  3. Small model or inefficient GPU utilization")
+
+        elif throughput < 150:
+            print("\nāš ļø Throughput is below baseline but acceptable (100-150 tokens/sec)")
+            print("\nMinor overhead from:")
+            print("  - Constrained decoding: ~20-30% overhead")
+            print("  - Profiling instrumentation: ~5-10% overhead")
+
+        else:
+            print(f"\nāœ“ Throughput is good ({throughput:.1f} tokens/sec)")
+            print("  Performance is within acceptable range")
+
+
+# Global instances
+timer = None
+llm_debugger = None
+
+
+def wrap_method_with_timing(obj, method_name: str, timing_key: str):
+    """Wrap a method with timing instrumentation"""
+    original_method = getattr(obj, method_name)
+
+    @wraps(original_method)
+    def timed_wrapper(*args, **kwargs):
+        with timer.time(timing_key):
+            return original_method(*args, **kwargs)
+
+    setattr(obj, method_name, timed_wrapper)
+    return original_method
+
+
+def wrap_llm_with_debug_tracking(llm_handler):
+    """Wrap LLM generation with detailed performance tracking"""
+    original_method = llm_handler.generate_with_stop_condition
+
+    @wraps(original_method)
+    def debug_wrapper(*args, **kwargs):
+        # Estimate prompt length
+        caption = kwargs.get('caption', args[0] if len(args) > 0 else "")
+        lyrics = kwargs.get('lyrics', args[1] if len(args) > 1 else "")
+        prompt_estimate = len(caption) + len(lyrics)
+        prompt_tokens_estimate = prompt_estimate // 4
+
+        # Start tracking
+        llm_debugger.reset()
+        llm_debugger.start(prompt_length=prompt_tokens_estimate)
+
+        # Call original with timing
+        with timer.time('llm_inference'):
+            result = original_method(*args, **kwargs)
+
+        # Extract and analyze output
+        output_text = ""
+        if isinstance(result, tuple) and len(result) >= 2:
+            if isinstance(result[1], list):
+                # Batch mode
+                output_text = "".join(result[1])
+            else:
+                # Single mode
+                cot_output = ""
+                if isinstance(result[0], dict):
+                    for v in result[0].values():
+                        if isinstance(v, str):
+                            cot_output += v
+                output_text = cot_output + str(result[1])
+
+        # Count tokens
+        import re
+        code_pattern = r'<\|audio_code_\d+\|>'
+        codes = re.findall(code_pattern, output_text)
+        remaining_text = re.sub(code_pattern, '', output_text)
+        cot_tokens_estimate = len(remaining_text) // 4
+        total_tokens = len(codes) + cot_tokens_estimate
+
+        llm_debugger.set_token_count(total_tokens)
+        llm_debugger.end(output_text)
+
+        return result
+
+    llm_handler.generate_with_stop_condition = debug_wrapper
+    return original_method
+
+
+def instrument_handlers(dit_handler, llm_handler, enable_llm_debug=False):
+    """Add timing instrumentation to handler methods"""
+    originals = {}
+
+    # Instrument LLM
+    if llm_handler and llm_handler.llm_initialized:
+        if enable_llm_debug:
+            originals['llm_generate'] = wrap_llm_with_debug_tracking(llm_handler)
+        else:
+            originals['llm_generate'] = wrap_method_with_timing(
+                llm_handler, 'generate_with_stop_condition', 'llm_inference'
+            )
+
+    # Instrument DiT handler
+    originals['dit_prepare'] = wrap_method_with_timing(
+        dit_handler, 'prepare_batch_data', 'prepare_batch_data'
     )
-    parser.add_argument(
-        "--lm-model",
-        type=str,
-        default="acestep-5Hz-lm-0.6B-v3",
-        help="LM model path"
+    originals['dit_generate'] = wrap_method_with_timing(
+        dit_handler, 'service_generate', 'dit_inference'
    )
-    parser.add_argument(
-        "--lm-backend",
-        type=str,
-        default="vllm",
-        help="LM backend"
+    originals['dit_decode'] = wrap_method_with_timing(
+        dit_handler, 'tiled_decode', 'vae_decode'
    )
-    parser.add_argument(
-        "--warmup",
-        action="store_true",
-        help="Enable warmup run before profiling (default: False, profile first run)"
+
+    return originals
+
+
+def restore_handlers(dit_handler, llm_handler, originals):
+    """Restore original handler methods after profiling"""
+    if llm_handler and 'llm_generate' in originals:
+        llm_handler.generate_with_stop_condition = originals['llm_generate']
+
+    dit_handler.prepare_batch_data = originals['dit_prepare']
+    dit_handler.service_generate = originals['dit_generate']
+    dit_handler.tiled_decode = originals['dit_decode']
+
+
+def print_profiling_results(total_time: float, show_llm_debug: bool = False):
+    """Print comprehensive profiling results with performance insights"""
+    print("\n" + "=" * 100)
+    print("šŸŽÆ PROFILING RESULTS")
+    print("=" * 100)
+
+    # Define timing categories
+    model_sections = {
+        'llm_inference': 'LLM Inference (5Hz Language Model)',
+        'dit_inference': 'DiT Inference (Diffusion Transformer)',
+        'vae_decode': 'VAE Decode (Audio Decoder)',
+    }
+
+    non_model_sections = {
+        'prepare_batch_data': 'Prepare Batch Data (embedding, formatting)',
+    }
+
+    # Calculate totals
+    model_time = sum(timer.get_total(k) for k in model_sections.keys())
+    non_model_time = sum(timer.get_total(k) for k in non_model_sections.keys())
+    other_time = total_time - model_time - non_model_time
+
+    # Print summary table
+    print(f"\n{'CATEGORY':<50} {'TIME (s)':<12} {'%':<8} {'CALLS':<8}")
+    print("-" * 100)
+
+    # Model time breakdown
+    print(f"\n{'šŸ¤– MODEL TIME (Total)':<50} {model_time:<12.3f} {100*model_time/total_time:>6.1f}% {'':<8}")
+    for key, desc in model_sections.items():
+        t = timer.get_total(key)
+        c = timer.get_count(key)
+        if c > 0:
+            mean = timer.get_mean(key)
+            pct = 100 * t / total_time
+            print(f"  {'ā”œā”€ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
+
+    # Non-model time breakdown
+    print(f"\n{'āš™ļø NON-MODEL TIME (Total)':<50} {non_model_time:<12.3f} {100*non_model_time/total_time:>6.1f}% {'':<8}")
+    for key, desc in non_model_sections.items():
+        t = timer.get_total(key)
+        c = timer.get_count(key)
+        if c > 0:
+            mean = timer.get_mean(key)
+            pct = 100 * t / total_time
+            print(f"  {'ā”œā”€ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")
+
+    # Other time
+    if other_time > 0.01:
+        pct = 100 * other_time / total_time
+        print(f"\n{'šŸ“¦ OTHER TIME (I/O, overhead, audio save)':<50} {other_time:<12.3f} {pct:>6.1f}% {'':<8}")
+
+    print(f"\n{'šŸ“Š TOTAL TIME':<50} {total_time:<12.3f} {'100.0%':>6} {'':<8}")
+
+    # Show LLM detailed analysis if enabled
+    if show_llm_debug:
+        llm_debugger.print_analysis()
+
+    # Performance insights
+    print("\n" + "=" * 100)
+    print("šŸ’” PERFORMANCE INSIGHTS")
+    print("=" * 100)
+
+    llm_t = timer.get_total('llm_inference')
+    dit_t = timer.get_total('dit_inference')
+    vae_t = timer.get_total('vae_decode')
+    prep_t = timer.get_total('prepare_batch_data')
+
+    # Model time insights
+    if model_time > 0:
+        print(f"\nāœ“ Model operations: {model_time:.3f}s ({100*model_time/total_time:.1f}% of total)")
+
+        if llm_t > 0:
+            print(f"  - LLM: {llm_t:.3f}s ({100*llm_t/model_time:.1f}% of model time)")
+        if dit_t > 0:
+            print(f"  - DiT: {dit_t:.3f}s ({100*dit_t/model_time:.1f}% of model time)")
+        if vae_t > 0:
+            print(f"  - VAE: {vae_t:.3f}s ({100*vae_t/model_time:.1f}% of model time)")
+
+    # LLM bottleneck analysis
+    if llm_t > dit_t and llm_t > 5.0:
+        print(f"\nāš ļø LLM IS THE BOTTLENECK: {llm_t:.3f}s ({100*llm_t/total_time:.1f}% of total)")
+        print(f"\n  Possible causes:")
+        print(f"  1. Generating too many tokens → use --llm-debug to verify")
+        print(f"  2. Constrained decoding overhead → FSM validation per token")
+        print(f"  3. CFG overhead → cfg_scale > 1.0 = 2x forward passes")
+        print(f"  4. First-token latency → warmup should help")
+        print(f"  5. KV cache inefficiency → should be ~5-10ms/token")
+
+    # Non-model insights
+    if non_model_time / total_time > 0.1:
+        print(f"\nāš ļø Non-model operations: {non_model_time:.3f}s ({100*non_model_time/total_time:.1f}%)")
+        if prep_t > 0.1:
+            print(f"  - Batch preparation: {prep_t:.3f}s")
+
+    # I/O overhead
+    if other_time / total_time > 0.2:
+        print(f"\nāš ļø Overhead/I/O: {other_time:.3f}s ({100*other_time/total_time:.1f}%)")
+
+    # Recommendations
+    print("\n" + "=" * 100)
+    print("šŸš€ OPTIMIZATION RECOMMENDATIONS")
+    print("=" * 100)
+
+    if llm_t > dit_t * 2:
+        print("\nšŸŽÆ Priority: Optimize LLM")
+        print("  1. Run: python profile_inference.py --llm-debug")
+        print("     → Shows exact token count and throughput")
+        print("  2. Check constrained decoding overhead")
+        print("  3. Check CFG scaling (lm_cfg_scale parameter)")
+        print("  4. Profile nanovllm engine step() timing")
+        print("  5. Compare vllm vs transformers backends")
+
+
+def run_profiled_generation(dit_handler, llm_handler, params, config,
+                            enable_cprofile=False, enable_llm_debug=False):
+    """Execute generation with full profiling instrumentation"""
+    # Instrument handlers
+    originals = instrument_handlers(dit_handler, llm_handler, enable_llm_debug)
+
+    try:
+        print("\n[Profiling] Starting generation...")
+        timer.sync()
+        total_start = time.perf_counter()
+
+        # Optional cProfile
+        prof = None
+        if enable_cprofile:
+            import cProfile
+            prof = cProfile.Profile()
+            prof.enable()
+
+        # Run generation
+        result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
+
+        # Stop timing
+        timer.sync()
+        total_time = time.perf_counter() - total_start
+
+        # Save cProfile if enabled
+        if enable_cprofile and prof:
+            prof.disable()
+
+            import pstats
+            import io
+
+            output_file = "profile_cprofile_detailed.txt"
+            with open(output_file, 'w') as f:
+                ps = pstats.Stats(prof, stream=f)
+                ps.sort_stats('cumulative')
+                ps.print_stats(100)
+
+            # Print top functions
+            print("\n" + "=" * 100)
+            print("šŸ“Š TOP 20 FUNCTIONS BY CUMULATIVE TIME (cProfile)")
+            print("=" * 100)
+            s = io.StringIO()
+            ps = pstats.Stats(prof, stream=s)
+            ps.sort_stats('cumulative')
+            ps.print_stats(20)
+            print(s.getvalue())
+
+            print(f"\nFull report: {output_file}")
+
+        # Print results
+        print_profiling_results(total_time, show_llm_debug=enable_llm_debug)
+
+        return result, total_time
+
+    finally:
+        restore_handlers(dit_handler, llm_handler, originals)
+
+
+def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
+    """Load configuration from example JSON file"""
+    try:
+        with open(example_file, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+
+        params = GenerationParams(
+            caption=data.get('caption', ''),
+            lyrics=data.get('lyrics', ''),
+            bpm=data.get('bpm'),
+            keyscale=data.get('keyscale', ''),
+            timesignature=data.get('timesignature', ''),
+            vocal_language=data.get('language', 'unknown'),
+            duration=data.get('duration'),
+            thinking=data.get('think', False),
+            inference_steps=data.get('inference_steps', 8),
+            seed=data.get('seed', 42),
+        )
+
+        config = GenerationConfig(batch_size=data.get('batch_size', 1), seeds=[42])
+
+        return params, config
+
+    except Exception as e:
+        print(f"  āŒ Failed to load: {e}")
+        return None, None
+
+
+def main():
+    global timer, llm_debugger
+
+    parser = argparse.ArgumentParser(
+        description="Profile ACE-Step inference with LLM debugging"
    )
+    parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints")
+    parser.add_argument("--config-path", type=str, default="acestep-v15-turbo-rl")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--lm-model", type=str, default="acestep-5Hz-lm-0.6B-v3")
+    parser.add_argument("--lm-backend", type=str, default="vllm")
+    parser.add_argument("--no-warmup", action="store_true")
+    parser.add_argument("--detailed", action="store_true")
+    parser.add_argument("--llm-debug", action="store_true",
+                        help="Enable deep LLM debugging (token count, throughput)")
+    parser.add_argument("--example", type=str, default="example_05.json")
+
+    # Inference mode parameters
+    parser.add_argument("--thinking", action="store_true",
+                        help="Enable CoT reasoning for LM to generate audio codes")
+    parser.add_argument("--use-constrained-decoding", action="store_true",
+                        help="Use FSM-based constrained decoding for meta generation")
+    parser.add_argument("--use-cot-metas", action="store_true",
+                        help="Enable LLM to generate music metadata via CoT reasoning")
 
     args = parser.parse_args()
 
-    # Initialize handlers
-    print("Initializing handlers...")
+    # Initialize
+    timer = PreciseTimer(device=args.device)
+    llm_debugger = LLMDebugger()
+
+    print("=" * 100)
+    print("šŸŽµ ACE-Step Inference Profiler (LLM Performance Analysis)")
+    print("=" * 100)
+    print(f"\nConfiguration:")
+    print(f"  Device: {args.device}")
+    print(f"  LLM Backend: {args.lm_backend}")
+    print(f"  LLM Debug: {'Enabled' if args.llm_debug else 'Disabled'}")
+    print(f"  Warmup: {'Disabled' if args.no_warmup else 'Enabled'}")
+    print(f"\nInference Mode:")
+    print(f"  Thinking (CoT): {'Enabled' if args.thinking else 'Disabled'}")
+    print(f"  Constrained Decoding: {'Enabled' if args.use_constrained_decoding else 'Disabled'}")
+    print(f"  Use CoT for Metas: {'Enabled' if args.use_cot_metas else 'Disabled'}")
+
+    # Initialize models
+    print(f"\nInitializing models...")
+
     dit_handler = AceStepHandler()
     llm_handler = LLMHandler()
 
-    # Initialize DiT
-    print("  - Initializing DiT model...")
+    print("  šŸŽ¹ Initializing DiT...")
     status_dit, success_dit = dit_handler.initialize_service(
         project_root=project_root,
         config_path=args.config_path,
         device=args.device,
+        use_flash_attention=True,
     )
     if not success_dit:
-        print(f"  āŒ DiT initialization failed: {status_dit}")
+        print(f"  āŒ Failed: {status_dit}")
         sys.exit(1)
-    print("  āœ“ DiT model initialized")
-
-    # Initialize LLM
-    print("  - Initializing LLM model...")
-    status_llm, success_llm = llm_handler.initialize(
-        checkpoint_dir=args.checkpoint_dir,
-        lm_model_path=args.lm_model,
-        backend=args.lm_backend,
-        device=args.device,
-    )
-    if success_llm:
-        print("  āœ“ LM model initialized")
-    else:
-        print(f"  ⚠ LM initialization failed: {status_llm}")
+    print(f"  āœ“ DiT ready")
+
+    print("  🧠 Initializing LLM...")
+    if args.thinking or args.use_cot_metas:
+        status_llm, success_llm = llm_handler.initialize(
+            checkpoint_dir=args.checkpoint_dir,
+            lm_model_path=args.lm_model,
+            backend=args.lm_backend,
+            device=args.device,
+        )
+        if success_llm:
+            print(f"  āœ“ LLM ready ({args.lm_backend})")
+        else:
+            print(f"  ⚠ Failed: {status_llm}")
+    else:
+        print(f"  āœ“ LLM not initialized (thinking or use_cot_metas is disabled)")
 
-    # Load test parameters from example file (same as acestep/inference.py)
-    def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
-        """Load configuration from an example JSON file."""
-        try:
-            with open(example_file, 'r', encoding='utf-8') as f:
-                data = json.load(f)
-
-            # Convert example format to GenerationParams and GenerationConfig
-            # Handle time signature format (example uses "4" instead of "4/4")
-            time_sig = data.get('timesignature', '')
-
-            params = GenerationParams(
-                caption=data.get('caption', ''),
-                lyrics=data.get('lyrics', ''),
-                bpm=data.get('bpm'),
-                keyscale=data.get('keyscale', ''),
-                timesignature=time_sig,
-                vocal_language=data.get('language', 'unknown'),
-                duration=data.get('duration'),
-                thinking=data.get('think', False),
-                inference_steps=data.get('inference_steps', 8),
-                seed=42,
-            )
-
-            config = GenerationConfig()
-            config.batch_size = data.get('batch_size', 1)
-
-            return params, config
-
-        except Exception as e:
-            print(f"  ⚠ Failed to load example file: {e}")
-            return None, None
-
-    # Load production example (same as acestep/inference.py)
-    example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
+    # Load example
+    example_file = os.path.join(project_root, "examples", "text2music", args.example)
     if not os.path.exists(example_file):
-        print(f"\n  āŒ Example file not found: {example_file}")
-        print("  Please ensure the examples directory exists.")
+        print(f"\nāŒ Not found: {example_file}")
         sys.exit(1)
 
-    print(f"\n  Loading example: {os.path.basename(example_file)}")
+    print(f"\nšŸ“„ Loading: {args.example}")
     params, config = load_example_config(example_file)
 
     if not params or not config:
-        print("  āŒ Failed to load example configuration")
+        print("āŒ Failed to load config")
         sys.exit(1)
 
-    print("\n" + "=" * 80)
-    print("Starting profiling...")
-    print("=" * 80)
+    print(f"  Caption: {params.caption[:60]}...")
+    print(f"  Batch: {config.batch_size}, Steps: {params.inference_steps}, LLM: {params.thinking}")
+
+    # Warmup
+    if not args.no_warmup:
+        print("\n" + "=" * 100)
+        print("šŸ”„ WARMUP RUN")
+        print("=" * 100)
+
+        warmup_params = GenerationParams(
+            caption=params.caption,
+            lyrics=params.lyrics,
+            bpm=params.bpm,
+            keyscale=params.keyscale,
+            timesignature=params.timesignature,
+            vocal_language=params.vocal_language,
+            duration=params.duration,
+            thinking=args.thinking,
+            use_cot_metas=args.use_cot_metas,
+            inference_steps=params.inference_steps,
+            seed=params.seed,
+        )
+        warmup_config = GenerationConfig(batch_size=1, seeds=[42])
+        warmup_config.use_constrained_decoding = args.use_constrained_decoding
+
+        warmup_start = time.perf_counter()
+        warmup_result = generate_music(dit_handler, llm_handler, warmup_params, warmup_config, save_dir="./")
+        warmup_time = time.perf_counter() - warmup_start
+
+        print(f"\nāœ“ Warmup: {warmup_time:.2f}s")
+        if not warmup_result.success:
+            print(f"āš ļø Warning: {warmup_result.error}")
+
+        # Reset
+        timer = PreciseTimer(device=args.device)
+        llm_debugger = LLMDebugger()
+
+    # Profiling run
+    print("\n" + "=" * 100)
+    print("ā±ļø PROFILING RUN")
+    print("=" * 100)
 
-    result = profile_with_cprofile(dit_handler, llm_handler, params, config, warmup=args.warmup)
+    # Apply inference mode settings
+    config.use_constrained_decoding = args.use_constrained_decoding
+    # Override thinking and use_cot_metas parameters if specified via CLI
+    if args.thinking:
+        params.thinking = True
+    if args.use_cot_metas:
+        params.use_cot_metas = True
 
-    if result and not result.success:
-        print(f"\n⚠ Generation failed: {result.error}")
+    result, total_time = run_profiled_generation(
+        dit_handler, llm_handler, params, config,
+        enable_cprofile=args.detailed,
+        enable_llm_debug=args.llm_debug
+    )
+
+    if not result.success:
+        print(f"\nāŒ Failed: {result.error}")
+        sys.exit(1)
+
+    print(f"\nāœ… Success! Generated {len(result.audios)} audio file(s)")
+
+    # Final tips
+    if args.detailed:
+        print("\nšŸ’” Check profile_cprofile_detailed.txt for function-level analysis")
+    elif not args.llm_debug:
+        print("\nšŸ’” Run with --llm-debug to see LLM token count and throughput analysis")
 
 
 if __name__ == "__main__":
     main()
-
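
The instrumentation in the new script is plain method monkey-patching: swap a bound method for a @wraps-preserving wrapper that times the call, and restore the original in a finally block so the handlers leave profiling unchanged. The same pattern in isolation (the Worker class here is made up for the sketch):

    import time
    from functools import wraps

    class Worker:                      # stand-in for a handler
        def step(self):
            time.sleep(0.01)

    def wrap_with_timing(obj, method_name, samples):
        original = getattr(obj, method_name)

        @wraps(original)               # keep the original name/docstring
        def timed(*args, **kwargs):
            start = time.perf_counter()
            try:
                return original(*args, **kwargs)
            finally:
                samples.append(time.perf_counter() - start)

        setattr(obj, method_name, timed)
        return original                # caller keeps this to restore later

    worker, samples = Worker(), []
    original = wrap_with_timing(worker, "step", samples)
    try:
        worker.step()
    finally:
        setattr(worker, "step", original)   # mirrors restore_handlers()
    print(f"step took {samples[0] * 1000:.1f} ms")
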