# acestep/inference.py
"""
ACE-Step Inference API Module
This module provides a standardized inference interface for music generation,
designed for third-party integration. It offers both a simplified API and
backward-compatible Gradio UI support.
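Typical usage (a minimal sketch; dit_handler and llm_handler are assumed to be
model handlers initialized elsewhere, e.g. by the serve or Gradio entry points):
>>> from acestep.inference import GenerationParams, GenerationConfig, generate_music
>>> params = GenerationParams(caption="mellow lofi hip hop with soft piano", duration=60.0)
>>> config = GenerationConfig(batch_size=1, audio_format="flac")
>>> result = generate_music(dit_handler, llm_handler, params, config, save_dir="./outputs")
>>> if result.success:
...     print(result.audios[0]["path"])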
"""
import math
import os
import tempfile
from typing import Optional, Union, List, Dict, Any, Tuple
from dataclasses import dataclass, field, asdict
from loguru import logger
from acestep.audio_utils import AudioSaver, generate_uuid_from_params
@dataclass
class GenerationParams:
"""Configuration for music generation parameters.
Attributes:
# Text Inputs
caption: A short text prompt describing the desired music (main prompt). < 512 characters
lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
instrumental: If True, generate instrumental music regardless of lyrics.
# Music Metadata
bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
# Generation Parameters
inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
guidance_scale: CFG (classifier-free guidance) strength. Higher values follow the prompt more strictly. Only supported by the non-turbo (base) model.
seed: Integer seed for reproducibility. -1 means use random seed each time.
# Advanced DiT Parameters
use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
infer_method: Diffusion inference method, either "ode" or "sde".
timesteps: Optional list of custom timesteps. If provided, overrides inference_steps and shift.
# Task-Specific Parameters
task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
reference_audio: Path to a reference audio file for style transfer or cover tasks.
src_audio: Path to a source audio file for audio-to-audio tasks.
audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). Use smaller values (e.g., 0.2) for style-transfer tasks.
instruction: Optional task instruction prompt. If empty, auto-generated by system.
# 5Hz Language Model Parameters for CoT reasoning
thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
lm_cfg_scale: Classifier-free guidance scale for the LLM.
lm_top_k: LLM top-k sampling (0 = disabled).
lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
lm_negative_prompt: Negative prompt to use for LLM (for control).
use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
use_cot_language: Whether to let LLM detect vocal language via CoT.
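Example:
>>> # A minimal text2music request; unspecified fields keep their defaults
>>> params = GenerationParams(
...     caption="Upbeat synth-pop with bright pads and a driving beat",
...     lyrics="[Instrumental]",
...     instrumental=True,
...     bpm=120,
...     duration=60.0,
... )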
"""
# Required Inputs
task_type: str = "text2music"
instruction: str = "Fill the audio semantic mask based on the given conditions:"
# Audio Uploads
reference_audio: Optional[str] = None
src_audio: Optional[str] = None
# LM Codes Hints
audio_codes: str = ""
# Text Inputs
caption: str = ""
lyrics: str = ""
instrumental: bool = False
# Metadata
vocal_language: str = "unknown"
bpm: Optional[int] = None
keyscale: str = ""
timesignature: str = ""
duration: float = -1.0
# Advanced Settings
inference_steps: int = 8
seed: int = -1
guidance_scale: float = 7.0
use_adg: bool = False
cfg_interval_start: float = 0.0
cfg_interval_end: float = 1.0
shift: float = 1.0
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
# Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
# If provided, overrides inference_steps and shift
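# e.g. timesteps=[0.97, 0.76, 0.615, 0.5, 0.395, 0.28, 0.18, 0.085, 0.0]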
timesteps: Optional[List[float]] = None
repainting_start: float = 0.0
repainting_end: float = -1
audio_cover_strength: float = 1.0
# 5Hz Language Model Parameters
thinking: bool = True
lm_temperature: float = 0.85
lm_cfg_scale: float = 2.0
lm_top_k: int = 0
lm_top_p: float = 0.9
lm_negative_prompt: str = "NO USER INPUT"
use_cot_metas: bool = True
use_cot_caption: bool = True
use_cot_lyrics: bool = False # TODO: not used yet
use_cot_language: bool = True
use_constrained_decoding: bool = True
cot_bpm: Optional[int] = None
cot_keyscale: str = ""
cot_timesignature: str = ""
cot_duration: Optional[float] = None
cot_vocal_language: str = "unknown"
cot_caption: str = ""
cot_lyrics: str = ""
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary for JSON serialization."""
return asdict(self)
@dataclass
class GenerationConfig:
"""Configuration for music generation.
Attributes:
batch_size: Number of audio samples to generate
allow_lm_batch: Whether to allow batch processing in LM
use_random_seed: Whether to use random seed
seeds: Seed(s) for batch generation. Can be:
- None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
- List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
- int: Single seed value (will be converted to list and padded)
lm_batch_chunk_size: Batch chunk size for LM processing
constrained_decoding_debug: Whether to enable constrained decoding debug
audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
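Example:
>>> # Two samples with fixed seeds, saved as WAV
>>> config = GenerationConfig(batch_size=2, use_random_seed=False, seeds=[42, 43], audio_format="wav")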
"""
batch_size: int = 2
allow_lm_batch: bool = False
use_random_seed: bool = True
seeds: Optional[List[int]] = None
lm_batch_chunk_size: int = 8
constrained_decoding_debug: bool = False
audio_format: str = "flac" # Default to FLAC for fast saving
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary for JSON serialization."""
return asdict(self)
@dataclass
class GenerationResult:
"""Result of music generation.
Attributes:
# Audio Outputs
audios: List of audio dicts, each containing "path", "tensor", "key", "sample_rate", and "params"
status_message: Status message from generation
extra_outputs: Extra outputs from generation
success: Whether generation completed successfully
error: Error message if generation failed
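Example:
>>> # Assuming `result` is a GenerationResult returned by generate_music()
>>> if result.success:
...     for audio in result.audios:
...         print(audio["key"], audio["path"], audio["params"]["seed"])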
"""
# Audio Outputs
audios: List[Dict[str, Any]] = field(default_factory=list)
# Generation Information
status_message: str = ""
extra_outputs: Dict[str, Any] = field(default_factory=dict)
# Success Status
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
@dataclass
class UnderstandResult:
"""Result of music understanding from audio codes.
Attributes:
# Metadata Fields
caption: Generated caption describing the music
lyrics: Generated or extracted lyrics
bpm: Beats per minute (None if not detected)
duration: Duration in seconds (None if not detected)
keyscale: Musical key (e.g., "C Major")
language: Vocal language code (e.g., "en", "zh")
timesignature: Time signature (e.g., "4/4")
# Status
status_message: Status message from understanding
success: Whether understanding completed successfully
error: Error message if understanding failed
"""
# Metadata Fields
caption: str = ""
lyrics: str = ""
bpm: Optional[int] = None
duration: Optional[float] = None
keyscale: str = ""
language: str = ""
timesignature: str = ""
# Status
status_message: str = ""
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
def _update_metadata_from_lm(
metadata: Dict[str, Any],
bpm: Optional[int],
key_scale: str,
time_signature: str,
audio_duration: Optional[float],
vocal_language: str,
caption: str,
lyrics: str,
) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
"""Update metadata fields from LM output if not provided by user."""
if bpm is None and metadata.get('bpm'):
bpm_value = metadata.get('bpm')
if bpm_value not in ["N/A", ""]:
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
if not key_scale and metadata.get('keyscale'):
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
if key_scale_value != "N/A":
key_scale = key_scale_value
if not time_signature and metadata.get('timesignature'):
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
if time_signature_value != "N/A":
time_signature = time_signature_value
if audio_duration is None or audio_duration <= 0:
audio_duration_value = metadata.get('duration', -1)
if audio_duration_value not in ["N/A", ""]:
try:
audio_duration = float(audio_duration_value)
except (ValueError, TypeError):
pass
if not vocal_language and metadata.get('vocal_language'):
vocal_language = metadata.get('vocal_language')
if not caption and metadata.get('caption'):
caption = metadata.get('caption')
if not lyrics and metadata.get('lyrics'):
lyrics = metadata.get('lyrics')
return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
def generate_music(
dit_handler,
llm_handler,
params: GenerationParams,
config: GenerationConfig,
save_dir: Optional[str] = None,
progress=None,
) -> GenerationResult:
"""Generate music using ACE-Step model with optional LM reasoning.
Args:
dit_handler: Initialized DiT model handler (AceStepHandler instance)
llm_handler: Initialized LLM handler (LLMHandler instance)
params: Generation parameters (GenerationParams instance)
config: Generation configuration (GenerationConfig instance)
save_dir: Optional directory in which generated audio files are saved. If None, no files are written and only audio tensors are returned.
progress: Optional progress callback forwarded to the LM and DiT handlers.
Returns:
GenerationResult with generated audio files and metadata
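Example:
>>> # A minimal sketch; dit_handler and llm_handler are assumed to be
>>> # already-initialized AceStepHandler / LLMHandler instances
>>> params = GenerationParams(caption="warm acoustic folk, fingerpicked guitar", duration=45.0)
>>> config = GenerationConfig(batch_size=1)
>>> result = generate_music(dit_handler, llm_handler, params, config, save_dir="./outputs")
>>> if result.success:
...     print(result.audios[0]["path"])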
"""
try:
# Phase 1: LM-based metadata and code generation (if enabled)
audio_code_string_to_use = params.audio_codes
lm_generated_metadata = None
lm_generated_audio_codes_list = []
lm_total_time_costs = {
"phase1_time": 0.0,
"phase2_time": 0.0,
"total_time": 0.0,
}
# Extract mutable copies of metadata (will be updated by LM if needed)
bpm = params.bpm
key_scale = params.keyscale
time_signature = params.timesignature
audio_duration = params.duration
dit_input_caption = params.caption
dit_input_vocal_language = params.vocal_language
dit_input_lyrics = params.lyrics
# Determine if we need to generate audio codes
# If user has provided audio_codes, we don't need to generate them
# Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
# Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
# For now, we use "llm_dit" if batch mode or if user hasn't provided codes
# Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
# Note: This logic can be refined based on specific requirements
need_audio_codes = not user_provided_audio_codes
# Determine if we should use chunk-based LM generation (always use chunks for consistency)
# Determine actual batch size for chunk processing
actual_batch_size = config.batch_size if config.batch_size is not None else 1
# Prepare seeds for batch generation
# Use config.seeds if provided; otherwise pass an empty string and let prepare_seeds
# fill in seeds (random seeds when config.use_random_seed is True)
# Convert config.seeds (None, int, or List[int]) to the string format prepare_seeds accepts
seed_for_generation = ""
if isinstance(config.seeds, int):
# A single int seed (documented as allowed) becomes a one-element seed string
seed_for_generation = str(config.seeds)
elif config.seeds is not None and len(config.seeds) > 0:
# Convert List[int] to comma-separated string
seed_for_generation = ",".join(str(s) for s in config.seeds)
# Use dit_handler.prepare_seeds to handle seed list generation and padding
# This will handle all the logic: padding with random seeds if needed, etc.
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
# LM-based Chain-of-Thought reasoning
# Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
# and don't need LM to generate audio codes
skip_lm_tasks = {"cover", "repaint"}
# Determine if we should use LLM
# LLM is needed for:
# 1. thinking=True: generate audio codes via LM
# 2. use_cot_caption=True: enhance/generate caption via CoT
# 3. use_cot_language=True: detect vocal language via CoT
# 4. use_cot_metas=True: fill missing metadata via CoT
need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
lm_status = []
if params.task_type in skip_lm_tasks:
logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
if use_lm:
# Convert sampling parameters - handle None values safely
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
# Build user_metadata from user-provided values
user_metadata = {}
if bpm is not None:
try:
bpm_value = float(bpm)
if bpm_value > 0:
user_metadata['bpm'] = int(bpm_value)
except (ValueError, TypeError):
pass
if key_scale and key_scale.strip():
key_scale_clean = key_scale.strip()
if key_scale_clean.lower() not in ["n/a", ""]:
user_metadata['keyscale'] = key_scale_clean
if time_signature and time_signature.strip():
time_sig_clean = time_signature.strip()
if time_sig_clean.lower() not in ["n/a", ""]:
user_metadata['timesignature'] = time_sig_clean
if audio_duration is not None:
try:
duration_value = float(audio_duration)
if duration_value > 0:
user_metadata['duration'] = int(duration_value)
except (ValueError, TypeError):
pass
user_metadata_to_pass = user_metadata if user_metadata else None
# Determine infer_type based on whether we need audio codes
# - "llm_dit": generates both metas and audio codes (two-phase internally)
# - "dit": generates only metas (single phase)
infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
# Use chunk size from config, or default to batch_size if not set
max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
all_metadata_list = []
all_audio_codes_list = []
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * max_inference_batch_size
chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
chunk_size = chunk_end - chunk_start
chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
f"(size: {chunk_size}, seeds: {chunk_seeds})")
# Use the determined infer_type
# - "llm_dit" will internally run two phases (metas + codes)
# - "dit" will only run phase 1 (metas only)
result = llm_handler.generate_with_stop_condition(
caption=params.caption or "",
lyrics=params.lyrics or "",
infer_type=infer_type,
temperature=params.lm_temperature,
cfg_scale=params.lm_cfg_scale,
negative_prompt=params.lm_negative_prompt,
top_k=top_k_value,
top_p=top_p_value,
user_metadata=user_metadata_to_pass,
use_cot_caption=params.use_cot_caption,
use_cot_language=params.use_cot_language,
use_cot_metas=params.use_cot_metas,
use_constrained_decoding=params.use_constrained_decoding,
constrained_decoding_debug=config.constrained_decoding_debug,
batch_size=chunk_size,
seeds=chunk_seeds,
progress=progress,
)
# Check if LM generation failed
if not result.get("success", False):
error_msg = result.get("error", "Unknown LM error")
lm_status.append(f"❌ LM Error: {error_msg}")
# Return early with error
return GenerationResult(
audios=[],
status_message=f"❌ LM generation failed: {error_msg}",
extra_outputs={},
success=False,
error=error_msg,
)
# Extract metadata and audio_codes from result dict
if chunk_size > 1:
metadata_list = result.get("metadata", [])
audio_codes_list = result.get("audio_codes", [])
all_metadata_list.extend(metadata_list)
all_audio_codes_list.extend(audio_codes_list)
else:
metadata = result.get("metadata", {})
audio_codes = result.get("audio_codes", "")
all_metadata_list.append(metadata)
all_audio_codes_list.append(audio_codes)
# Collect time costs from LM extra_outputs
lm_extra = result.get("extra_outputs", {})
lm_chunk_time_costs = lm_extra.get("time_costs", {})
if lm_chunk_time_costs:
# Accumulate time costs from all chunks
for key in ["phase1_time", "phase2_time", "total_time"]:
if key in lm_chunk_time_costs:
lm_total_time_costs[key] += lm_chunk_time_costs[key]
time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
lm_generated_audio_codes_list = all_audio_codes_list
# Set audio_code_string_to_use based on infer_type
if infer_type == "llm_dit":
# If batch mode, use list; otherwise use single string
if actual_batch_size > 1:
audio_code_string_to_use = all_audio_codes_list
else:
audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
else:
# For "dit" mode, keep user-provided codes or empty
audio_code_string_to_use = params.audio_codes
# Update metadata from LM if not provided by user
if lm_generated_metadata:
bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
metadata=lm_generated_metadata,
bpm=bpm,
key_scale=key_scale,
time_signature=time_signature,
audio_duration=audio_duration,
vocal_language=dit_input_vocal_language,
caption=dit_input_caption,
lyrics=dit_input_lyrics)
if not params.bpm:
params.cot_bpm = bpm
if not params.keyscale:
params.cot_keyscale = key_scale
if not params.timesignature:
params.cot_timesignature = time_signature
if params.duration is None or params.duration <= 0:
params.cot_duration = audio_duration
if not params.vocal_language or params.vocal_language == "unknown":
params.cot_vocal_language = vocal_language
if not params.caption:
params.cot_caption = caption
if not params.lyrics:
params.cot_lyrics = lyrics
# set cot caption and language if needed
if params.use_cot_caption:
dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
if params.use_cot_language:
dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
# Phase 2: DiT music generation
# Use seed_for_generation (derived from config.seeds) rather than params.seed for actual generation
result = dit_handler.generate_music(
captions=dit_input_caption,
lyrics=dit_input_lyrics,
bpm=bpm,
key_scale=key_scale,
time_signature=time_signature,
vocal_language=dit_input_vocal_language,
inference_steps=params.inference_steps,
guidance_scale=params.guidance_scale,
use_random_seed=config.use_random_seed,
seed=seed_for_generation, # seed string derived from config.seeds (see above), not params.seed directly
reference_audio=params.reference_audio,
audio_duration=audio_duration,
batch_size=config.batch_size if config.batch_size is not None else 1,
src_audio=params.src_audio,
audio_code_string=audio_code_string_to_use,
repainting_start=params.repainting_start,
repainting_end=params.repainting_end,
instruction=params.instruction,
audio_cover_strength=params.audio_cover_strength,
task_type=params.task_type,
use_adg=params.use_adg,
cfg_interval_start=params.cfg_interval_start,
cfg_interval_end=params.cfg_interval_end,
shift=params.shift,
infer_method=params.infer_method,
timesteps=params.timesteps,
progress=progress,
)
# Check if generation failed
if not result.get("success", False):
return GenerationResult(
audios=[],
status_message=result.get("status_message", ""),
extra_outputs={},
success=False,
error=result.get("error"),
)
# Extract results from dit_handler.generate_music dict
dit_audios = result.get("audios", [])
status_message = result.get("status_message", "")
dit_extra_outputs = result.get("extra_outputs", {})
# Use the seed list already prepared above (from config.seeds)
# actual_seed_list was computed earlier using dit_handler.prepare_seeds
seed_list = actual_seed_list
# Get base params dictionary
base_params_dict = params.to_dict()
# Save audio files using AudioSaver (format from config)
audio_format = config.audio_format if config.audio_format else "flac"
audio_saver = AudioSaver(default_format=audio_format)
# Use handler's temp_dir for saving files
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
# Build audios list for GenerationResult with params and save files
# Audio saving and UUID generation handled here, outside of handler
audios = []
for idx, dit_audio in enumerate(dit_audios):
# Create a copy of params dict for this audio
audio_params = base_params_dict.copy()
# Update audio-specific values
audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
# Add audio codes if batch mode
if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
# Get audio tensor and metadata
audio_tensor = dit_audio.get("tensor")
sample_rate = dit_audio.get("sample_rate", 48000)
# Generate UUID for this audio (moved from handler)
batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
audio_code_str = lm_generated_audio_codes_list[idx] if (
lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
if isinstance(audio_code_str, list):
audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
audio_key = generate_uuid_from_params(audio_params)
# Save audio file (handled outside handler)
audio_path = None
if audio_tensor is not None and save_dir is not None:
try:
audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
audio_path = audio_saver.save_audio(audio_tensor,
audio_file,
sample_rate=sample_rate,
format=audio_format,
channels_first=True)
except Exception as e:
logger.error(f"[generate_music] Failed to save audio file: {e}")
audio_path = "" # Fallback to empty path
audio_dict = {
"path": audio_path or "", # File path (saved here, not in handler)
"tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
"key": audio_key,
"sample_rate": sample_rate,
"params": audio_params,
}
audios.append(audio_dict)
# Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
extra_outputs = dit_extra_outputs.copy()
extra_outputs["lm_metadata"] = lm_generated_metadata
# Merge time_costs from both LM and DiT into a unified dictionary
unified_time_costs = {}
# Add LM time costs (if LM was used)
if use_lm and lm_total_time_costs:
for key, value in lm_total_time_costs.items():
unified_time_costs[f"lm_{key}"] = value
# Add DiT time costs (if available)
dit_time_costs = dit_extra_outputs.get("time_costs", {})
if dit_time_costs:
for key, value in dit_time_costs.items():
unified_time_costs[f"dit_{key}"] = value
# Calculate total pipeline time
if unified_time_costs:
lm_total = unified_time_costs.get("lm_total_time", 0.0)
dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
unified_time_costs["pipeline_total_time"] = lm_total + dit_total
# Update extra_outputs with unified time_costs
extra_outputs["time_costs"] = unified_time_costs
if lm_status:
status_message = "\n".join(lm_status) + "\n" + status_message
# Create and return GenerationResult
return GenerationResult(
audios=audios,
status_message=status_message,
extra_outputs=extra_outputs,
success=True,
error=None,
)
except Exception as e:
logger.exception("Music generation failed")
return GenerationResult(
audios=[],
status_message=f"Error: {str(e)}",
extra_outputs={},
success=False,
error=str(e),
)
def understand_music(
llm_handler,
audio_codes: str,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> UnderstandResult:
"""Understand music from audio codes using the 5Hz Language Model.
This function analyzes audio semantic codes and generates metadata about the music,
including caption, lyrics, BPM, duration, key scale, language, and time signature.
If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
instead of analyzing existing codes.
Note: cfg_scale and negative_prompt are not supported in understand mode.
Args:
llm_handler: Initialized LLM handler (LLMHandler instance)
audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
Use empty string or "NO USER INPUT" to generate a sample example.
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
top_k: Top-K sampling (None or 0 = disabled)
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
Returns:
UnderstandResult with parsed metadata fields and status
Example:
>>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
>>> if result.success:
... print(f"Caption: {result.caption}")
... print(f"BPM: {result.bpm}")
... print(f"Lyrics: {result.lyrics}")
"""
# Check if LLM is initialized
if not llm_handler.llm_initialized:
return UnderstandResult(
status_message="5Hz LM not initialized. Please initialize it first.",
success=False,
error="LLM not initialized",
)
# If codes are empty, use "NO USER INPUT" to generate a sample example
if not audio_codes or not audio_codes.strip():
audio_codes = "NO USER INPUT"
try:
# Call LLM understanding
metadata, status = llm_handler.understand_audio_from_codes(
audio_codes=audio_codes,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
)
# Check if LLM returned empty metadata (error case)
if not metadata:
return UnderstandResult(
status_message=status or "Failed to understand audio codes",
success=False,
error=status or "Empty metadata returned",
)
# Extract and convert fields
caption = metadata.get('caption', '')
lyrics = metadata.get('lyrics', '')
keyscale = metadata.get('keyscale', '')
language = metadata.get('language', metadata.get('vocal_language', ''))
timesignature = metadata.get('timesignature', '')
# Convert BPM to int
bpm = None
bpm_value = metadata.get('bpm')
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
# Convert duration to float
duration = None
duration_value = metadata.get('duration')
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
try:
duration = float(duration_value)
except (ValueError, TypeError):
pass
# Clean up N/A values
if keyscale == 'N/A':
keyscale = ''
if language == 'N/A':
language = ''
if timesignature == 'N/A':
timesignature = ''
return UnderstandResult(
caption=caption,
lyrics=lyrics,
bpm=bpm,
duration=duration,
keyscale=keyscale,
language=language,
timesignature=timesignature,
status_message=status,
success=True,
error=None,
)
except Exception as e:
logger.exception("Music understanding failed")
return UnderstandResult(
status_message=f"Error: {str(e)}",
success=False,
error=str(e),
)
@dataclass
class CreateSampleResult:
"""Result of creating a music sample from a natural language query.
This is used by the "Simple Mode" / "Inspiration Mode" feature where users
provide a natural language description and the LLM generates a complete
sample with caption, lyrics, and metadata.
Attributes:
# Metadata Fields
caption: Generated detailed music description/caption
lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
bpm: Beats per minute (None if not generated)
duration: Duration in seconds (None if not generated)
keyscale: Musical key (e.g., "C Major")
language: Vocal language code (e.g., "en", "zh")
timesignature: Time signature (e.g., "4")
instrumental: Whether this is an instrumental piece
# Status
status_message: Status message from sample creation
success: Whether sample creation completed successfully
error: Error message if sample creation failed
"""
# Metadata Fields
caption: str = ""
lyrics: str = ""
bpm: Optional[int] = None
duration: Optional[float] = None
keyscale: str = ""
language: str = ""
timesignature: str = ""
instrumental: bool = False
# Status
status_message: str = ""
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
def create_sample(
llm_handler,
query: str,
instrumental: bool = False,
vocal_language: Optional[str] = None,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> CreateSampleResult:
"""Create a music sample from a natural language query using the 5Hz Language Model.
This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
language description of music and generates a complete sample including:
- Detailed caption/description
- Lyrics (unless instrumental)
- Metadata (BPM, duration, key, language, time signature)
Note: cfg_scale and negative_prompt are not supported in create_sample mode.
Args:
llm_handler: Initialized LLM handler (LLMHandler instance)
query: User's natural language music description (e.g., "a soft Bengali love song")
instrumental: Whether to generate instrumental music (no vocals)
vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
If provided, the model will be constrained to generate lyrics in this language.
If None or "unknown", no language constraint is applied.
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
top_k: Top-K sampling (None or 0 = disabled)
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding
constrained_decoding_debug: Whether to enable debug logging
Returns:
CreateSampleResult with generated sample fields and status
Example:
>>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
>>> if result.success:
... print(f"Caption: {result.caption}")
... print(f"Lyrics: {result.lyrics}")
... print(f"BPM: {result.bpm}")
"""
# Check if LLM is initialized
if not llm_handler.llm_initialized:
return CreateSampleResult(
status_message="5Hz LM not initialized. Please initialize it first.",
success=False,
error="LLM not initialized",
)
try:
# Call LLM to create sample
metadata, status = llm_handler.create_sample_from_query(
query=query,
instrumental=instrumental,
vocal_language=vocal_language,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
)
# Check if LLM returned empty metadata (error case)
if not metadata:
return CreateSampleResult(
status_message=status or "Failed to create sample",
success=False,
error=status or "Empty metadata returned",
)
# Extract and convert fields
caption = metadata.get('caption', '')
lyrics = metadata.get('lyrics', '')
keyscale = metadata.get('keyscale', '')
language = metadata.get('language', metadata.get('vocal_language', ''))
timesignature = metadata.get('timesignature', '')
is_instrumental = metadata.get('instrumental', instrumental)
# Convert BPM to int
bpm = None
bpm_value = metadata.get('bpm')
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
# Convert duration to float
duration = None
duration_value = metadata.get('duration')
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
try:
duration = float(duration_value)
except (ValueError, TypeError):
pass
# Clean up N/A values
if keyscale == 'N/A':
keyscale = ''
if language == 'N/A':
language = ''
if timesignature == 'N/A':
timesignature = ''
return CreateSampleResult(
caption=caption,
lyrics=lyrics,
bpm=bpm,
duration=duration,
keyscale=keyscale,
language=language,
timesignature=timesignature,
instrumental=is_instrumental,
status_message=status,
success=True,
error=None,
)
except Exception as e:
logger.exception("Sample creation failed")
return CreateSampleResult(
status_message=f"Error: {str(e)}",
success=False,
error=str(e),
)
@dataclass
class FormatSampleResult:
"""Result of formatting user-provided caption and lyrics.
This is used by the "Format" feature where users provide caption and lyrics,
and the LLM formats them into structured music metadata and an enhanced description.
Attributes:
# Metadata Fields
caption: Enhanced/formatted music description/caption
lyrics: Formatted lyrics (may be same as input or reformatted)
bpm: Beats per minute (None if not detected)
duration: Duration in seconds (None if not detected)
keyscale: Musical key (e.g., "C Major")
language: Vocal language code (e.g., "en", "zh")
timesignature: Time signature (e.g., "4")
# Status
status_message: Status message from formatting
success: Whether formatting completed successfully
error: Error message if formatting failed
"""
# Metadata Fields
caption: str = ""
lyrics: str = ""
bpm: Optional[int] = None
duration: Optional[float] = None
keyscale: str = ""
language: str = ""
timesignature: str = ""
# Status
status_message: str = ""
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
def format_sample(
llm_handler,
caption: str,
lyrics: str,
user_metadata: Optional[Dict[str, Any]] = None,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> FormatSampleResult:
"""Format user-provided caption and lyrics using the 5Hz Language Model.
This function takes user input (caption and lyrics) and generates structured
music metadata including an enhanced caption, BPM, duration, key, language,
and time signature.
If user_metadata is provided, those values will be used to constrain the
decoding, ensuring the output matches user-specified values.
Note: cfg_scale and negative_prompt are not supported in format mode.
Args:
llm_handler: Initialized LLM handler (LLMHandler instance)
caption: User's caption/description (e.g., "Latin pop, reggaeton")
lyrics: User's lyrics with structure tags
user_metadata: Optional dict with user-provided metadata to constrain decoding.
Supported keys: bpm, duration, keyscale, timesignature, language
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
top_k: Top-K sampling (None or 0 = disabled)
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
Returns:
FormatSampleResult with formatted metadata fields and status
Example:
>>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
>>> if result.success:
... print(f"Caption: {result.caption}")
... print(f"BPM: {result.bpm}")
... print(f"Lyrics: {result.lyrics}")
"""
# Check if LLM is initialized
if not llm_handler.llm_initialized:
return FormatSampleResult(
status_message="5Hz LM not initialized. Please initialize it first.",
success=False,
error="LLM not initialized",
)
try:
# Call LLM formatting
metadata, status = llm_handler.format_sample_from_input(
caption=caption,
lyrics=lyrics,
user_metadata=user_metadata,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
)
# Check if LLM returned empty metadata (error case)
if not metadata:
return FormatSampleResult(
status_message=status or "Failed to format input",
success=False,
error=status or "Empty metadata returned",
)
# Extract and convert fields
result_caption = metadata.get('caption', '')
result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
keyscale = metadata.get('keyscale', '')
language = metadata.get('language', metadata.get('vocal_language', ''))
timesignature = metadata.get('timesignature', '')
# Convert BPM to int
bpm = None
bpm_value = metadata.get('bpm')
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
# Convert duration to float
duration = None
duration_value = metadata.get('duration')
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
try:
duration = float(duration_value)
except (ValueError, TypeError):
pass
# Clean up N/A values
if keyscale == 'N/A':
keyscale = ''
if language == 'N/A':
language = ''
if timesignature == 'N/A':
timesignature = ''
return FormatSampleResult(
caption=result_caption,
lyrics=result_lyrics,
bpm=bpm,
duration=duration,
keyscale=keyscale,
language=language,
timesignature=timesignature,
status_message=status,
success=True,
error=None,
)
except Exception as e:
logger.exception("Format sample failed")
return FormatSampleResult(
status_message=f"Error: {str(e)}",
success=False,
error=str(e),
)