ChuxiJ committed on
Commit
85c5902
·
1 Parent(s): 06446b3

refactor understand_music

Browse files
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -13,6 +13,7 @@ from acestep.constants import (
13
  TASK_TYPES_BASE,
14
  )
15
  from acestep.gradio_ui.i18n import t
 
16
 
17
 
18
  def load_metadata(file_obj):
@@ -206,6 +207,9 @@ def load_random_example(task_type: str):
206
  def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
207
  """Smart sample function that uses LM if initialized, otherwise falls back to examples
208
 
 
 
 
209
  Args:
210
  llm_handler: LLM handler instance
211
  task_type: The task type (e.g., "text2music")
@@ -216,50 +220,28 @@ def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug
216
  """
217
  # Check if LM is initialized
218
  if llm_handler.llm_initialized:
219
- # Use LM to generate example
220
  try:
221
- # Generate example using LM with empty input (NO USER INPUT)
222
- metadata, status = llm_handler.understand_audio_from_codes(
223
- audio_codes="NO USER INPUT",
224
- use_constrained_decoding=True,
225
  temperature=0.85,
 
226
  constrained_decoding_debug=constrained_decoding_debug,
227
  )
228
 
229
- if metadata:
230
- caption_value = metadata.get('caption', '')
231
- lyrics_value = metadata.get('lyrics', '')
232
- think_value = True # Always enable think when using LM-generated examples
233
-
234
- # Extract optional metadata fields
235
- bpm_value = None
236
- if 'bpm' in metadata and metadata['bpm'] not in [None, "N/A", ""]:
237
- try:
238
- bpm_value = int(metadata['bpm'])
239
- except (ValueError, TypeError):
240
- pass
241
-
242
- duration_value = None
243
- if 'duration' in metadata and metadata['duration'] not in [None, "N/A", ""]:
244
- try:
245
- duration_value = float(metadata['duration'])
246
- except (ValueError, TypeError):
247
- pass
248
-
249
- keyscale_value = metadata.get('keyscale', '')
250
- if keyscale_value in [None, "N/A"]:
251
- keyscale_value = ''
252
-
253
- language_value = metadata.get('language', '')
254
- if language_value in [None, "N/A"]:
255
- language_value = ''
256
-
257
- timesignature_value = metadata.get('timesignature', '')
258
- if timesignature_value in [None, "N/A"]:
259
- timesignature_value = ''
260
-
261
  gr.Info(t("messages.lm_generated"))
262
- return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
 
 
 
 
 
 
 
 
 
263
  else:
264
  gr.Warning(t("messages.lm_fallback"))
265
  return load_random_example(task_type)
@@ -437,58 +419,40 @@ def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_
437
  Transcribe audio codes to metadata using LLM understanding.
438
  If audio_code_string is empty, generate a sample example instead.
439
 
 
 
440
  Args:
441
  llm_handler: LLM handler instance
442
  audio_code_string: String containing audio codes (or empty for example generation)
443
  constrained_decoding_debug: Whether to enable debug logging for constrained decoding
444
 
445
  Returns:
446
- Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature)
447
  """
448
- if not llm_handler.llm_initialized:
449
- return t("messages.lm_not_initialized"), "", "", None, None, "", "", ""
450
-
451
- # If codes are empty, this becomes a "generate example" task
452
- # Use "NO USER INPUT" as the input to generate a sample
453
- if not audio_code_string or not audio_code_string.strip():
454
- audio_code_string = "NO USER INPUT"
455
-
456
- # Call LLM understanding
457
- metadata, status = llm_handler.understand_audio_from_codes(
458
  audio_codes=audio_code_string,
459
  use_constrained_decoding=True,
460
  constrained_decoding_debug=constrained_decoding_debug,
461
  )
462
 
463
- # Extract fields for UI update
464
- caption = metadata.get('caption', '')
465
- lyrics = metadata.get('lyrics', '')
466
- bpm = metadata.get('bpm')
467
- duration = metadata.get('duration')
468
- keyscale = metadata.get('keyscale', '')
469
- language = metadata.get('language', '')
470
- timesignature = metadata.get('timesignature', '')
471
-
472
- # Convert to appropriate types
473
- try:
474
- bpm = int(bpm) if bpm and bpm != 'N/A' else None
475
- except:
476
- bpm = None
477
-
478
- try:
479
- duration = float(duration) if duration and duration != 'N/A' else None
480
- except:
481
- duration = None
482
 
483
  return (
484
- status,
485
- caption,
486
- lyrics,
487
- bpm,
488
- duration,
489
- keyscale,
490
- language,
491
- timesignature,
492
  True # Set is_format_caption to True (from Transcribe/LM understanding)
493
  )
494
 
 
13
  TASK_TYPES_BASE,
14
  )
15
  from acestep.gradio_ui.i18n import t
16
+ from acestep.inference import understand_music
17
 
18
 
19
  def load_metadata(file_obj):
 
207
  def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
208
  """Smart sample function that uses LM if initialized, otherwise falls back to examples
