Commit 24f370e by ChuxiJ (1 parent: 1da0418)

add inference code and doc
INFERENCE.md ADDED
@@ -0,0 +1,695 @@
# ACE-Step Inference API Documentation

This document provides comprehensive documentation for the ACE-Step inference API, including parameter specifications for all supported task types.

## Table of Contents

- [Quick Start](#quick-start)
- [API Overview](#api-overview)
- [Configuration Parameters](#configuration-parameters)
- [Task Types](#task-types)
- [Complete Examples](#complete-examples)
- [Best Practices](#best-practices)
- [Troubleshooting](#troubleshooting)
- [API Reference Summary](#api-reference-summary)

---

## Quick Start

### Basic Usage

```python
from acestep.handler import AceStepHandler
from acestep.llm_inference import LLMHandler
from acestep.inference import GenerationConfig, generate_music

# Initialize handlers
dit_handler = AceStepHandler()
llm_handler = LLMHandler()

# Initialize services
dit_handler.initialize_service(
    project_root="/path/to/project",
    config_path="acestep-v15-turbo-rl",
    device="cuda"
)

llm_handler.initialize(
    checkpoint_dir="/path/to/checkpoints",
    lm_model_path="acestep-5Hz-lm-0.6B-v3",
    backend="vllm",
    device="cuda"
)

# Configure generation
config = GenerationConfig(
    caption="upbeat electronic dance music with heavy bass",
    bpm=128,
    audio_duration=30,
    batch_size=1,
)

# Generate music
result = generate_music(dit_handler, llm_handler, config)

# Access results
if result.success:
    for audio_path in result.audio_paths:
        print(f"Generated: {audio_path}")
else:
    print(f"Error: {result.error}")
```

---

## API Overview

### Main Function

```python
def generate_music(
    dit_handler: AceStepHandler,
    llm_handler: LLMHandler,
    config: GenerationConfig,
) -> GenerationResult
```

### Configuration Object

The `GenerationConfig` dataclass consolidates all generation parameters:

```python
@dataclass
class GenerationConfig:
    # Required parameters with sensible defaults
    caption: str = ""
    lyrics: str = ""
    # ... (see full parameter list below)
```

### Result Object

```python
@dataclass
class GenerationResult:
    audio_paths: List[str]        # Paths to generated audio files
    generation_info: str          # Markdown-formatted info
    status_message: str           # Status message
    seed_value: str               # Seed used
    lm_metadata: Optional[Dict]   # LM-generated metadata
    success: bool                 # Success flag
    error: Optional[str]          # Error message if failed
    # ... (see full fields below)
```

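`GenerationResult` also exposes a `to_dict()` helper (defined in `acestep/inference.py` in this commit), which makes results easy to serialize. A minimal sketch:

```python
import json

result = generate_music(dit_handler, llm_handler, config)

# Persist the full result, including LM metadata, as JSON for logging or a web API.
with open("generation_result.json", "w", encoding="utf-8") as f:
    json.dump(result.to_dict(), f, ensure_ascii=False, indent=2)
```
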
---

## Configuration Parameters

### Text Inputs

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `caption` | `str` | `""` | Text description of the desired music. Can be a simple prompt like "relaxing piano music" or a detailed description with genre, mood, instruments, etc. |
| `lyrics` | `str` | `""` | Lyrics text for vocal music. Use `"[Instrumental]"` for instrumental tracks. Supports multiple languages. |

### Music Metadata

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `bpm` | `Optional[int]` | `None` | Beats per minute (30-300). `None` enables auto-detection via LM. |
| `key_scale` | `str` | `""` | Musical key (e.g., "C Major", "Am", "F# minor"). Empty string enables auto-detection. |
| `time_signature` | `str` | `""` | Time signature (e.g., "4/4", "3/4", "6/8"). Empty string enables auto-detection. |
| `vocal_language` | `str` | `"unknown"` | Language code for vocals (ISO 639-1). Supported: `"en"`, `"zh"`, `"ja"`, `"es"`, `"fr"`, etc. Use `"unknown"` for auto-detection. |
| `audio_duration` | `Optional[float]` | `None` | Duration in seconds (10-600). `None` enables auto-detection based on lyrics length. |

### Generation Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `inference_steps` | `int` | `8` | Number of denoising steps. Turbo model: 1-8 (recommended 8). Base model: 1-100 (recommended 32-64). Higher = better quality but slower. |
| `guidance_scale` | `float` | `7.0` | Classifier-free guidance scale (1.0-15.0). Higher values increase adherence to the text prompt. Typical range: 5.0-9.0. |
| `use_random_seed` | `bool` | `True` | Whether to use a random seed. `True` for different results each time, `False` for reproducible results. |
| `seed` | `int` | `-1` | Random seed for reproducibility. Use `-1` for a random seed, or any positive integer for a fixed seed. |
| `batch_size` | `int` | `1` | Number of samples to generate in parallel (1-8). Higher values require more GPU memory. |

### Advanced DiT Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `use_adg` | `bool` | `False` | Use Adaptive Dual Guidance (base model only). Improves quality at the cost of speed. |
| `cfg_interval_start` | `float` | `0.0` | CFG application start ratio (0.0-1.0). Controls when to start applying classifier-free guidance. |
| `cfg_interval_end` | `float` | `1.0` | CFG application end ratio (0.0-1.0). Controls when to stop applying classifier-free guidance. |
| `audio_format` | `str` | `"mp3"` | Output audio format. Options: `"mp3"`, `"wav"`, `"flac"`. |

### Task-Specific Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `task_type` | `str` | `"text2music"` | Generation task type. See the [Task Types](#task-types) section for details. |
| `reference_audio` | `Optional[str]` | `None` | Path to a reference audio file for style transfer or continuation tasks. |
| `src_audio` | `Optional[str]` | `None` | Path to a source audio file for audio-to-audio tasks (cover, repaint, etc.). |
| `audio_code_string` | `Union[str, List[str]]` | `""` | Pre-extracted 5Hz audio codes. Can be a single string or a list for batch mode. Advanced use only. |
| `repainting_start` | `float` | `0.0` | Repainting start time in seconds (for repaint/lego tasks). |
| `repainting_end` | `float` | `-1` | Repainting end time in seconds. Use `-1` for end of audio. |
| `audio_cover_strength` | `float` | `1.0` | Strength of audio cover/codes influence (0.0-1.0). Higher = stronger influence from the source audio. |
| `instruction` | `str` | `""` | Task-specific instruction prompt. Auto-generated if empty. |

### 5Hz Language Model Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `use_llm_thinking` | `bool` | `False` | Enable LM-based Chain-of-Thought reasoning. When enabled, the LM generates metadata and/or audio codes. |
| `lm_temperature` | `float` | `0.85` | LM sampling temperature (0.0-2.0). Higher = more creative/diverse, lower = more conservative. |
| `lm_cfg_scale` | `float` | `2.0` | LM classifier-free guidance scale (1.0-5.0). Higher = stronger adherence to the prompt. |
| `lm_top_k` | `int` | `0` | LM top-k sampling. `0` disables top-k filtering. Typical values: 40-100. |
| `lm_top_p` | `float` | `0.9` | LM nucleus sampling (0.0-1.0). `1.0` disables nucleus sampling. Typical values: 0.9-0.95. |
| `lm_negative_prompt` | `str` | `"NO USER INPUT"` | Negative prompt for LM guidance. Helps avoid unwanted characteristics. |
| `use_cot_metas` | `bool` | `True` | Generate metadata using LM CoT reasoning (BPM, key, duration, etc.). |
| `use_cot_caption` | `bool` | `True` | Refine the user caption using LM CoT reasoning. |
| `use_cot_language` | `bool` | `True` | Detect vocal language using LM CoT reasoning. |
| `is_format_caption` | `bool` | `False` | Whether the caption is already formatted/refined (skip LM refinement). |
| `constrained_decoding_debug` | `bool` | `False` | Enable debug logging for constrained decoding. |

### Batch LM Generation

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `allow_lm_batch` | `bool` | `False` | Allow batch LM code generation. Faster when `batch_size >= 2` and `use_llm_thinking=True`. |
| `lm_batch_chunk_size` | `int` | `4` | Maximum batch size per LM inference chunk (GPU memory constraint). |

---

## Task Types

ACE-Step supports six generation task types, each optimized for a specific use case.

### 1. Text2Music (Default)

**Purpose**: Generate music from text descriptions and optional metadata.

**Key Parameters**:
```python
config = GenerationConfig(
    task_type="text2music",
    caption="energetic rock music with electric guitar",
    lyrics="[Instrumental]",  # or actual lyrics
    bpm=140,
    audio_duration=30,
)
```

**Required**:
- `caption` or `lyrics` (at least one)

**Optional but Recommended**:
- `bpm`: Controls tempo
- `key_scale`: Controls musical key
- `time_signature`: Controls rhythm structure
- `audio_duration`: Controls length
- `vocal_language`: Controls vocal characteristics

**Use Cases**:
- Generate music from text descriptions
- Create backing tracks from prompts
- Generate songs with lyrics

---

### 2. Cover

**Purpose**: Transform existing audio while maintaining structure but changing style/timbre.

**Key Parameters**:
```python
config = GenerationConfig(
    task_type="cover",
    src_audio="original_song.mp3",
    caption="jazz piano version",
    audio_cover_strength=0.8,  # 0.0-1.0
)
```

**Required**:
- `src_audio`: Path to the source audio file
- `caption`: Description of the desired style/transformation

**Optional**:
- `audio_cover_strength`: Controls the influence of the original audio
  - `1.0`: Strong adherence to original structure
  - `0.5`: Balanced transformation
  - `0.1`: Loose interpretation
- `lyrics`: New lyrics (if changing vocals)

**Use Cases**:
- Create covers in different styles
- Change instrumentation while keeping the melody
- Genre transformation

---

### 3. Repaint

**Purpose**: Regenerate a specific time segment of audio while keeping the rest unchanged.

**Key Parameters**:
```python
config = GenerationConfig(
    task_type="repaint",
    src_audio="original.mp3",
    repainting_start=10.0,  # seconds
    repainting_end=20.0,    # seconds
    caption="smooth transition with piano solo",
)
```

**Required**:
- `src_audio`: Path to the source audio file
- `repainting_start`: Start time in seconds
- `repainting_end`: End time in seconds (use `-1` for end of file)
- `caption`: Description of the desired content for the repainted section

**Use Cases**:
- Fix specific sections of generated music
- Add variations to parts of a song
- Create smooth transitions
- Replace problematic segments

---

### 4. Lego (Base Model Only)

**Purpose**: Generate a specific instrument track in the context of existing audio.

**Key Parameters**:
```python
config = GenerationConfig(
    task_type="lego",
    src_audio="backing_track.mp3",
    instruction="Generate the guitar track based on the audio context:",
    caption="lead guitar melody with bluesy feel",
    repainting_start=0.0,
    repainting_end=-1,
)
```

**Required**:
- `src_audio`: Path to the source/backing audio
- `instruction`: Must specify the track type (e.g., "Generate the {TRACK_NAME} track...")
- `caption`: Description of the desired track characteristics

**Available Tracks**:
- `"vocals"`, `"backing_vocals"`, `"drums"`, `"bass"`, `"guitar"`, `"keyboard"`
- `"percussion"`, `"strings"`, `"synth"`, `"fx"`, `"brass"`, `"woodwinds"`

**Use Cases**:
- Add specific instrument tracks
- Layer additional instruments over backing tracks
- Create multi-track compositions iteratively

---

### 5. Extract (Base Model Only)

**Purpose**: Extract/isolate a specific instrument track from mixed audio.

**Key Parameters**:
```python
config = GenerationConfig(
    task_type="extract",
    src_audio="full_mix.mp3",
    instruction="Extract the vocals track from the audio:",
)
```

**Required**:
- `src_audio`: Path to the mixed audio file
- `instruction`: Must specify the track to extract

**Available Tracks**: Same as the Lego task

**Use Cases**:
- Stem separation
- Isolate specific instruments
- Create remixes
- Analyze individual tracks

---

### 6. Complete (Base Model Only)

**Purpose**: Complete/extend partial tracks with specified instruments.

**Key Parameters**:
```python
config = GenerationConfig(
    task_type="complete",
    src_audio="incomplete_track.mp3",
    instruction="Complete the input track with drums, bass, guitar:",
    caption="rock style completion",
)
```

**Required**:
- `src_audio`: Path to the incomplete/partial track
- `instruction`: Must specify which tracks to add
- `caption`: Description of the desired style

**Use Cases**:
- Arrange incomplete compositions
- Add backing tracks
- Auto-complete musical ideas

---

## Complete Examples

### Example 1: Simple Text-to-Music Generation

```python
from acestep.inference import GenerationConfig, generate_music

config = GenerationConfig(
    task_type="text2music",
    caption="calm ambient music with soft piano and strings",
    audio_duration=60,
    bpm=80,
    key_scale="C Major",
    batch_size=2,  # Generate 2 variations
)

result = generate_music(dit_handler, llm_handler, config)

if result.success:
    for i, path in enumerate(result.audio_paths, 1):
        print(f"Variation {i}: {path}")
```

### Example 2: Song Generation with Lyrics

```python
config = GenerationConfig(
    task_type="text2music",
    caption="pop ballad with emotional vocals",
    lyrics="""Verse 1:
Walking down the street today
Thinking of the words you used to say
Everything feels different now
But I'll find my way somehow

Chorus:
I'm moving on, I'm staying strong
This is where I belong
""",
    vocal_language="en",
    bpm=72,
    audio_duration=45,
)

result = generate_music(dit_handler, llm_handler, config)
```

### Example 3: Style Cover with LM Reasoning

```python
config = GenerationConfig(
    task_type="cover",
    src_audio="original_pop_song.mp3",
    caption="orchestral symphonic arrangement",
    audio_cover_strength=0.7,
    use_llm_thinking=True,  # Enable LM for metadata
    use_cot_metas=True,
)

result = generate_music(dit_handler, llm_handler, config)

# Access LM-generated metadata
if result.lm_metadata:
    print(f"LM detected BPM: {result.lm_metadata.get('bpm')}")
    print(f"LM detected Key: {result.lm_metadata.get('keyscale')}")
```

### Example 4: Repaint Section of Audio

```python
config = GenerationConfig(
    task_type="repaint",
    src_audio="generated_track.mp3",
    repainting_start=15.0,  # Start at 15 seconds
    repainting_end=25.0,    # End at 25 seconds
    caption="dramatic orchestral buildup",
    inference_steps=32,     # Higher quality for base model
)

result = generate_music(dit_handler, llm_handler, config)
```

### Example 5: Batch Generation with LM

```python
config = GenerationConfig(
    task_type="text2music",
    caption="epic cinematic trailer music",
    batch_size=4,             # Generate 4 variations
    use_llm_thinking=True,
    use_cot_metas=True,
    allow_lm_batch=True,      # Faster batch processing
    lm_batch_chunk_size=2,    # Process 2 at a time (GPU memory)
)

result = generate_music(dit_handler, llm_handler, config)

if result.success:
    print(f"Generated {len(result.audio_paths)} variations")
```

### Example 6: High-Quality Generation (Base Model)

```python
config = GenerationConfig(
    task_type="text2music",
    caption="intricate jazz fusion with complex harmonies",
    inference_steps=64,     # High quality
    guidance_scale=8.0,
    use_adg=True,           # Adaptive Dual Guidance
    cfg_interval_start=0.0,
    cfg_interval_end=1.0,
    audio_format="wav",     # Lossless format
    use_random_seed=False,
    seed=42,                # Reproducible results
)

result = generate_music(dit_handler, llm_handler, config)
```

### Example 7: Extract Vocals from Mix

```python
config = GenerationConfig(
    task_type="extract",
    src_audio="full_song_mix.mp3",
    instruction="Extract the vocals track from the audio:",
)

result = generate_music(dit_handler, llm_handler, config)

if result.success:
    print(f"Extracted vocals: {result.audio_paths[0]}")
```

### Example 8: Add Guitar Track (Lego)

```python
config = GenerationConfig(
    task_type="lego",
    src_audio="drums_and_bass.mp3",
    instruction="Generate the guitar track based on the audio context:",
    caption="funky rhythm guitar with wah-wah effect",
    repainting_start=0.0,
    repainting_end=-1,  # Full duration
)

result = generate_music(dit_handler, llm_handler, config)
```

---

## Best Practices

### 1. Caption Writing

**Good Captions**:
```python
# Specific and descriptive
caption="upbeat electronic dance music with heavy bass and synthesizer leads"

# Include mood and genre
caption="melancholic indie folk with acoustic guitar and soft vocals"

# Specify instruments
caption="jazz trio with piano, upright bass, and brush drums"
```

**Avoid**:
```python
# Too vague
caption="good music"

# Contradictory
caption="fast slow music"  # Conflicting tempos
```

### 2. Parameter Tuning

The four presets below are combined into a sketch after these lists.

**For Best Quality**:
- Use the base model with `inference_steps=64` or higher
- Enable `use_adg=True`
- Set `guidance_scale=7.0-9.0`
- Use a lossless audio format (`audio_format="wav"`)

**For Speed**:
- Use the turbo model with `inference_steps=8`
- Disable ADG (`use_adg=False`)
- Lower `guidance_scale=5.0-7.0`
- Use a compressed format (`audio_format="mp3"`)

**For Consistency**:
- Set `use_random_seed=False`
- Use a fixed `seed` value
- Keep `lm_temperature` lower (0.7-0.85)

**For Diversity**:
- Set `use_random_seed=True`
- Increase `lm_temperature` (0.9-1.1)
- Use `batch_size > 1` for variations

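Putting these presets together: a minimal sketch of a speed-oriented turbo configuration and a quality-oriented base-model configuration, using only parameters documented above (handler setup is assumed to match the Quick Start).

```python
from acestep.inference import GenerationConfig

# Speed-oriented preset for the turbo model (see "For Speed" above).
fast_config = GenerationConfig(
    caption="lofi hip hop beat with vinyl crackle",
    inference_steps=8,    # turbo model sweet spot
    guidance_scale=6.0,   # lower CFG within the speed-friendly range
    use_adg=False,
    audio_format="mp3",
)

# Quality-oriented preset for the base model (see "For Best Quality" above).
quality_config = GenerationConfig(
    caption="lofi hip hop beat with vinyl crackle",
    inference_steps=64,
    guidance_scale=8.0,
    use_adg=True,          # base model only
    audio_format="wav",
    use_random_seed=False,
    seed=1234,             # fixed seed for reproducible comparisons
)
```
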
### 3. Duration Guidelines

- **Instrumental**: 30-180 seconds works well
- **With Lyrics**: Auto-detection recommended (set `audio_duration=None`; see the sketch below)
- **Short clips**: 10-20 seconds minimum
- **Long form**: Up to 600 seconds (10 minutes) maximum

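For lyric-driven songs, leaving `audio_duration` unset lets the pipeline pick a length that fits the lyrics. A minimal sketch (lyrics abbreviated):

```python
config = GenerationConfig(
    caption="acoustic folk song with warm male vocals",
    lyrics="Verse 1:\nDown by the river where the willows grow...",  # abbreviated example lyrics
    audio_duration=None,  # None = auto-detect the duration from the lyrics
    vocal_language="en",
)
```
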
### 4. LM Usage

**When to Enable LM (`use_llm_thinking=True`)** (see the sketch after these lists):
- Need automatic metadata detection
- Want caption refinement
- Generating from minimal input
- Need diverse outputs

**When to Disable LM**:
- Have precise metadata already
- Need faster generation
- Want full control over parameters

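A minimal sketch of a generation that starts from a bare prompt and leans on the LM for everything else, using only fields defined in `GenerationConfig`:

```python
config = GenerationConfig(
    caption="dreamy synthwave",  # minimal input; the LM fills in the rest
    use_llm_thinking=True,       # enable Chain-of-Thought reasoning
    use_cot_metas=True,          # let the LM pick BPM, key, duration, etc.
    use_cot_caption=True,        # let the LM refine the caption
    use_cot_language=True,       # let the LM detect the vocal language
    lm_temperature=0.85,
)

result = generate_music(dit_handler, llm_handler, config)
if result.lm_metadata:
    print(result.lm_metadata)    # inspect what the LM decided
```
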
### 5. Batch Processing

```python
# Efficient batch generation
config = GenerationConfig(
    batch_size=8,            # Max supported
    use_llm_thinking=True,
    allow_lm_batch=True,     # Enable for speed
    lm_batch_chunk_size=4,   # Adjust based on GPU memory
)
```

### 6. Error Handling

```python
result = generate_music(dit_handler, llm_handler, config)

if not result.success:
    print(f"Generation failed: {result.error}")
    # Check logs for details
else:
    # Process successful result
    for path in result.audio_paths:
        # ... process audio files
        pass
```

### 7. Memory Management

For large batch sizes or long durations:
- Monitor GPU memory usage
- Reduce `batch_size` if OOM errors occur
- Reduce `lm_batch_chunk_size` for LM operations
- Consider using `offload_to_cpu=True` during initialization (see the sketch below)

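A sketch of CPU offloading at initialization time. `offload_to_cpu` is the flag named in the list above; whether `initialize_service` accepts it as a keyword argument is an assumption, so verify against `AceStepHandler.initialize_service` in your checkout before relying on it.

```python
dit_handler = AceStepHandler()
status, ok = dit_handler.initialize_service(
    project_root="/path/to/project",
    config_path="acestep-v15-turbo-rl",
    device="cuda",
    offload_to_cpu=True,  # assumed keyword; trades speed for lower GPU memory
)
```
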
---

## Troubleshooting

### Common Issues

**Issue**: Out of memory errors
- **Solution**: Reduce `batch_size`, `inference_steps`, or enable CPU offloading (a retry sketch follows this list)

**Issue**: Poor quality results
- **Solution**: Increase `inference_steps`, adjust `guidance_scale`, use the base model

**Issue**: Results don't match the prompt
- **Solution**: Make the caption more specific, increase `guidance_scale`, enable LM refinement

**Issue**: Slow generation
- **Solution**: Use the turbo model, reduce `inference_steps`, disable ADG

**Issue**: LM not generating codes
- **Solution**: Verify `llm_handler` is initialized, check `use_llm_thinking=True` and `use_cot_metas=True`

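One way to act on the OOM advice programmatically: a hedged retry sketch that halves `batch_size` when generation fails with a memory-related error. The error-string check is an assumption about how OOM surfaces through `GenerationResult.error`; adjust it to match your environment.

```python
import dataclasses

config = GenerationConfig(caption="cinematic orchestral piece", batch_size=8)

result = generate_music(dit_handler, llm_handler, config)
while not result.success and "memory" in (result.error or "").lower() and config.batch_size > 1:
    # Halve the batch size and retry; all other parameters stay unchanged.
    config = dataclasses.replace(config, batch_size=max(1, config.batch_size // 2))
    result = generate_music(dit_handler, llm_handler, config)
```
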
---

## API Reference Summary

### GenerationConfig Fields

See [Configuration Parameters](#configuration-parameters) for complete documentation.

### GenerationResult Fields

```python
@dataclass
class GenerationResult:
    # Audio outputs
    audio_paths: List[str]           # List of generated audio file paths
    first_audio: Optional[str]       # First audio (backward compatibility)
    second_audio: Optional[str]      # Second audio (backward compatibility)

    # Generation metadata
    generation_info: str             # Markdown-formatted generation info
    status_message: str              # Status message
    seed_value: str                  # Seed value used

    # LM outputs
    lm_metadata: Optional[Dict[str, Any]]  # LM-generated metadata

    # Alignment scores (if available)
    align_score_1: Optional[float]
    align_text_1: Optional[str]
    align_plot_1: Optional[Any]
    align_score_2: Optional[float]
    align_text_2: Optional[str]
    align_plot_2: Optional[Any]

    # Status
    success: bool                    # Whether generation succeeded
    error: Optional[str]             # Error message if failed
```

---

## Version History

- **v1.5**: Current version with refactored inference API
  - Introduced `GenerationConfig` and `GenerationResult` dataclasses
  - Simplified parameter passing
  - Added comprehensive documentation
  - Maintained backward compatibility with the Gradio UI

---

For more information, see:
- Main README: [`README.md`](README.md)
- REST API Documentation: [`API.md`](API.md)
- Project repository: [ACE-Step-1.5](https://github.com/yourusername/ACE-Step-1.5)
acestep/gradio_ui/event.py CHANGED
@@ -9,6 +9,8 @@ import glob
 import time as time_module
 import tempfile
 import gradio as gr
+import math
+from loguru import logger
 from typing import Optional
 from acestep.constants import (
     TASK_TYPES_TURBO,
@@ -655,16 +657,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             user_metadata_to_pass = user_metadata if user_metadata else None
 
             if should_use_lm_batch:
-                # BATCH LM GENERATION
-                import math
-                from loguru import logger
-
+                # BATCH LM GENERATION
                 logger.info(f"Using LM batch generation for {batch_size_input} items...")
 
                 # Prepare seeds for batch items
-                from acestep.handler import AceStepHandler
-                temp_handler = AceStepHandler()
-                actual_seed_list, _ = temp_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
+                actual_seed_list, _ = dit_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
 
                 # Split batch into chunks (GPU memory constraint)
                 max_inference_batch_size = int(lm_batch_chunk_size)
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -5,6 +5,7 @@ Contains event handlers and helper functions related to result display, scoring,
 import os
 import json
 import datetime
+import math
 import tempfile
 import shutil
 import zipfile
@@ -310,14 +311,10 @@ def generate_with_progress(
 
     if should_use_lm_batch:
         # BATCH LM GENERATION
-        import math
-        from acestep.handler import AceStepHandler
-
        logger.info(f"Using LM batch generation for {batch_size_input} items...")
 
        # Prepare seeds for batch items
-        temp_handler = AceStepHandler()
-        actual_seed_list, _ = temp_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
+        actual_seed_list, _ = dit_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
 
        # Split batch into chunks (GPU memory constraint)
        max_inference_batch_size = int(lm_batch_chunk_size)
acestep/handler.py CHANGED
@@ -37,12 +37,15 @@ warnings.filterwarnings("ignore")
 class AceStepHandler:
     """ACE-Step Business Logic Handler"""
 
-    def __init__(self):
+    def __init__(self, save_root=None):
         self.model = None
         self.config = None
         self.device = "cpu"
         self.dtype = torch.float32  # Will be set based on device in initialize_service
-        self.temp_dir = tempfile.mkdtemp()
+        if save_root is None:
+            self.temp_dir = tempfile.mkdtemp()
+        else:
+            self.temp_dir = save_root
 
         # VAE for audio encoding/decoding
         self.vae = None
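The new `save_root` argument lets callers route generated files to a known directory instead of a fresh temp dir (the module's own test suite below uses `save_root="./"`). A minimal usage sketch:

```python
from acestep.handler import AceStepHandler

# Default behavior: outputs go to a fresh tempfile.mkdtemp() directory.
handler = AceStepHandler()

# New in this commit: route outputs to a caller-chosen directory instead.
handler = AceStepHandler(save_root="./outputs")
```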
acestep/inference.py ADDED
@@ -0,0 +1,928 @@
1
+ """
2
+ ACE-Step Inference API Module
3
+
4
+ This module provides a standardized inference interface for music generation,
5
+ designed for third-party integration. It offers both a simplified API and
6
+ backward-compatible Gradio UI support.
7
+ """
8
+
9
+ import math
10
+ from typing import Optional, Union, List, Dict, Any, Tuple
11
+ from dataclasses import dataclass, field, asdict
12
+ from loguru import logger
13
+ import time as time_module
14
+
15
+
16
+ @dataclass
17
+ class GenerationConfig:
18
+ """Configuration for music generation.
19
+
20
+ Attributes:
21
+ # Text Inputs
22
+ caption: Text description of the desired music
23
+ lyrics: Lyrics text for vocal music (use "[Instrumental]" for instrumental)
24
+
25
+ # Music Metadata
26
+ bpm: Beats per minute (e.g., 120). None for auto-detection
27
+ key_scale: Musical key (e.g., "C Major", "Am"). Empty for auto-detection
28
+ time_signature: Time signature (e.g., "4/4", "3/4"). Empty for auto-detection
29
+ vocal_language: Language code for vocals (e.g., "en", "zh", "ja")
30
+ audio_duration: Duration in seconds. None for auto-detection
31
+
32
+ # Generation Parameters
33
+ inference_steps: Number of denoising steps (8 for turbo, 32-100 for base)
34
+ guidance_scale: Classifier-free guidance scale (higher = more adherence to prompt)
35
+ use_random_seed: Whether to use random seed (True) or fixed seed
36
+ seed: Random seed for reproducibility (-1 for random)
37
+ batch_size: Number of samples to generate (1-8)
38
+
39
+ # Advanced DiT Parameters
40
+ use_adg: Use Adaptive Dual Guidance (base model only)
41
+ cfg_interval_start: CFG application start ratio (0.0-1.0)
42
+ cfg_interval_end: CFG application end ratio (0.0-1.0)
43
+ audio_format: Output audio format ("mp3", "wav", "flac")
44
+
45
+ # Task-Specific Parameters
46
+ task_type: Generation task type ("text2music", "cover", "repaint", "lego", "extract", "complete")
47
+ reference_audio: Path to reference audio file (for style transfer)
48
+ src_audio: Path to source audio file (for audio-to-audio tasks)
49
+ audio_code_string: Pre-extracted audio codes (advanced use)
50
+ repainting_start: Repainting start time in seconds (for repaint/lego tasks)
51
+ repainting_end: Repainting end time in seconds (-1 for end of audio)
52
+ audio_cover_strength: Strength of audio cover/codes influence (0.0-1.0)
53
+ instruction: Task-specific instruction prompt (auto-generated if empty)
54
+
55
+ # 5Hz Language Model Parameters (CoT Reasoning)
56
+ use_llm_thinking: Enable LM-based Chain-of-Thought reasoning for metadata/codes
57
+ lm_temperature: LM sampling temperature (0.0-2.0, higher = more creative)
58
+ lm_cfg_scale: LM classifier-free guidance scale
59
+ lm_top_k: LM top-k sampling (0 = disabled)
60
+ lm_top_p: LM nucleus sampling (1.0 = disabled)
61
+ lm_negative_prompt: Negative prompt for LM guidance
62
+ use_cot_metas: Generate metadata using LM CoT
63
+ use_cot_caption: Refine caption using LM CoT
64
+ use_cot_language: Detect language using LM CoT
65
+ is_format_caption: Whether caption is already formatted
66
+ constrained_decoding_debug: Enable debug logging for constrained decoding
67
+
68
+ # Batch LM Generation
69
+ allow_lm_batch: Allow batch LM code generation (faster for batch_size >= 2)
70
+ lm_batch_chunk_size: Maximum batch size per LM inference chunk (GPU memory constraint)
71
+ """
72
+
73
+ # Text Inputs
74
+ caption: str = ""
75
+ lyrics: str = ""
76
+
77
+ # Music Metadata
78
+ bpm: Optional[int] = None
79
+ key_scale: str = ""
80
+ time_signature: str = ""
81
+ vocal_language: str = "unknown"
82
+ audio_duration: Optional[float] = None
83
+
84
+ # Generation Parameters
85
+ inference_steps: int = 8
86
+ guidance_scale: float = 7.0
87
+ use_random_seed: bool = True
88
+ seed: int = -1
89
+ batch_size: int = 1
90
+
91
+ # Advanced DiT Parameters
92
+ use_adg: bool = False
93
+ cfg_interval_start: float = 0.0
94
+ cfg_interval_end: float = 1.0
95
+ audio_format: str = "mp3"
96
+
97
+ # Task-Specific Parameters
98
+ task_type: str = "text2music"
99
+ reference_audio: Optional[str] = None
100
+ src_audio: Optional[str] = None
101
+ audio_code_string: Union[str, List[str]] = ""
102
+ repainting_start: float = 0.0
103
+ repainting_end: float = -1
104
+ audio_cover_strength: float = 1.0
105
+ instruction: str = ""
106
+
107
+ # 5Hz Language Model Parameters
108
+ use_llm_thinking: bool = False
109
+ lm_temperature: float = 0.85
110
+ lm_cfg_scale: float = 2.0
111
+ lm_top_k: int = 0
112
+ lm_top_p: float = 0.9
113
+ lm_negative_prompt: str = "NO USER INPUT"
114
+ use_cot_metas: bool = True
115
+ use_cot_caption: bool = True
116
+ use_cot_language: bool = True
117
+ is_format_caption: bool = False
118
+ constrained_decoding_debug: bool = False
119
+
120
+ # Batch LM Generation
121
+ allow_lm_batch: bool = False
122
+ lm_batch_chunk_size: int = 4
123
+
124
+
125
+ @dataclass
126
+ class GenerationResult:
127
+ """Result of music generation.
128
+
129
+ Attributes:
130
+ # Audio Outputs
131
+ audio_paths: List of paths to generated audio files
132
+ first_audio: Path to first generated audio (backward compatibility)
133
+ second_audio: Path to second generated audio (backward compatibility)
134
+
135
+ # Generation Information
136
+ generation_info: Markdown-formatted generation information
137
+ status_message: Status message from generation
138
+ seed_value: Actual seed value used for generation
139
+
140
+ # LM-Generated Metadata (if applicable)
141
+ lm_metadata: Metadata generated by language model (dict or None)
142
+
143
+ # Audio-Text Alignment Scores (if available)
144
+ align_score_1: First alignment score
145
+ align_text_1: First alignment text description
146
+ align_plot_1: First alignment plot image
147
+ align_score_2: Second alignment score
148
+ align_text_2: Second alignment text description
149
+ align_plot_2: Second alignment plot image
150
+
151
+ # Success Status
152
+ success: Whether generation completed successfully
153
+ error: Error message if generation failed
154
+ """
155
+
156
+ # Audio Outputs
157
+ audio_paths: List[str] = field(default_factory=list)
158
+ first_audio: Optional[str] = None
159
+ second_audio: Optional[str] = None
160
+
161
+ # Generation Information
162
+ generation_info: str = ""
163
+ status_message: str = ""
164
+ seed_value: str = ""
165
+
166
+ # LM-Generated Metadata
167
+ lm_metadata: Optional[Dict[str, Any]] = None
168
+
169
+ # Audio-Text Alignment Scores
170
+ align_score_1: Optional[float] = None
171
+ align_text_1: Optional[str] = None
172
+ align_plot_1: Optional[Any] = None
173
+ align_score_2: Optional[float] = None
174
+ align_text_2: Optional[str] = None
175
+ align_plot_2: Optional[Any] = None
176
+
177
+ # Success Status
178
+ success: bool = True
179
+ error: Optional[str] = None
180
+
181
+ def to_dict(self) -> Dict[str, Any]:
182
+ """Convert result to dictionary for JSON serialization."""
183
+ return asdict(self)
184
+
185
+
186
+ def generate_music(
187
+ dit_handler,
188
+ llm_handler,
189
+ config: GenerationConfig,
190
+ ) -> GenerationResult:
191
+ """Generate music using ACE-Step model with optional LM reasoning.
192
+
193
+ This is the main inference API for music generation. It supports various task types
194
+ (text2music, cover, repaint, etc.) and can optionally use a 5Hz Language Model for
195
+ Chain-of-Thought reasoning to generate metadata and audio codes.
196
+
197
+ Args:
198
+ dit_handler: Initialized DiT model handler (AceStepHandler instance)
199
+ llm_handler: Initialized LLM handler (LLMHandler instance)
200
+ config: Generation configuration (GenerationConfig instance)
201
+
202
+ Returns:
203
+ GenerationResult: Generation result containing audio paths and metadata
204
+
205
+ Example:
206
+ >>> from acestep.handler import AceStepHandler
207
+ >>> from acestep.llm_inference import LLMHandler
208
+ >>> from acestep.inference import GenerationConfig, generate_music
209
+ >>>
210
+ >>> # Initialize handlers
211
+ >>> dit_handler = AceStepHandler()
212
+ >>> llm_handler = LLMHandler()
213
+ >>> dit_handler.initialize_service(...)
214
+ >>> llm_handler.initialize(...)
215
+ >>>
216
+ >>> # Configure generation
217
+ >>> config = GenerationConfig(
218
+ ... caption="upbeat electronic dance music",
219
+ ... bpm=128,
220
+ ... audio_duration=30,
221
+ ... batch_size=2,
222
+ ... )
223
+ >>>
224
+ >>> # Generate music
225
+ >>> result = generate_music(dit_handler, llm_handler, config)
226
+ >>> print(f"Generated {len(result.audio_paths)} audio files")
227
+ >>> for path in result.audio_paths:
228
+ ... print(f"Audio: {path}")
229
+ """
230
+
231
+ try:
232
+ # Phase 1: LM-based metadata and code generation (if enabled)
233
+ audio_code_string_to_use = config.audio_code_string
234
+ lm_generated_metadata = None
235
+ lm_generated_audio_codes = None
236
+ lm_generated_audio_codes_list = []
237
+
238
+ # Extract mutable copies of metadata (will be updated by LM if needed)
239
+ bpm = config.bpm
240
+ key_scale = config.key_scale
241
+ time_signature = config.time_signature
242
+ audio_duration = config.audio_duration
243
+
244
+ # Determine if we should use batch LM generation
245
+ should_use_lm_batch = (
246
+ config.use_llm_thinking
247
+ and llm_handler.llm_initialized
248
+ and config.use_cot_metas
249
+ and config.allow_lm_batch
250
+ and config.batch_size >= 2
251
+ )
252
+
253
+ # LM-based Chain-of-Thought reasoning
254
+ if config.use_llm_thinking and llm_handler.llm_initialized and config.use_cot_metas:
255
+ # Convert sampling parameters
256
+ top_k_value = None if config.lm_top_k == 0 else int(config.lm_top_k)
257
+ top_p_value = None if config.lm_top_p >= 1.0 else config.lm_top_p
258
+
259
+ # Build user_metadata from user-provided values
260
+ user_metadata = {}
261
+ if bpm is not None:
262
+ try:
263
+ bpm_value = float(bpm)
264
+ if bpm_value > 0:
265
+ user_metadata['bpm'] = str(int(bpm_value))
266
+ except (ValueError, TypeError):
267
+ pass
268
+
269
+ if key_scale and key_scale.strip():
270
+ key_scale_clean = key_scale.strip()
271
+ if key_scale_clean.lower() not in ["n/a", ""]:
272
+ user_metadata['keyscale'] = key_scale_clean
273
+
274
+ if time_signature and time_signature.strip():
275
+ time_sig_clean = time_signature.strip()
276
+ if time_sig_clean.lower() not in ["n/a", ""]:
277
+ user_metadata['timesignature'] = time_sig_clean
278
+
279
+ if audio_duration is not None:
280
+ try:
281
+ duration_value = float(audio_duration)
282
+ if duration_value > 0:
283
+ user_metadata['duration'] = str(int(duration_value))
284
+ except (ValueError, TypeError):
285
+ pass
286
+
287
+ user_metadata_to_pass = user_metadata if user_metadata else None
288
+
289
+ # Batch LM generation (faster for multiple samples)
290
+ if should_use_lm_batch:
291
+ actual_seed_list, _ = dit_handler.prepare_seeds(
292
+ config.batch_size, config.seed, config.use_random_seed
293
+ )
294
+
295
+ max_inference_batch_size = int(config.lm_batch_chunk_size)
296
+ num_chunks = math.ceil(config.batch_size / max_inference_batch_size)
297
+
298
+ all_metadata_list = []
299
+ all_audio_codes_list = []
300
+
301
+ for chunk_idx in range(num_chunks):
302
+ chunk_start = chunk_idx * max_inference_batch_size
303
+ chunk_end = min(chunk_start + max_inference_batch_size, config.batch_size)
304
+ chunk_size = chunk_end - chunk_start
305
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end]
306
+
307
+ logger.info(
308
+ f"LM batch chunk {chunk_idx+1}/{num_chunks} "
309
+ f"(size: {chunk_size}, seeds: {chunk_seeds})"
310
+ )
311
+
312
+ metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
313
+ caption=config.caption or "",
314
+ lyrics=config.lyrics or "",
315
+ batch_size=chunk_size,
316
+ infer_type="llm_dit",
317
+ temperature=config.lm_temperature,
318
+ cfg_scale=config.lm_cfg_scale,
319
+ negative_prompt=config.lm_negative_prompt,
320
+ top_k=top_k_value,
321
+ top_p=top_p_value,
322
+ user_metadata=user_metadata_to_pass,
323
+ use_cot_caption=config.use_cot_caption,
324
+ use_cot_language=config.use_cot_language,
325
+ is_format_caption=config.is_format_caption,
326
+ constrained_decoding_debug=config.constrained_decoding_debug,
327
+ seeds=chunk_seeds,
328
+ )
329
+
330
+ all_metadata_list.extend(metadata_list)
331
+ all_audio_codes_list.extend(audio_codes_list)
332
+
333
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
334
+ lm_generated_audio_codes_list = all_audio_codes_list
335
+ audio_code_string_to_use = all_audio_codes_list
336
+
337
+ # Update metadata from LM if not provided by user
338
+ if lm_generated_metadata:
339
+ bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
340
+ lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
341
+ )
342
+
343
+ else:
344
+ # Sequential LM generation (current behavior)
345
+ # Phase 1: Generate CoT metadata
346
+ phase1_start = time_module.time()
347
+ metadata, _, status = llm_handler.generate_with_stop_condition(
348
+ caption=config.caption or "",
349
+ lyrics=config.lyrics or "",
350
+ infer_type="dit",
351
+ temperature=config.lm_temperature,
352
+ cfg_scale=config.lm_cfg_scale,
353
+ negative_prompt=config.lm_negative_prompt,
354
+ top_k=top_k_value,
355
+ top_p=top_p_value,
356
+ user_metadata=user_metadata_to_pass,
357
+ use_cot_caption=config.use_cot_caption,
358
+ use_cot_language=config.use_cot_language,
359
+ is_format_caption=config.is_format_caption,
360
+ constrained_decoding_debug=config.constrained_decoding_debug,
361
+ )
362
+ lm_phase1_time = time_module.time() - phase1_start
363
+ logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
364
+
365
+ # Phase 2: Generate audio codes
366
+ phase2_start = time_module.time()
367
+ metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
368
+ caption=config.caption or "",
369
+ lyrics=config.lyrics or "",
370
+ infer_type="llm_dit",
371
+ temperature=config.lm_temperature,
372
+ cfg_scale=config.lm_cfg_scale,
373
+ negative_prompt=config.lm_negative_prompt,
374
+ top_k=top_k_value,
375
+ top_p=top_p_value,
376
+ user_metadata=user_metadata_to_pass,
377
+ use_cot_caption=config.use_cot_caption,
378
+ use_cot_language=config.use_cot_language,
379
+ is_format_caption=config.is_format_caption,
380
+ constrained_decoding_debug=config.constrained_decoding_debug,
381
+ )
382
+ lm_phase2_time = time_module.time() - phase2_start
383
+ logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
384
+
385
+ lm_generated_metadata = metadata
386
+ if audio_codes:
387
+ audio_code_string_to_use = audio_codes
388
+ lm_generated_audio_codes = audio_codes
389
+
390
+ # Update metadata from LM if not provided by user
391
+ bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
392
+ metadata, bpm, key_scale, time_signature, audio_duration
393
+ )
394
+
395
+ # Phase 2: DiT music generation
396
+ result = dit_handler.generate_music(
397
+ captions=config.caption,
398
+ lyrics=config.lyrics,
399
+ bpm=bpm,
400
+ key_scale=key_scale,
401
+ time_signature=time_signature,
402
+ vocal_language=config.vocal_language,
403
+ inference_steps=config.inference_steps,
404
+ guidance_scale=config.guidance_scale,
405
+ use_random_seed=config.use_random_seed,
406
+ seed=config.seed,
407
+ reference_audio=config.reference_audio,
408
+ audio_duration=audio_duration,
409
+ batch_size=config.batch_size,
410
+ src_audio=config.src_audio,
411
+ audio_code_string=audio_code_string_to_use,
412
+ repainting_start=config.repainting_start,
413
+ repainting_end=config.repainting_end,
414
+ instruction=config.instruction,
415
+ audio_cover_strength=config.audio_cover_strength,
416
+ task_type=config.task_type,
417
+ use_adg=config.use_adg,
418
+ cfg_interval_start=config.cfg_interval_start,
419
+ cfg_interval_end=config.cfg_interval_end,
420
+ audio_format=config.audio_format,
421
+ lm_temperature=config.lm_temperature,
422
+ )
423
+
424
+ # Extract results
425
+ (first_audio, second_audio, all_audio_paths, generation_info, status_message,
426
+ seed_value, align_score_1, align_text_1, align_plot_1,
427
+ align_score_2, align_text_2, align_plot_2) = result
428
+
429
+ # Append LM metadata to generation info
430
+ if lm_generated_metadata:
431
+ generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
432
+
433
+ # Create result object
434
+ return GenerationResult(
435
+ audio_paths=all_audio_paths or [],
436
+ first_audio=first_audio,
437
+ second_audio=second_audio,
438
+ generation_info=generation_info,
439
+ status_message=status_message,
440
+ seed_value=seed_value,
441
+ lm_metadata=lm_generated_metadata,
442
+ align_score_1=align_score_1,
443
+ align_text_1=align_text_1,
444
+ align_plot_1=align_plot_1,
445
+ align_score_2=align_score_2,
446
+ align_text_2=align_text_2,
447
+ align_plot_2=align_plot_2,
448
+ success=True,
449
+ error=None,
450
+ )
451
+
452
+ except Exception as e:
453
+ logger.exception("Music generation failed")
454
+ return GenerationResult(
455
+ success=False,
456
+ error=str(e),
457
+ generation_info=f"❌ Generation failed: {str(e)}",
458
+ status_message=f"Error: {str(e)}",
459
+ )
460
+
461
+
462
+ def _update_metadata_from_lm(
463
+ metadata: Dict[str, Any],
464
+ bpm: Optional[int],
465
+ key_scale: str,
466
+ time_signature: str,
467
+ audio_duration: Optional[float],
468
+ ) -> Tuple[Optional[int], str, str, Optional[float]]:
469
+ """Update metadata fields from LM output if not provided by user."""
470
+
471
+ if bpm is None and metadata.get('bpm'):
472
+ bpm_value = metadata.get('bpm')
473
+ if bpm_value not in ["N/A", ""]:
474
+ try:
475
+ bpm = int(bpm_value)
476
+ except (ValueError, TypeError):
477
+ pass
478
+
479
+ if not key_scale and metadata.get('keyscale'):
480
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
481
+ if key_scale_value != "N/A":
482
+ key_scale = key_scale_value
483
+
484
+ if not time_signature and metadata.get('timesignature'):
485
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
486
+ if time_signature_value != "N/A":
487
+ time_signature = time_signature_value
488
+
489
+ if audio_duration is None or audio_duration <= 0:
490
+ audio_duration_value = metadata.get('duration', -1)
491
+ if audio_duration_value not in ["N/A", ""]:
492
+ try:
493
+ audio_duration = float(audio_duration_value)
494
+ except (ValueError, TypeError):
495
+ pass
496
+
497
+ return bpm, key_scale, time_signature, audio_duration
498
+
499
+
500
+ def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any]) -> str:
501
+ """Append LM-generated metadata to generation info string."""
502
+
503
+ metadata_lines = []
504
+ if metadata.get('bpm'):
505
+ metadata_lines.append(f"- **BPM:** {metadata['bpm']}")
506
+ if metadata.get('caption'):
507
+ metadata_lines.append(f"- **Refined Caption:** {metadata['caption']}")
508
+ if metadata.get('duration'):
509
+ metadata_lines.append(f"- **Duration:** {metadata['duration']} seconds")
510
+ if metadata.get('keyscale'):
511
+ metadata_lines.append(f"- **Key Scale:** {metadata['keyscale']}")
512
+ if metadata.get('language'):
513
+ metadata_lines.append(f"- **Language:** {metadata['language']}")
514
+ if metadata.get('timesignature'):
515
+ metadata_lines.append(f"- **Time Signature:** {metadata['timesignature']}")
516
+
517
+ if metadata_lines:
518
+ metadata_section = "\n\n**πŸ€– LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
519
+ return metadata_section + "\n\n" + generation_info
520
+
521
+ return generation_info
522
+
523
+
524
+ # ============================================================================
525
+ # LEGACY GRADIO UI COMPATIBILITY LAYER
526
+ # ============================================================================
527
+
528
+ def generate(
529
+ dit_handler,
530
+ llm_handler,
531
+ captions,
532
+ lyrics,
533
+ bpm,
534
+ key_scale,
535
+ time_signature,
536
+ vocal_language,
537
+ inference_steps,
538
+ guidance_scale,
539
+ random_seed_checkbox,
540
+ seed,
541
+ reference_audio,
542
+ audio_duration,
543
+ batch_size_input,
544
+ src_audio,
545
+ text2music_audio_code_string,
546
+ repainting_start,
547
+ repainting_end,
548
+ instruction_display_gen,
549
+ audio_cover_strength,
550
+ task_type,
551
+ use_adg,
552
+ cfg_interval_start,
553
+ cfg_interval_end,
554
+ audio_format,
555
+ lm_temperature,
556
+ think_checkbox,
557
+ lm_cfg_scale,
558
+ lm_top_k,
559
+ lm_top_p,
560
+ lm_negative_prompt,
561
+ use_cot_metas,
562
+ use_cot_caption,
563
+ use_cot_language,
564
+ is_format_caption,
565
+ constrained_decoding_debug,
566
+ allow_lm_batch,
567
+ lm_batch_chunk_size,
568
+ ):
569
+ """Legacy Gradio UI compatibility wrapper.
570
+
571
+ This function maintains backward compatibility with the Gradio UI.
572
+ For new integrations, use generate_music() with GenerationConfig instead.
573
+
574
+ Returns:
575
+ Tuple with 28 elements for Gradio UI component updates
576
+ """
577
+
578
+ # Convert legacy parameters to new config
579
+ config = GenerationConfig(
580
+ caption=captions,
581
+ lyrics=lyrics,
582
+ bpm=bpm,
583
+ key_scale=key_scale,
584
+ time_signature=time_signature,
585
+ vocal_language=vocal_language,
586
+ audio_duration=audio_duration,
587
+ inference_steps=inference_steps,
588
+ guidance_scale=guidance_scale,
589
+ use_random_seed=random_seed_checkbox,
590
+ seed=seed,
591
+ batch_size=batch_size_input,
592
+ use_adg=use_adg,
593
+ cfg_interval_start=cfg_interval_start,
594
+ cfg_interval_end=cfg_interval_end,
595
+ audio_format=audio_format,
596
+ task_type=task_type,
597
+ reference_audio=reference_audio,
598
+ src_audio=src_audio,
599
+ audio_code_string=text2music_audio_code_string,
600
+ repainting_start=repainting_start,
601
+ repainting_end=repainting_end,
602
+ audio_cover_strength=audio_cover_strength,
603
+ instruction=instruction_display_gen,
604
+ use_llm_thinking=think_checkbox,
605
+ lm_temperature=lm_temperature,
606
+ lm_cfg_scale=lm_cfg_scale,
607
+ lm_top_k=lm_top_k,
608
+ lm_top_p=lm_top_p,
609
+ lm_negative_prompt=lm_negative_prompt,
610
+ use_cot_metas=use_cot_metas,
611
+ use_cot_caption=use_cot_caption,
612
+ use_cot_language=use_cot_language,
613
+ is_format_caption=is_format_caption,
614
+ constrained_decoding_debug=constrained_decoding_debug,
615
+ allow_lm_batch=allow_lm_batch,
616
+ lm_batch_chunk_size=lm_batch_chunk_size,
617
+ )
618
+
619
+ # Call new API
620
+ result = generate_music(dit_handler, llm_handler, config)
621
+
622
+ # Determine which codes to update in UI
623
+ if config.allow_lm_batch and result.lm_metadata:
624
+ # Batch mode: extract codes from metadata if available
625
+ lm_codes_list = result.lm_metadata.get('audio_codes_list', [])
626
+ updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
627
+ codes_outputs = (lm_codes_list + [""] * 8)[:8]
628
+ else:
629
+ # Single mode
630
+ lm_codes = result.lm_metadata.get('audio_codes', '') if result.lm_metadata else ''
631
+ updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
632
+ codes_outputs = [""] * 8
633
+
634
+ # Prepare audio outputs (up to 8)
635
+ audio_outputs = (result.audio_paths + [None] * 8)[:8]
636
+
637
+ # Return tuple for Gradio UI (28 elements)
638
+ return (
639
+ audio_outputs[0], # generated_audio_1
640
+ audio_outputs[1], # generated_audio_2
641
+ audio_outputs[2], # generated_audio_3
642
+ audio_outputs[3], # generated_audio_4
643
+ audio_outputs[4], # generated_audio_5
644
+ audio_outputs[5], # generated_audio_6
645
+ audio_outputs[6], # generated_audio_7
646
+ audio_outputs[7], # generated_audio_8
647
+ result.audio_paths, # generated_audio_batch
648
+ result.generation_info,
649
+ result.status_message,
650
+ result.seed_value,
651
+ result.align_score_1,
652
+ result.align_text_1,
653
+ result.align_plot_1,
654
+ result.align_score_2,
655
+ result.align_text_2,
656
+ result.align_plot_2,
657
+ updated_audio_codes, # Update main audio codes in UI
658
+ codes_outputs[0], # text2music_audio_code_string_1
659
+ codes_outputs[1], # text2music_audio_code_string_2
660
+ codes_outputs[2], # text2music_audio_code_string_3
661
+ codes_outputs[3], # text2music_audio_code_string_4
662
+ codes_outputs[4], # text2music_audio_code_string_5
663
+ codes_outputs[5], # text2music_audio_code_string_6
664
+ codes_outputs[6], # text2music_audio_code_string_7
665
+ codes_outputs[7], # text2music_audio_code_string_8
666
+ result.lm_metadata, # Store metadata for "Send to src audio" buttons
667
+ is_format_caption, # Keep is_format_caption unchanged
668
+ )
669
+
670
+
671
+ # ============================================================================
672
+ # TESTING & EXAMPLES
673
+ # ============================================================================
674
+
+if __name__ == "__main__":
+    """
+    Test suite for the inference API.
+    Demonstrates various usage scenarios and validates basic functionality.
+
+    Usage:
+        python -m acestep.inference
+    """
+
+    import os
+    import json
+    import time
+    from typing import Optional
+
+    from acestep.handler import AceStepHandler
+    from acestep.llm_inference import LLMHandler
+
+    # Initialize paths
+    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    checkpoint_dir = os.path.join(project_root, "checkpoints")
+
+    print("=" * 80)
+    print("ACE-Step Inference API Test Suite")
+    print("=" * 80)
+
+    # ========================================================================
+    # Initialize Handlers
+    # ========================================================================
+    print("\n[1/3] Initializing handlers...")
+    dit_handler = AceStepHandler(save_root="./")
+    llm_handler = LLMHandler()
+
+    try:
+        # Initialize DiT handler
+        print(" - Initializing DiT model...")
+        status_dit, success_dit = dit_handler.initialize_service(
+            project_root=project_root,
+            config_path="acestep-v15-turbo-rl",
+            device="cuda",
+        )
+        if not success_dit:
+            print(f" ❌ DiT initialization failed: {status_dit}")
+            exit(1)
+        print(" βœ“ DiT model initialized successfully")
+
+        # Initialize LLM handler
+        print(" - Initializing 5Hz LM model...")
+        status_llm, success_llm = llm_handler.initialize(
+            checkpoint_dir=checkpoint_dir,
+            lm_model_path="acestep-5Hz-lm-0.6B-v3",
+            backend="vllm",
+            device="cuda",
+        )
+        if success_llm:
+            print(" βœ“ LM model initialized successfully")
+        else:
+            print(f" ⚠ LM initialization failed (will skip LM tests): {status_llm}")
+
+    except Exception as e:
+        print(f" ❌ Initialization error: {e}")
+        exit(1)
+
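+    # Both initializers return a (status_message, success_flag) pair, so callers can
+    # report failures uniformly. A minimal sketch of that pattern (names as above):
+    #   for name, ok, status in [("DiT", success_dit, status_dit), ("LM", success_llm, status_llm)]:
+    #       print(f"{name}: {'ready' if ok else status}")
+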
+    # ========================================================================
+    # Helper Functions
+    # ========================================================================
+    def load_example_config(example_file: str) -> Optional[GenerationConfig]:
+        """Load a GenerationConfig from an example JSON file, or None on failure."""
+        try:
+            with open(example_file, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+
+            # Convert the example format to GenerationConfig.
+            # Handle the time signature format: examples may store only the numerator
+            # (e.g. "4"), so default the denominator to /4 (e.g. "4" -> "4/4").
+            time_sig = data.get('timesignature', '')
+            if time_sig and '/' not in time_sig:
+                time_sig = f"{time_sig}/4"
+
+            config = GenerationConfig(
+                caption=data.get('caption', ''),
+                lyrics=data.get('lyrics', ''),
+                bpm=data.get('bpm'),
+                key_scale=data.get('keyscale', ''),
+                time_signature=time_sig,
+                vocal_language=data.get('language', 'unknown'),
+                audio_duration=data.get('duration'),
+                use_llm_thinking=data.get('think', False),
+                batch_size=data.get('batch_size', 1),
+                inference_steps=data.get('inference_steps', 8),
+            )
+            return config
+
+        except Exception as e:
+            print(f" ⚠ Failed to load example file: {e}")
+            return None
+
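+    # Typical use (the path below is illustrative): load an example, then override
+    # fields before generation, e.g.:
+    #   cfg = load_example_config("examples/text2music/example_01.json")
+    #   if cfg:
+    #       cfg.audio_duration = 10  # shorten for a quick check
+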
+    # ========================================================================
+    # Test Cases
+    # ========================================================================
+    test_results = []
+
+    def run_test(test_name: str, config: GenerationConfig, expected_outputs: int = 1):
+        """Run a single test case, print a report, and record the outcome."""
+        print(f"\n{'=' * 80}")
+        print(f"Test: {test_name}")
+        print(f"{'=' * 80}")
+
+        # Display configuration (long captions/lyrics are truncated for readability)
+        print("\nConfiguration:")
+        print(f"  Task Type: {config.task_type}")
+        caption_display = config.caption[:60] + "..." if len(config.caption) > 60 else config.caption
+        print(f"  Caption: {caption_display}")
+        if config.lyrics:
+            lyrics_display = config.lyrics[:60] + "..." if len(config.lyrics) > 60 else config.lyrics
+            print(f"  Lyrics: {lyrics_display}")
+        if config.bpm:
+            print(f"  BPM: {config.bpm}")
+        if config.key_scale:
+            print(f"  Key Scale: {config.key_scale}")
+        if config.time_signature:
+            print(f"  Time Signature: {config.time_signature}")
+        if config.audio_duration:
+            print(f"  Duration: {config.audio_duration}s")
+        print(f"  Batch Size: {config.batch_size}")
+        print(f"  Inference Steps: {config.inference_steps}")
+        print(f"  Use LLM Thinking: {config.use_llm_thinking}")
+
+        # Run generation and time it
+        print("\nGenerating...")
+        start_time = time.time()
+        result = generate_music(dit_handler, llm_handler, config)
+        elapsed_time = time.time() - start_time
+
+        # Display results
+        print("\nResults:")
+        print(f"  Success: {'βœ“' if result.success else 'βœ—'}")
+
+        if result.success:
+            print(f"  Generated Files: {len(result.audio_paths)}")
+            for i, path in enumerate(result.audio_paths, 1):
+                if os.path.exists(path):
+                    file_size = os.path.getsize(path) / (1024 * 1024)  # bytes -> MB
+                    print(f"    [{i}] {os.path.basename(path)} ({file_size:.2f} MB)")
+                else:
+                    print(f"    [{i}] {os.path.basename(path)} (file not found)")
+
+            print(f"  Seed: {result.seed_value}")
+            print(f"  Generation Time: {elapsed_time:.2f}s")
+
+            # Display LM metadata if available
+            if result.lm_metadata:
+                print("\n  LM-Generated Metadata:")
+                for key, value in result.lm_metadata.items():
+                    if key not in ['audio_codes', 'audio_codes_list']:  # Skip large code strings
+                        print(f"    {key}: {value}")
+
+            # Validate outputs
+            if len(result.audio_paths) != expected_outputs:
+                print(f"  ⚠ Warning: Expected {expected_outputs} outputs, got {len(result.audio_paths)}")
+                success = False
+            else:
+                success = True
+
+        else:
+            print(f"  Error: {result.error}")
+            success = False
+
+        # Store test result
+        test_results.append({
+            "test_name": test_name,
+            "success": success,
+            "generation_success": result.success,
+            "num_outputs": len(result.audio_paths) if result.success else 0,
+            "expected_outputs": expected_outputs,
+            "elapsed_time": elapsed_time,
+            "error": result.error if not result.success else None,
+        })
+
+        return result
+
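+    # run_test accepts any GenerationConfig; a minimal hypothetical extra case
+    # (not executed below) would look like:
+    #   run_test(
+    #       "Short instrumental",
+    #       GenerationConfig(caption="calm solo piano", audio_duration=15),
+    #   )
+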
+    # ========================================================================
+    # Test: Production Example (from examples directory)
+    # ========================================================================
+    print("\n[2/3] Running Test...")
+
+    # Load a production example (J-Rock song from examples/text2music/example_05.json)
+    example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
+
+    if not os.path.exists(example_file):
+        print(f"\n ❌ Example file not found: {example_file}")
+        print(" Please ensure the examples directory exists.")
+        exit(1)
+
+    print(f" Loading example: {os.path.basename(example_file)}")
+    config = load_example_config(example_file)
+
+    if not config:
+        print(" ❌ Failed to load example configuration")
+        exit(1)
+
+    # Reduce the duration for faster testing (the original example runs 200s) and
+    # fix the seed so repeated runs are directly comparable.
+    print(f" Original duration: {config.audio_duration}s")
+    config.audio_duration = 30
+    config.use_random_seed = False
+    config.seed = 42
+    print(f" Test duration: {config.audio_duration}s (reduced for testing)")
+
+    run_test("Production Example (J-Rock Song)", config, expected_outputs=1)
+
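+    # Further scenarios can reuse the same pattern; e.g. a hypothetical batch run
+    # (not executed here):
+    #   config.batch_size = 2
+    #   run_test("Batch of 2", config, expected_outputs=2)
+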
+    # ========================================================================
+    # Test Summary
+    # ========================================================================
+    print("\n[3/3] Test Summary")
+    print("=" * 80)
+
+    if not test_results:
+        print("No tests were run.")
+        exit(1)
+
+    result = test_results[0]
+
+    print(f"\nTest: {result['test_name']}")
+    print(f"Status: {'βœ“ PASS' if result['success'] else 'βœ— FAIL'}")
+    print(f"Generation: {'Success' if result['generation_success'] else 'Failed'}")
+    print(f"Outputs: {result['num_outputs']}/{result['expected_outputs']}")
+    print(f"Time: {result['elapsed_time']:.2f}s")
+
+    if result["error"]:
+        print(f"Error: {result['error']}")
+
+    # Save test results to JSON
+    results_file = os.path.join(project_root, "test_results.json")
+    try:
+        with open(results_file, "w") as f:
+            json.dump({
+                "test_name": result['test_name'],
+                "success": result['success'],
+                "generation_success": result['generation_success'],
+                "num_outputs": result['num_outputs'],
+                "expected_outputs": result['expected_outputs'],
+                "elapsed_time": result['elapsed_time'],
+                "error": result['error'],
+            }, f, indent=2)
+        print(f"\nβœ“ Test results saved to: {results_file}")
+    except Exception as e:
+        print(f"\n⚠ Failed to save test results: {e}")
+
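+    # The JSON report is meant for machine consumption; e.g. a CI step could gate
+    # on it (sketch):
+    #   with open(results_file) as f:
+    #       assert json.load(f)["success"]
+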
+    # Exit with an appropriate status code
+    print("\n" + "=" * 80)
+    if result['success']:
+        print("Test passed! βœ“")
+        print("=" * 80)
+        exit(0)
+    else:
+        print("Test failed! βœ—")
+        print("=" * 80)
+        exit(1)