Gong Junmin committed
Commit 06446b3 · unverified · 2 parents: da41c7b bdc442a

Merge pull request #4 from ace-step/fix_cover_repaint

acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -21,87 +21,149 @@ from acestep.audio_utils import save_audio
 
 def parse_lrc_to_subtitles(lrc_text: str, total_duration: Optional[float] = None) -> List[Dict[str, Any]]:
     """
-    Parse LRC lyrics text to Gradio subtitles format.
-
-    LRC format: [MM:SS.ss]Lyric text or [MM:SS.ss][MM:SS.ss]Lyric text (with end time)
-    Gradio subtitles format: [{"text": str, "timestamp": [start, end]}]
+    Parse LRC lyrics text to Gradio subtitles format with smart post-processing.
+
+    Fixes the issue where lines that start very close together (e.g. Intro/Verse tags)
+    disappear too quickly, by merging short-lived lines into the subsequent line.
 
     Args:
         lrc_text: LRC format lyrics string
-        total_duration: Total audio duration in seconds (used for last line's end time)
+        total_duration: Total audio duration in seconds
 
     Returns:
-        List of subtitle dictionaries for Gradio Audio component
+        List of subtitle dictionaries
     """
     if not lrc_text or not lrc_text.strip():
         return []
 
-    subtitles = []
-    lines = lrc_text.strip().split('\n')
-
     # Regex patterns for LRC timestamps
-    # Pattern 1: [MM:SS.ss] or [MM:SS.sss] - standard LRC with start time only
-    # Pattern 2: [MM:SS.ss][MM:SS.ss] - LRC with both start and end time
-    # Support both 2-digit (centiseconds) and 3-digit (milliseconds) formats
     timestamp_pattern = r'\[(\d{2}):(\d{2})\.(\d{2,3})\]'
 
-    parsed_lines = []
+    raw_entries = []
+    lines = lrc_text.strip().split('\n')
 
     for line in lines:
         line = line.strip()
         if not line:
             continue
 
-        # Find all timestamps in the line
         timestamps = re.findall(timestamp_pattern, line)
         if not timestamps:
             continue
 
-        # Remove timestamps from text to get the lyric content
         text = re.sub(timestamp_pattern, '', line).strip()
+        # A timestamp with no text usually marks silence or an instrumental break;
+        # such lines carry no subtitle content, so keep a line only if it has text
        if not text:
             continue
-
-        # Parse first timestamp as start time
-        # Handle both 2-digit (centiseconds, /100) and 3-digit (milliseconds, /1000) formats
+
+        # Parse the first timestamp as the start time
         start_minutes, start_seconds, start_centiseconds = timestamps[0]
         cs = int(start_centiseconds)
+        # Handle both 2-digit (centiseconds, /100) and 3-digit (milliseconds, /1000) fractions
         start_time = int(start_minutes) * 60 + int(start_seconds) + (cs / 100.0 if len(start_centiseconds) == 2 else cs / 1000.0)
 
-        # If there's a second timestamp, use it as end time
+        # Determine the explicit end time if present (e.g. [start][end]text)
         end_time = None
         if len(timestamps) >= 2:
             end_minutes, end_seconds, end_centiseconds = timestamps[1]
             cs_end = int(end_centiseconds)
             end_time = int(end_minutes) * 60 + int(end_seconds) + (cs_end / 100.0 if len(end_centiseconds) == 2 else cs_end / 1000.0)
-
-        parsed_lines.append({
+
+        raw_entries.append({
             'start': start_time,
-            'end': end_time,
+            'explicit_end': end_time,
             'text': text
         })
 
     # Sort by start time
-    parsed_lines.sort(key=lambda x: x['start'])
-
-    # Fill in missing end times using next line's start time
-    for i, line_data in enumerate(parsed_lines):
-        if line_data['end'] is None:
-            if i + 1 < len(parsed_lines):
-                # Use next line's start time as end time
-                line_data['end'] = parsed_lines[i + 1]['start']
-            elif total_duration is not None:
-                # Use total duration for last line
-                line_data['end'] = total_duration
-            else:
-                # Default: add 5 seconds if no duration info
-                line_data['end'] = line_data['start'] + 5.0
-
-        subtitles.append({
-            'text': line_data['text'],
-            'timestamp': [line_data['start'], line_data['end']]
-        })
-
+    raw_entries.sort(key=lambda x: x['start'])
+
+    if not raw_entries:
+        return []
+
+    # --- Post-processing: merge short-lived lines ---
+    # If a line would display for less than this many seconds before the next
+    # line starts, merge it into that next line
+    MIN_DISPLAY_DURATION = 2.0  # seconds
+
+    merged_entries = []
+    i = 0
+    while i < len(raw_entries):
+        current = raw_entries[i]
+
+        # Accumulate text while looking ahead for lines that start too soon
+        combined_text = current['text']
+        combined_start = current['start']
+        combined_explicit_end = current['explicit_end']
+
+        next_idx = i + 1
+
+        # While there is a next line and the gap from the combined start is too small
+        while next_idx < len(raw_entries):
+            next_entry = raw_entries[next_idx]
+            gap = next_entry['start'] - combined_start
+
+            if gap < MIN_DISPLAY_DURATION:
+                # Stack the texts on separate lines for readability in subtitles
+                combined_text += "\n" + next_entry['text']
+
+                # The explicit end becomes the next entry's explicit end (if any),
+                # effectively extending the merged block
+                if next_entry['explicit_end'] is not None:
+                    combined_explicit_end = next_entry['explicit_end']
+
+                # Consume this entry and keep looking ahead
+                next_idx += 1
+            else:
+                # Gap is big enough; stop merging
+                break
+
+        # Add the (potentially merged) entry
+        merged_entries.append({
+            'start': combined_start,
+            'explicit_end': combined_explicit_end,
+            'text': combined_text
+        })
+
+        # Move the loop index past everything that was consumed
+        i = next_idx
+
+    # --- Generate the final subtitles ---
+    subtitles = []
+    for i, entry in enumerate(merged_entries):
+        start = entry['start']
+        text = entry['text']
+
+        # Determine the end time
+        if entry['explicit_end'] is not None:
+            end = entry['explicit_end']
+        else:
+            if i + 1 < len(merged_entries):
+                # No explicit end: use the next line's start time
+                end = merged_entries[i + 1]['start']
+            else:
+                # Last line: fall back to the total duration, then to a default
+                if total_duration is not None and total_duration > start:
+                    end = total_duration
+                else:
+                    end = start + 5.0  # default duration for the last line
+
+        # Final safety check: ensure end > start
+        if end <= start:
+            end = start + 3.0
+
+        subtitles.append({
+            'text': text,
+            'timestamp': [start, end]
+        })
+
     return subtitles
 
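For reference, a minimal sketch of the new merge behavior; the sample LRC text and the expected output are illustrative, assuming the 2.0 s MIN_DISPLAY_DURATION default above:

from acestep.gradio_ui.events.results_handlers import parse_lrc_to_subtitles

# The "[Intro]" tag starts only 0.5 s before the first verse line,
# so the merge step folds the two into a single subtitle block
lrc = (
    "[00:00.00][Intro]\n"
    "[00:00.50]First verse line\n"
    "[00:05.00]Second verse line"
)
print(parse_lrc_to_subtitles(lrc, total_duration=10.0))
# Expected:
# [{'text': '[Intro]\nFirst verse line', 'timestamp': [0.0, 5.0]},
#  {'text': 'Second verse line', 'timestamp': [5.0, 10.0]}]
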
@@ -351,82 +413,35 @@ def update_navigation_buttons(current_batch, total_batches):
     return can_go_previous, can_go_next
 
 def send_audio_to_src_with_metadata(audio_file, lm_metadata):
-    """Send generated audio file to src_audio input and populate metadata fields
+    """Send the generated audio file to the src_audio input WITHOUT modifying other fields.
+
+    This function ONLY sets the src_audio field. All other metadata fields (caption,
+    lyrics, etc.) are preserved by returning gr.skip(), so the user's existing inputs
+    are never overwritten.
 
     Args:
         audio_file: Audio file path
-        lm_metadata: Dictionary containing LM-generated metadata
+        lm_metadata: Dictionary containing LM-generated metadata (unused, kept for API compatibility)
 
     Returns:
         Tuple of (audio_file, bpm, caption, lyrics, duration, key_scale, language, time_signature, is_format_caption)
+        All values except audio_file are gr.skip() to preserve existing UI values
     """
     if audio_file is None:
-        return None, None, None, None, None, None, None, None, True  # Keep is_format_caption as True
-
-    # Extract metadata fields if available
-    bpm_value = None
-    caption_value = None
-    lyrics_value = None
-    duration_value = None
-    key_scale_value = None
-    language_value = None
-    time_signature_value = None
-
-    if lm_metadata:
-        # BPM
-        if lm_metadata.get('bpm'):
-            bpm_str = lm_metadata.get('bpm')
-            if bpm_str and bpm_str != "N/A":
-                try:
-                    bpm_value = int(bpm_str)
-                except (ValueError, TypeError):
-                    pass
-
-        # Caption (rewritten caption)
-        if lm_metadata.get('caption'):
-            caption_value = lm_metadata.get('caption')
-
-        # Lyrics
-        if lm_metadata.get('lyrics'):
-            lyrics_value = lm_metadata.get('lyrics')
-
-        # Duration
-        if lm_metadata.get('duration'):
-            duration_str = lm_metadata.get('duration')
-            if duration_str and duration_str != "N/A":
-                try:
-                    duration_value = float(duration_str)
-                except (ValueError, TypeError):
-                    pass
-
-        # KeyScale
-        if lm_metadata.get('keyscale'):
-            key_scale_str = lm_metadata.get('keyscale')
-            if key_scale_str and key_scale_str != "N/A":
-                key_scale_value = key_scale_str
-
-        # Language
-        if lm_metadata.get('language'):
-            language_str = lm_metadata.get('language')
-            if language_str and language_str != "N/A":
-                language_value = language_str
-
-        # Time Signature
-        if lm_metadata.get('timesignature'):
-            time_sig_str = lm_metadata.get('timesignature')
-            if time_sig_str and time_sig_str != "N/A":
-                time_signature_value = time_sig_str
-
+        # Return all skips so nothing is modified
+        return (gr.skip(),) * 9
+
+    # Only set the audio file; skip all other fields to preserve their existing values.
+    # This ensures the user's caption, lyrics, bpm, etc. are NOT cleared.
     return (
-        audio_file,
-        bpm_value,
-        caption_value,
-        lyrics_value,
-        duration_value,
-        key_scale_value,
-        language_value,
-        time_signature_value,
-        True  # Set is_format_caption to True (from LM-generated metadata)
+        audio_file,  # src_audio - set the audio file
+        gr.skip(),   # bpm - preserve existing value
+        gr.skip(),   # caption - preserve existing value
+        gr.skip(),   # lyrics - preserve existing value
+        gr.skip(),   # duration - preserve existing value
+        gr.skip(),   # key_scale - preserve existing value
+        gr.skip(),   # language - preserve existing value
+        gr.skip(),   # time_signature - preserve existing value
+        gr.skip(),   # is_format_caption - preserve existing value
     )
 
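The load-bearing piece here is gr.skip(): in Gradio versions that provide it, returning gr.skip() for an output leaves that component's current value untouched, whereas returning None clears it. A minimal self-contained sketch of the pattern; the component names and layout below are hypothetical, not the app's real UI:

import gradio as gr

def send_to_src(generated_audio):
    # src_audio receives the file; the caption output is skipped,
    # so whatever the user typed stays in place
    return generated_audio, gr.skip()

with gr.Blocks() as demo:
    generated = gr.Audio(label="generated", type="filepath")
    src_audio = gr.Audio(label="src_audio", type="filepath")
    caption = gr.Textbox(label="caption", value="user-written caption")
    gr.Button("Send to src").click(send_to_src, inputs=generated, outputs=[src_audio, caption])

# demo.launch()  # clicking the button fills src_audio without touching caption
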
 
acestep/handler.py CHANGED
@@ -793,8 +793,11 @@ class AceStepHandler:
         Returns:
             Latents tensor [T, D] or [batch, T, D]
         """
+        # Save original dimension info BEFORE modifying audio
+        input_was_2d = (audio.dim() == 2)
+
         # Ensure batch dimension
-        if audio.dim() == 2:
+        if input_was_2d:
             audio = audio.unsqueeze(0)
 
         # Ensure input is in VAE's dtype
@@ -811,7 +814,7 @@ class AceStepHandler:
         latents = latents.transpose(1, 2)
 
         # Remove batch dimension if input didn't have it
-        if audio.dim() == 2:
+        if input_was_2d:
             latents = latents.squeeze(0)
 
         return latents
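The bug this fixes is easy to reproduce in isolation: once a 2-D input has been unsqueezed, the later audio.dim() == 2 check can never be true, so the batch dimension was never removed for 2-D inputs. A standalone sketch of the corrected pattern, with a toy mean() reduction standing in for the VAE encode:

import torch

def to_latents(audio: torch.Tensor) -> torch.Tensor:
    # Save the flag BEFORE modifying the tensor (the essence of the fix)
    input_was_2d = (audio.dim() == 2)
    if input_was_2d:
        audio = audio.unsqueeze(0)  # [C, T] -> [1, C, T]

    latents = audio.mean(dim=1, keepdim=True)  # toy stand-in for the VAE encode

    # With the old check (audio.dim() == 2) this branch never ran,
    # because audio is already 3-D at this point
    if input_was_2d:
        latents = latents.squeeze(0)
    return latents

print(to_latents(torch.randn(2, 16)).shape)     # 2-D in -> un-batched latents out
print(to_latents(torch.randn(4, 2, 16)).shape)  # 3-D in -> batched latents out
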
acestep/inference.py CHANGED
@@ -297,8 +297,14 @@ def generate_music(
     actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
 
     # LM-based Chain-of-Thought reasoning
-    use_lm = params.thinking and llm_handler.llm_initialized
+    # Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
+    # and don't need LM to generate audio codes
+    skip_lm_tasks = {"cover", "repaint"}
+    use_lm = params.thinking and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
     lm_status = []
+
+    if params.task_type in skip_lm_tasks:
+        logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
     if use_lm:
         # Convert sampling parameters - handle None values safely
         top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
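The new gating reduces to a single predicate. A reduced sketch of that logic; the Params dataclass is a stand-in for the real parameter object, and task names other than cover/repaint are illustrative:

from dataclasses import dataclass

SKIP_LM_TASKS = {"cover", "repaint"}

@dataclass
class Params:  # stand-in for the real params object
    thinking: bool
    task_type: str

def should_use_lm(params: Params, llm_initialized: bool) -> bool:
    # Cover/repaint consume reference/src audio directly, so the LM
    # chain-of-thought step is skipped even when `thinking` is on
    return params.thinking and llm_initialized and params.task_type not in SKIP_LM_TASKS

assert should_use_lm(Params(True, "text2music"), True)
assert not should_use_lm(Params(True, "cover"), True)
assert not should_use_lm(Params(True, "repaint"), True)
assert not should_use_lm(Params(False, "text2music"), True)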