""" ACE-Step Inference API Module This module provides a standardized inference interface for music generation, designed for third-party integration. It offers both a simplified API and backward-compatible Gradio UI support. """ import math import os import tempfile from typing import Optional, Union, List, Dict, Any, Tuple from dataclasses import dataclass, field, asdict from loguru import logger from acestep.audio_utils import AudioSaver, generate_uuid_from_params @dataclass class GenerationParams: """Configuration for music generation parameters. Attributes: # Text Inputs caption: A short text prompt describing the desired music (main prompt). < 512 characters lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters instrumental: If True, generate instrumental music regardless of lyrics. # Music Metadata bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300 keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection. vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600 # Generation Parameters inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model). guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model. seed: Integer seed for reproducibility. -1 means use random seed each time. # Advanced DiT Parameters use_adg: Whether to use Adaptive Dual Guidance (only works for base model). cfg_interval_start: Start ratio (0.0–1.0) to apply CFG. cfg_interval_end: End ratio (0.0–1.0) to apply CFG. shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps. # Task-Specific Parameters task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete". reference_audio: Path to a reference audio file for style transfer or cover tasks. src_audio: Path to a source audio file for audio-to-audio tasks. audio_codes: Audio semantic codes as a string (advanced use, for code-control generation). repainting_start: For repaint/lego tasks: start time in seconds for region to repaint. repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end). audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks. instruction: Optional task instruction prompt. If empty, auto-generated by system. # 5Hz Language Model Parameters for CoT reasoning thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes. lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results. lm_cfg_scale: Classifier-free guidance scale for the LLM. lm_top_k: LLM top-k sampling (0 = disabled). lm_top_p: LLM top-p nucleus sampling (1.0 = disabled). lm_negative_prompt: Negative prompt to use for LLM (for control). use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning. use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning. use_cot_language: Whether to let LLM detect vocal language via CoT. """ # Required Inputs task_type: str = "text2music" instruction: str = "Fill the audio semantic mask based on the given conditions:" # Audio Uploads reference_audio: Optional[str] = None src_audio: Optional[str] = None # LM Codes Hints audio_codes: str = "" # Text Inputs caption: str = "" lyrics: str = "" instrumental: bool = False # Metadata vocal_language: str = "unknown" bpm: Optional[int] = None keyscale: str = "" timesignature: str = "" duration: float = -1.0 # Advanced Settings inference_steps: int = 8 seed: int = -1 guidance_scale: float = 7.0 use_adg: bool = False cfg_interval_start: float = 0.0 cfg_interval_end: float = 1.0 shift: float = 1.0 infer_method: str = "ode" # "ode" or "sde" - diffusion inference method # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0") # If provided, overrides inference_steps and shift timesteps: Optional[List[float]] = None repainting_start: float = 0.0 repainting_end: float = -1 audio_cover_strength: float = 1.0 # 5Hz Language Model Parameters thinking: bool = True lm_temperature: float = 0.85 lm_cfg_scale: float = 2.0 lm_top_k: int = 0 lm_top_p: float = 0.9 lm_negative_prompt: str = "NO USER INPUT" use_cot_metas: bool = True use_cot_caption: bool = True use_cot_lyrics: bool = False # TODO: not used yet use_cot_language: bool = True use_constrained_decoding: bool = True cot_bpm: Optional[int] = None cot_keyscale: str = "" cot_timesignature: str = "" cot_duration: Optional[float] = None cot_vocal_language: str = "unknown" cot_caption: str = "" cot_lyrics: str = "" def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary for JSON serialization.""" return asdict(self) @dataclass class GenerationConfig: """Configuration for music generation. Attributes: batch_size: Number of audio samples to generate allow_lm_batch: Whether to allow batch processing in LM use_random_seed: Whether to use random seed seeds: Seed(s) for batch generation. Can be: - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False) - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size - int: Single seed value (will be converted to list and padded) lm_batch_chunk_size: Batch chunk size for LM processing constrained_decoding_debug: Whether to enable constrained decoding debug audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac" """ batch_size: int = 2 allow_lm_batch: bool = False use_random_seed: bool = True seeds: Optional[List[int]] = None lm_batch_chunk_size: int = 8 constrained_decoding_debug: bool = False audio_format: str = "flac" # Default to FLAC for fast saving def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary for JSON serialization.""" return asdict(self) @dataclass class GenerationResult: """Result of music generation. Attributes: # Audio Outputs audios: List of audio dictionaries with paths, keys, params status_message: Status message from generation extra_outputs: Extra outputs from generation success: Whether generation completed successfully error: Error message if generation failed """ # Audio Outputs audios: List[Dict[str, Any]] = field(default_factory=list) # Generation Information status_message: str = "" extra_outputs: Dict[str, Any] = field(default_factory=dict) # Success Status success: bool = True error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for JSON serialization.""" return asdict(self) @dataclass class UnderstandResult: """Result of music understanding from audio codes. Attributes: # Metadata Fields caption: Generated caption describing the music lyrics: Generated or extracted lyrics bpm: Beats per minute (None if not detected) duration: Duration in seconds (None if not detected) keyscale: Musical key (e.g., "C Major") language: Vocal language code (e.g., "en", "zh") timesignature: Time signature (e.g., "4/4") # Status status_message: Status message from understanding success: Whether understanding completed successfully error: Error message if understanding failed """ # Metadata Fields caption: str = "" lyrics: str = "" bpm: Optional[int] = None duration: Optional[float] = None keyscale: str = "" language: str = "" timesignature: str = "" # Status status_message: str = "" success: bool = True error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for JSON serialization.""" return asdict(self) def _update_metadata_from_lm( metadata: Dict[str, Any], bpm: Optional[int], key_scale: str, time_signature: str, audio_duration: Optional[float], vocal_language: str, caption: str, lyrics: str, ) -> Tuple[Optional[int], str, str, Optional[float]]: """Update metadata fields from LM output if not provided by user.""" if bpm is None and metadata.get('bpm'): bpm_value = metadata.get('bpm') if bpm_value not in ["N/A", ""]: try: bpm = int(bpm_value) except (ValueError, TypeError): pass if not key_scale and metadata.get('keyscale'): key_scale_value = metadata.get('keyscale', metadata.get('key_scale', "")) if key_scale_value != "N/A": key_scale = key_scale_value if not time_signature and metadata.get('timesignature'): time_signature_value = metadata.get('timesignature', metadata.get('time_signature', "")) if time_signature_value != "N/A": time_signature = time_signature_value if audio_duration is None or audio_duration <= 0: audio_duration_value = metadata.get('duration', -1) if audio_duration_value not in ["N/A", ""]: try: audio_duration = float(audio_duration_value) except (ValueError, TypeError): pass if not vocal_language and metadata.get('vocal_language'): vocal_language = metadata.get('vocal_language') if not caption and metadata.get('caption'): caption = metadata.get('caption') if not lyrics and metadata.get('lyrics'): lyrics = metadata.get('lyrics') return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics def generate_music( dit_handler, llm_handler, params: GenerationParams, config: GenerationConfig, save_dir: Optional[str] = None, progress=None, ) -> GenerationResult: """Generate music using ACE-Step model with optional LM reasoning. Args: dit_handler: Initialized DiT model handler (AceStepHandler instance) llm_handler: Initialized LLM handler (LLMHandler instance) params: Generation parameters (GenerationParams instance) config: Generation configuration (GenerationConfig instance) Returns: GenerationResult with generated audio files and metadata """ try: # Phase 1: LM-based metadata and code generation (if enabled) audio_code_string_to_use = params.audio_codes lm_generated_metadata = None lm_generated_audio_codes_list = [] lm_total_time_costs = { "phase1_time": 0.0, "phase2_time": 0.0, "total_time": 0.0, } # Extract mutable copies of metadata (will be updated by LM if needed) bpm = params.bpm key_scale = params.keyscale time_signature = params.timesignature audio_duration = params.duration dit_input_caption = params.caption dit_input_vocal_language = params.vocal_language dit_input_lyrics = params.lyrics # Determine if we need to generate audio codes # If user has provided audio_codes, we don't need to generate them # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode) user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip()) # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed # For now, we use "llm_dit" if batch mode or if user hasn't provided codes # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas # Note: This logic can be refined based on specific requirements need_audio_codes = not user_provided_audio_codes # Determine if we should use chunk-based LM generation (always use chunks for consistency) # Determine actual batch size for chunk processing actual_batch_size = config.batch_size if config.batch_size is not None else 1 # Prepare seeds for batch generation # Use config.seed if provided, otherwise fallback to params.seed # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts seed_for_generation = "" if config.seeds is not None and len(config.seeds) > 0: if isinstance(config.seeds, list): # Convert List[int] to comma-separated string seed_for_generation = ",".join(str(s) for s in config.seeds) # Use dit_handler.prepare_seeds to handle seed list generation and padding # This will handle all the logic: padding with random seeds if needed, etc. actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed) # LM-based Chain-of-Thought reasoning # Skip LM for cover/repaint tasks - these tasks use reference/src audio directly # and don't need LM to generate audio codes skip_lm_tasks = {"cover", "repaint"} # Determine if we should use LLM # LLM is needed for: # 1. thinking=True: generate audio codes via LM # 2. use_cot_caption=True: enhance/generate caption via CoT # 3. use_cot_language=True: detect vocal language via CoT # 4. use_cot_metas=True: fill missing metadata via CoT need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks lm_status = [] if params.task_type in skip_lm_tasks: logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly") logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, " f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, " f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, " f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}") if use_lm: # Convert sampling parameters - handle None values safely top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k) top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p # Build user_metadata from user-provided values user_metadata = {} if bpm is not None: try: bpm_value = float(bpm) if bpm_value > 0: user_metadata['bpm'] = int(bpm_value) except (ValueError, TypeError): pass if key_scale and key_scale.strip(): key_scale_clean = key_scale.strip() if key_scale_clean.lower() not in ["n/a", ""]: user_metadata['keyscale'] = key_scale_clean if time_signature and time_signature.strip(): time_sig_clean = time_signature.strip() if time_sig_clean.lower() not in ["n/a", ""]: user_metadata['timesignature'] = time_sig_clean if audio_duration is not None: try: duration_value = float(audio_duration) if duration_value > 0: user_metadata['duration'] = int(duration_value) except (ValueError, TypeError): pass user_metadata_to_pass = user_metadata if user_metadata else None # Determine infer_type based on whether we need audio codes # - "llm_dit": generates both metas and audio codes (two-phase internally) # - "dit": generates only metas (single phase) infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit" # Use chunk size from config, or default to batch_size if not set max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size num_chunks = math.ceil(actual_batch_size / max_inference_batch_size) all_metadata_list = [] all_audio_codes_list = [] for chunk_idx in range(num_chunks): chunk_start = chunk_idx * max_inference_batch_size chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size) chunk_size = chunk_end - chunk_start chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) " f"(size: {chunk_size}, seeds: {chunk_seeds})") # Use the determined infer_type # - "llm_dit" will internally run two phases (metas + codes) # - "dit" will only run phase 1 (metas only) result = llm_handler.generate_with_stop_condition( caption=params.caption or "", lyrics=params.lyrics or "", infer_type=infer_type, temperature=params.lm_temperature, cfg_scale=params.lm_cfg_scale, negative_prompt=params.lm_negative_prompt, top_k=top_k_value, top_p=top_p_value, user_metadata=user_metadata_to_pass, use_cot_caption=params.use_cot_caption, use_cot_language=params.use_cot_language, use_cot_metas=params.use_cot_metas, use_constrained_decoding=params.use_constrained_decoding, constrained_decoding_debug=config.constrained_decoding_debug, batch_size=chunk_size, seeds=chunk_seeds, progress=progress, ) # Check if LM generation failed if not result.get("success", False): error_msg = result.get("error", "Unknown LM error") lm_status.append(f"❌ LM Error: {error_msg}") # Return early with error return GenerationResult( audios=[], status_message=f"❌ LM generation failed: {error_msg}", extra_outputs={}, success=False, error=error_msg, ) # Extract metadata and audio_codes from result dict if chunk_size > 1: metadata_list = result.get("metadata", []) audio_codes_list = result.get("audio_codes", []) all_metadata_list.extend(metadata_list) all_audio_codes_list.extend(audio_codes_list) else: metadata = result.get("metadata", {}) audio_codes = result.get("audio_codes", "") all_metadata_list.append(metadata) all_audio_codes_list.append(audio_codes) # Collect time costs from LM extra_outputs lm_extra = result.get("extra_outputs", {}) lm_chunk_time_costs = lm_extra.get("time_costs", {}) if lm_chunk_time_costs: # Accumulate time costs from all chunks for key in ["phase1_time", "phase2_time", "total_time"]: if key in lm_chunk_time_costs: lm_total_time_costs[key] += lm_chunk_time_costs[key] time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()]) lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}") lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None lm_generated_audio_codes_list = all_audio_codes_list # Set audio_code_string_to_use based on infer_type if infer_type == "llm_dit": # If batch mode, use list; otherwise use single string if actual_batch_size > 1: audio_code_string_to_use = all_audio_codes_list else: audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else "" else: # For "dit" mode, keep user-provided codes or empty audio_code_string_to_use = params.audio_codes # Update metadata from LM if not provided by user if lm_generated_metadata: bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm( metadata=lm_generated_metadata, bpm=bpm, key_scale=key_scale, time_signature=time_signature, audio_duration=audio_duration, vocal_language=dit_input_vocal_language, caption=dit_input_caption, lyrics=dit_input_lyrics) if not params.bpm: params.cot_bpm = bpm if not params.keyscale: params.cot_keyscale = key_scale if not params.timesignature: params.cot_timesignature = time_signature if not params.duration: params.cot_duration = audio_duration if not params.vocal_language: params.cot_vocal_language = vocal_language if not params.caption: params.cot_caption = caption if not params.lyrics: params.cot_lyrics = lyrics # set cot caption and language if needed if params.use_cot_caption: dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption) if params.use_cot_language: dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language) # Phase 2: DiT music generation # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation result = dit_handler.generate_music( captions=dit_input_caption, lyrics=dit_input_lyrics, bpm=bpm, key_scale=key_scale, time_signature=time_signature, vocal_language=dit_input_vocal_language, inference_steps=params.inference_steps, guidance_scale=params.guidance_scale, use_random_seed=config.use_random_seed, seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly reference_audio=params.reference_audio, audio_duration=audio_duration, batch_size=config.batch_size if config.batch_size is not None else 1, src_audio=params.src_audio, audio_code_string=audio_code_string_to_use, repainting_start=params.repainting_start, repainting_end=params.repainting_end, instruction=params.instruction, audio_cover_strength=params.audio_cover_strength, task_type=params.task_type, use_adg=params.use_adg, cfg_interval_start=params.cfg_interval_start, cfg_interval_end=params.cfg_interval_end, shift=params.shift, infer_method=params.infer_method, timesteps=params.timesteps, progress=progress, ) # Check if generation failed if not result.get("success", False): return GenerationResult( audios=[], status_message=result.get("status_message", ""), extra_outputs={}, success=False, error=result.get("error"), ) # Extract results from dit_handler.generate_music dict dit_audios = result.get("audios", []) status_message = result.get("status_message", "") dit_extra_outputs = result.get("extra_outputs", {}) # Use the seed list already prepared above (from config.seed or params.seed fallback) # actual_seed_list was computed earlier using dit_handler.prepare_seeds seed_list = actual_seed_list # Get base params dictionary base_params_dict = params.to_dict() # Save audio files using AudioSaver (format from config) audio_format = config.audio_format if config.audio_format else "flac" audio_saver = AudioSaver(default_format=audio_format) # Use handler's temp_dir for saving files if save_dir is not None: os.makedirs(save_dir, exist_ok=True) # Build audios list for GenerationResult with params and save files # Audio saving and UUID generation handled here, outside of handler audios = [] for idx, dit_audio in enumerate(dit_audios): # Create a copy of params dict for this audio audio_params = base_params_dict.copy() # Update audio-specific values audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None # Add audio codes if batch mode if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list): audio_params["audio_codes"] = lm_generated_audio_codes_list[idx] # Get audio tensor and metadata audio_tensor = dit_audio.get("tensor") sample_rate = dit_audio.get("sample_rate", 48000) # Generate UUID for this audio (moved from handler) batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1 audio_code_str = lm_generated_audio_codes_list[idx] if ( lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use if isinstance(audio_code_str, list): audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else "" audio_key = generate_uuid_from_params(audio_params) # Save audio file (handled outside handler) audio_path = None if audio_tensor is not None and save_dir is not None: try: audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}") audio_path = audio_saver.save_audio(audio_tensor, audio_file, sample_rate=sample_rate, format=audio_format, channels_first=True) except Exception as e: logger.error(f"[generate_music] Failed to save audio file: {e}") audio_path = "" # Fallback to empty path audio_dict = { "path": audio_path or "", # File path (saved here, not in handler) "tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32 "key": audio_key, "sample_rate": sample_rate, "params": audio_params, } audios.append(audio_dict) # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata extra_outputs = dit_extra_outputs.copy() extra_outputs["lm_metadata"] = lm_generated_metadata # Merge time_costs from both LM and DiT into a unified dictionary unified_time_costs = {} # Add LM time costs (if LM was used) if use_lm and lm_total_time_costs: for key, value in lm_total_time_costs.items(): unified_time_costs[f"lm_{key}"] = value # Add DiT time costs (if available) dit_time_costs = dit_extra_outputs.get("time_costs", {}) if dit_time_costs: for key, value in dit_time_costs.items(): unified_time_costs[f"dit_{key}"] = value # Calculate total pipeline time if unified_time_costs: lm_total = unified_time_costs.get("lm_total_time", 0.0) dit_total = unified_time_costs.get("dit_total_time_cost", 0.0) unified_time_costs["pipeline_total_time"] = lm_total + dit_total # Update extra_outputs with unified time_costs extra_outputs["time_costs"] = unified_time_costs if lm_status: status_message = "\n".join(lm_status) + "\n" + status_message else: status_message = status_message # Create and return GenerationResult return GenerationResult( audios=audios, status_message=status_message, extra_outputs=extra_outputs, success=True, error=None, ) except Exception as e: logger.exception("Music generation failed") return GenerationResult( audios=[], status_message=f"Error: {str(e)}", extra_outputs={}, success=False, error=str(e), ) def understand_music( llm_handler, audio_codes: str, temperature: float = 0.85, top_k: Optional[int] = None, top_p: Optional[float] = None, repetition_penalty: float = 1.0, use_constrained_decoding: bool = True, constrained_decoding_debug: bool = False, ) -> UnderstandResult: """Understand music from audio codes using the 5Hz Language Model. This function analyzes audio semantic codes and generates metadata about the music, including caption, lyrics, BPM, duration, key scale, language, and time signature. If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example instead of analyzing existing codes. Note: cfg_scale and negative_prompt are not supported in understand mode. Args: llm_handler: Initialized LLM handler (LLMHandler instance) audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...") Use empty string or "NO USER INPUT" to generate a sample example. temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative. top_k: Top-K sampling (None or 0 = disabled) top_p: Top-P (nucleus) sampling (None or 1.0 = disabled) repetition_penalty: Repetition penalty (1.0 = no penalty) use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata constrained_decoding_debug: Whether to enable debug logging for constrained decoding Returns: UnderstandResult with parsed metadata fields and status Example: >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...") >>> if result.success: ... print(f"Caption: {result.caption}") ... print(f"BPM: {result.bpm}") ... print(f"Lyrics: {result.lyrics}") """ # Check if LLM is initialized if not llm_handler.llm_initialized: return UnderstandResult( status_message="5Hz LM not initialized. Please initialize it first.", success=False, error="LLM not initialized", ) # If codes are empty, use "NO USER INPUT" to generate a sample example if not audio_codes or not audio_codes.strip(): audio_codes = "NO USER INPUT" try: # Call LLM understanding metadata, status = llm_handler.understand_audio_from_codes( audio_codes=audio_codes, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, use_constrained_decoding=use_constrained_decoding, constrained_decoding_debug=constrained_decoding_debug, ) # Check if LLM returned empty metadata (error case) if not metadata: return UnderstandResult( status_message=status or "Failed to understand audio codes", success=False, error=status or "Empty metadata returned", ) # Extract and convert fields caption = metadata.get('caption', '') lyrics = metadata.get('lyrics', '') keyscale = metadata.get('keyscale', '') language = metadata.get('language', metadata.get('vocal_language', '')) timesignature = metadata.get('timesignature', '') # Convert BPM to int bpm = None bpm_value = metadata.get('bpm') if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '': try: bpm = int(bpm_value) except (ValueError, TypeError): pass # Convert duration to float duration = None duration_value = metadata.get('duration') if duration_value is not None and duration_value != 'N/A' and duration_value != '': try: duration = float(duration_value) except (ValueError, TypeError): pass # Clean up N/A values if keyscale == 'N/A': keyscale = '' if language == 'N/A': language = '' if timesignature == 'N/A': timesignature = '' return UnderstandResult( caption=caption, lyrics=lyrics, bpm=bpm, duration=duration, keyscale=keyscale, language=language, timesignature=timesignature, status_message=status, success=True, error=None, ) except Exception as e: logger.exception("Music understanding failed") return UnderstandResult( status_message=f"Error: {str(e)}", success=False, error=str(e), ) @dataclass class CreateSampleResult: """Result of creating a music sample from a natural language query. This is used by the "Simple Mode" / "Inspiration Mode" feature where users provide a natural language description and the LLM generates a complete sample with caption, lyrics, and metadata. Attributes: # Metadata Fields caption: Generated detailed music description/caption lyrics: Generated lyrics (or "[Instrumental]" for instrumental music) bpm: Beats per minute (None if not generated) duration: Duration in seconds (None if not generated) keyscale: Musical key (e.g., "C Major") language: Vocal language code (e.g., "en", "zh") timesignature: Time signature (e.g., "4") instrumental: Whether this is an instrumental piece # Status status_message: Status message from sample creation success: Whether sample creation completed successfully error: Error message if sample creation failed """ # Metadata Fields caption: str = "" lyrics: str = "" bpm: Optional[int] = None duration: Optional[float] = None keyscale: str = "" language: str = "" timesignature: str = "" instrumental: bool = False # Status status_message: str = "" success: bool = True error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for JSON serialization.""" return asdict(self) def create_sample( llm_handler, query: str, instrumental: bool = False, vocal_language: Optional[str] = None, temperature: float = 0.85, top_k: Optional[int] = None, top_p: Optional[float] = None, repetition_penalty: float = 1.0, use_constrained_decoding: bool = True, constrained_decoding_debug: bool = False, ) -> CreateSampleResult: """Create a music sample from a natural language query using the 5Hz Language Model. This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural language description of music and generates a complete sample including: - Detailed caption/description - Lyrics (unless instrumental) - Metadata (BPM, duration, key, language, time signature) Note: cfg_scale and negative_prompt are not supported in create_sample mode. Args: llm_handler: Initialized LLM handler (LLMHandler instance) query: User's natural language music description (e.g., "a soft Bengali love song") instrumental: Whether to generate instrumental music (no vocals) vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh"). If provided, the model will be constrained to generate lyrics in this language. If None or "unknown", no language constraint is applied. temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative. top_k: Top-K sampling (None or 0 = disabled) top_p: Top-P (nucleus) sampling (None or 1.0 = disabled) repetition_penalty: Repetition penalty (1.0 = no penalty) use_constrained_decoding: Whether to use FSM-based constrained decoding constrained_decoding_debug: Whether to enable debug logging Returns: CreateSampleResult with generated sample fields and status Example: >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn") >>> if result.success: ... print(f"Caption: {result.caption}") ... print(f"Lyrics: {result.lyrics}") ... print(f"BPM: {result.bpm}") """ # Check if LLM is initialized if not llm_handler.llm_initialized: return CreateSampleResult( status_message="5Hz LM not initialized. Please initialize it first.", success=False, error="LLM not initialized", ) try: # Call LLM to create sample metadata, status = llm_handler.create_sample_from_query( query=query, instrumental=instrumental, vocal_language=vocal_language, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, use_constrained_decoding=use_constrained_decoding, constrained_decoding_debug=constrained_decoding_debug, ) # Check if LLM returned empty metadata (error case) if not metadata: return CreateSampleResult( status_message=status or "Failed to create sample", success=False, error=status or "Empty metadata returned", ) # Extract and convert fields caption = metadata.get('caption', '') lyrics = metadata.get('lyrics', '') keyscale = metadata.get('keyscale', '') language = metadata.get('language', metadata.get('vocal_language', '')) timesignature = metadata.get('timesignature', '') is_instrumental = metadata.get('instrumental', instrumental) # Convert BPM to int bpm = None bpm_value = metadata.get('bpm') if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '': try: bpm = int(bpm_value) except (ValueError, TypeError): pass # Convert duration to float duration = None duration_value = metadata.get('duration') if duration_value is not None and duration_value != 'N/A' and duration_value != '': try: duration = float(duration_value) except (ValueError, TypeError): pass # Clean up N/A values if keyscale == 'N/A': keyscale = '' if language == 'N/A': language = '' if timesignature == 'N/A': timesignature = '' return CreateSampleResult( caption=caption, lyrics=lyrics, bpm=bpm, duration=duration, keyscale=keyscale, language=language, timesignature=timesignature, instrumental=is_instrumental, status_message=status, success=True, error=None, ) except Exception as e: logger.exception("Sample creation failed") return CreateSampleResult( status_message=f"Error: {str(e)}", success=False, error=str(e), ) @dataclass class FormatSampleResult: """Result of formatting user-provided caption and lyrics. This is used by the "Format" feature where users provide caption and lyrics, and the LLM formats them into structured music metadata and an enhanced description. Attributes: # Metadata Fields caption: Enhanced/formatted music description/caption lyrics: Formatted lyrics (may be same as input or reformatted) bpm: Beats per minute (None if not detected) duration: Duration in seconds (None if not detected) keyscale: Musical key (e.g., "C Major") language: Vocal language code (e.g., "en", "zh") timesignature: Time signature (e.g., "4") # Status status_message: Status message from formatting success: Whether formatting completed successfully error: Error message if formatting failed """ # Metadata Fields caption: str = "" lyrics: str = "" bpm: Optional[int] = None duration: Optional[float] = None keyscale: str = "" language: str = "" timesignature: str = "" # Status status_message: str = "" success: bool = True error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for JSON serialization.""" return asdict(self) def format_sample( llm_handler, caption: str, lyrics: str, user_metadata: Optional[Dict[str, Any]] = None, temperature: float = 0.85, top_k: Optional[int] = None, top_p: Optional[float] = None, repetition_penalty: float = 1.0, use_constrained_decoding: bool = True, constrained_decoding_debug: bool = False, ) -> FormatSampleResult: """Format user-provided caption and lyrics using the 5Hz Language Model. This function takes user input (caption and lyrics) and generates structured music metadata including an enhanced caption, BPM, duration, key, language, and time signature. If user_metadata is provided, those values will be used to constrain the decoding, ensuring the output matches user-specified values. Note: cfg_scale and negative_prompt are not supported in format mode. Args: llm_handler: Initialized LLM handler (LLMHandler instance) caption: User's caption/description (e.g., "Latin pop, reggaeton") lyrics: User's lyrics with structure tags user_metadata: Optional dict with user-provided metadata to constrain decoding. Supported keys: bpm, duration, keyscale, timesignature, language temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative. top_k: Top-K sampling (None or 0 = disabled) top_p: Top-P (nucleus) sampling (None or 1.0 = disabled) repetition_penalty: Repetition penalty (1.0 = no penalty) use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata constrained_decoding_debug: Whether to enable debug logging for constrained decoding Returns: FormatSampleResult with formatted metadata fields and status Example: >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...") >>> if result.success: ... print(f"Caption: {result.caption}") ... print(f"BPM: {result.bpm}") ... print(f"Lyrics: {result.lyrics}") """ # Check if LLM is initialized if not llm_handler.llm_initialized: return FormatSampleResult( status_message="5Hz LM not initialized. Please initialize it first.", success=False, error="LLM not initialized", ) try: # Call LLM formatting metadata, status = llm_handler.format_sample_from_input( caption=caption, lyrics=lyrics, user_metadata=user_metadata, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, use_constrained_decoding=use_constrained_decoding, constrained_decoding_debug=constrained_decoding_debug, ) # Check if LLM returned empty metadata (error case) if not metadata: return FormatSampleResult( status_message=status or "Failed to format input", success=False, error=status or "Empty metadata returned", ) # Extract and convert fields result_caption = metadata.get('caption', '') result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics keyscale = metadata.get('keyscale', '') language = metadata.get('language', metadata.get('vocal_language', '')) timesignature = metadata.get('timesignature', '') # Convert BPM to int bpm = None bpm_value = metadata.get('bpm') if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '': try: bpm = int(bpm_value) except (ValueError, TypeError): pass # Convert duration to float duration = None duration_value = metadata.get('duration') if duration_value is not None and duration_value != 'N/A' and duration_value != '': try: duration = float(duration_value) except (ValueError, TypeError): pass # Clean up N/A values if keyscale == 'N/A': keyscale = '' if language == 'N/A': language = '' if timesignature == 'N/A': timesignature = '' return FormatSampleResult( caption=result_caption, lyrics=result_lyrics, bpm=bpm, duration=duration, keyscale=keyscale, language=language, timesignature=timesignature, status_message=status, success=True, error=None, ) except Exception as e: logger.exception("Format sample failed") return FormatSampleResult( status_message=f"Error: {str(e)}", success=False, error=str(e), )