209
 
210
+ This is a Gradio wrapper that uses the understand_music API from acestep.inference
211
+ to generate examples when LM is available.
212
+
213
  Args:
214
  llm_handler: LLM handler instance
215
  task_type: The task type (e.g., "text2music")
 
220
  """
221
  # Check if LM is initialized
222
  if llm_handler.llm_initialized:
223
+ # Use LM to generate example via understand_music API
224
  try:
225
+ result = understand_music(
226
+ llm_handler=llm_handler,
227
+ audio_codes="NO USER INPUT", # Empty input triggers example generation
 
228
  temperature=0.85,
229
+ use_constrained_decoding=True,
230
  constrained_decoding_debug=constrained_decoding_debug,
231
  )
232
 
233
+ if result.success:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  gr.Info(t("messages.lm_generated"))
235
+ return (
236
+ result.caption,
237
+ result.lyrics,
238
+ True, # Always enable think when using LM-generated examples
239
+ result.bpm,
240
+ result.duration,
241
+ result.keyscale,
242
+ result.language,
243
+ result.timesignature,
244
+ )
245
  else:
246
  gr.Warning(t("messages.lm_fallback"))
247
  return load_random_example(task_type)
 
419
  Transcribe audio codes to metadata using LLM understanding.
420
  If audio_code_string is empty, generate a sample example instead.
421
 
422
+ This is a Gradio wrapper around the understand_music API in acestep.inference.
423
+
424
  Args:
425
  llm_handler: LLM handler instance
426
  audio_code_string: String containing audio codes (or empty for example generation)
427
  constrained_decoding_debug: Whether to enable debug logging for constrained decoding
428
 
429
  Returns:
430
+ Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
431
  """
432
+ # Call the inference API
433
+ result = understand_music(
434
+ llm_handler=llm_handler,
 
 
 
 
 
 
 
435
  audio_codes=audio_code_string,
436
  use_constrained_decoding=True,
437
  constrained_decoding_debug=constrained_decoding_debug,
438
  )
439
 
440
+ # Handle error case with localized message
441
+ if not result.success:
442
+ # Use localized error message for LLM not initialized
443
+ if result.error == "LLM not initialized":
444
+ return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
445
+ return result.status_message, "", "", None, None, "", "", "", False
 
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
  return (
448
+ result.status_message,
449
+ result.caption,
450
+ result.lyrics,
451
+ result.bpm,
452
+ result.duration,
453
+ result.keyscale,
454
+ result.language,
455
+ result.timesignature,
456
  True # Set is_format_caption to True (from Transcribe/LM understanding)
457
  )
458
 
