Spaces:
Running
on
A100
Running
on
A100
Merge pull request #4 from ace-step/fix_cover_repaint
Browse files- acestep/gradio_ui/events/results_handlers.py +118 -103
- acestep/handler.py +5 -2
- acestep/inference.py +7 -1
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -21,87 +21,149 @@ from acestep.audio_utils import save_audio
|
|
| 21 |
|
| 22 |
def parse_lrc_to_subtitles(lrc_text: str, total_duration: Optional[float] = None) -> List[Dict[str, Any]]:
|
| 23 |
"""
|
| 24 |
-
Parse LRC lyrics text to Gradio subtitles format.
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
|
| 29 |
Args:
|
| 30 |
lrc_text: LRC format lyrics string
|
| 31 |
-
total_duration: Total audio duration in seconds
|
| 32 |
|
| 33 |
Returns:
|
| 34 |
-
List of subtitle dictionaries
|
| 35 |
"""
|
| 36 |
if not lrc_text or not lrc_text.strip():
|
| 37 |
return []
|
| 38 |
|
| 39 |
-
subtitles = []
|
| 40 |
-
lines = lrc_text.strip().split('\n')
|
| 41 |
-
|
| 42 |
# Regex patterns for LRC timestamps
|
| 43 |
-
# Pattern 1: [MM:SS.ss] or [MM:SS.sss] - standard LRC with start time only
|
| 44 |
-
# Pattern 2: [MM:SS.ss][MM:SS.ss] - LRC with both start and end time
|
| 45 |
-
# Support both 2-digit (centiseconds) and 3-digit (milliseconds) formats
|
| 46 |
timestamp_pattern = r'\[(\d{2}):(\d{2})\.(\d{2,3})\]'
|
| 47 |
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
for line in lines:
|
| 51 |
line = line.strip()
|
| 52 |
if not line:
|
| 53 |
continue
|
| 54 |
|
| 55 |
-
# Find all timestamps in the line
|
| 56 |
timestamps = re.findall(timestamp_pattern, line)
|
| 57 |
if not timestamps:
|
| 58 |
continue
|
| 59 |
|
| 60 |
-
# Remove timestamps from text to get the lyric content
|
| 61 |
text = re.sub(timestamp_pattern, '', line).strip()
|
|
|
|
|
|
|
|
|
|
| 62 |
if not text:
|
| 63 |
continue
|
| 64 |
-
|
| 65 |
-
# Parse
|
| 66 |
-
# Handle both 2-digit (centiseconds, /100) and 3-digit (milliseconds, /1000) formats
|
| 67 |
start_minutes, start_seconds, start_centiseconds = timestamps[0]
|
| 68 |
cs = int(start_centiseconds)
|
|
|
|
| 69 |
start_time = int(start_minutes) * 60 + int(start_seconds) + (cs / 100.0 if len(start_centiseconds) == 2 else cs / 1000.0)
|
| 70 |
|
| 71 |
-
#
|
| 72 |
end_time = None
|
| 73 |
if len(timestamps) >= 2:
|
| 74 |
end_minutes, end_seconds, end_centiseconds = timestamps[1]
|
| 75 |
cs_end = int(end_centiseconds)
|
| 76 |
end_time = int(end_minutes) * 60 + int(end_seconds) + (cs_end / 100.0 if len(end_centiseconds) == 2 else cs_end / 1000.0)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
'start': start_time,
|
| 80 |
-
'
|
| 81 |
'text': text
|
| 82 |
})
|
| 83 |
|
| 84 |
# Sort by start time
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
-
#
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
subtitles.append({
|
| 101 |
-
'text':
|
| 102 |
-
'timestamp': [
|
| 103 |
})
|
| 104 |
-
|
| 105 |
return subtitles
|
| 106 |
|
| 107 |
|
|
@@ -351,82 +413,35 @@ def update_navigation_buttons(current_batch, total_batches):
|
|
| 351 |
return can_go_previous, can_go_next
|
| 352 |
|
| 353 |
def send_audio_to_src_with_metadata(audio_file, lm_metadata):
|
| 354 |
-
"""Send generated audio file to src_audio input
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
Args:
|
| 357 |
audio_file: Audio file path
|
| 358 |
-
lm_metadata: Dictionary containing LM-generated metadata
|
| 359 |
|
| 360 |
Returns:
|
| 361 |
Tuple of (audio_file, bpm, caption, lyrics, duration, key_scale, language, time_signature, is_format_caption)
|
|
|
|
| 362 |
"""
|
| 363 |
if audio_file is None:
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
# Extract metadata fields if available
|
| 367 |
-
bpm_value = None
|
| 368 |
-
caption_value = None
|
| 369 |
-
lyrics_value = None
|
| 370 |
-
duration_value = None
|
| 371 |
-
key_scale_value = None
|
| 372 |
-
language_value = None
|
| 373 |
-
time_signature_value = None
|
| 374 |
-
|
| 375 |
-
if lm_metadata:
|
| 376 |
-
# BPM
|
| 377 |
-
if lm_metadata.get('bpm'):
|
| 378 |
-
bpm_str = lm_metadata.get('bpm')
|
| 379 |
-
if bpm_str and bpm_str != "N/A":
|
| 380 |
-
try:
|
| 381 |
-
bpm_value = int(bpm_str)
|
| 382 |
-
except (ValueError, TypeError):
|
| 383 |
-
pass
|
| 384 |
-
|
| 385 |
-
# Caption (Rewritten Caption)
|
| 386 |
-
if lm_metadata.get('caption'):
|
| 387 |
-
caption_value = lm_metadata.get('caption')
|
| 388 |
-
|
| 389 |
-
# Lyrics
|
| 390 |
-
if lm_metadata.get('lyrics'):
|
| 391 |
-
lyrics_value = lm_metadata.get('lyrics')
|
| 392 |
-
|
| 393 |
-
# Duration
|
| 394 |
-
if lm_metadata.get('duration'):
|
| 395 |
-
duration_str = lm_metadata.get('duration')
|
| 396 |
-
if duration_str and duration_str != "N/A":
|
| 397 |
-
try:
|
| 398 |
-
duration_value = float(duration_str)
|
| 399 |
-
except (ValueError, TypeError):
|
| 400 |
-
pass
|
| 401 |
-
|
| 402 |
-
# KeyScale
|
| 403 |
-
if lm_metadata.get('keyscale'):
|
| 404 |
-
key_scale_str = lm_metadata.get('keyscale')
|
| 405 |
-
if key_scale_str and key_scale_str != "N/A":
|
| 406 |
-
key_scale_value = key_scale_str
|
| 407 |
-
|
| 408 |
-
# Language
|
| 409 |
-
if lm_metadata.get('language'):
|
| 410 |
-
language_str = lm_metadata.get('language')
|
| 411 |
-
if language_str and language_str != "N/A":
|
| 412 |
-
language_value = language_str
|
| 413 |
-
|
| 414 |
-
# Time Signature
|
| 415 |
-
if lm_metadata.get('timesignature'):
|
| 416 |
-
time_sig_str = lm_metadata.get('timesignature')
|
| 417 |
-
if time_sig_str and time_sig_str != "N/A":
|
| 418 |
-
time_signature_value = time_sig_str
|
| 419 |
|
|
|
|
|
|
|
| 420 |
return (
|
| 421 |
-
audio_file,
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
)
|
| 431 |
|
| 432 |
|
|
|
|
| 21 |
|
| 22 |
def parse_lrc_to_subtitles(lrc_text: str, total_duration: Optional[float] = None) -> List[Dict[str, Any]]:
|
| 23 |
"""
|
| 24 |
+
Parse LRC lyrics text to Gradio subtitles format with SMART POST-PROCESSING.
|
| 25 |
|
| 26 |
+
Fixes the issue where lines starting very close to each other (e.g. Intro/Verse tags)
|
| 27 |
+
disappear too quickly. It merges short lines into the subsequent line.
|
| 28 |
|
| 29 |
Args:
|
| 30 |
lrc_text: LRC format lyrics string
|
| 31 |
+
total_duration: Total audio duration in seconds
|
| 32 |
|
| 33 |
Returns:
|
| 34 |
+
List of subtitle dictionaries
|
| 35 |
"""
|
| 36 |
if not lrc_text or not lrc_text.strip():
|
| 37 |
return []
|
| 38 |
|
|
|
|
|
|
|
|
|
|
| 39 |
# Regex patterns for LRC timestamps
|
|
|
|
|
|
|
|
|
|
| 40 |
timestamp_pattern = r'\[(\d{2}):(\d{2})\.(\d{2,3})\]'
|
| 41 |
|
| 42 |
+
raw_entries = []
|
| 43 |
+
lines = lrc_text.strip().split('\n')
|
| 44 |
|
| 45 |
for line in lines:
|
| 46 |
line = line.strip()
|
| 47 |
if not line:
|
| 48 |
continue
|
| 49 |
|
|
|
|
| 50 |
timestamps = re.findall(timestamp_pattern, line)
|
| 51 |
if not timestamps:
|
| 52 |
continue
|
| 53 |
|
|
|
|
| 54 |
text = re.sub(timestamp_pattern, '', line).strip()
|
| 55 |
+
# Even if text is empty, we might want to capture the timestamp to mark an end,
|
| 56 |
+
# but for subtitles, empty text usually means silence or instrumental.
|
| 57 |
+
# We keep it if it has text, or if it looks like a functional tag.
|
| 58 |
if not text:
|
| 59 |
continue
|
| 60 |
+
|
| 61 |
+
# Parse start time
|
|
|
|
| 62 |
start_minutes, start_seconds, start_centiseconds = timestamps[0]
|
| 63 |
cs = int(start_centiseconds)
|
| 64 |
+
# Handle 2-digit (1/100) vs 3-digit (1/1000)
|
| 65 |
start_time = int(start_minutes) * 60 + int(start_seconds) + (cs / 100.0 if len(start_centiseconds) == 2 else cs / 1000.0)
|
| 66 |
|
| 67 |
+
# Determine explicit end time if present (e.g. [start][end]text)
|
| 68 |
end_time = None
|
| 69 |
if len(timestamps) >= 2:
|
| 70 |
end_minutes, end_seconds, end_centiseconds = timestamps[1]
|
| 71 |
cs_end = int(end_centiseconds)
|
| 72 |
end_time = int(end_minutes) * 60 + int(end_seconds) + (cs_end / 100.0 if len(end_centiseconds) == 2 else cs_end / 1000.0)
|
| 73 |
+
|
| 74 |
+
raw_entries.append({
|
| 75 |
'start': start_time,
|
| 76 |
+
'explicit_end': end_time,
|
| 77 |
'text': text
|
| 78 |
})
|
| 79 |
|
| 80 |
# Sort by start time
|
| 81 |
+
raw_entries.sort(key=lambda x: x['start'])
|
| 82 |
+
|
| 83 |
+
if not raw_entries:
|
| 84 |
+
return []
|
| 85 |
+
|
| 86 |
+
# --- POST-PROCESSING: MERGE SHORT LINES ---
|
| 87 |
+
# Threshold: If a line displays for less than X seconds before the next line, merge them.
|
| 88 |
+
MIN_DISPLAY_DURATION = 2.0 # seconds
|
| 89 |
+
|
| 90 |
+
merged_entries = []
|
| 91 |
+
i = 0
|
| 92 |
+
while i < len(raw_entries):
|
| 93 |
+
current = raw_entries[i]
|
| 94 |
+
|
| 95 |
+
# Look ahead to see if we need to merge multiple lines
|
| 96 |
+
# We act as an accumulator
|
| 97 |
+
combined_text = current['text']
|
| 98 |
+
combined_start = current['start']
|
| 99 |
+
# Default end is strictly the explicit end, or we figure it out later
|
| 100 |
+
combined_explicit_end = current['explicit_end']
|
| 101 |
+
|
| 102 |
+
next_idx = i + 1
|
| 103 |
+
|
| 104 |
+
# While there is a next line, and the gap between current start and next start is too small
|
| 105 |
+
while next_idx < len(raw_entries):
|
| 106 |
+
next_entry = raw_entries[next_idx]
|
| 107 |
+
gap = next_entry['start'] - combined_start
|
| 108 |
+
|
| 109 |
+
# If the gap is smaller than threshold (and the next line doesn't start way later)
|
| 110 |
+
# We merge 'current' into 'next' visually by stacking text
|
| 111 |
+
if gap < MIN_DISPLAY_DURATION:
|
| 112 |
+
# Merge text
|
| 113 |
+
# If text is wrapped in brackets [], likely a tag, separate with space
|
| 114 |
+
# If regular lyrics, maybe newline? Let's use newline for clarity in subtitles.
|
| 115 |
+
combined_text += "\n" + next_entry['text']
|
| 116 |
+
|
| 117 |
+
# The explicit end becomes the next entry's explicit end (if any),
|
| 118 |
+
# effectively extending the block
|
| 119 |
+
if next_entry['explicit_end']:
|
| 120 |
+
combined_explicit_end = next_entry['explicit_end']
|
| 121 |
+
|
| 122 |
+
# Consume this next entry
|
| 123 |
+
next_idx += 1
|
| 124 |
+
else:
|
| 125 |
+
# Gap is big enough, stop merging
|
| 126 |
+
break
|
| 127 |
+
|
| 128 |
+
# Add the (potentially merged) entry
|
| 129 |
+
merged_entries.append({
|
| 130 |
+
'start': combined_start,
|
| 131 |
+
'explicit_end': combined_explicit_end,
|
| 132 |
+
'text': combined_text
|
| 133 |
+
})
|
| 134 |
+
|
| 135 |
+
# Move loop index
|
| 136 |
+
i = next_idx
|
| 137 |
+
|
| 138 |
+
# --- GENERATE FINAL SUBTITLES ---
|
| 139 |
+
subtitles = []
|
| 140 |
+
for i, entry in enumerate(merged_entries):
|
| 141 |
+
start = entry['start']
|
| 142 |
+
text = entry['text']
|
| 143 |
+
|
| 144 |
+
# Determine End Time
|
| 145 |
+
if entry['explicit_end'] is not None:
|
| 146 |
+
end = entry['explicit_end']
|
| 147 |
+
else:
|
| 148 |
+
# If no explicit end, use next line's start
|
| 149 |
+
if i + 1 < len(merged_entries):
|
| 150 |
+
end = merged_entries[i + 1]['start']
|
| 151 |
else:
|
| 152 |
+
# Last line
|
| 153 |
+
if total_duration is not None and total_duration > start:
|
| 154 |
+
end = total_duration
|
| 155 |
+
else:
|
| 156 |
+
end = start + 5.0 # Default duration for last line
|
| 157 |
|
| 158 |
+
# Final safety: Ensure end > start
|
| 159 |
+
if end <= start:
|
| 160 |
+
end = start + 3.0
|
| 161 |
+
|
| 162 |
subtitles.append({
|
| 163 |
+
'text': text,
|
| 164 |
+
'timestamp': [start, end]
|
| 165 |
})
|
| 166 |
+
|
| 167 |
return subtitles
|
| 168 |
|
| 169 |
|
|
|
|
| 413 |
return can_go_previous, can_go_next
|
| 414 |
|
| 415 |
def send_audio_to_src_with_metadata(audio_file, lm_metadata):
|
| 416 |
+
"""Send generated audio file to src_audio input WITHOUT modifying other fields
|
| 417 |
+
|
| 418 |
+
This function ONLY sets the src_audio field. All other metadata fields (caption, lyrics, etc.)
|
| 419 |
+
are preserved by returning gr.skip() to avoid overwriting user's existing inputs.
|
| 420 |
|
| 421 |
Args:
|
| 422 |
audio_file: Audio file path
|
| 423 |
+
lm_metadata: Dictionary containing LM-generated metadata (unused, kept for API compatibility)
|
| 424 |
|
| 425 |
Returns:
|
| 426 |
Tuple of (audio_file, bpm, caption, lyrics, duration, key_scale, language, time_signature, is_format_caption)
|
| 427 |
+
All values except audio_file are gr.skip() to preserve existing UI values
|
| 428 |
"""
|
| 429 |
if audio_file is None:
|
| 430 |
+
# Return all skip to not modify anything
|
| 431 |
+
return (gr.skip(),) * 9
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
+
# Only set the audio file, skip all other fields to preserve existing values
|
| 434 |
+
# This ensures user's caption, lyrics, bpm, etc. are NOT cleared
|
| 435 |
return (
|
| 436 |
+
audio_file, # src_audio - set the audio file
|
| 437 |
+
gr.skip(), # bpm - preserve existing value
|
| 438 |
+
gr.skip(), # caption - preserve existing value
|
| 439 |
+
gr.skip(), # lyrics - preserve existing value
|
| 440 |
+
gr.skip(), # duration - preserve existing value
|
| 441 |
+
gr.skip(), # key_scale - preserve existing value
|
| 442 |
+
gr.skip(), # language - preserve existing value
|
| 443 |
+
gr.skip(), # time_signature - preserve existing value
|
| 444 |
+
gr.skip(), # is_format_caption - preserve existing value
|
| 445 |
)
|
| 446 |
|
| 447 |
|
acestep/handler.py
CHANGED
|
@@ -793,8 +793,11 @@ class AceStepHandler:
|
|
| 793 |
Returns:
|
| 794 |
Latents tensor [T, D] or [batch, T, D]
|
| 795 |
"""
|
|
|
|
|
|
|
|
|
|
| 796 |
# Ensure batch dimension
|
| 797 |
-
if
|
| 798 |
audio = audio.unsqueeze(0)
|
| 799 |
|
| 800 |
# Ensure input is in VAE's dtype
|
|
@@ -811,7 +814,7 @@ class AceStepHandler:
|
|
| 811 |
latents = latents.transpose(1, 2)
|
| 812 |
|
| 813 |
# Remove batch dimension if input didn't have it
|
| 814 |
-
if
|
| 815 |
latents = latents.squeeze(0)
|
| 816 |
|
| 817 |
return latents
|
|
|
|
| 793 |
Returns:
|
| 794 |
Latents tensor [T, D] or [batch, T, D]
|
| 795 |
"""
|
| 796 |
+
# Save original dimension info BEFORE modifying audio
|
| 797 |
+
input_was_2d = (audio.dim() == 2)
|
| 798 |
+
|
| 799 |
# Ensure batch dimension
|
| 800 |
+
if input_was_2d:
|
| 801 |
audio = audio.unsqueeze(0)
|
| 802 |
|
| 803 |
# Ensure input is in VAE's dtype
|
|
|
|
| 814 |
latents = latents.transpose(1, 2)
|
| 815 |
|
| 816 |
# Remove batch dimension if input didn't have it
|
| 817 |
+
if input_was_2d:
|
| 818 |
latents = latents.squeeze(0)
|
| 819 |
|
| 820 |
return latents
|
acestep/inference.py
CHANGED
|
@@ -297,8 +297,14 @@ def generate_music(
|
|
| 297 |
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
|
| 298 |
|
| 299 |
# LM-based Chain-of-Thought reasoning
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
| 301 |
lm_status = []
|
|
|
|
|
|
|
|
|
|
| 302 |
if use_lm:
|
| 303 |
# Convert sampling parameters - handle None values safely
|
| 304 |
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
|
|
|
|
| 297 |
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
|
| 298 |
|
| 299 |
# LM-based Chain-of-Thought reasoning
|
| 300 |
+
# Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
|
| 301 |
+
# and don't need LM to generate audio codes
|
| 302 |
+
skip_lm_tasks = {"cover", "repaint"}
|
| 303 |
+
use_lm = params.thinking and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
|
| 304 |
lm_status = []
|
| 305 |
+
|
| 306 |
+
if params.task_type in skip_lm_tasks:
|
| 307 |
+
logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
|
| 308 |
if use_lm:
|
| 309 |
# Convert sampling parameters - handle None values safely
|
| 310 |
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
|