Spaces:
Running
on
A100
Running
on
A100
add inference code and doc
Browse files- INFERENCE.md +695 -0
- acestep/gradio_ui/event.py +4 -7
- acestep/gradio_ui/events/results_handlers.py +2 -5
- acestep/handler.py +5 -2
- acestep/inference.py +928 -0
INFERENCE.md
ADDED
|
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ACE-Step Inference API Documentation
|
| 2 |
+
|
| 3 |
+
This document provides comprehensive documentation for the ACE-Step inference API, including parameter specifications for all supported task types.
|
| 4 |
+
|
| 5 |
+
## Table of Contents
|
| 6 |
+
|
| 7 |
+
- [Quick Start](#quick-start)
- [API Overview](#api-overview)
- [Configuration Parameters](#configuration-parameters)
- [Task Types](#task-types)
- [Complete Examples](#complete-examples)
- [Best Practices](#best-practices)
- [Troubleshooting](#troubleshooting)
- [API Reference Summary](#api-reference-summary)
- [Version History](#version-history)
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## Quick Start
|
| 17 |
+
|
| 18 |
+
### Basic Usage
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
from acestep.handler import AceStepHandler
|
| 22 |
+
from acestep.llm_inference import LLMHandler
|
| 23 |
+
from acestep.inference import GenerationConfig, generate_music
|
| 24 |
+
|
| 25 |
+
# Initialize handlers
|
| 26 |
+
dit_handler = AceStepHandler()
|
| 27 |
+
llm_handler = LLMHandler()
|
| 28 |
+
|
| 29 |
+
# Initialize services
|
| 30 |
+
dit_handler.initialize_service(
|
| 31 |
+
project_root="/path/to/project",
|
| 32 |
+
config_path="acestep-v15-turbo-rl",
|
| 33 |
+
device="cuda"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
llm_handler.initialize(
|
| 37 |
+
checkpoint_dir="/path/to/checkpoints",
|
| 38 |
+
lm_model_path="acestep-5Hz-lm-0.6B-v3",
|
| 39 |
+
backend="vllm",
|
| 40 |
+
device="cuda"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Configure generation
|
| 44 |
+
config = GenerationConfig(
|
| 45 |
+
caption="upbeat electronic dance music with heavy bass",
|
| 46 |
+
bpm=128,
|
| 47 |
+
audio_duration=30,
|
| 48 |
+
batch_size=1,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Generate music
|
| 52 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 53 |
+
|
| 54 |
+
# Access results
|
| 55 |
+
if result.success:
|
| 56 |
+
for audio_path in result.audio_paths:
|
| 57 |
+
print(f"Generated: {audio_path}")
|
| 58 |
+
else:
|
| 59 |
+
print(f"Error: {result.error}")
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
## API Overview
|
| 65 |
+
|
| 66 |
+
### Main Function
|
| 67 |
+
|
| 68 |
+
```python
|
| 69 |
+
def generate_music(
|
| 70 |
+
dit_handler: AceStepHandler,
|
| 71 |
+
llm_handler: LLMHandler,
|
| 72 |
+
config: GenerationConfig,
|
| 73 |
+
) -> GenerationResult
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### Configuration Object
|
| 77 |
+
|
| 78 |
+
The `GenerationConfig` dataclass consolidates all generation parameters:
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
@dataclass
|
| 82 |
+
class GenerationConfig:
|
| 83 |
+
# Required parameters with sensible defaults
|
| 84 |
+
caption: str = ""
|
| 85 |
+
lyrics: str = ""
|
| 86 |
+
# ... (see full parameter list below)
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### Result Object
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
@dataclass
|
| 93 |
+
class GenerationResult:
|
| 94 |
+
audio_paths: List[str] # Paths to generated audio files
|
| 95 |
+
generation_info: str # Markdown-formatted info
|
| 96 |
+
status_message: str # Status message
|
| 97 |
+
seed_value: str # Seed used
|
| 98 |
+
lm_metadata: Optional[Dict] # LM-generated metadata
|
| 99 |
+
success: bool # Success flag
|
| 100 |
+
error: Optional[str] # Error message if failed
|
| 101 |
+
# ... (see full fields below)
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## Configuration Parameters
|
| 107 |
+
|
| 108 |
+
### Text Inputs
|
| 109 |
+
|
| 110 |
+
| Parameter | Type | Default | Description |
|
| 111 |
+
|-----------|------|---------|-------------|
|
| 112 |
+
| `caption` | `str` | `""` | Text description of the desired music. Can be a simple prompt like "relaxing piano music" or detailed description with genre, mood, instruments, etc. |
|
| 113 |
+
| `lyrics` | `str` | `""` | Lyrics text for vocal music. Use `"[Instrumental]"` for instrumental tracks. Supports multiple languages. |
|
| 114 |
+
|
| 115 |
+
### Music Metadata
|
| 116 |
+
|
| 117 |
+
| Parameter | Type | Default | Description |
|
| 118 |
+
|-----------|------|---------|-------------|
|
| 119 |
+
| `bpm` | `Optional[int]` | `None` | Beats per minute (30-300). `None` enables auto-detection via LM. |
|
| 120 |
+
| `key_scale` | `str` | `""` | Musical key (e.g., "C Major", "Am", "F# minor"). Empty string enables auto-detection. |
|
| 121 |
+
| `time_signature` | `str` | `""` | Time signature (e.g., "4/4", "3/4", "6/8"). Empty string enables auto-detection. |
|
| 122 |
+
| `vocal_language` | `str` | `"unknown"` | Language code for vocals (ISO 639-1). Supported: `"en"`, `"zh"`, `"ja"`, `"es"`, `"fr"`, etc. Use `"unknown"` for auto-detection. |
|
| 123 |
+
| `audio_duration` | `Optional[float]` | `None` | Duration in seconds (10-600). `None` enables auto-detection based on lyrics length. |
|
| 124 |
+
|
| 125 |
+
### Generation Parameters
|
| 126 |
+
|
| 127 |
+
| Parameter | Type | Default | Description |
|
| 128 |
+
|-----------|------|---------|-------------|
|
| 129 |
+
| `inference_steps` | `int` | `8` | Number of denoising steps. Turbo model: 1-8 (recommended 8). Base model: 1-100 (recommended 32-64). Higher = better quality but slower. |
|
| 130 |
+
| `guidance_scale` | `float` | `7.0` | Classifier-free guidance scale (1.0-15.0). Higher values increase adherence to text prompt. Typical range: 5.0-9.0. |
|
| 131 |
+
| `use_random_seed` | `bool` | `True` | Whether to use random seed. `True` for different results each time, `False` for reproducible results. |
|
| 132 |
+
| `seed` | `int` | `-1` | Random seed for reproducibility. Use `-1` for random seed, or any positive integer for fixed seed. |
|
| 133 |
+
| `batch_size` | `int` | `1` | Number of samples to generate in parallel (1-8). Higher values require more GPU memory. |
|
| 134 |
+
|
| 135 |
+
### Advanced DiT Parameters
|
| 136 |
+
|
| 137 |
+
| Parameter | Type | Default | Description |
|
| 138 |
+
|-----------|------|---------|-------------|
|
| 139 |
+
| `use_adg` | `bool` | `False` | Use Adaptive Dual Guidance (base model only). Improves quality at the cost of speed. |
|
| 140 |
+
| `cfg_interval_start` | `float` | `0.0` | CFG application start ratio (0.0-1.0). Controls when to start applying classifier-free guidance. |
|
| 141 |
+
| `cfg_interval_end` | `float` | `1.0` | CFG application end ratio (0.0-1.0). Controls when to stop applying classifier-free guidance. |
|
| 142 |
+
| `audio_format` | `str` | `"mp3"` | Output audio format. Options: `"mp3"`, `"wav"`, `"flac"`. |
|
| 143 |
+
|
| 144 |
+
### Task-Specific Parameters
|
| 145 |
+
|
| 146 |
+
| Parameter | Type | Default | Description |
|
| 147 |
+
|-----------|------|---------|-------------|
|
| 148 |
+
| `task_type` | `str` | `"text2music"` | Generation task type. See [Task Types](#task-types) section for details. |
|
| 149 |
+
| `reference_audio` | `Optional[str]` | `None` | Path to reference audio file for style transfer or continuation tasks. |
|
| 150 |
+
| `src_audio` | `Optional[str]` | `None` | Path to source audio file for audio-to-audio tasks (cover, repaint, etc.). |
|
| 151 |
+
| `audio_code_string` | `Union[str, List[str]]` | `""` | Pre-extracted 5Hz audio codes. Can be single string or list for batch mode. Advanced use only. |
|
| 152 |
+
| `repainting_start` | `float` | `0.0` | Repainting start time in seconds (for repaint/lego tasks). |
|
| 153 |
+
| `repainting_end` | `float` | `-1` | Repainting end time in seconds. Use `-1` for end of audio. |
|
| 154 |
+
| `audio_cover_strength` | `float` | `1.0` | Strength of audio cover/codes influence (0.0-1.0). Higher = stronger influence from source audio. |
|
| 155 |
+
| `instruction` | `str` | `""` | Task-specific instruction prompt. Auto-generated if empty. |
|
| 156 |
+
|
| 157 |
+
### 5Hz Language Model Parameters
|
| 158 |
+
|
| 159 |
+
| Parameter | Type | Default | Description |
|
| 160 |
+
|-----------|------|---------|-------------|
|
| 161 |
+
| `use_llm_thinking` | `bool` | `False` | Enable LM-based Chain-of-Thought reasoning. When enabled, LM generates metadata and/or audio codes. |
|
| 162 |
+
| `lm_temperature` | `float` | `0.85` | LM sampling temperature (0.0-2.0). Higher = more creative/diverse, lower = more conservative. |
|
| 163 |
+
| `lm_cfg_scale` | `float` | `2.0` | LM classifier-free guidance scale (1.0-5.0). Higher = stronger adherence to prompt. |
|
| 164 |
+
| `lm_top_k` | `int` | `0` | LM top-k sampling. `0` disables top-k filtering. Typical values: 40-100. |
|
| 165 |
+
| `lm_top_p` | `float` | `0.9` | LM nucleus sampling (0.0-1.0). `1.0` disables nucleus sampling. Typical values: 0.9-0.95. |
|
| 166 |
+
| `lm_negative_prompt` | `str` | `"NO USER INPUT"` | Negative prompt for LM guidance. Helps avoid unwanted characteristics. |
|
| 167 |
+
| `use_cot_metas` | `bool` | `True` | Generate metadata using LM CoT reasoning (BPM, key, duration, etc.). |
|
| 168 |
+
| `use_cot_caption` | `bool` | `True` | Refine user caption using LM CoT reasoning. |
|
| 169 |
+
| `use_cot_language` | `bool` | `True` | Detect vocal language using LM CoT reasoning. |
|
| 170 |
+
| `is_format_caption` | `bool` | `False` | Whether caption is already formatted/refined (skip LM refinement). |
|
| 171 |
+
| `constrained_decoding_debug` | `bool` | `False` | Enable debug logging for constrained decoding. |
|
| 172 |
+
|
| 173 |
+
### Batch LM Generation
|
| 174 |
+
|
| 175 |
+
| Parameter | Type | Default | Description |
|
| 176 |
+
|-----------|------|---------|-------------|
|
| 177 |
+
| `allow_lm_batch` | `bool` | `False` | Allow batch LM code generation. Faster when `batch_size >= 2` and `use_llm_thinking=True`. |
|
| 178 |
+
| `lm_batch_chunk_size` | `int` | `4` | Maximum batch size per LM inference chunk (GPU memory constraint). |
|
| 179 |
+
|
| 180 |
+
---
|
| 181 |
+
|
| 182 |
+
## Task Types
|
| 183 |
+
|
| 184 |
+
ACE-Step supports 6 different generation task types, each optimized for specific use cases.
|
| 185 |
+
|
| 186 |
+
### 1. Text2Music (Default)
|
| 187 |
+
|
| 188 |
+
**Purpose**: Generate music from text descriptions and optional metadata.
|
| 189 |
+
|
| 190 |
+
**Key Parameters**:
|
| 191 |
+
```python
|
| 192 |
+
config = GenerationConfig(
|
| 193 |
+
task_type="text2music",
|
| 194 |
+
caption="energetic rock music with electric guitar",
|
| 195 |
+
lyrics="[Instrumental]", # or actual lyrics
|
| 196 |
+
bpm=140,
|
| 197 |
+
audio_duration=30,
|
| 198 |
+
)
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
**Required**:
|
| 202 |
+
- `caption` or `lyrics` (at least one)
|
| 203 |
+
|
| 204 |
+
**Optional but Recommended**:
|
| 205 |
+
- `bpm`: Controls tempo
|
| 206 |
+
- `key_scale`: Controls musical key
|
| 207 |
+
- `time_signature`: Controls rhythm structure
|
| 208 |
+
- `audio_duration`: Controls length
|
| 209 |
+
- `vocal_language`: Controls vocal characteristics
|
| 210 |
+
|
| 211 |
+
**Use Cases**:
|
| 212 |
+
- Generate music from text descriptions
|
| 213 |
+
- Create backing tracks from prompts
|
| 214 |
+
- Generate songs with lyrics
|
| 215 |
+
|
| 216 |
+
---
|
| 217 |
+
|
| 218 |
+
### 2. Cover
|
| 219 |
+
|
| 220 |
+
**Purpose**: Transform existing audio while maintaining structure but changing style/timbre.
|
| 221 |
+
|
| 222 |
+
**Key Parameters**:
|
| 223 |
+
```python
|
| 224 |
+
config = GenerationConfig(
|
| 225 |
+
task_type="cover",
|
| 226 |
+
src_audio="original_song.mp3",
|
| 227 |
+
caption="jazz piano version",
|
| 228 |
+
audio_cover_strength=0.8, # 0.0-1.0
|
| 229 |
+
)
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
**Required**:
|
| 233 |
+
- `src_audio`: Path to source audio file
|
| 234 |
+
- `caption`: Description of desired style/transformation
|
| 235 |
+
|
| 236 |
+
**Optional**:
|
| 237 |
+
- `audio_cover_strength`: Controls influence of original audio
|
| 238 |
+
- `1.0`: Strong adherence to original structure
|
| 239 |
+
- `0.5`: Balanced transformation
|
| 240 |
+
- `0.1`: Loose interpretation
|
| 241 |
+
- `lyrics`: New lyrics (if changing vocals)
|
| 242 |
+
|
| 243 |
+
**Use Cases**:
|
| 244 |
+
- Create covers in different styles
|
| 245 |
+
- Change instrumentation while keeping melody
|
| 246 |
+
- Genre transformation
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
### 3. Repaint
|
| 251 |
+
|
| 252 |
+
**Purpose**: Regenerate a specific time segment of audio while keeping the rest unchanged.
|
| 253 |
+
|
| 254 |
+
**Key Parameters**:
|
| 255 |
+
```python
|
| 256 |
+
config = GenerationConfig(
|
| 257 |
+
task_type="repaint",
|
| 258 |
+
src_audio="original.mp3",
|
| 259 |
+
repainting_start=10.0, # seconds
|
| 260 |
+
repainting_end=20.0, # seconds
|
| 261 |
+
caption="smooth transition with piano solo",
|
| 262 |
+
)
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
**Required**:
|
| 266 |
+
- `src_audio`: Path to source audio file
|
| 267 |
+
- `repainting_start`: Start time in seconds
|
| 268 |
+
- `repainting_end`: End time in seconds (use `-1` for end of file)
|
| 269 |
+
- `caption`: Description of desired content for repainted section
|
| 270 |
+
|
| 271 |
+
**Use Cases**:
|
| 272 |
+
- Fix specific sections of generated music
|
| 273 |
+
- Add variations to parts of a song
|
| 274 |
+
- Create smooth transitions
|
| 275 |
+
- Replace problematic segments
|
| 276 |
+
|
| 277 |
+
---
|
| 278 |
+
|
| 279 |
+
### 4. Lego (Base Model Only)
|
| 280 |
+
|
| 281 |
+
**Purpose**: Generate a specific instrument track in context of existing audio.
|
| 282 |
+
|
| 283 |
+
**Key Parameters**:
|
| 284 |
+
```python
|
| 285 |
+
config = GenerationConfig(
|
| 286 |
+
task_type="lego",
|
| 287 |
+
src_audio="backing_track.mp3",
|
| 288 |
+
instruction="Generate the guitar track based on the audio context:",
|
| 289 |
+
caption="lead guitar melody with bluesy feel",
|
| 290 |
+
repainting_start=0.0,
|
| 291 |
+
repainting_end=-1,
|
| 292 |
+
)
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
**Required**:
|
| 296 |
+
- `src_audio`: Path to source/backing audio
|
| 297 |
+
- `instruction`: Must specify the track type (e.g., "Generate the {TRACK_NAME} track...")
|
| 298 |
+
- `caption`: Description of desired track characteristics
|
| 299 |
+
|
| 300 |
+
**Available Tracks**:
|
| 301 |
+
- `"vocals"`, `"backing_vocals"`, `"drums"`, `"bass"`, `"guitar"`, `"keyboard"`
- `"percussion"`, `"strings"`, `"synth"`, `"fx"`, `"brass"`, `"woodwinds"`
|
| 303 |
+
|
| 304 |
+
**Use Cases**:
|
| 305 |
+
- Add specific instrument tracks
|
| 306 |
+
- Layer additional instruments over backing tracks
|
| 307 |
+
- Create multi-track compositions iteratively
|
| 308 |
+
|
| 309 |
+
---
|
| 310 |
+
|
| 311 |
+
### 5. Extract (Base Model Only)
|
| 312 |
+
|
| 313 |
+
**Purpose**: Extract/isolate a specific instrument track from mixed audio.
|
| 314 |
+
|
| 315 |
+
**Key Parameters**:
|
| 316 |
+
```python
|
| 317 |
+
config = GenerationConfig(
|
| 318 |
+
task_type="extract",
|
| 319 |
+
src_audio="full_mix.mp3",
|
| 320 |
+
instruction="Extract the vocals track from the audio:",
|
| 321 |
+
)
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+
**Required**:
|
| 325 |
+
- `src_audio`: Path to mixed audio file
|
| 326 |
+
- `instruction`: Must specify track to extract
|
| 327 |
+
|
| 328 |
+
**Available Tracks**: Same as Lego task
|
| 329 |
+
|
| 330 |
+
**Use Cases**:
|
| 331 |
+
- Stem separation
|
| 332 |
+
- Isolate specific instruments
|
| 333 |
+
- Create remixes
|
| 334 |
+
- Analyze individual tracks
|
| 335 |
+
|
| 336 |
+
---
|
| 337 |
+
|
| 338 |
+
### 6. Complete (Base Model Only)
|
| 339 |
+
|
| 340 |
+
**Purpose**: Complete/extend partial tracks with specified instruments.
|
| 341 |
+
|
| 342 |
+
**Key Parameters**:
|
| 343 |
+
```python
|
| 344 |
+
config = GenerationConfig(
|
| 345 |
+
task_type="complete",
|
| 346 |
+
src_audio="incomplete_track.mp3",
|
| 347 |
+
instruction="Complete the input track with drums, bass, guitar:",
|
| 348 |
+
caption="rock style completion",
|
| 349 |
+
)
|
| 350 |
+
```
|
| 351 |
+
|
| 352 |
+
**Required**:
|
| 353 |
+
- `src_audio`: Path to incomplete/partial track
|
| 354 |
+
- `instruction`: Must specify which tracks to add
|
| 355 |
+
- `caption`: Description of desired style
|
| 356 |
+
|
| 357 |
+
**Use Cases**:
|
| 358 |
+
- Arrange incomplete compositions
|
| 359 |
+
- Add backing tracks
|
| 360 |
+
- Auto-complete musical ideas
|
| 361 |
+
|
| 362 |
+
---
|
| 363 |
+
|
| 364 |
+
## Complete Examples
|
| 365 |
+
|
| 366 |
+
### Example 1: Simple Text-to-Music Generation
|
| 367 |
+
|
| 368 |
+
```python
|
| 369 |
+
from acestep.inference import GenerationConfig, generate_music
|
| 370 |
+
|
| 371 |
+
config = GenerationConfig(
|
| 372 |
+
task_type="text2music",
|
| 373 |
+
caption="calm ambient music with soft piano and strings",
|
| 374 |
+
audio_duration=60,
|
| 375 |
+
bpm=80,
|
| 376 |
+
key_scale="C Major",
|
| 377 |
+
batch_size=2, # Generate 2 variations
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 381 |
+
|
| 382 |
+
if result.success:
|
| 383 |
+
for i, path in enumerate(result.audio_paths, 1):
|
| 384 |
+
print(f"Variation {i}: {path}")
|
| 385 |
+
```
|
| 386 |
+
|
| 387 |
+
### Example 2: Song Generation with Lyrics
|
| 388 |
+
|
| 389 |
+
```python
|
| 390 |
+
config = GenerationConfig(
|
| 391 |
+
task_type="text2music",
|
| 392 |
+
caption="pop ballad with emotional vocals",
|
| 393 |
+
lyrics="""Verse 1:
|
| 394 |
+
Walking down the street today
|
| 395 |
+
Thinking of the words you used to say
|
| 396 |
+
Everything feels different now
|
| 397 |
+
But I'll find my way somehow
|
| 398 |
+
|
| 399 |
+
Chorus:
|
| 400 |
+
I'm moving on, I'm staying strong
|
| 401 |
+
This is where I belong
|
| 402 |
+
""",
|
| 403 |
+
vocal_language="en",
|
| 404 |
+
bpm=72,
|
| 405 |
+
audio_duration=45,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
### Example 3: Style Cover with LM Reasoning
|
| 412 |
+
|
| 413 |
+
```python
|
| 414 |
+
config = GenerationConfig(
|
| 415 |
+
task_type="cover",
|
| 416 |
+
src_audio="original_pop_song.mp3",
|
| 417 |
+
caption="orchestral symphonic arrangement",
|
| 418 |
+
audio_cover_strength=0.7,
|
| 419 |
+
use_llm_thinking=True, # Enable LM for metadata
|
| 420 |
+
use_cot_metas=True,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 424 |
+
|
| 425 |
+
# Access LM-generated metadata
|
| 426 |
+
if result.lm_metadata:
|
| 427 |
+
print(f"LM detected BPM: {result.lm_metadata.get('bpm')}")
|
| 428 |
+
print(f"LM detected Key: {result.lm_metadata.get('keyscale')}")
|
| 429 |
+
```
|
| 430 |
+
|
| 431 |
+
### Example 4: Repaint Section of Audio
|
| 432 |
+
|
| 433 |
+
```python
|
| 434 |
+
config = GenerationConfig(
|
| 435 |
+
task_type="repaint",
|
| 436 |
+
src_audio="generated_track.mp3",
|
| 437 |
+
repainting_start=15.0, # Start at 15 seconds
|
| 438 |
+
repainting_end=25.0, # End at 25 seconds
|
| 439 |
+
caption="dramatic orchestral buildup",
|
| 440 |
+
inference_steps=32, # Higher quality for base model
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 444 |
+
```
|
| 445 |
+
|
| 446 |
+
### Example 5: Batch Generation with LM
|
| 447 |
+
|
| 448 |
+
```python
|
| 449 |
+
config = GenerationConfig(
|
| 450 |
+
task_type="text2music",
|
| 451 |
+
caption="epic cinematic trailer music",
|
| 452 |
+
batch_size=4, # Generate 4 variations
|
| 453 |
+
use_llm_thinking=True,
|
| 454 |
+
use_cot_metas=True,
|
| 455 |
+
allow_lm_batch=True, # Faster batch processing
|
| 456 |
+
lm_batch_chunk_size=2, # Process 2 at a time (GPU memory)
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 460 |
+
|
| 461 |
+
if result.success:
|
| 462 |
+
print(f"Generated {len(result.audio_paths)} variations")
|
| 463 |
+
```
|
| 464 |
+
|
| 465 |
+
### Example 6: High-Quality Generation (Base Model)
|
| 466 |
+
|
| 467 |
+
```python
|
| 468 |
+
config = GenerationConfig(
|
| 469 |
+
task_type="text2music",
|
| 470 |
+
caption="intricate jazz fusion with complex harmonies",
|
| 471 |
+
inference_steps=64, # High quality
|
| 472 |
+
guidance_scale=8.0,
|
| 473 |
+
use_adg=True, # Adaptive Dual Guidance
|
| 474 |
+
cfg_interval_start=0.0,
|
| 475 |
+
cfg_interval_end=1.0,
|
| 476 |
+
audio_format="wav", # Lossless format
|
| 477 |
+
use_random_seed=False,
|
| 478 |
+
seed=42, # Reproducible results
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 482 |
+
```
|
| 483 |
+
|
| 484 |
+
### Example 7: Extract Vocals from Mix
|
| 485 |
+
|
| 486 |
+
```python
|
| 487 |
+
config = GenerationConfig(
|
| 488 |
+
task_type="extract",
|
| 489 |
+
src_audio="full_song_mix.mp3",
|
| 490 |
+
instruction="Extract the vocals track from the audio:",
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 494 |
+
|
| 495 |
+
if result.success:
|
| 496 |
+
print(f"Extracted vocals: {result.audio_paths[0]}")
|
| 497 |
+
```
|
| 498 |
+
|
| 499 |
+
### Example 8: Add Guitar Track (Lego)
|
| 500 |
+
|
| 501 |
+
```python
|
| 502 |
+
config = GenerationConfig(
|
| 503 |
+
task_type="lego",
|
| 504 |
+
src_audio="drums_and_bass.mp3",
|
| 505 |
+
instruction="Generate the guitar track based on the audio context:",
|
| 506 |
+
caption="funky rhythm guitar with wah-wah effect",
|
| 507 |
+
repainting_start=0.0,
|
| 508 |
+
repainting_end=-1, # Full duration
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 512 |
+
```
|
| 513 |
+
|
| 514 |
+
---
|
| 515 |
+
|
| 516 |
+
## Best Practices
|
| 517 |
+
|
| 518 |
+
### 1. Caption Writing
|
| 519 |
+
|
| 520 |
+
**Good Captions**:
|
| 521 |
+
```python
|
| 522 |
+
# Specific and descriptive
|
| 523 |
+
caption="upbeat electronic dance music with heavy bass and synthesizer leads"
|
| 524 |
+
|
| 525 |
+
# Include mood and genre
|
| 526 |
+
caption="melancholic indie folk with acoustic guitar and soft vocals"
|
| 527 |
+
|
| 528 |
+
# Specify instruments
|
| 529 |
+
caption="jazz trio with piano, upright bass, and brush drums"
|
| 530 |
+
```
|
| 531 |
+
|
| 532 |
+
**Avoid**:
|
| 533 |
+
```python
|
| 534 |
+
# Too vague
|
| 535 |
+
caption="good music"
|
| 536 |
+
|
| 537 |
+
# Contradictory
|
| 538 |
+
caption="fast slow music" # Conflicting tempos
|
| 539 |
+
```
|
| 540 |
+
|
| 541 |
+
### 2. Parameter Tuning
|
| 542 |
+
|
| 543 |
+
**For Best Quality**:
|
| 544 |
+
- Use base model with `inference_steps=64` or higher
|
| 545 |
+
- Enable `use_adg=True`
|
| 546 |
+
- Set `guidance_scale=7.0-9.0`
|
| 547 |
+
- Use lossless audio format (`audio_format="wav"`)
|
| 548 |
+
|
| 549 |
+
**For Speed**:
|
| 550 |
+
- Use turbo model with `inference_steps=8`
|
| 551 |
+
- Disable ADG (`use_adg=False`)
|
| 552 |
+
- Lower `guidance_scale=5.0-7.0`
|
| 553 |
+
- Use compressed format (`audio_format="mp3"`)
|
| 554 |
+
|
| 555 |
+
**For Consistency**:
|
| 556 |
+
- Set `use_random_seed=False`
|
| 557 |
+
- Use fixed `seed` value
|
| 558 |
+
- Keep `lm_temperature` lower (0.7-0.85)
|
| 559 |
+
|
| 560 |
+
**For Diversity**:
|
| 561 |
+
- Set `use_random_seed=True`
|
| 562 |
+
- Increase `lm_temperature` (0.9-1.1)
|
| 563 |
+
- Use `batch_size > 1` for variations
|
| 564 |
+
|
| 565 |
+
### 3. Duration Guidelines
|
| 566 |
+
|
| 567 |
+
- **Instrumental**: 30-180 seconds works well
|
| 568 |
+
- **With Lyrics**: Auto-detection recommended (set `audio_duration=None`)
|
| 569 |
+
- **Short clips**: 10-20 seconds minimum
|
| 570 |
+
- **Long form**: Up to 600 seconds (10 minutes) maximum
|
| 571 |
+
|
| 572 |
+
### 4. LM Usage
|
| 573 |
+
|
| 574 |
+
**When to Enable LM (`use_llm_thinking=True`)**:
|
| 575 |
+
- Need automatic metadata detection
|
| 576 |
+
- Want caption refinement
|
| 577 |
+
- Generating from minimal input
|
| 578 |
+
- Need diverse outputs
|
| 579 |
+
|
| 580 |
+
**When to Disable LM**:
|
| 581 |
+
- Have precise metadata already
|
| 582 |
+
- Need faster generation
|
| 583 |
+
- Want full control over parameters
|
| 584 |
+
|
| 585 |
+
### 5. Batch Processing
|
| 586 |
+
|
| 587 |
+
```python
|
| 588 |
+
# Efficient batch generation
|
| 589 |
+
config = GenerationConfig(
|
| 590 |
+
batch_size=8, # Max supported
|
| 591 |
+
use_llm_thinking=True,
|
| 592 |
+
allow_lm_batch=True, # Enable for speed
|
| 593 |
+
lm_batch_chunk_size=4, # Adjust based on GPU memory
|
| 594 |
+
)
|
| 595 |
+
```
|
| 596 |
+
|
| 597 |
+
### 6. Error Handling
|
| 598 |
+
|
| 599 |
+
```python
|
| 600 |
+
result = generate_music(dit_handler, llm_handler, config)
|
| 601 |
+
|
| 602 |
+
if not result.success:
|
| 603 |
+
print(f"Generation failed: {result.error}")
|
| 604 |
+
# Check logs for details
|
| 605 |
+
else:
|
| 606 |
+
# Process successful result
|
| 607 |
+
for path in result.audio_paths:
|
| 608 |
+
# ... process audio files
|
| 609 |
+
pass
|
| 610 |
+
```
|
| 611 |
+
|
| 612 |
+
### 7. Memory Management
|
| 613 |
+
|
| 614 |
+
For large batch sizes or long durations:
|
| 615 |
+
- Monitor GPU memory usage
|
| 616 |
+
- Reduce `batch_size` if OOM errors occur
|
| 617 |
+
- Reduce `lm_batch_chunk_size` for LM operations
|
| 618 |
+
- Consider using `offload_to_cpu=True` during initialization
|
| 619 |
+
|
| 620 |
+
---
|
| 621 |
+
|
| 622 |
+
## Troubleshooting
|
| 623 |
+
|
| 624 |
+
### Common Issues
|
| 625 |
+
|
| 626 |
+
**Issue**: Out of memory errors
|
| 627 |
+
- **Solution**: Reduce `batch_size`, `inference_steps`, or enable CPU offloading
|
| 628 |
+
|
| 629 |
+
**Issue**: Poor quality results
|
| 630 |
+
- **Solution**: Increase `inference_steps`, adjust `guidance_scale`, use base model
|
| 631 |
+
|
| 632 |
+
**Issue**: Results don't match prompt
|
| 633 |
+
- **Solution**: Make caption more specific, increase `guidance_scale`, enable LM refinement
|
| 634 |
+
|
| 635 |
+
**Issue**: Slow generation
|
| 636 |
+
- **Solution**: Use turbo model, reduce `inference_steps`, disable ADG
|
| 637 |
+
|
| 638 |
+
**Issue**: LM not generating codes
|
| 639 |
+
- **Solution**: Verify `llm_handler` is initialized, check `use_llm_thinking=True` and `use_cot_metas=True`
|
| 640 |
+
|
| 641 |
+
---
|
| 642 |
+
|
| 643 |
+
## API Reference Summary
|
| 644 |
+
|
| 645 |
+
### GenerationConfig Fields
|
| 646 |
+
|
| 647 |
+
See [Configuration Parameters](#configuration-parameters) for complete documentation.
|
| 648 |
+
|
| 649 |
+
### GenerationResult Fields
|
| 650 |
+
|
| 651 |
+
```python
|
| 652 |
+
@dataclass
|
| 653 |
+
class GenerationResult:
|
| 654 |
+
# Audio outputs
|
| 655 |
+
audio_paths: List[str] # List of generated audio file paths
|
| 656 |
+
first_audio: Optional[str] # First audio (backward compatibility)
|
| 657 |
+
second_audio: Optional[str] # Second audio (backward compatibility)
|
| 658 |
+
|
| 659 |
+
# Generation metadata
|
| 660 |
+
generation_info: str # Markdown-formatted generation info
|
| 661 |
+
status_message: str # Status message
|
| 662 |
+
seed_value: str # Seed value used
|
| 663 |
+
|
| 664 |
+
# LM outputs
|
| 665 |
+
lm_metadata: Optional[Dict[str, Any]] # LM-generated metadata
|
| 666 |
+
|
| 667 |
+
# Alignment scores (if available)
|
| 668 |
+
align_score_1: Optional[float]
|
| 669 |
+
align_text_1: Optional[str]
|
| 670 |
+
align_plot_1: Optional[Any]
|
| 671 |
+
align_score_2: Optional[float]
|
| 672 |
+
align_text_2: Optional[str]
|
| 673 |
+
align_plot_2: Optional[Any]
|
| 674 |
+
|
| 675 |
+
# Status
|
| 676 |
+
success: bool # Whether generation succeeded
|
| 677 |
+
error: Optional[str] # Error message if failed
|
| 678 |
+
```
|
| 679 |
+
|
| 680 |
+
---
|
| 681 |
+
|
| 682 |
+
## Version History
|
| 683 |
+
|
| 684 |
+
- **v1.5**: Current version with refactored inference API
|
| 685 |
+
- Introduced `GenerationConfig` and `GenerationResult` dataclasses
|
| 686 |
+
- Simplified parameter passing
|
| 687 |
+
- Added comprehensive documentation
|
| 688 |
+
- Maintained backward compatibility with Gradio UI
|
| 689 |
+
|
| 690 |
+
---
|
| 691 |
+
|
| 692 |
+
For more information, see:
|
| 693 |
+
- Main README: [`README.md`](README.md)
|
| 694 |
+
- REST API Documentation: [`API.md`](API.md)
|
| 695 |
+
- Project repository: [ACE-Step-1.5](https://github.com/ace-step/ACE-Step-1.5) <!-- NOTE(review): replaced "yourusername" placeholder with the project org — confirm the actual repository URL -->
|
acestep/gradio_ui/event.py
CHANGED
|
@@ -9,6 +9,8 @@ import glob
|
|
| 9 |
import time as time_module
|
| 10 |
import tempfile
|
| 11 |
import gradio as gr
|
|
|
|
|
|
|
| 12 |
from typing import Optional
|
| 13 |
from acestep.constants import (
|
| 14 |
TASK_TYPES_TURBO,
|
|
@@ -655,16 +657,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 655 |
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 656 |
|
| 657 |
if should_use_lm_batch:
|
| 658 |
-
# BATCH LM GENERATION
|
| 659 |
-
import math
|
| 660 |
-
from loguru import logger
|
| 661 |
-
|
| 662 |
logger.info(f"Using LM batch generation for {batch_size_input} items...")
|
| 663 |
|
| 664 |
# Prepare seeds for batch items
|
| 665 |
-
|
| 666 |
-
temp_handler = AceStepHandler()
|
| 667 |
-
actual_seed_list, _ = temp_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
|
| 668 |
|
| 669 |
# Split batch into chunks (GPU memory constraint)
|
| 670 |
max_inference_batch_size = int(lm_batch_chunk_size)
|
|
|
|
| 9 |
import time as time_module
|
| 10 |
import tempfile
|
| 11 |
import gradio as gr
|
| 12 |
+
import math
|
| 13 |
+
from loguru import logger
|
| 14 |
from typing import Optional
|
| 15 |
from acestep.constants import (
|
| 16 |
TASK_TYPES_TURBO,
|
|
|
|
| 657 |
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 658 |
|
| 659 |
if should_use_lm_batch:
|
| 660 |
+
# BATCH LM GENERATION
|
|
|
|
|
|
|
|
|
|
| 661 |
logger.info(f"Using LM batch generation for {batch_size_input} items...")
|
| 662 |
|
| 663 |
# Prepare seeds for batch items
|
| 664 |
+
actual_seed_list, _ = dit_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
|
|
|
|
|
|
|
| 665 |
|
| 666 |
# Split batch into chunks (GPU memory constraint)
|
| 667 |
max_inference_batch_size = int(lm_batch_chunk_size)
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -5,6 +5,7 @@ Contains event handlers and helper functions related to result display, scoring,
|
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
import datetime
|
|
|
|
| 8 |
import tempfile
|
| 9 |
import shutil
|
| 10 |
import zipfile
|
|
@@ -310,14 +311,10 @@ def generate_with_progress(
|
|
| 310 |
|
| 311 |
if should_use_lm_batch:
|
| 312 |
# BATCH LM GENERATION
|
| 313 |
-
import math
|
| 314 |
-
from acestep.handler import AceStepHandler
|
| 315 |
-
|
| 316 |
logger.info(f"Using LM batch generation for {batch_size_input} items...")
|
| 317 |
|
| 318 |
# Prepare seeds for batch items
|
| 319 |
-
|
| 320 |
-
actual_seed_list, _ = temp_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
|
| 321 |
|
| 322 |
# Split batch into chunks (GPU memory constraint)
|
| 323 |
max_inference_batch_size = int(lm_batch_chunk_size)
|
|
|
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
import datetime
|
| 8 |
+
import math
|
| 9 |
import tempfile
|
| 10 |
import shutil
|
| 11 |
import zipfile
|
|
|
|
| 311 |
|
| 312 |
if should_use_lm_batch:
|
| 313 |
# BATCH LM GENERATION
|
|
|
|
|
|
|
|
|
|
| 314 |
logger.info(f"Using LM batch generation for {batch_size_input} items...")
|
| 315 |
|
| 316 |
# Prepare seeds for batch items
|
| 317 |
+
actual_seed_list, _ = dit_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
|
|
|
|
| 318 |
|
| 319 |
# Split batch into chunks (GPU memory constraint)
|
| 320 |
max_inference_batch_size = int(lm_batch_chunk_size)
|
acestep/handler.py
CHANGED
|
@@ -37,12 +37,15 @@ warnings.filterwarnings("ignore")
|
|
| 37 |
class AceStepHandler:
|
| 38 |
"""ACE-Step Business Logic Handler"""
|
| 39 |
|
| 40 |
-
def __init__(self):
|
| 41 |
self.model = None
|
| 42 |
self.config = None
|
| 43 |
self.device = "cpu"
|
| 44 |
self.dtype = torch.float32 # Will be set based on device in initialize_service
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
# VAE for audio encoding/decoding
|
| 48 |
self.vae = None
|
|
|
|
| 37 |
class AceStepHandler:
|
| 38 |
"""ACE-Step Business Logic Handler"""
|
| 39 |
|
| 40 |
+
def __init__(self, save_root = None):
|
| 41 |
self.model = None
|
| 42 |
self.config = None
|
| 43 |
self.device = "cpu"
|
| 44 |
self.dtype = torch.float32 # Will be set based on device in initialize_service
|
| 45 |
+
if save_root is None:
|
| 46 |
+
self.temp_dir = tempfile.mkdtemp()
|
| 47 |
+
else:
|
| 48 |
+
self.temp_dir = save_root
|
| 49 |
|
| 50 |
# VAE for audio encoding/decoding
|
| 51 |
self.vae = None
|
acestep/inference.py
ADDED
|
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Inference API Module
|
| 3 |
+
|
| 4 |
+
This module provides a standardized inference interface for music generation,
|
| 5 |
+
designed for third-party integration. It offers both a simplified API and
|
| 6 |
+
backward-compatible Gradio UI support.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 11 |
+
from dataclasses import dataclass, field, asdict
|
| 12 |
+
from loguru import logger
|
| 13 |
+
import time as time_module
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class GenerationConfig:
    """Configuration for music generation.

    All fields have defaults, so a minimal call only needs a ``caption``.
    Fields left at their "auto" sentinel (``None`` / empty string) may be
    filled in by the 5Hz LM when ``use_llm_thinking`` is enabled.

    Attributes:
        # Text Inputs
        caption: Text description of the desired music
        lyrics: Lyrics text for vocal music (use "[Instrumental]" for instrumental)

        # Music Metadata
        bpm: Beats per minute (e.g., 120). None for auto-detection
        key_scale: Musical key (e.g., "C Major", "Am"). Empty for auto-detection
        time_signature: Time signature (e.g., "4/4", "3/4"). Empty for auto-detection
        vocal_language: Language code for vocals (e.g., "en", "zh", "ja")
        audio_duration: Duration in seconds. None for auto-detection

        # Generation Parameters
        inference_steps: Number of denoising steps (8 for turbo, 32-100 for base)
        guidance_scale: Classifier-free guidance scale (higher = more adherence to prompt)
        use_random_seed: Whether to use random seed (True) or fixed seed
        seed: Random seed for reproducibility (-1 for random)
        batch_size: Number of samples to generate (1-8)

        # Advanced DiT Parameters
        use_adg: Use Adaptive Dual Guidance (base model only)
        cfg_interval_start: CFG application start ratio (0.0-1.0)
        cfg_interval_end: CFG application end ratio (0.0-1.0)
        audio_format: Output audio format ("mp3", "wav", "flac")

        # Task-Specific Parameters
        task_type: Generation task type ("text2music", "cover", "repaint", "lego", "extract", "complete")
        reference_audio: Path to reference audio file (for style transfer)
        src_audio: Path to source audio file (for audio-to-audio tasks)
        audio_code_string: Pre-extracted audio codes (advanced use)
        repainting_start: Repainting start time in seconds (for repaint/lego tasks)
        repainting_end: Repainting end time in seconds (-1 for end of audio)
        audio_cover_strength: Strength of audio cover/codes influence (0.0-1.0)
        instruction: Task-specific instruction prompt (auto-generated if empty)

        # 5Hz Language Model Parameters (CoT Reasoning)
        use_llm_thinking: Enable LM-based Chain-of-Thought reasoning for metadata/codes
        lm_temperature: LM sampling temperature (0.0-2.0, higher = more creative)
        lm_cfg_scale: LM classifier-free guidance scale
        lm_top_k: LM top-k sampling (0 = disabled)
        lm_top_p: LM nucleus sampling (1.0 = disabled)
        lm_negative_prompt: Negative prompt for LM guidance
        use_cot_metas: Generate metadata using LM CoT
        use_cot_caption: Refine caption using LM CoT
        use_cot_language: Detect language using LM CoT
        is_format_caption: Whether caption is already formatted
        constrained_decoding_debug: Enable debug logging for constrained decoding

        # Batch LM Generation
        allow_lm_batch: Allow batch LM code generation (faster for batch_size >= 2)
        lm_batch_chunk_size: Maximum batch size per LM inference chunk (GPU memory constraint)
    """

    # Text Inputs
    caption: str = ""
    lyrics: str = ""

    # Music Metadata
    bpm: Optional[int] = None
    key_scale: str = ""
    time_signature: str = ""
    vocal_language: str = "unknown"
    audio_duration: Optional[float] = None

    # Generation Parameters
    inference_steps: int = 8
    guidance_scale: float = 7.0
    use_random_seed: bool = True
    seed: int = -1
    batch_size: int = 1

    # Advanced DiT Parameters
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0
    audio_format: str = "mp3"

    # Task-Specific Parameters
    task_type: str = "text2music"
    reference_audio: Optional[str] = None
    src_audio: Optional[str] = None
    # May be a single code string or one string per batch item (batch LM path).
    audio_code_string: Union[str, List[str]] = ""
    repainting_start: float = 0.0
    repainting_end: float = -1
    audio_cover_strength: float = 1.0
    instruction: str = ""

    # 5Hz Language Model Parameters
    use_llm_thinking: bool = False
    lm_temperature: float = 0.85
    lm_cfg_scale: float = 2.0
    lm_top_k: int = 0
    lm_top_p: float = 0.9
    lm_negative_prompt: str = "NO USER INPUT"
    use_cot_metas: bool = True
    use_cot_caption: bool = True
    use_cot_language: bool = True
    is_format_caption: bool = False
    constrained_decoding_debug: bool = False

    # Batch LM Generation
    allow_lm_batch: bool = False
    lm_batch_chunk_size: int = 4
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass
class GenerationResult:
    """Result of music generation.

    Returned by :func:`generate_music`. Check ``success`` before reading
    audio fields; on failure only ``error``/``generation_info``/
    ``status_message`` are populated.

    Attributes:
        # Audio Outputs
        audio_paths: List of paths to generated audio files
        first_audio: Path to first generated audio (backward compatibility)
        second_audio: Path to second generated audio (backward compatibility)

        # Generation Information
        generation_info: Markdown-formatted generation information
        status_message: Status message from generation
        seed_value: Actual seed value used for generation

        # LM-Generated Metadata (if applicable)
        lm_metadata: Metadata generated by language model (dict or None)

        # Audio-Text Alignment Scores (if available)
        align_score_1: First alignment score
        align_text_1: First alignment text description
        align_plot_1: First alignment plot image
        align_score_2: Second alignment score
        align_text_2: Second alignment text description
        align_plot_2: Second alignment plot image

        # Success Status
        success: Whether generation completed successfully
        error: Error message if generation failed
    """

    # Audio Outputs
    audio_paths: List[str] = field(default_factory=list)
    first_audio: Optional[str] = None
    second_audio: Optional[str] = None

    # Generation Information
    generation_info: str = ""
    status_message: str = ""
    seed_value: str = ""

    # LM-Generated Metadata
    lm_metadata: Optional[Dict[str, Any]] = None

    # Audio-Text Alignment Scores
    align_score_1: Optional[float] = None
    align_text_1: Optional[str] = None
    align_plot_1: Optional[Any] = None
    align_score_2: Optional[float] = None
    align_text_2: Optional[str] = None
    align_plot_2: Optional[Any] = None

    # Success Status
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization.

        Note: ``align_plot_*`` entries may not be JSON-serializable if they
        hold plot objects; ``asdict`` copies them through unchanged.
        """
        return asdict(self)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def generate_music(
    dit_handler,
    llm_handler,
    config: GenerationConfig,
) -> GenerationResult:
    """Generate music using ACE-Step model with optional LM reasoning.

    This is the main inference API for music generation. It supports various task types
    (text2music, cover, repaint, etc.) and can optionally use a 5Hz Language Model for
    Chain-of-Thought reasoning to generate metadata and audio codes.

    Args:
        dit_handler: Initialized DiT model handler (AceStepHandler instance)
        llm_handler: Initialized LLM handler (LLMHandler instance)
        config: Generation configuration (GenerationConfig instance)

    Returns:
        GenerationResult: Generation result containing audio paths and metadata

    Example:
        >>> from acestep.handler import AceStepHandler
        >>> from acestep.llm_inference import LLMHandler
        >>> from acestep.inference import GenerationConfig, generate_music
        >>>
        >>> # Initialize handlers
        >>> dit_handler = AceStepHandler()
        >>> llm_handler = LLMHandler()
        >>> dit_handler.initialize_service(...)
        >>> llm_handler.initialize(...)
        >>>
        >>> # Configure generation
        >>> config = GenerationConfig(
        ...     caption="upbeat electronic dance music",
        ...     bpm=128,
        ...     audio_duration=30,
        ...     batch_size=2,
        ... )
        >>>
        >>> # Generate music
        >>> result = generate_music(dit_handler, llm_handler, config)
        >>> print(f"Generated {len(result.audio_paths)} audio files")
        >>> for path in result.audio_paths:
        ...     print(f"Audio: {path}")
    """

    try:
        # Phase 1: LM-based metadata and code generation (if enabled)
        audio_code_string_to_use = config.audio_code_string
        lm_generated_metadata = None
        # NOTE(review): lm_generated_audio_codes / lm_generated_audio_codes_list
        # are assigned below but never read afterwards — kept for parity with
        # the Gradio handlers; confirm whether they can be removed.
        lm_generated_audio_codes = None
        lm_generated_audio_codes_list = []

        # Extract mutable copies of metadata (will be updated by LM if needed)
        bpm = config.bpm
        key_scale = config.key_scale
        time_signature = config.time_signature
        audio_duration = config.audio_duration

        # Determine if we should use batch LM generation
        # (only worthwhile for batch_size >= 2 and when explicitly allowed).
        should_use_lm_batch = (
            config.use_llm_thinking
            and llm_handler.llm_initialized
            and config.use_cot_metas
            and config.allow_lm_batch
            and config.batch_size >= 2
        )

        # LM-based Chain-of-Thought reasoning
        if config.use_llm_thinking and llm_handler.llm_initialized and config.use_cot_metas:
            # Convert sampling parameters: 0 disables top-k, >= 1.0 disables top-p.
            top_k_value = None if config.lm_top_k == 0 else int(config.lm_top_k)
            top_p_value = None if config.lm_top_p >= 1.0 else config.lm_top_p

            # Build user_metadata from user-provided values.
            # Values are normalized to strings; invalid/sentinel values are dropped
            # so the LM is free to fill those fields itself.
            user_metadata = {}
            if bpm is not None:
                try:
                    bpm_value = float(bpm)
                    if bpm_value > 0:
                        user_metadata['bpm'] = str(int(bpm_value))
                except (ValueError, TypeError):
                    pass

            if key_scale and key_scale.strip():
                key_scale_clean = key_scale.strip()
                if key_scale_clean.lower() not in ["n/a", ""]:
                    user_metadata['keyscale'] = key_scale_clean

            if time_signature and time_signature.strip():
                time_sig_clean = time_signature.strip()
                if time_sig_clean.lower() not in ["n/a", ""]:
                    user_metadata['timesignature'] = time_sig_clean

            if audio_duration is not None:
                try:
                    duration_value = float(audio_duration)
                    if duration_value > 0:
                        user_metadata['duration'] = str(int(duration_value))
                except (ValueError, TypeError):
                    pass

            user_metadata_to_pass = user_metadata if user_metadata else None

            # Batch LM generation (faster for multiple samples)
            if should_use_lm_batch:
                actual_seed_list, _ = dit_handler.prepare_seeds(
                    config.batch_size, config.seed, config.use_random_seed
                )

                # Split the batch into chunks to respect GPU memory constraints.
                max_inference_batch_size = int(config.lm_batch_chunk_size)
                num_chunks = math.ceil(config.batch_size / max_inference_batch_size)

                all_metadata_list = []
                all_audio_codes_list = []

                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * max_inference_batch_size
                    chunk_end = min(chunk_start + max_inference_batch_size, config.batch_size)
                    chunk_size = chunk_end - chunk_start
                    chunk_seeds = actual_seed_list[chunk_start:chunk_end]

                    logger.info(
                        f"LM batch chunk {chunk_idx+1}/{num_chunks} "
                        f"(size: {chunk_size}, seeds: {chunk_seeds})"
                    )

                    metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
                        caption=config.caption or "",
                        lyrics=config.lyrics or "",
                        batch_size=chunk_size,
                        infer_type="llm_dit",
                        temperature=config.lm_temperature,
                        cfg_scale=config.lm_cfg_scale,
                        negative_prompt=config.lm_negative_prompt,
                        top_k=top_k_value,
                        top_p=top_p_value,
                        user_metadata=user_metadata_to_pass,
                        use_cot_caption=config.use_cot_caption,
                        use_cot_language=config.use_cot_language,
                        is_format_caption=config.is_format_caption,
                        constrained_decoding_debug=config.constrained_decoding_debug,
                        seeds=chunk_seeds,
                    )

                    all_metadata_list.extend(metadata_list)
                    all_audio_codes_list.extend(audio_codes_list)

                # Only the first item's metadata is surfaced in the result;
                # per-item audio codes are passed through to the DiT as a list.
                lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
                lm_generated_audio_codes_list = all_audio_codes_list
                audio_code_string_to_use = all_audio_codes_list

                # Update metadata from LM if not provided by user
                if lm_generated_metadata:
                    bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
                        lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
                    )

            else:
                # Sequential LM generation (current behavior)
                # Phase 1: Generate CoT metadata
                # NOTE(review): the phase-1 "dit" result is overwritten by the
                # phase-2 "llm_dit" call below; presumably phase 1 warms internal
                # handler state — confirm against LLMHandler before simplifying.
                phase1_start = time_module.time()
                metadata, _, status = llm_handler.generate_with_stop_condition(
                    caption=config.caption or "",
                    lyrics=config.lyrics or "",
                    infer_type="dit",
                    temperature=config.lm_temperature,
                    cfg_scale=config.lm_cfg_scale,
                    negative_prompt=config.lm_negative_prompt,
                    top_k=top_k_value,
                    top_p=top_p_value,
                    user_metadata=user_metadata_to_pass,
                    use_cot_caption=config.use_cot_caption,
                    use_cot_language=config.use_cot_language,
                    is_format_caption=config.is_format_caption,
                    constrained_decoding_debug=config.constrained_decoding_debug,
                )
                lm_phase1_time = time_module.time() - phase1_start
                logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")

                # Phase 2: Generate audio codes
                phase2_start = time_module.time()
                metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
                    caption=config.caption or "",
                    lyrics=config.lyrics or "",
                    infer_type="llm_dit",
                    temperature=config.lm_temperature,
                    cfg_scale=config.lm_cfg_scale,
                    negative_prompt=config.lm_negative_prompt,
                    top_k=top_k_value,
                    top_p=top_p_value,
                    user_metadata=user_metadata_to_pass,
                    use_cot_caption=config.use_cot_caption,
                    use_cot_language=config.use_cot_language,
                    is_format_caption=config.is_format_caption,
                    constrained_decoding_debug=config.constrained_decoding_debug,
                )
                lm_phase2_time = time_module.time() - phase2_start
                logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")

                lm_generated_metadata = metadata
                if audio_codes:
                    audio_code_string_to_use = audio_codes
                    lm_generated_audio_codes = audio_codes

                # Update metadata from LM if not provided by user
                bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
                    metadata, bpm, key_scale, time_signature, audio_duration
                )

        # Phase 2: DiT music generation
        result = dit_handler.generate_music(
            captions=config.caption,
            lyrics=config.lyrics,
            bpm=bpm,
            key_scale=key_scale,
            time_signature=time_signature,
            vocal_language=config.vocal_language,
            inference_steps=config.inference_steps,
            guidance_scale=config.guidance_scale,
            use_random_seed=config.use_random_seed,
            seed=config.seed,
            reference_audio=config.reference_audio,
            audio_duration=audio_duration,
            batch_size=config.batch_size,
            src_audio=config.src_audio,
            audio_code_string=audio_code_string_to_use,
            repainting_start=config.repainting_start,
            repainting_end=config.repainting_end,
            instruction=config.instruction,
            audio_cover_strength=config.audio_cover_strength,
            task_type=config.task_type,
            use_adg=config.use_adg,
            cfg_interval_start=config.cfg_interval_start,
            cfg_interval_end=config.cfg_interval_end,
            audio_format=config.audio_format,
            lm_temperature=config.lm_temperature,
        )

        # Extract results (12-tuple contract with AceStepHandler.generate_music)
        (first_audio, second_audio, all_audio_paths, generation_info, status_message,
         seed_value, align_score_1, align_text_1, align_plot_1,
         align_score_2, align_text_2, align_plot_2) = result

        # Append LM metadata to generation info
        if lm_generated_metadata:
            generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)

        # Create result object
        return GenerationResult(
            audio_paths=all_audio_paths or [],
            first_audio=first_audio,
            second_audio=second_audio,
            generation_info=generation_info,
            status_message=status_message,
            seed_value=seed_value,
            lm_metadata=lm_generated_metadata,
            align_score_1=align_score_1,
            align_text_1=align_text_1,
            align_plot_1=align_plot_1,
            align_score_2=align_score_2,
            align_text_2=align_text_2,
            align_plot_2=align_plot_2,
            success=True,
            error=None,
        )

    except Exception as e:
        # Top-level boundary: log the full traceback and surface the error
        # in a structured result instead of raising to the caller.
        logger.exception("Music generation failed")
        return GenerationResult(
            success=False,
            error=str(e),
            generation_info=f"β Generation failed: {str(e)}",
            status_message=f"Error: {str(e)}",
        )
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _update_metadata_from_lm(
|
| 463 |
+
metadata: Dict[str, Any],
|
| 464 |
+
bpm: Optional[int],
|
| 465 |
+
key_scale: str,
|
| 466 |
+
time_signature: str,
|
| 467 |
+
audio_duration: Optional[float],
|
| 468 |
+
) -> Tuple[Optional[int], str, str, Optional[float]]:
|
| 469 |
+
"""Update metadata fields from LM output if not provided by user."""
|
| 470 |
+
|
| 471 |
+
if bpm is None and metadata.get('bpm'):
|
| 472 |
+
bpm_value = metadata.get('bpm')
|
| 473 |
+
if bpm_value not in ["N/A", ""]:
|
| 474 |
+
try:
|
| 475 |
+
bpm = int(bpm_value)
|
| 476 |
+
except (ValueError, TypeError):
|
| 477 |
+
pass
|
| 478 |
+
|
| 479 |
+
if not key_scale and metadata.get('keyscale'):
|
| 480 |
+
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
|
| 481 |
+
if key_scale_value != "N/A":
|
| 482 |
+
key_scale = key_scale_value
|
| 483 |
+
|
| 484 |
+
if not time_signature and metadata.get('timesignature'):
|
| 485 |
+
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
|
| 486 |
+
if time_signature_value != "N/A":
|
| 487 |
+
time_signature = time_signature_value
|
| 488 |
+
|
| 489 |
+
if audio_duration is None or audio_duration <= 0:
|
| 490 |
+
audio_duration_value = metadata.get('duration', -1)
|
| 491 |
+
if audio_duration_value not in ["N/A", ""]:
|
| 492 |
+
try:
|
| 493 |
+
audio_duration = float(audio_duration_value)
|
| 494 |
+
except (ValueError, TypeError):
|
| 495 |
+
pass
|
| 496 |
+
|
| 497 |
+
return bpm, key_scale, time_signature, audio_duration
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any]) -> str:
|
| 501 |
+
"""Append LM-generated metadata to generation info string."""
|
| 502 |
+
|
| 503 |
+
metadata_lines = []
|
| 504 |
+
if metadata.get('bpm'):
|
| 505 |
+
metadata_lines.append(f"- **BPM:** {metadata['bpm']}")
|
| 506 |
+
if metadata.get('caption'):
|
| 507 |
+
metadata_lines.append(f"- **Refined Caption:** {metadata['caption']}")
|
| 508 |
+
if metadata.get('duration'):
|
| 509 |
+
metadata_lines.append(f"- **Duration:** {metadata['duration']} seconds")
|
| 510 |
+
if metadata.get('keyscale'):
|
| 511 |
+
metadata_lines.append(f"- **Key Scale:** {metadata['keyscale']}")
|
| 512 |
+
if metadata.get('language'):
|
| 513 |
+
metadata_lines.append(f"- **Language:** {metadata['language']}")
|
| 514 |
+
if metadata.get('timesignature'):
|
| 515 |
+
metadata_lines.append(f"- **Time Signature:** {metadata['timesignature']}")
|
| 516 |
+
|
| 517 |
+
if metadata_lines:
|
| 518 |
+
metadata_section = "\n\n**π€ LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
|
| 519 |
+
return metadata_section + "\n\n" + generation_info
|
| 520 |
+
|
| 521 |
+
return generation_info
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
# ============================================================================
|
| 525 |
+
# LEGACY GRADIO UI COMPATIBILITY LAYER
|
| 526 |
+
# ============================================================================
|
| 527 |
+
|
| 528 |
+
def generate(
    dit_handler,
    llm_handler,
    captions,
    lyrics,
    bpm,
    key_scale,
    time_signature,
    vocal_language,
    inference_steps,
    guidance_scale,
    random_seed_checkbox,
    seed,
    reference_audio,
    audio_duration,
    batch_size_input,
    src_audio,
    text2music_audio_code_string,
    repainting_start,
    repainting_end,
    instruction_display_gen,
    audio_cover_strength,
    task_type,
    use_adg,
    cfg_interval_start,
    cfg_interval_end,
    audio_format,
    lm_temperature,
    think_checkbox,
    lm_cfg_scale,
    lm_top_k,
    lm_top_p,
    lm_negative_prompt,
    use_cot_metas,
    use_cot_caption,
    use_cot_language,
    is_format_caption,
    constrained_decoding_debug,
    allow_lm_batch,
    lm_batch_chunk_size,
):
    """Legacy Gradio UI compatibility wrapper.

    Maps the flat Gradio callback arguments onto a GenerationConfig, runs
    generate_music(), and unpacks the result into the fixed-length tuple of
    component updates the Gradio UI expects.

    This function maintains backward compatibility with the Gradio UI.
    For new integrations, use generate_music() with GenerationConfig instead.

    Returns:
        Tuple of Gradio UI component updates (8 audio slots, batch list,
        info/status/seed, alignment fields, audio-code textboxes, LM metadata,
        and the pass-through is_format_caption flag).
    """
    # Translate the legacy flat argument list into the structured config.
    cfg = GenerationConfig(
        caption=captions,
        lyrics=lyrics,
        bpm=bpm,
        key_scale=key_scale,
        time_signature=time_signature,
        vocal_language=vocal_language,
        audio_duration=audio_duration,
        inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        use_random_seed=random_seed_checkbox,
        seed=seed,
        batch_size=batch_size_input,
        use_adg=use_adg,
        cfg_interval_start=cfg_interval_start,
        cfg_interval_end=cfg_interval_end,
        audio_format=audio_format,
        task_type=task_type,
        reference_audio=reference_audio,
        src_audio=src_audio,
        audio_code_string=text2music_audio_code_string,
        repainting_start=repainting_start,
        repainting_end=repainting_end,
        audio_cover_strength=audio_cover_strength,
        instruction=instruction_display_gen,
        use_llm_thinking=think_checkbox,
        lm_temperature=lm_temperature,
        lm_cfg_scale=lm_cfg_scale,
        lm_top_k=lm_top_k,
        lm_top_p=lm_top_p,
        lm_negative_prompt=lm_negative_prompt,
        use_cot_metas=use_cot_metas,
        use_cot_caption=use_cot_caption,
        use_cot_language=use_cot_language,
        is_format_caption=is_format_caption,
        constrained_decoding_debug=constrained_decoding_debug,
        allow_lm_batch=allow_lm_batch,
        lm_batch_chunk_size=lm_batch_chunk_size,
    )

    # Delegate to the structured API.
    outcome = generate_music(dit_handler, llm_handler, cfg)

    # Decide which LM-produced audio codes feed back into the UI textboxes.
    # Default: keep the user's main textbox content and clear per-sample slots.
    main_codes = text2music_audio_code_string
    per_sample_codes = [""] * 8
    if cfg.allow_lm_batch and outcome.lm_metadata:
        # Batch mode: one code string per generated sample, if the LM produced them.
        batch_codes = outcome.lm_metadata.get('audio_codes_list', [])
        if batch_codes:
            main_codes = batch_codes[0]
        per_sample_codes = (batch_codes + [""] * 8)[:8]
    elif outcome.lm_metadata:
        # Single mode: one shared code string for the main textbox only.
        single_codes = outcome.lm_metadata.get('audio_codes', '')
        if single_codes:
            main_codes = single_codes

    # Pad/truncate to the 8 fixed audio player slots.
    padded_audio = (outcome.audio_paths + [None] * 8)[:8]

    return (
        *padded_audio,            # generated_audio_1 .. generated_audio_8
        outcome.audio_paths,      # generated_audio_batch
        outcome.generation_info,
        outcome.status_message,
        outcome.seed_value,
        outcome.align_score_1,
        outcome.align_text_1,
        outcome.align_plot_1,
        outcome.align_score_2,
        outcome.align_text_2,
        outcome.align_plot_2,
        main_codes,               # update main audio-code textbox in UI
        *per_sample_codes,        # text2music_audio_code_string_1 .. _8
        outcome.lm_metadata,      # stored for "Send to src audio" buttons
        is_format_caption,        # keep is_format_caption unchanged
    )
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
# ============================================================================
|
| 672 |
+
# TESTING & EXAMPLES
|
| 673 |
+
# ============================================================================
|
| 674 |
+
|
| 675 |
+
if __name__ == "__main__":
    """
    Test suite for the inference API.
    Demonstrates various usage scenarios and validates functionality.

    Usage:
        python -m acestep.inference
    """

    # Script-local imports; the handlers are only needed when run as a test.
    import os
    import json
    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler

    # Initialize paths: project root is the parent of this package directory,
    # and model checkpoints are expected under <project_root>/checkpoints.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    checkpoint_dir = os.path.join(project_root, "checkpoints")

    print("=" * 80)
    print("ACE-Step Inference API Test Suite")
    print("=" * 80)

    # ========================================================================
    # Initialize Handlers
    # ========================================================================
    # NOTE(review): the "β" glyphs in the status strings below look like
    # mojibake of check/cross-mark characters -- confirm the source encoding.
    print("\n[1/3] Initializing handlers...")
    dit_handler = AceStepHandler(save_root="./")
    llm_handler = LLMHandler()

    try:
        # Initialize DiT handler. A failed DiT init is fatal for the suite.
        print(" - Initializing DiT model...")
        status_dit, success_dit = dit_handler.initialize_service(
            project_root=project_root,
            config_path="acestep-v15-turbo-rl",
            device="cuda",
        )
        if not success_dit:
            print(f" β DiT initialization failed: {status_dit}")
            exit(1)
        print(f" β DiT model initialized successfully")

        # Initialize LLM handler. Unlike the DiT model, an LM failure is
        # non-fatal: generation tests that need the LM are simply skipped.
        print(" - Initializing 5Hz LM model...")
        status_llm, success_llm = llm_handler.initialize(
            checkpoint_dir=checkpoint_dir,
            lm_model_path="acestep-5Hz-lm-0.6B-v3",
            backend="vllm",
            device="cuda",
        )
        if success_llm:
            print(f" β LM model initialized successfully")
        else:
            print(f" β LM initialization failed (will skip LM tests): {status_llm}")

    except Exception as e:
        # Any unexpected error during model setup aborts the suite.
        print(f" β Initialization error: {e}")
        exit(1)

    # ========================================================================
    # Helper Functions
    # ========================================================================
    def load_example_config(example_file: str) -> GenerationConfig:
        """Load configuration from an example JSON file.

        Maps the example-file schema (caption/lyrics/bpm/keyscale/...) onto a
        GenerationConfig. Returns None (after printing the error) when the
        file cannot be read or parsed.
        """
        try:
            with open(example_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Convert example format to GenerationConfig
            # Handle time signature format (example uses "4" instead of "4/4")
            time_sig = data.get('timesignature', '')
            if time_sig and '/' not in time_sig:
                time_sig = f"{time_sig}/4"  # Default to /4 if only numerator given

            config = GenerationConfig(
                caption=data.get('caption', ''),
                lyrics=data.get('lyrics', ''),
                bpm=data.get('bpm'),
                key_scale=data.get('keyscale', ''),
                time_signature=time_sig,
                vocal_language=data.get('language', 'unknown'),
                audio_duration=data.get('duration'),
                use_llm_thinking=data.get('think', False),
                batch_size=data.get('batch_size', 1),
                inference_steps=data.get('inference_steps', 8),
            )
            return config

        except Exception as e:
            print(f" β Failed to load example file: {e}")
            return None

    # ========================================================================
    # Test Cases
    # ========================================================================
    # Accumulates one summary dict per executed test (see run_test below).
    test_results = []

    def run_test(test_name: str, config: GenerationConfig, expected_outputs: int = 1):
        """Run a single test case, print a report, and collect the result.

        Executes generate_music() with the given config, prints the
        configuration and outcome, validates the output count against
        expected_outputs, and appends a summary dict to test_results.
        Returns the raw GenerationResult.
        """
        print(f"\n{'=' * 80}")
        print(f"Test: {test_name}")
        print(f"{'=' * 80}")

        # Display configuration (long captions/lyrics are truncated to 60 chars)
        print("\nConfiguration:")
        print(f" Task Type: {config.task_type}")
        print(f" Caption: {config.caption[:60]}..." if len(config.caption) > 60 else f" Caption: {config.caption}")
        if config.lyrics:
            print(f" Lyrics: {config.lyrics[:60]}..." if len(config.lyrics) > 60 else f" Lyrics: {config.lyrics}")
        if config.bpm:
            print(f" BPM: {config.bpm}")
        if config.key_scale:
            print(f" Key Scale: {config.key_scale}")
        if config.time_signature:
            print(f" Time Signature: {config.time_signature}")
        if config.audio_duration:
            print(f" Duration: {config.audio_duration}s")
        print(f" Batch Size: {config.batch_size}")
        print(f" Inference Steps: {config.inference_steps}")
        print(f" Use LLM Thinking: {config.use_llm_thinking}")

        # Run generation and time the full end-to-end call.
        print("\nGenerating...")
        import time
        start_time = time.time()

        result = generate_music(dit_handler, llm_handler, config)

        elapsed_time = time.time() - start_time

        # Display results
        print("\nResults:")
        print(f" Success: {'β' if result.success else 'β'}")

        if result.success:
            print(f" Generated Files: {len(result.audio_paths)}")
            for i, path in enumerate(result.audio_paths, 1):
                if os.path.exists(path):
                    file_size = os.path.getsize(path) / (1024 * 1024)  # MB
                    print(f" [{i}] {os.path.basename(path)} ({file_size:.2f} MB)")
                else:
                    print(f" [{i}] {os.path.basename(path)} (file not found)")

            print(f" Seed: {result.seed_value}")
            print(f" Generation Time: {elapsed_time:.2f}s")

            # Display LM metadata if available
            if result.lm_metadata:
                print(f"\n LM-Generated Metadata:")
                for key, value in result.lm_metadata.items():
                    if key not in ['audio_codes', 'audio_codes_list']:  # Skip large code strings
                        print(f" {key}: {value}")

            # Validate outputs: a successful generation with the wrong number
            # of files still counts as a test failure.
            if len(result.audio_paths) != expected_outputs:
                print(f" β Warning: Expected {expected_outputs} outputs, got {len(result.audio_paths)}")
                success = False
            else:
                success = True

        else:
            print(f" Error: {result.error}")
            success = False

        # Store test result for the summary section below.
        test_results.append({
            "test_name": test_name,
            "success": success,
            "generation_success": result.success,
            "num_outputs": len(result.audio_paths) if result.success else 0,
            "expected_outputs": expected_outputs,
            "elapsed_time": elapsed_time,
            "error": result.error if not result.success else None,
        })

        return result

    # ========================================================================
    # Test: Production Example (from examples directory)
    # ========================================================================
    print("\n[2/3] Running Test...")

    # Load production example (J-Rock song from examples/text2music/example_05.json)
    example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")

    if not os.path.exists(example_file):
        print(f"\n β Example file not found: {example_file}")
        print(" Please ensure the examples directory exists.")
        exit(1)

    print(f" Loading example: {os.path.basename(example_file)}")
    config = load_example_config(example_file)

    if not config:
        print(" β Failed to load example configuration")
        exit(1)

    # Reduce duration for faster testing (original is 200s), and pin the seed
    # so repeated runs are comparable.
    print(f" Original duration: {config.audio_duration}s")
    config.audio_duration = 30
    config.use_random_seed = False
    config.seed = 42
    print(f" Test duration: {config.audio_duration}s (reduced for testing)")

    run_test("Production Example (J-Rock Song)", config, expected_outputs=1)

    # ========================================================================
    # Test Summary
    # ========================================================================
    print("\n[3/3] Test Summary")
    print("=" * 80)

    if len(test_results) == 0:
        print("No tests were run.")
        exit(1)

    # Only one test is run above, so the summary reports test_results[0].
    result = test_results[0]

    print(f"\nTest: {result['test_name']}")
    print(f"Status: {'β PASS' if result['success'] else 'β FAIL'}")
    print(f"Generation: {'Success' if result['generation_success'] else 'Failed'}")
    print(f"Outputs: {result['num_outputs']}/{result['expected_outputs']}")
    print(f"Time: {result['elapsed_time']:.2f}s")

    if result["error"]:
        print(f"Error: {result['error']}")

    # Save test results to JSON for downstream tooling / CI inspection.
    results_file = os.path.join(project_root, "test_results.json")
    try:
        with open(results_file, "w") as f:
            json.dump({
                "test_name": result['test_name'],
                "success": result['success'],
                "generation_success": result['generation_success'],
                "num_outputs": result['num_outputs'],
                "expected_outputs": result['expected_outputs'],
                "elapsed_time": result['elapsed_time'],
                "error": result['error'],
            }, f, indent=2)
        print(f"\nβ Test results saved to: {results_file}")
    except Exception as e:
        # Failing to persist results is reported but does not change exit code.
        print(f"\nβ Failed to save test results: {e}")

    # Exit with appropriate code: 0 on pass, 1 on fail.
    print("\n" + "=" * 80)
    if result['success']:
        print("Test passed! β")
        print("=" * 80)
        exit(0)
    else:
        print("Test failed! β")
        print("=" * 80)
        exit(1)
|