acestep/inference.py CHANGED
@@ -183,6 +183,44 @@ class GenerationResult:
183
  return asdict(self)
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  def _update_metadata_from_lm(
187
  metadata: Dict[str, Any],
188
  bpm: Optional[int],
@@ -627,3 +665,135 @@ def generate_music(
627
  success=False,
628
  error=str(e),
629
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  return asdict(self)
184
 
185
 
186
+ @dataclass
187
+ class UnderstandResult:
188
+ """Result of music understanding from audio codes.
189
+
190
+ Attributes:
191
+ # Metadata Fields
192
+ caption: Generated caption describing the music
193
+ lyrics: Generated or extracted lyrics
194
+ bpm: Beats per minute (None if not detected)
195
+ duration: Duration in seconds (None if not detected)
196
+ keyscale: Musical key (e.g., "C Major")
197
+ language: Vocal language code (e.g., "en", "zh")
198
+ timesignature: Time signature (e.g., "4/4")
199
+
200
+ # Status
201
+ status_message: Status message from understanding
202
+ success: Whether understanding completed successfully
203
+ error: Error message if understanding failed
204
+ """
205
+ # Metadata Fields
206
+ caption: str = ""
207
+ lyrics: str = ""
208
+ bpm: Optional[int] = None
209
+ duration: Optional[float] = None
210
+ keyscale: str = ""
211
+ language: str = ""
212
+ timesignature: str = ""
213
+
214
+ # Status
215
+ status_message: str = ""
216
+ success: bool = True
217
+ error: Optional[str] = None
218
+
219
+ def to_dict(self) -> Dict[str, Any]:
220
+ """Convert result to dictionary for JSON serialization."""
221
+ return asdict(self)
222
+
223
+
224
  def _update_metadata_from_lm(
225
  metadata: Dict[str, Any],
226
  bpm: Optional[int],
 
665
  success=False,
666
  error=str(e),
667
  )
668
+
669
+
670
+ def understand_music(
671
+ llm_handler,
672
+ audio_codes: str,
673
+ temperature: float = 0.85,
674
+ cfg_scale: float = 1.0,
675
+ negative_prompt: str = "NO USER INPUT",
676
+ top_k: Optional[int] = None,
677
+ top_p: Optional[float] = None,
678
+ repetition_penalty: float = 1.0,
679
+ use_constrained_decoding: bool = True,
680
+ constrained_decoding_debug: bool = False,
681
+ ) -> UnderstandResult:
682
+ """Understand music from audio codes using the 5Hz Language Model.
683
+
684
+ This function analyzes audio semantic codes and generates metadata about the music,
685
+ including caption, lyrics, BPM, duration, key scale, language, and time signature.
686
+
687
+ If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
688
+ instead of analyzing existing codes.
689
+
690
+ Args:
691
+ llm_handler: Initialized LLM handler (LLMHandler instance)
692
+ audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
693
+ Use empty string or "NO USER INPUT" to generate a sample example.
694
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
695
+ cfg_scale: Classifier-Free Guidance scale (1.0 = no CFG, >1.0 = use CFG)
696
+ negative_prompt: Negative prompt for CFG guidance
697
+ top_k: Top-K sampling (None or 0 = disabled)
698
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
699
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
700
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
701
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
702
+
703
+ Returns:
704
+ UnderstandResult with parsed metadata fields and status
705
+
706
+ Example:
707
+ >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
708
+ >>> if result.success:
709
+ ... print(f"Caption: {result.caption}")
710
+ ... print(f"BPM: {result.bpm}")
711
+ ... print(f"Lyrics: {result.lyrics}")
712
+ """
713
+ # Check if LLM is initialized
714
+ if not llm_handler.llm_initialized:
715
+ return UnderstandResult(
716
+ status_message="5Hz LM not initialized. Please initialize it first.",
717
+ success=False,
718
+ error="LLM not initialized",
719
+ )
720
+
721
+ # If codes are empty, use "NO USER INPUT" to generate a sample example
722
+ if not audio_codes or not audio_codes.strip():
723
+ audio_codes = "NO USER INPUT"
724
+
725
+ try:
726
+ # Call LLM understanding
727
+ metadata, status = llm_handler.understand_audio_from_codes(
728
+ audio_codes=audio_codes,
729
+ temperature=temperature,
730
+ cfg_scale=cfg_scale,
731
+ negative_prompt=negative_prompt,
732
+ top_k=top_k,
733
+ top_p=top_p,
734
+ repetition_penalty=repetition_penalty,
735
+ use_constrained_decoding=use_constrained_decoding,
736
+ constrained_decoding_debug=constrained_decoding_debug,
737
+ )
738
+
739
+ # Check if LLM returned empty metadata (error case)
740
+ if not metadata:
741
+ return UnderstandResult(
742
+ status_message=status or "Failed to understand audio codes",
743
+ success=False,
744
+ error=status or "Empty metadata returned",
745
+ )
746
+
747
+ # Extract and convert fields
748
+ caption = metadata.get('caption', '')
749
+ lyrics = metadata.get('lyrics', '')
750
+ keyscale = metadata.get('keyscale', '')
751
+ language = metadata.get('language', metadata.get('vocal_language', ''))
752
+ timesignature = metadata.get('timesignature', '')
753
+
754
+ # Convert BPM to int
755
+ bpm = None
756
+ bpm_value = metadata.get('bpm')
757
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
758
+ try:
759
+ bpm = int(bpm_value)
760
+ except (ValueError, TypeError):
761
+ pass
762
+
763
+ # Convert duration to float
764
+ duration = None
765
+ duration_value = metadata.get('duration')
766
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
767
+ try:
768
+ duration = float(duration_value)
769
+ except (ValueError, TypeError):
770
+ pass
771
+
772
+ # Clean up N/A values
773
+ if keyscale == 'N/A':
774
+ keyscale = ''
775
+ if language == 'N/A':
776
+ language = ''
777
+ if timesignature == 'N/A':
778
+ timesignature = ''
779
+
780
+ return UnderstandResult(
781
+ caption=caption,
782
+ lyrics=lyrics,
783
+ bpm=bpm,
784
+ duration=duration,
785
+ keyscale=keyscale,
786
+ language=language,
787
+ timesignature=timesignature,
788
+ status_message=status,
789
+ success=True,
790
+ error=None,
791
+ )
792
+
793
+ except Exception as e:
794
+ logger.exception("Music understanding failed")
795
+ return UnderstandResult(
796
+ status_message=f"Error: {str(e)}",
797
+ success=False,
798
+ error=str(e),
799
+ )