Spaces:
Running
on
A100
Running
on
A100
support input timesteps
Browse files- acestep/api_server.py +12 -1
- acestep/gradio_ui/events/__init__.py +16 -1
- acestep/gradio_ui/events/generation_handlers.py +102 -36
- acestep/gradio_ui/events/results_handlers.py +18 -5
- acestep/gradio_ui/i18n/en.json +7 -2
- acestep/gradio_ui/i18n/ja.json +7 -2
- acestep/gradio_ui/i18n/zh.json +7 -2
- acestep/gradio_ui/interfaces/generation.py +11 -0
- acestep/handler.py +13 -1
- acestep/inference.py +4 -0
acestep/api_server.py
CHANGED
|
@@ -102,6 +102,10 @@ class GenerateMusicRequest(BaseModel):
|
|
| 102 |
cfg_interval_start: float = 0.0
|
| 103 |
cfg_interval_end: float = 1.0
|
| 104 |
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
audio_format: str = "mp3"
|
| 107 |
use_tiled_decode: bool = True
|
|
@@ -754,13 +758,14 @@ def create_app() -> FastAPI:
|
|
| 754 |
keyscale=key_scale,
|
| 755 |
timesignature=time_signature,
|
| 756 |
duration=audio_duration if audio_duration else -1.0,
|
| 757 |
-
inference_steps=
|
| 758 |
seed=req.seed,
|
| 759 |
guidance_scale=req.guidance_scale,
|
| 760 |
use_adg=req.use_adg,
|
| 761 |
cfg_interval_start=req.cfg_interval_start,
|
| 762 |
cfg_interval_end=req.cfg_interval_end,
|
| 763 |
infer_method=req.infer_method,
|
|
|
|
| 764 |
repainting_start=req.repainting_start,
|
| 765 |
repainting_end=req.repainting_end if req.repainting_end else -1,
|
| 766 |
audio_cover_strength=req.audio_cover_strength,
|
|
@@ -1289,5 +1294,11 @@ def main() -> None:
|
|
| 1289 |
)
|
| 1290 |
|
| 1291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1292 |
if __name__ == "__main__":
|
| 1293 |
main()
|
|
|
|
| 102 |
cfg_interval_start: float = 0.0
|
| 103 |
cfg_interval_end: float = 1.0
|
| 104 |
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 105 |
+
timesteps: Optional[str] = Field(
|
| 106 |
+
default=None,
|
| 107 |
+
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')"
|
| 108 |
+
)
|
| 109 |
|
| 110 |
audio_format: str = "mp3"
|
| 111 |
use_tiled_decode: bool = True
|
|
|
|
| 758 |
keyscale=key_scale,
|
| 759 |
timesignature=time_signature,
|
| 760 |
duration=audio_duration if audio_duration else -1.0,
|
| 761 |
+
inference_steps=actual_inference_steps,
|
| 762 |
seed=req.seed,
|
| 763 |
guidance_scale=req.guidance_scale,
|
| 764 |
use_adg=req.use_adg,
|
| 765 |
cfg_interval_start=req.cfg_interval_start,
|
| 766 |
cfg_interval_end=req.cfg_interval_end,
|
| 767 |
infer_method=req.infer_method,
|
| 768 |
+
timesteps=parsed_timesteps,
|
| 769 |
repainting_start=req.repainting_start,
|
| 770 |
repainting_end=req.repainting_end if req.repainting_end else -1,
|
| 771 |
audio_cover_strength=req.audio_cover_strength,
|
|
|
|
| 1294 |
)
|
| 1295 |
|
| 1296 |
|
| 1297 |
+
if __name__ == "__main__":
|
| 1298 |
+
main()
|
| 1299 |
+
,
|
| 1300 |
+
)
|
| 1301 |
+
|
| 1302 |
+
|
| 1303 |
if __name__ == "__main__":
|
| 1304 |
main()
|
acestep/gradio_ui/events/__init__.py
CHANGED
|
@@ -54,7 +54,19 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 54 |
generation_section["offload_to_cpu_checkbox"],
|
| 55 |
generation_section["offload_dit_to_cpu_checkbox"],
|
| 56 |
],
|
| 57 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
# ========== UI Visibility Updates ==========
|
|
@@ -312,6 +324,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 312 |
generation_section["cfg_interval_end"],
|
| 313 |
generation_section["shift"],
|
| 314 |
generation_section["infer_method"],
|
|
|
|
| 315 |
generation_section["audio_format"],
|
| 316 |
generation_section["lm_temperature"],
|
| 317 |
generation_section["lm_cfg_scale"],
|
|
@@ -510,6 +523,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 510 |
generation_section["cfg_interval_end"],
|
| 511 |
generation_section["shift"],
|
| 512 |
generation_section["infer_method"],
|
|
|
|
| 513 |
generation_section["audio_format"],
|
| 514 |
generation_section["lm_temperature"],
|
| 515 |
generation_section["think_checkbox"],
|
|
@@ -697,6 +711,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 697 |
generation_section["cfg_interval_end"],
|
| 698 |
generation_section["shift"],
|
| 699 |
generation_section["infer_method"],
|
|
|
|
| 700 |
generation_section["audio_format"],
|
| 701 |
generation_section["lm_temperature"],
|
| 702 |
generation_section["think_checkbox"],
|
|
|
|
| 54 |
generation_section["offload_to_cpu_checkbox"],
|
| 55 |
generation_section["offload_dit_to_cpu_checkbox"],
|
| 56 |
],
|
| 57 |
+
outputs=[
|
| 58 |
+
generation_section["init_status"],
|
| 59 |
+
generation_section["generate_btn"],
|
| 60 |
+
generation_section["service_config_accordion"],
|
| 61 |
+
# Model type settings (updated based on actual loaded model)
|
| 62 |
+
generation_section["inference_steps"],
|
| 63 |
+
generation_section["guidance_scale"],
|
| 64 |
+
generation_section["use_adg"],
|
| 65 |
+
generation_section["shift"],
|
| 66 |
+
generation_section["cfg_interval_start"],
|
| 67 |
+
generation_section["cfg_interval_end"],
|
| 68 |
+
generation_section["task_type"],
|
| 69 |
+
]
|
| 70 |
)
|
| 71 |
|
| 72 |
# ========== UI Visibility Updates ==========
|
|
|
|
| 324 |
generation_section["cfg_interval_end"],
|
| 325 |
generation_section["shift"],
|
| 326 |
generation_section["infer_method"],
|
| 327 |
+
generation_section["custom_timesteps"],
|
| 328 |
generation_section["audio_format"],
|
| 329 |
generation_section["lm_temperature"],
|
| 330 |
generation_section["lm_cfg_scale"],
|
|
|
|
| 523 |
generation_section["cfg_interval_end"],
|
| 524 |
generation_section["shift"],
|
| 525 |
generation_section["infer_method"],
|
| 526 |
+
generation_section["custom_timesteps"],
|
| 527 |
generation_section["audio_format"],
|
| 528 |
generation_section["lm_temperature"],
|
| 529 |
generation_section["think_checkbox"],
|
|
|
|
| 711 |
generation_section["cfg_interval_end"],
|
| 712 |
generation_section["shift"],
|
| 713 |
generation_section["infer_method"],
|
| 714 |
+
generation_section["custom_timesteps"],
|
| 715 |
generation_section["audio_format"],
|
| 716 |
generation_section["lm_temperature"],
|
| 717 |
generation_section["think_checkbox"],
|
acestep/gradio_ui/events/generation_handlers.py
CHANGED
|
@@ -7,7 +7,7 @@ import json
|
|
| 7 |
import random
|
| 8 |
import glob
|
| 9 |
import gradio as gr
|
| 10 |
-
from typing import Optional
|
| 11 |
from acestep.constants import (
|
| 12 |
TASK_TYPES_TURBO,
|
| 13 |
TASK_TYPES_BASE,
|
|
@@ -16,6 +16,56 @@ from acestep.gradio_ui.i18n import t
|
|
| 16 |
from acestep.inference import understand_music, create_sample, format_sample
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def load_metadata(file_obj):
|
| 20 |
"""Load generation parameters from a JSON file"""
|
| 21 |
if file_obj is None:
|
|
@@ -321,50 +371,31 @@ def refresh_checkpoints(dit_handler):
|
|
| 321 |
|
| 322 |
|
| 323 |
def update_model_type_settings(config_path):
|
| 324 |
-
"""Update UI settings based on model type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
if config_path is None:
|
| 326 |
config_path = ""
|
| 327 |
config_path_lower = config_path.lower()
|
| 328 |
|
|
|
|
|
|
|
| 329 |
if "turbo" in config_path_lower:
|
| 330 |
-
|
| 331 |
-
# Shift is not effective for turbo models, default to 1.0
|
| 332 |
-
return (
|
| 333 |
-
gr.update(value=8, maximum=8, minimum=1), # inference_steps
|
| 334 |
-
gr.update(visible=False), # guidance_scale
|
| 335 |
-
gr.update(visible=False), # use_adg
|
| 336 |
-
gr.update(value=1.0, visible=False), # shift (not effective for turbo)
|
| 337 |
-
gr.update(visible=False), # cfg_interval_start
|
| 338 |
-
gr.update(visible=False), # cfg_interval_end
|
| 339 |
-
gr.update(choices=TASK_TYPES_TURBO), # task_type
|
| 340 |
-
)
|
| 341 |
elif "base" in config_path_lower:
|
| 342 |
-
|
| 343 |
-
# Shift range 1.0~5.0, default 3.0 for base models
|
| 344 |
-
return (
|
| 345 |
-
gr.update(value=32, maximum=100, minimum=1), # inference_steps
|
| 346 |
-
gr.update(visible=True), # guidance_scale
|
| 347 |
-
gr.update(visible=True), # use_adg
|
| 348 |
-
gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
|
| 349 |
-
gr.update(visible=True), # cfg_interval_start
|
| 350 |
-
gr.update(visible=True), # cfg_interval_end
|
| 351 |
-
gr.update(choices=TASK_TYPES_BASE), # task_type
|
| 352 |
-
)
|
| 353 |
else:
|
| 354 |
-
# Default to turbo settings
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
gr.update(visible=False),
|
| 359 |
-
gr.update(value=1.0, visible=False), # shift default 1.0
|
| 360 |
-
gr.update(visible=False),
|
| 361 |
-
gr.update(visible=False),
|
| 362 |
-
gr.update(choices=TASK_TYPES_TURBO), # task_type
|
| 363 |
-
)
|
| 364 |
|
| 365 |
|
| 366 |
def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
|
| 367 |
-
"""Wrapper for service initialization, returns status, button state, and
|
| 368 |
# Initialize DiT handler
|
| 369 |
status, enable = dit_handler.initialize_service(
|
| 370 |
checkpoint, config_path, device,
|
|
@@ -400,7 +431,42 @@ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, devi
|
|
| 400 |
is_model_initialized = dit_handler.model is not None
|
| 401 |
accordion_state = gr.update(open=not is_model_initialized)
|
| 402 |
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
|
| 406 |
def update_negative_prompt_visibility(init_llm_checked):
|
|
|
|
| 7 |
import random
|
| 8 |
import glob
|
| 9 |
import gradio as gr
|
| 10 |
+
from typing import Optional, List, Tuple
|
| 11 |
from acestep.constants import (
|
| 12 |
TASK_TYPES_TURBO,
|
| 13 |
TASK_TYPES_BASE,
|
|
|
|
| 16 |
from acestep.inference import understand_music, create_sample, format_sample
|
| 17 |
|
| 18 |
|
| 19 |
+
def parse_and_validate_timesteps(
    timesteps_str: str,
    inference_steps: int
) -> Tuple[Optional[List[float]], bool, str]:
    """
    Parse a comma-separated timesteps string and validate it.

    Empty entries between commas are ignored. A trailing 0 is optional: if the
    last value is not numerically zero, 0.0 is appended so the schedule always
    ends at 0. User-facing problems are reported through ``gr.Warning`` using
    the translated messages from ``t(...)``.

    Args:
        timesteps_str: Comma-separated timesteps string
            (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0").
            ``None``, empty, or whitespace-only input means "no custom timesteps".
        inference_steps: Expected number of inference steps

    Returns:
        Tuple of (parsed_timesteps, has_warning, warning_message)
        - parsed_timesteps: List of float timesteps, or None if invalid/empty
        - has_warning: Whether a warning was shown
        - warning_message: Description of the warning ("" when there is none)
    """
    if not timesteps_str or not timesteps_str.strip():
        return None, False, ""

    # Split on commas and drop blank entries such as the ones in "0.5,,0.2,"
    values = [v.strip() for v in timesteps_str.split(",") if v.strip()]

    if not values:
        return None, False, ""

    try:
        timesteps = [float(v) for v in values]
    except ValueError:
        gr.Warning(t("messages.invalid_timesteps_format"))
        return None, True, "Invalid format"

    # Validate range [0, 1]. The negated chained comparison also rejects NaN,
    # which would slip through a plain `ts < 0 or ts > 1` test because every
    # comparison against NaN is False.
    if any(not (0 <= ts <= 1) for ts in timesteps):
        gr.Warning(t("messages.timesteps_out_of_range"))
        return None, True, "Out of range"

    # Handle optional trailing 0. Compare numerically rather than as text so
    # that "0.0", "0.00" or ".0" are recognised as the terminating zero and no
    # duplicate 0 (i.e. a spurious extra step) is appended.
    if timesteps[-1] != 0:
        timesteps.append(0.0)

    # N + 1 timestep boundaries describe N steps
    actual_steps = len(timesteps) - 1
    if actual_steps != inference_steps:
        # Not fatal: the custom schedule wins, the caller is just informed
        gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
        return timesteps, True, f"Using {actual_steps} steps from timesteps"

    return timesteps, False, ""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
def load_metadata(file_obj):
|
| 70 |
"""Load generation parameters from a JSON file"""
|
| 71 |
if file_obj is None:
|
|
|
|
| 371 |
|
| 372 |
|
| 373 |
def update_model_type_settings(config_path):
    """Derive the model-type UI settings from the config path name alone.

    Fallback for when the user changes the config_path dropdown before the
    model has been initialized. Once a model is loaded, the settings come from
    the handler's is_turbo_model() method instead.

    Args:
        config_path: Selected config path/name; ``None`` is treated as "".

    Returns:
        Whatever get_model_type_ui_settings() returns for the guessed model type.
    """
    normalized = ("" if config_path is None else config_path).lower()

    # Name-based heuristic: "turbo" in the name wins; only a name containing
    # "base" (and not "turbo") selects base settings; unknown names default to
    # turbo settings.
    is_turbo = "turbo" in normalized or "base" not in normalized

    return get_model_type_ui_settings(is_turbo)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
|
| 396 |
|
| 397 |
def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
|
| 398 |
+
"""Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
|
| 399 |
# Initialize DiT handler
|
| 400 |
status, enable = dit_handler.initialize_service(
|
| 401 |
checkpoint, config_path, device,
|
|
|
|
| 431 |
is_model_initialized = dit_handler.model is not None
|
| 432 |
accordion_state = gr.update(open=not is_model_initialized)
|
| 433 |
|
| 434 |
+
# Get model type settings based on actual loaded model
|
| 435 |
+
is_turbo = dit_handler.is_turbo_model()
|
| 436 |
+
model_type_settings = get_model_type_ui_settings(is_turbo)
|
| 437 |
+
|
| 438 |
+
return (
|
| 439 |
+
status,
|
| 440 |
+
gr.update(interactive=enable),
|
| 441 |
+
accordion_state,
|
| 442 |
+
*model_type_settings
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def get_model_type_ui_settings(is_turbo: bool):
    """Build the UI component updates for a turbo or a base model.

    Args:
        is_turbo: Truthy for turbo settings, falsy for base settings.

    Returns:
        A 7-tuple of ``gr.update(...)`` objects, in this order:
        inference_steps, guidance_scale, use_adg, shift,
        cfg_interval_start, cfg_interval_end, task_type.
    """
    if is_turbo:
        # Turbo model: max 8 steps, hide CFG/ADG/shift, turbo task types only
        steps_update = gr.update(value=8, maximum=8, minimum=1)
        shift_default = 1.0  # shift is not effective for turbo
        task_choices = TASK_TYPES_TURBO
        show_base_controls = False
    else:
        # Base model: max 200 steps, default 32, show CFG/ADG/shift, base task types
        steps_update = gr.update(value=32, maximum=200, minimum=1)
        shift_default = 3.0  # shift is effective for base, default 3.0
        task_choices = TASK_TYPES_BASE
        show_base_controls = True

    return (
        steps_update,                                                # inference_steps
        gr.update(visible=show_base_controls),                       # guidance_scale
        gr.update(visible=show_base_controls),                       # use_adg
        gr.update(value=shift_default, visible=show_base_controls),  # shift
        gr.update(visible=show_base_controls),                       # cfg_interval_start
        gr.update(visible=show_base_controls),                       # cfg_interval_end
        gr.update(choices=task_choices),                             # task_type
    )
|
| 470 |
|
| 471 |
|
| 472 |
def update_negative_prompt_visibility(init_llm_checked):
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -15,6 +15,7 @@ from typing import Dict, Any, Optional, List
|
|
| 15 |
import gradio as gr
|
| 16 |
from loguru import logger
|
| 17 |
from acestep.gradio_ui.i18n import t
|
|
|
|
| 18 |
from acestep.inference import generate_music, GenerationParams, GenerationConfig
|
| 19 |
from acestep.audio_utils import save_audio
|
| 20 |
|
|
@@ -452,7 +453,7 @@ def generate_with_progress(
|
|
| 452 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 453 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 454 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 455 |
-
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
|
| 456 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 457 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 458 |
constrained_decoding_debug,
|
|
@@ -473,6 +474,14 @@ def generate_with_progress(
|
|
| 473 |
logger.info("[generate_with_progress] Skipping Phase 1 metas COT: sample is already formatted (is_format_caption=True)")
|
| 474 |
gr.Info(t("messages.skipping_metas_cot"))
|
| 475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
# step 1: prepare inputs
|
| 477 |
# generate_music, GenerationParams, GenerationConfig
|
| 478 |
gen_params = GenerationParams(
|
|
@@ -489,13 +498,14 @@ def generate_with_progress(
|
|
| 489 |
keyscale=key_scale,
|
| 490 |
timesignature=time_signature,
|
| 491 |
duration=audio_duration,
|
| 492 |
-
inference_steps=
|
| 493 |
guidance_scale=guidance_scale,
|
| 494 |
use_adg=use_adg,
|
| 495 |
cfg_interval_start=cfg_interval_start,
|
| 496 |
cfg_interval_end=cfg_interval_end,
|
| 497 |
shift=shift,
|
| 498 |
infer_method=infer_method,
|
|
|
|
| 499 |
repainting_start=repainting_start,
|
| 500 |
repainting_end=repainting_end,
|
| 501 |
audio_cover_strength=audio_cover_strength,
|
|
@@ -1311,7 +1321,7 @@ def capture_current_params(
|
|
| 1311 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 1312 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 1313 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1314 |
-
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
|
| 1315 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1316 |
use_cot_metas, use_cot_caption, use_cot_language,
|
| 1317 |
constrained_decoding_debug, allow_lm_batch, auto_score, auto_lrc, score_scale, lm_batch_chunk_size,
|
|
@@ -1349,6 +1359,7 @@ def capture_current_params(
|
|
| 1349 |
"cfg_interval_end": cfg_interval_end,
|
| 1350 |
"shift": shift,
|
| 1351 |
"infer_method": infer_method,
|
|
|
|
| 1352 |
"audio_format": audio_format,
|
| 1353 |
"lm_temperature": lm_temperature,
|
| 1354 |
"think_checkbox": think_checkbox,
|
|
@@ -1377,7 +1388,7 @@ def generate_with_batch_management(
|
|
| 1377 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 1378 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 1379 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1380 |
-
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
|
| 1381 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1382 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 1383 |
constrained_decoding_debug,
|
|
@@ -1406,7 +1417,7 @@ def generate_with_batch_management(
|
|
| 1406 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 1407 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 1408 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1409 |
-
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
|
| 1410 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1411 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 1412 |
constrained_decoding_debug,
|
|
@@ -1673,6 +1684,7 @@ def generate_next_batch_background(
|
|
| 1673 |
params.setdefault("cfg_interval_end", 1.0)
|
| 1674 |
params.setdefault("shift", 1.0)
|
| 1675 |
params.setdefault("infer_method", "ode")
|
|
|
|
| 1676 |
params.setdefault("audio_format", "mp3")
|
| 1677 |
params.setdefault("lm_temperature", 0.85)
|
| 1678 |
params.setdefault("think_checkbox", True)
|
|
@@ -1724,6 +1736,7 @@ def generate_next_batch_background(
|
|
| 1724 |
cfg_interval_end=params.get("cfg_interval_end"),
|
| 1725 |
shift=params.get("shift"),
|
| 1726 |
infer_method=params.get("infer_method"),
|
|
|
|
| 1727 |
audio_format=params.get("audio_format"),
|
| 1728 |
lm_temperature=params.get("lm_temperature"),
|
| 1729 |
think_checkbox=params.get("think_checkbox"),
|
|
|
|
| 15 |
import gradio as gr
|
| 16 |
from loguru import logger
|
| 17 |
from acestep.gradio_ui.i18n import t
|
| 18 |
+
from acestep.gradio_ui.events.generation_handlers import parse_and_validate_timesteps
|
| 19 |
from acestep.inference import generate_music, GenerationParams, GenerationConfig
|
| 20 |
from acestep.audio_utils import save_audio
|
| 21 |
|
|
|
|
| 453 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 454 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 455 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 456 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
| 457 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 458 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 459 |
constrained_decoding_debug,
|
|
|
|
| 474 |
logger.info("[generate_with_progress] Skipping Phase 1 metas COT: sample is already formatted (is_format_caption=True)")
|
| 475 |
gr.Info(t("messages.skipping_metas_cot"))
|
| 476 |
|
| 477 |
+
# Parse and validate custom timesteps
|
| 478 |
+
parsed_timesteps, has_timesteps_warning, _ = parse_and_validate_timesteps(custom_timesteps, inference_steps)
|
| 479 |
+
|
| 480 |
+
# Update inference_steps if custom timesteps provided (to match UI display)
|
| 481 |
+
actual_inference_steps = inference_steps
|
| 482 |
+
if parsed_timesteps is not None:
|
| 483 |
+
actual_inference_steps = len(parsed_timesteps) - 1
|
| 484 |
+
|
| 485 |
# step 1: prepare inputs
|
| 486 |
# generate_music, GenerationParams, GenerationConfig
|
| 487 |
gen_params = GenerationParams(
|
|
|
|
| 498 |
keyscale=key_scale,
|
| 499 |
timesignature=time_signature,
|
| 500 |
duration=audio_duration,
|
| 501 |
+
inference_steps=actual_inference_steps,
|
| 502 |
guidance_scale=guidance_scale,
|
| 503 |
use_adg=use_adg,
|
| 504 |
cfg_interval_start=cfg_interval_start,
|
| 505 |
cfg_interval_end=cfg_interval_end,
|
| 506 |
shift=shift,
|
| 507 |
infer_method=infer_method,
|
| 508 |
+
timesteps=parsed_timesteps,
|
| 509 |
repainting_start=repainting_start,
|
| 510 |
repainting_end=repainting_end,
|
| 511 |
audio_cover_strength=audio_cover_strength,
|
|
|
|
| 1321 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 1322 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 1323 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1324 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
| 1325 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1326 |
use_cot_metas, use_cot_caption, use_cot_language,
|
| 1327 |
constrained_decoding_debug, allow_lm_batch, auto_score, auto_lrc, score_scale, lm_batch_chunk_size,
|
|
|
|
| 1359 |
"cfg_interval_end": cfg_interval_end,
|
| 1360 |
"shift": shift,
|
| 1361 |
"infer_method": infer_method,
|
| 1362 |
+
"custom_timesteps": custom_timesteps,
|
| 1363 |
"audio_format": audio_format,
|
| 1364 |
"lm_temperature": lm_temperature,
|
| 1365 |
"think_checkbox": think_checkbox,
|
|
|
|
| 1388 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 1389 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 1390 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1391 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
| 1392 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1393 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 1394 |
constrained_decoding_debug,
|
|
|
|
| 1417 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 1418 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 1419 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 1420 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
| 1421 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 1422 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 1423 |
constrained_decoding_debug,
|
|
|
|
| 1684 |
params.setdefault("cfg_interval_end", 1.0)
|
| 1685 |
params.setdefault("shift", 1.0)
|
| 1686 |
params.setdefault("infer_method", "ode")
|
| 1687 |
+
params.setdefault("custom_timesteps", "")
|
| 1688 |
params.setdefault("audio_format", "mp3")
|
| 1689 |
params.setdefault("lm_temperature", 0.85)
|
| 1690 |
params.setdefault("think_checkbox", True)
|
|
|
|
| 1736 |
cfg_interval_end=params.get("cfg_interval_end"),
|
| 1737 |
shift=params.get("shift"),
|
| 1738 |
infer_method=params.get("infer_method"),
|
| 1739 |
+
custom_timesteps=params.get("custom_timesteps"),
|
| 1740 |
audio_format=params.get("audio_format"),
|
| 1741 |
lm_temperature=params.get("lm_temperature"),
|
| 1742 |
think_checkbox=params.get("think_checkbox"),
|
acestep/gradio_ui/i18n/en.json
CHANGED
|
@@ -115,7 +115,7 @@
|
|
| 115 |
"batch_size_info": "Number of audio to generate (max 8)",
|
| 116 |
"advanced_settings": "🔧 Advanced Settings",
|
| 117 |
"inference_steps_label": "DiT Inference Steps",
|
| 118 |
-
"inference_steps_info": "Turbo: max 8, Base: max
|
| 119 |
"guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
|
| 120 |
"guidance_scale_info": "Higher values follow text more closely",
|
| 121 |
"seed_label": "Seed",
|
|
@@ -130,6 +130,8 @@
|
|
| 130 |
"shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
|
| 131 |
"infer_method_label": "Inference Method",
|
| 132 |
"infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
|
|
|
|
|
|
| 133 |
"cfg_interval_start": "CFG Interval Start",
|
| 134 |
"cfg_interval_end": "CFG Interval End",
|
| 135 |
"lm_params_title": "🤖 LM Generation Parameters",
|
|
@@ -233,6 +235,9 @@
|
|
| 233 |
"simple_example_loaded": "🎲 Loaded random example from {filename}",
|
| 234 |
"format_success": "✅ Caption and lyrics formatted successfully",
|
| 235 |
"format_failed": "❌ Format failed: {error}",
|
| 236 |
-
"skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)"
|
|
|
|
|
|
|
|
|
|
| 237 |
}
|
| 238 |
}
|
|
|
|
| 115 |
"batch_size_info": "Number of audio to generate (max 8)",
|
| 116 |
"advanced_settings": "🔧 Advanced Settings",
|
| 117 |
"inference_steps_label": "DiT Inference Steps",
|
| 118 |
+
"inference_steps_info": "Turbo: max 8, Base: max 200",
|
| 119 |
"guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
|
| 120 |
"guidance_scale_info": "Higher values follow text more closely",
|
| 121 |
"seed_label": "Seed",
|
|
|
|
| 130 |
"shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
|
| 131 |
"infer_method_label": "Inference Method",
|
| 132 |
"infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 133 |
+
"custom_timesteps_label": "Custom Timesteps",
|
| 134 |
+
"custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
|
| 135 |
"cfg_interval_start": "CFG Interval Start",
|
| 136 |
"cfg_interval_end": "CFG Interval End",
|
| 137 |
"lm_params_title": "🤖 LM Generation Parameters",
|
|
|
|
| 235 |
"simple_example_loaded": "🎲 Loaded random example from {filename}",
|
| 236 |
"format_success": "✅ Caption and lyrics formatted successfully",
|
| 237 |
"format_failed": "❌ Format failed: {error}",
|
| 238 |
+
"skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
|
| 239 |
+
"invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
|
| 240 |
+
"timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
|
| 241 |
+
"timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
|
| 242 |
}
|
| 243 |
}
|
acestep/gradio_ui/i18n/ja.json
CHANGED
|
@@ -115,7 +115,7 @@
|
|
| 115 |
"batch_size_info": "生成するオーディオの数(最大8)",
|
| 116 |
"advanced_settings": "🔧 詳細設定",
|
| 117 |
"inference_steps_label": "DiT 推論ステップ",
|
| 118 |
-
"inference_steps_info": "Turbo: 最大8、Base: 最大
|
| 119 |
"guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
|
| 120 |
"guidance_scale_info": "値が高いほどテキストに忠実に従う",
|
| 121 |
"seed_label": "シード",
|
|
@@ -130,6 +130,8 @@
|
|
| 130 |
"shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
|
| 131 |
"infer_method_label": "推論方法",
|
| 132 |
"infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
|
|
|
|
|
|
|
| 133 |
"cfg_interval_start": "CFG 間隔開始",
|
| 134 |
"cfg_interval_end": "CFG 間隔終了",
|
| 135 |
"lm_params_title": "🤖 LM 生成パラメータ",
|
|
@@ -233,6 +235,9 @@
|
|
| 233 |
"simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
|
| 234 |
"format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
|
| 235 |
"format_failed": "❌ フォーマットに失敗しました: {error}",
|
| 236 |
-
"skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)"
|
|
|
|
|
|
|
|
|
|
| 237 |
}
|
| 238 |
}
|
|
|
|
| 115 |
"batch_size_info": "生成するオーディオの数(最大8)",
|
| 116 |
"advanced_settings": "🔧 詳細設定",
|
| 117 |
"inference_steps_label": "DiT 推論ステップ",
|
| 118 |
+
"inference_steps_info": "Turbo: 最大8、Base: 最大200",
|
| 119 |
"guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
|
| 120 |
"guidance_scale_info": "値が高いほどテキストに忠実に従う",
|
| 121 |
"seed_label": "シード",
|
|
|
|
| 130 |
"shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
|
| 131 |
"infer_method_label": "推論方法",
|
| 132 |
"infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
|
| 133 |
+
"custom_timesteps_label": "カスタムタイムステップ",
|
| 134 |
+
"custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
|
| 135 |
"cfg_interval_start": "CFG 間隔開始",
|
| 136 |
"cfg_interval_end": "CFG 間隔終了",
|
| 137 |
"lm_params_title": "🤖 LM 生成パラメータ",
|
|
|
|
| 235 |
"simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
|
| 236 |
"format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
|
| 237 |
"format_failed": "❌ フォーマットに失敗しました: {error}",
|
| 238 |
+
"skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
|
| 239 |
+
"invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
|
| 240 |
+
"timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
|
| 241 |
+
"timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
|
| 242 |
}
|
| 243 |
}
|
acestep/gradio_ui/i18n/zh.json
CHANGED
|
@@ -115,7 +115,7 @@
|
|
| 115 |
"batch_size_info": "要生成的音频数量(最多8个)",
|
| 116 |
"advanced_settings": "🔧 高级设置",
|
| 117 |
"inference_steps_label": "DiT 推理步数",
|
| 118 |
-
"inference_steps_info": "Turbo: 最多8, Base: 最多
|
| 119 |
"guidance_scale_label": "DiT 引导比例(仅支持base模型)",
|
| 120 |
"guidance_scale_info": "更高的值更紧密地遵循文本",
|
| 121 |
"seed_label": "种子",
|
|
@@ -130,6 +130,8 @@
|
|
| 130 |
"shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
|
| 131 |
"infer_method_label": "推理方法",
|
| 132 |
"infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
|
|
|
|
|
|
|
| 133 |
"cfg_interval_start": "CFG 间隔开始",
|
| 134 |
"cfg_interval_end": "CFG 间隔结束",
|
| 135 |
"lm_params_title": "🤖 LM 生成参数",
|
|
@@ -233,6 +235,9 @@
|
|
| 233 |
"simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
|
| 234 |
"format_success": "✅ 描述和歌词格式化成功",
|
| 235 |
"format_failed": "❌ 格式化失败: {error}",
|
| 236 |
-
"skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)"
|
|
|
|
|
|
|
|
|
|
| 237 |
}
|
| 238 |
}
|
|
|
|
| 115 |
"batch_size_info": "要生成的音频数量(最多8个)",
|
| 116 |
"advanced_settings": "🔧 高级设置",
|
| 117 |
"inference_steps_label": "DiT 推理步数",
|
| 118 |
+
"inference_steps_info": "Turbo: 最多8, Base: 最多200",
|
| 119 |
"guidance_scale_label": "DiT 引导比例(仅支持base模型)",
|
| 120 |
"guidance_scale_info": "更高的值更紧密地遵循文本",
|
| 121 |
"seed_label": "种子",
|
|
|
|
| 130 |
"shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
|
| 131 |
"infer_method_label": "推理方法",
|
| 132 |
"infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
|
| 133 |
+
"custom_timesteps_label": "自定义时间步",
|
| 134 |
+
"custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
|
| 135 |
"cfg_interval_start": "CFG 间隔开始",
|
| 136 |
"cfg_interval_end": "CFG 间隔结束",
|
| 137 |
"lm_params_title": "🤖 LM 生成参数",
|
|
|
|
| 235 |
"simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
|
| 236 |
"format_success": "✅ 描述和歌词格式化成功",
|
| 237 |
"format_failed": "❌ 格式化失败: {error}",
|
| 238 |
+
"skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
|
| 239 |
+
"invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
|
| 240 |
+
"timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
|
| 241 |
+
"timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
|
| 242 |
}
|
| 243 |
}
|
acestep/gradio_ui/interfaces/generation.py
CHANGED
|
@@ -402,6 +402,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 402 |
)
|
| 403 |
|
| 404 |
# Advanced Settings
|
|
|
|
|
|
|
| 405 |
with gr.Accordion(t("generation.advanced_settings"), open=False):
|
| 406 |
with gr.Row():
|
| 407 |
inference_steps = gr.Slider(
|
|
@@ -462,6 +464,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 462 |
info=t("generation.infer_method_info"),
|
| 463 |
)
|
| 464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
with gr.Row():
|
| 466 |
cfg_interval_start = gr.Slider(
|
| 467 |
minimum=0.0,
|
|
@@ -698,6 +708,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 698 |
"cfg_interval_end": cfg_interval_end,
|
| 699 |
"shift": shift,
|
| 700 |
"infer_method": infer_method,
|
|
|
|
| 701 |
"audio_format": audio_format,
|
| 702 |
"output_alignment_preference": output_alignment_preference,
|
| 703 |
"think_checkbox": think_checkbox,
|
|
|
|
| 402 |
)
|
| 403 |
|
| 404 |
# Advanced Settings
|
| 405 |
+
# Default UI settings use turbo mode (max 8 steps, hide CFG/ADG/shift)
|
| 406 |
+
# These will be updated after model initialization based on handler.is_turbo_model()
|
| 407 |
with gr.Accordion(t("generation.advanced_settings"), open=False):
|
| 408 |
with gr.Row():
|
| 409 |
inference_steps = gr.Slider(
|
|
|
|
| 464 |
info=t("generation.infer_method_info"),
|
| 465 |
)
|
| 466 |
|
| 467 |
+
with gr.Row():
|
| 468 |
+
custom_timesteps = gr.Textbox(
|
| 469 |
+
label=t("generation.custom_timesteps_label"),
|
| 470 |
+
placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
|
| 471 |
+
value="",
|
| 472 |
+
info=t("generation.custom_timesteps_info"),
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
with gr.Row():
|
| 476 |
cfg_interval_start = gr.Slider(
|
| 477 |
minimum=0.0,
|
|
|
|
| 708 |
"cfg_interval_end": cfg_interval_end,
|
| 709 |
"shift": shift,
|
| 710 |
"infer_method": infer_method,
|
| 711 |
+
"custom_timesteps": custom_timesteps,
|
| 712 |
"audio_format": audio_format,
|
| 713 |
"output_alignment_preference": output_alignment_preference,
|
| 714 |
"think_checkbox": think_checkbox,
|
acestep/handler.py
CHANGED
|
@@ -108,6 +108,12 @@ class AceStepHandler:
|
|
| 108 |
except ImportError:
|
| 109 |
return False
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
def initialize_service(
|
| 112 |
self,
|
| 113 |
project_root: str,
|
|
@@ -1786,6 +1792,7 @@ class AceStepHandler:
|
|
| 1786 |
shift: float = 1.0,
|
| 1787 |
audio_code_hints: Optional[Union[str, List[str]]] = None,
|
| 1788 |
infer_method: str = "ode",
|
|
|
|
| 1789 |
) -> Dict[str, Any]:
|
| 1790 |
|
| 1791 |
"""
|
|
@@ -1949,6 +1956,9 @@ class AceStepHandler:
|
|
| 1949 |
"cfg_interval_end": cfg_interval_end,
|
| 1950 |
"shift": shift,
|
| 1951 |
}
|
|
|
|
|
|
|
|
|
|
| 1952 |
logger.info("[service_generate] Generating audio...")
|
| 1953 |
with self._load_model_context("model"):
|
| 1954 |
# Prepare condition tensors first (for LRC timestamp generation)
|
|
@@ -2081,6 +2091,7 @@ class AceStepHandler:
|
|
| 2081 |
shift: float = 1.0,
|
| 2082 |
infer_method: str = "ode",
|
| 2083 |
use_tiled_decode: bool = True,
|
|
|
|
| 2084 |
progress=None
|
| 2085 |
) -> Dict[str, Any]:
|
| 2086 |
"""
|
|
@@ -2230,7 +2241,8 @@ class AceStepHandler:
|
|
| 2230 |
shift=shift, # Pass shift parameter
|
| 2231 |
infer_method=infer_method, # Pass infer method (ode or sde)
|
| 2232 |
audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
|
| 2233 |
-
return_intermediate=should_return_intermediate
|
|
|
|
| 2234 |
)
|
| 2235 |
|
| 2236 |
logger.info("[generate_music] Model generation completed. Decoding latents...")
|
|
|
|
| 108 |
except ImportError:
|
| 109 |
return False
|
| 110 |
|
| 111 |
+
def is_turbo_model(self) -> bool:
|
| 112 |
+
"""Check if the currently loaded model is a turbo model"""
|
| 113 |
+
if self.config is None:
|
| 114 |
+
return False
|
| 115 |
+
return getattr(self.config, 'is_turbo', False)
|
| 116 |
+
|
| 117 |
def initialize_service(
|
| 118 |
self,
|
| 119 |
project_root: str,
|
|
|
|
| 1792 |
shift: float = 1.0,
|
| 1793 |
audio_code_hints: Optional[Union[str, List[str]]] = None,
|
| 1794 |
infer_method: str = "ode",
|
| 1795 |
+
timesteps: Optional[List[float]] = None,
|
| 1796 |
) -> Dict[str, Any]:
|
| 1797 |
|
| 1798 |
"""
|
|
|
|
| 1956 |
"cfg_interval_end": cfg_interval_end,
|
| 1957 |
"shift": shift,
|
| 1958 |
}
|
| 1959 |
+
# Add custom timesteps if provided (convert to tensor)
|
| 1960 |
+
if timesteps is not None:
|
| 1961 |
+
generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32)
|
| 1962 |
logger.info("[service_generate] Generating audio...")
|
| 1963 |
with self._load_model_context("model"):
|
| 1964 |
# Prepare condition tensors first (for LRC timestamp generation)
|
|
|
|
| 2091 |
shift: float = 1.0,
|
| 2092 |
infer_method: str = "ode",
|
| 2093 |
use_tiled_decode: bool = True,
|
| 2094 |
+
timesteps: Optional[List[float]] = None,
|
| 2095 |
progress=None
|
| 2096 |
) -> Dict[str, Any]:
|
| 2097 |
"""
|
|
|
|
| 2241 |
shift=shift, # Pass shift parameter
|
| 2242 |
infer_method=infer_method, # Pass infer method (ode or sde)
|
| 2243 |
audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
|
| 2244 |
+
return_intermediate=should_return_intermediate,
|
| 2245 |
+
timesteps=timesteps, # Pass custom timesteps if provided
|
| 2246 |
)
|
| 2247 |
|
| 2248 |
logger.info("[generate_music] Model generation completed. Decoding latents...")
|
acestep/inference.py
CHANGED
|
@@ -97,6 +97,9 @@ class GenerationParams:
|
|
| 97 |
cfg_interval_end: float = 1.0
|
| 98 |
shift: float = 1.0
|
| 99 |
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
repainting_start: float = 0.0
|
| 102 |
repainting_end: float = -1
|
|
@@ -534,6 +537,7 @@ def generate_music(
|
|
| 534 |
cfg_interval_end=params.cfg_interval_end,
|
| 535 |
shift=params.shift,
|
| 536 |
infer_method=params.infer_method,
|
|
|
|
| 537 |
progress=progress,
|
| 538 |
)
|
| 539 |
|
|
|
|
| 97 |
cfg_interval_end: float = 1.0
|
| 98 |
shift: float = 1.0
|
| 99 |
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 100 |
+
# Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
|
| 101 |
+
# If provided, overrides inference_steps and shift
|
| 102 |
+
timesteps: Optional[List[float]] = None
|
| 103 |
|
| 104 |
repainting_start: float = 0.0
|
| 105 |
repainting_end: float = -1
|
|
|
|
| 537 |
cfg_interval_end=params.cfg_interval_end,
|
| 538 |
shift=params.shift,
|
| 539 |
infer_method=params.infer_method,
|
| 540 |
+
timesteps=params.timesteps,
|
| 541 |
progress=progress,
|
| 542 |
)
|
| 543 |
|