Merge pull request #5 from ace-step/fix_transcribe_audio_codes

Files changed:
- acestep/gradio_ui/events/generation_handlers.py  +40 -76
- acestep/inference.py  +170 -0
acestep/gradio_ui/events/generation_handlers.py
CHANGED

@@ -13,6 +13,7 @@ from acestep.constants import (
     TASK_TYPES_BASE,
 )
 from acestep.gradio_ui.i18n import t
+from acestep.inference import understand_music
 
 
 def load_metadata(file_obj):
@@ -206,6 +207,9 @@ def load_random_example(task_type: str):
 def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
     """Smart sample function that uses LM if initialized, otherwise falls back to examples
 
+    This is a Gradio wrapper that uses the understand_music API from acestep.inference
+    to generate examples when LM is available.
+
     Args:
         llm_handler: LLM handler instance
         task_type: The task type (e.g., "text2music")
@@ -216,50 +220,28 @@ def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug
     """
     # Check if LM is initialized
     if llm_handler.llm_initialized:
-        # Use LM to generate example
+        # Use LM to generate example via understand_music API
         try:
-            metadata, status = llm_handler.understand_audio_from_codes(
-                audio_codes="NO USER INPUT",
-                use_constrained_decoding=True,
+            result = understand_music(
+                llm_handler=llm_handler,
+                audio_codes="NO USER INPUT",  # Empty input triggers example generation
                 temperature=0.85,
+                use_constrained_decoding=True,
                 constrained_decoding_debug=constrained_decoding_debug,
             )
 
-            if metadata:
-                caption_value = metadata.get('caption', '')
-                lyrics_value = metadata.get('lyrics', '')
-                think_value = True  # Always enable think when using LM-generated examples
-
-                # Extract optional metadata fields
-                bpm_value = None
-                if 'bpm' in metadata and metadata['bpm'] not in [None, "N/A", ""]:
-                    try:
-                        bpm_value = int(metadata['bpm'])
-                    except (ValueError, TypeError):
-                        pass
-
-                duration_value = None
-                if 'duration' in metadata and metadata['duration'] not in [None, "N/A", ""]:
-                    try:
-                        duration_value = float(metadata['duration'])
-                    except (ValueError, TypeError):
-                        pass
-
-                keyscale_value = metadata.get('keyscale', '')
-                if keyscale_value in [None, "N/A"]:
-                    keyscale_value = ''
-
-                language_value = metadata.get('language', '')
-                if language_value in [None, "N/A"]:
-                    language_value = ''
-
-                timesignature_value = metadata.get('timesignature', '')
-                if timesignature_value in [None, "N/A"]:
-                    timesignature_value = ''
-
+            if result.success:
                 gr.Info(t("messages.lm_generated"))
-                return (caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value)
+                return (
+                    result.caption,
+                    result.lyrics,
+                    True,  # Always enable think when using LM-generated examples
+                    result.bpm,
+                    result.duration,
+                    result.keyscale,
+                    result.language,
+                    result.timesignature,
+                )
             else:
                 gr.Warning(t("messages.lm_fallback"))
                 return load_random_example(task_type)
@@ -437,58 +419,40 @@ def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_
     Transcribe audio codes to metadata using LLM understanding.
     If audio_code_string is empty, generate a sample example instead.
 
+    This is a Gradio wrapper around the understand_music API in acestep.inference.
+
     Args:
         llm_handler: LLM handler instance
        audio_code_string: String containing audio codes (or empty for example generation)
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding
 
     Returns:
-        Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature)
+        Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
     """
-    # Check if LLM is initialized
-    if not llm_handler.llm_initialized:
-        return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
-    # If codes are empty, this becomes a "generate example" task
-    # Use "NO USER INPUT" as the input to generate a sample
-    if not audio_code_string or not audio_code_string.strip():
-        audio_code_string = "NO USER INPUT"
-
-    # Call LLM understanding
-    metadata, status = llm_handler.understand_audio_from_codes(
+    # Call the inference API
+    result = understand_music(
+        llm_handler=llm_handler,
         audio_codes=audio_code_string,
         use_constrained_decoding=True,
         constrained_decoding_debug=constrained_decoding_debug,
     )
 
-    # Extract metadata fields
-    caption = metadata.get('caption', '')
-    lyrics = metadata.get('lyrics', '')
-    bpm = metadata.get('bpm', '')
-    duration = metadata.get('duration', '')
-    keyscale = metadata.get('keyscale', '')
-    language = metadata.get('language', '')
-    timesignature = metadata.get('timesignature', '')
-
-    # Convert to appropriate types
-    try:
-        bpm = int(bpm) if bpm and bpm != 'N/A' else None
-    except:
-        bpm = None
-
-    try:
-        duration = float(duration) if duration and duration != 'N/A' else None
-    except:
-        duration = None
+    # Handle error case with localized message
+    if not result.success:
+        # Use localized error message for LLM not initialized
+        if result.error == "LLM not initialized":
+            return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
+        return result.status_message, "", "", None, None, "", "", "", False
 
     return (
-        status,
-        caption,
-        lyrics,
-        bpm,
-        duration,
-        keyscale,
-        language,
-        timesignature,
+        result.status_message,
+        result.caption,
+        result.lyrics,
+        result.bpm,
+        result.duration,
+        result.keyscale,
+        result.language,
+        result.timesignature,
        True  # Set is_format_caption to True (from Transcribe/LM understanding)
     )
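The two handlers now return differently shaped tuples, and the Gradio output components must match them positionally. A minimal sketch of how a caller might unpack each (variable names are illustrative; the field order is taken from the diffs above):

    # sample_example_smart -> 8 values on the LM success path: think flag third, no status message
    (caption, lyrics, think, bpm, duration,
     keyscale, language, timesignature) = sample_example_smart(llm_handler, "text2music")

    # transcribe_audio_codes -> 9 values: status message first, is_format_caption last
    (status_message, caption, lyrics, bpm, duration, keyscale,
     language, timesignature, is_format_caption) = transcribe_audio_codes(llm_handler, codes)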
acestep/inference.py
CHANGED

@@ -183,6 +183,44 @@ class GenerationResult:
         return asdict(self)
 
 
+@dataclass
+class UnderstandResult:
+    """Result of music understanding from audio codes.
+
+    Attributes:
+        # Metadata Fields
+        caption: Generated caption describing the music
+        lyrics: Generated or extracted lyrics
+        bpm: Beats per minute (None if not detected)
+        duration: Duration in seconds (None if not detected)
+        keyscale: Musical key (e.g., "C Major")
+        language: Vocal language code (e.g., "en", "zh")
+        timesignature: Time signature (e.g., "4/4")
+
+        # Status
+        status_message: Status message from understanding
+        success: Whether understanding completed successfully
+        error: Error message if understanding failed
+    """
+    # Metadata Fields
+    caption: str = ""
+    lyrics: str = ""
+    bpm: Optional[int] = None
+    duration: Optional[float] = None
+    keyscale: str = ""
+    language: str = ""
+    timesignature: str = ""
+
+    # Status
+    status_message: str = ""
+    success: bool = True
+    error: Optional[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert result to dictionary for JSON serialization."""
+        return asdict(self)
+
+
 def _update_metadata_from_lm(
     metadata: Dict[str, Any],
     bpm: Optional[int],
@@ -627,3 +665,135 @@ def generate_music(
             success=False,
             error=str(e),
         )
+
+
+def understand_music(
+    llm_handler,
+    audio_codes: str,
+    temperature: float = 0.85,
+    cfg_scale: float = 1.0,
+    negative_prompt: str = "NO USER INPUT",
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+    repetition_penalty: float = 1.0,
+    use_constrained_decoding: bool = True,
+    constrained_decoding_debug: bool = False,
+) -> UnderstandResult:
+    """Understand music from audio codes using the 5Hz Language Model.
+
+    This function analyzes audio semantic codes and generates metadata about the music,
+    including caption, lyrics, BPM, duration, key scale, language, and time signature.
+
+    If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
+    instead of analyzing existing codes.
+
+    Args:
+        llm_handler: Initialized LLM handler (LLMHandler instance)
+        audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
+            Use empty string or "NO USER INPUT" to generate a sample example.
+        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
+        cfg_scale: Classifier-Free Guidance scale (1.0 = no CFG, >1.0 = use CFG)
+        negative_prompt: Negative prompt for CFG guidance
+        top_k: Top-K sampling (None or 0 = disabled)
+        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
+        repetition_penalty: Repetition penalty (1.0 = no penalty)
+        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
+        constrained_decoding_debug: Whether to enable debug logging for constrained decoding
+
+    Returns:
+        UnderstandResult with parsed metadata fields and status
+
+    Example:
+        >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
+        >>> if result.success:
+        ...     print(f"Caption: {result.caption}")
+        ...     print(f"BPM: {result.bpm}")
+        ...     print(f"Lyrics: {result.lyrics}")
+    """
+    # Check if LLM is initialized
+    if not llm_handler.llm_initialized:
+        return UnderstandResult(
+            status_message="5Hz LM not initialized. Please initialize it first.",
+            success=False,
+            error="LLM not initialized",
+        )
+
+    # If codes are empty, use "NO USER INPUT" to generate a sample example
+    if not audio_codes or not audio_codes.strip():
+        audio_codes = "NO USER INPUT"
+
+    try:
+        # Call LLM understanding
+        metadata, status = llm_handler.understand_audio_from_codes(
+            audio_codes=audio_codes,
+            temperature=temperature,
+            cfg_scale=cfg_scale,
+            negative_prompt=negative_prompt,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            use_constrained_decoding=use_constrained_decoding,
+            constrained_decoding_debug=constrained_decoding_debug,
+        )
+
+        # Check if LLM returned empty metadata (error case)
+        if not metadata:
+            return UnderstandResult(
+                status_message=status or "Failed to understand audio codes",
+                success=False,
+                error=status or "Empty metadata returned",
+            )
+
+        # Extract and convert fields
+        caption = metadata.get('caption', '')
+        lyrics = metadata.get('lyrics', '')
+        keyscale = metadata.get('keyscale', '')
+        language = metadata.get('language', metadata.get('vocal_language', ''))
+        timesignature = metadata.get('timesignature', '')
+
+        # Convert BPM to int
+        bpm = None
+        bpm_value = metadata.get('bpm')
+        if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
+            try:
+                bpm = int(bpm_value)
+            except (ValueError, TypeError):
+                pass
+
+        # Convert duration to float
+        duration = None
+        duration_value = metadata.get('duration')
+        if duration_value is not None and duration_value != 'N/A' and duration_value != '':
+            try:
+                duration = float(duration_value)
+            except (ValueError, TypeError):
+                pass
+
+        # Clean up N/A values
+        if keyscale == 'N/A':
+            keyscale = ''
+        if language == 'N/A':
+            language = ''
+        if timesignature == 'N/A':
+            timesignature = ''
+
+        return UnderstandResult(
+            caption=caption,
+            lyrics=lyrics,
+            bpm=bpm,
+            duration=duration,
+            keyscale=keyscale,
+            language=language,
+            timesignature=timesignature,
+            status_message=status,
+            success=True,
+            error=None,
+        )
+
+    except Exception as e:
+        logger.exception("Music understanding failed")
+        return UnderstandResult(
+            status_message=f"Error: {str(e)}",
+            success=False,
+            error=str(e),
+        )
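The new understand_music entry point is usable outside the Gradio UI as well. A minimal sketch, assuming llm_handler is an already-initialized LLMHandler as described in the docstring above (handler setup is not part of this diff):

    from acestep.inference import understand_music

    # An empty string (or "NO USER INPUT") asks the LM to generate a sample example
    result = understand_music(llm_handler, audio_codes="")
    if result.success:
        print(result.caption, result.bpm, result.keyscale)
        payload = result.to_dict()  # plain dict via dataclasses.asdict, ready for JSON
    else:
        print(result.status_message, result.error)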