ChuxiJ committed
Commit bc7e55b · 1 parent: 7ef4a67

support input timesteps

acestep/api_server.py CHANGED
@@ -102,6 +102,10 @@ class GenerateMusicRequest(BaseModel):
     cfg_interval_start: float = 0.0
     cfg_interval_end: float = 1.0
     infer_method: str = "ode"  # "ode" or "sde" - diffusion inference method
+    timesteps: Optional[str] = Field(
+        default=None,
+        description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')"
+    )
 
     audio_format: str = "mp3"
     use_tiled_decode: bool = True
@@ -754,13 +758,14 @@ def create_app() -> FastAPI:
         keyscale=key_scale,
         timesignature=time_signature,
         duration=audio_duration if audio_duration else -1.0,
-        inference_steps=req.inference_steps,
+        inference_steps=actual_inference_steps,
         seed=req.seed,
         guidance_scale=req.guidance_scale,
         use_adg=req.use_adg,
         cfg_interval_start=req.cfg_interval_start,
         cfg_interval_end=req.cfg_interval_end,
         infer_method=req.infer_method,
+        timesteps=parsed_timesteps,
         repainting_start=req.repainting_start,
         repainting_end=req.repainting_end if req.repainting_end else -1,
         audio_cover_strength=req.audio_cover_strength,
@@ -1289,5 +1294,11 @@ def main() -> None:
     )
 
 
+if __name__ == "__main__":
+    main()
+    ,
+    )
+
+
 if __name__ == "__main__":
     main()
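For reference, a client call exercising the new request field might look like the sketch below. The endpoint path, port, and the "prompt" field are placeholders invented for illustration; only `timesteps`, `infer_method`, and `audio_format` come from the GenerateMusicRequest fields shown above, and per the UI help text added in this commit a custom schedule overrides `inference_steps` and `shift`.

import requests

# Hypothetical endpoint/port and "prompt" field; timesteps/infer_method/audio_format
# mirror the GenerateMusicRequest fields shown or added in this commit.
payload = {
    "prompt": "warm lo-fi piano with vinyl crackle",  # placeholder field
    "infer_method": "ode",
    "audio_format": "mp3",
    # 9 boundaries -> 8 denoising steps
    "timesteps": "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
}
resp = requests.post("http://localhost:8000/generate_music", json=payload)
resp.raise_for_status()
print(resp.json())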
acestep/gradio_ui/events/__init__.py CHANGED
@@ -54,7 +54,19 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["offload_to_cpu_checkbox"],
             generation_section["offload_dit_to_cpu_checkbox"],
         ],
-        outputs=[generation_section["init_status"], generation_section["generate_btn"], generation_section["service_config_accordion"]]
+        outputs=[
+            generation_section["init_status"],
+            generation_section["generate_btn"],
+            generation_section["service_config_accordion"],
+            # Model type settings (updated based on actual loaded model)
+            generation_section["inference_steps"],
+            generation_section["guidance_scale"],
+            generation_section["use_adg"],
+            generation_section["shift"],
+            generation_section["cfg_interval_start"],
+            generation_section["cfg_interval_end"],
+            generation_section["task_type"],
+        ]
     )
 
     # ========== UI Visibility Updates ==========
@@ -312,6 +324,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["cfg_interval_end"],
             generation_section["shift"],
             generation_section["infer_method"],
+            generation_section["custom_timesteps"],
             generation_section["audio_format"],
             generation_section["lm_temperature"],
             generation_section["lm_cfg_scale"],
@@ -510,6 +523,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["cfg_interval_end"],
             generation_section["shift"],
             generation_section["infer_method"],
+            generation_section["custom_timesteps"],
             generation_section["audio_format"],
             generation_section["lm_temperature"],
             generation_section["think_checkbox"],
@@ -697,6 +711,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["cfg_interval_end"],
             generation_section["shift"],
             generation_section["infer_method"],
+            generation_section["custom_timesteps"],
             generation_section["audio_format"],
             generation_section["lm_temperature"],
             generation_section["think_checkbox"],
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -7,7 +7,7 @@ import json
 import random
 import glob
 import gradio as gr
-from typing import Optional
+from typing import Optional, List, Tuple
 from acestep.constants import (
     TASK_TYPES_TURBO,
     TASK_TYPES_BASE,
@@ -16,6 +16,56 @@ from acestep.gradio_ui.i18n import t
 from acestep.inference import understand_music, create_sample, format_sample
 
 
+def parse_and_validate_timesteps(
+    timesteps_str: str,
+    inference_steps: int
+) -> Tuple[Optional[List[float]], bool, str]:
+    """
+    Parse timesteps string and validate.
+
+    Args:
+        timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
+        inference_steps: Expected number of inference steps
+
+    Returns:
+        Tuple of (parsed_timesteps, has_warning, warning_message)
+        - parsed_timesteps: List of float timesteps, or None if invalid/empty
+        - has_warning: Whether a warning was shown
+        - warning_message: Description of the warning
+    """
+    if not timesteps_str or not timesteps_str.strip():
+        return None, False, ""
+
+    # Parse comma-separated values
+    values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
+
+    if not values:
+        return None, False, ""
+
+    # Handle optional trailing 0
+    if values[-1] != "0":
+        values.append("0")
+
+    try:
+        timesteps = [float(v) for v in values]
+    except ValueError:
+        gr.Warning(t("messages.invalid_timesteps_format"))
+        return None, True, "Invalid format"
+
+    # Validate range [0, 1]
+    if any(ts < 0 or ts > 1 for ts in timesteps):
+        gr.Warning(t("messages.timesteps_out_of_range"))
+        return None, True, "Out of range"
+
+    # Check if count matches inference_steps
+    actual_steps = len(timesteps) - 1
+    if actual_steps != inference_steps:
+        gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
+        return timesteps, True, f"Using {actual_steps} steps from timesteps"
+
+    return timesteps, False, ""
+
+
 def load_metadata(file_obj):
     """Load generation parameters from a JSON file"""
     if file_obj is None:
@@ -321,50 +371,31 @@ def refresh_checkpoints(dit_handler):
 
 
 def update_model_type_settings(config_path):
-    """Update UI settings based on model type"""
+    """Update UI settings based on model type (fallback when handler not initialized yet)
+
+    Note: This is used as a fallback when the user changes config_path dropdown
+    before initializing the model. The actual settings are determined by the
+    handler's is_turbo_model() method after initialization.
+    """
     if config_path is None:
         config_path = ""
     config_path_lower = config_path.lower()
 
+    # Determine is_turbo based on config_path string
+    # This is a heuristic fallback - actual model type is determined after loading
     if "turbo" in config_path_lower:
-        # Turbo model: max 8 steps, hide CFG/ADG/shift, only show text2music/repaint/cover
-        # Shift is not effective for turbo models, default to 1.0
-        return (
-            gr.update(value=8, maximum=8, minimum=1),  # inference_steps
-            gr.update(visible=False),  # guidance_scale
-            gr.update(visible=False),  # use_adg
-            gr.update(value=1.0, visible=False),  # shift (not effective for turbo)
-            gr.update(visible=False),  # cfg_interval_start
-            gr.update(visible=False),  # cfg_interval_end
-            gr.update(choices=TASK_TYPES_TURBO),  # task_type
-        )
+        is_turbo = True
     elif "base" in config_path_lower:
-        # Base model: max 100 steps, show CFG/ADG/shift, show all task types
-        # Shift range 1.0~5.0, default 3.0 for base models
-        return (
-            gr.update(value=32, maximum=100, minimum=1),  # inference_steps
-            gr.update(visible=True),  # guidance_scale
-            gr.update(visible=True),  # use_adg
-            gr.update(value=3.0, visible=True),  # shift (effective for base, default 3.0)
-            gr.update(visible=True),  # cfg_interval_start
-            gr.update(visible=True),  # cfg_interval_end
-            gr.update(choices=TASK_TYPES_BASE),  # task_type
-        )
+        is_turbo = False
     else:
-        # Default to turbo settings
-        return (
-            gr.update(value=8, maximum=8, minimum=1),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(value=1.0, visible=False),  # shift default 1.0
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(choices=TASK_TYPES_TURBO),  # task_type
-        )
+        # Default to turbo settings for unknown model types
+        is_turbo = True
+
+    return get_model_type_ui_settings(is_turbo)
 
 
 def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
-    """Wrapper for service initialization, returns status, button state, and accordion state"""
+    """Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
     # Initialize DiT handler
     status, enable = dit_handler.initialize_service(
         checkpoint, config_path, device,
@@ -400,7 +431,42 @@ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, devi
     is_model_initialized = dit_handler.model is not None
     accordion_state = gr.update(open=not is_model_initialized)
 
-    return status, gr.update(interactive=enable), accordion_state
+    # Get model type settings based on actual loaded model
+    is_turbo = dit_handler.is_turbo_model()
+    model_type_settings = get_model_type_ui_settings(is_turbo)
+
+    return (
+        status,
+        gr.update(interactive=enable),
+        accordion_state,
+        *model_type_settings
+    )
+
+
+def get_model_type_ui_settings(is_turbo: bool):
+    """Get UI settings based on whether the model is turbo or base"""
+    if is_turbo:
+        # Turbo model: max 8 steps, hide CFG/ADG/shift, only show text2music/repaint/cover
+        return (
+            gr.update(value=8, maximum=8, minimum=1),  # inference_steps
+            gr.update(visible=False),  # guidance_scale
+            gr.update(visible=False),  # use_adg
+            gr.update(value=1.0, visible=False),  # shift (not effective for turbo)
+            gr.update(visible=False),  # cfg_interval_start
+            gr.update(visible=False),  # cfg_interval_end
+            gr.update(choices=TASK_TYPES_TURBO),  # task_type
+        )
+    else:
+        # Base model: max 200 steps, default 32, show CFG/ADG/shift, show all task types
+        return (
+            gr.update(value=32, maximum=200, minimum=1),  # inference_steps
+            gr.update(visible=True),  # guidance_scale
+            gr.update(visible=True),  # use_adg
+            gr.update(value=3.0, visible=True),  # shift (effective for base, default 3.0)
+            gr.update(visible=True),  # cfg_interval_start
+            gr.update(visible=True),  # cfg_interval_end
            gr.update(choices=TASK_TYPES_BASE),  # task_type
+        )
 
 
 def update_negative_prompt_visibility(init_llm_checked):
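As a quick sanity check on the helper added above, the snippet below walks through its three outcomes. It is illustrative only; note that the gr.Warning calls inside the helper assume a live Gradio request context.

from acestep.gradio_ui.events.generation_handlers import parse_and_validate_timesteps

# 9 boundaries matching inference_steps=8: parsed cleanly, no warning
ts, warned, msg = parse_and_validate_timesteps(
    "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0", 8
)
assert warned is False and len(ts) - 1 == 8

# Trailing 0 omitted: it is appended automatically, so this is still 8 steps
ts, warned, msg = parse_and_validate_timesteps(
    "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085", 8
)
assert len(ts) - 1 == 8

# Count mismatch: the schedule still parses, a warning is raised, and the
# caller (generate_with_progress) switches to len(ts) - 1 steps
ts, warned, msg = parse_and_validate_timesteps("0.8,0.5,0.2,0", 8)
assert warned is True and len(ts) - 1 == 3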
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -15,6 +15,7 @@ from typing import Dict, Any, Optional, List
 import gradio as gr
 from loguru import logger
 from acestep.gradio_ui.i18n import t
+from acestep.gradio_ui.events.generation_handlers import parse_and_validate_timesteps
 from acestep.inference import generate_music, GenerationParams, GenerationConfig
 from acestep.audio_utils import save_audio
 
@@ -452,7 +453,7 @@ def generate_with_progress(
     reference_audio, audio_duration, batch_size_input, src_audio,
     text2music_audio_code_string, repainting_start, repainting_end,
     instruction_display_gen, audio_cover_strength, task_type,
-    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
+    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
     think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
     use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
     constrained_decoding_debug,
@@ -473,6 +474,14 @@
         logger.info("[generate_with_progress] Skipping Phase 1 metas COT: sample is already formatted (is_format_caption=True)")
         gr.Info(t("messages.skipping_metas_cot"))
 
+    # Parse and validate custom timesteps
+    parsed_timesteps, has_timesteps_warning, _ = parse_and_validate_timesteps(custom_timesteps, inference_steps)
+
+    # Update inference_steps if custom timesteps provided (to match UI display)
+    actual_inference_steps = inference_steps
+    if parsed_timesteps is not None:
+        actual_inference_steps = len(parsed_timesteps) - 1
+
     # step 1: prepare inputs
     # generate_music, GenerationParams, GenerationConfig
     gen_params = GenerationParams(
@@ -489,13 +498,14 @@
         keyscale=key_scale,
         timesignature=time_signature,
         duration=audio_duration,
-        inference_steps=inference_steps,
+        inference_steps=actual_inference_steps,
         guidance_scale=guidance_scale,
         use_adg=use_adg,
         cfg_interval_start=cfg_interval_start,
         cfg_interval_end=cfg_interval_end,
         shift=shift,
         infer_method=infer_method,
+        timesteps=parsed_timesteps,
         repainting_start=repainting_start,
         repainting_end=repainting_end,
         audio_cover_strength=audio_cover_strength,
@@ -1311,7 +1321,7 @@ def capture_current_params(
     reference_audio, audio_duration, batch_size_input, src_audio,
     text2music_audio_code_string, repainting_start, repainting_end,
     instruction_display_gen, audio_cover_strength, task_type,
-    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
+    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
     think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
     use_cot_metas, use_cot_caption, use_cot_language,
     constrained_decoding_debug, allow_lm_batch, auto_score, auto_lrc, score_scale, lm_batch_chunk_size,
@@ -1349,6 +1359,7 @@
         "cfg_interval_end": cfg_interval_end,
         "shift": shift,
         "infer_method": infer_method,
+        "custom_timesteps": custom_timesteps,
         "audio_format": audio_format,
         "lm_temperature": lm_temperature,
         "think_checkbox": think_checkbox,
@@ -1377,7 +1388,7 @@ def generate_with_batch_management(
     reference_audio, audio_duration, batch_size_input, src_audio,
     text2music_audio_code_string, repainting_start, repainting_end,
     instruction_display_gen, audio_cover_strength, task_type,
-    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
+    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
     think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
     use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
     constrained_decoding_debug,
@@ -1406,7 +1417,7 @@
     reference_audio, audio_duration, batch_size_input, src_audio,
     text2music_audio_code_string, repainting_start, repainting_end,
     instruction_display_gen, audio_cover_strength, task_type,
-    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
+    use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
     think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
     use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
     constrained_decoding_debug,
@@ -1673,6 +1684,7 @@ def generate_next_batch_background(
     params.setdefault("cfg_interval_end", 1.0)
     params.setdefault("shift", 1.0)
     params.setdefault("infer_method", "ode")
+    params.setdefault("custom_timesteps", "")
     params.setdefault("audio_format", "mp3")
     params.setdefault("lm_temperature", 0.85)
     params.setdefault("think_checkbox", True)
@@ -1724,6 +1736,7 @@
         cfg_interval_end=params.get("cfg_interval_end"),
         shift=params.get("shift"),
         infer_method=params.get("infer_method"),
+        custom_timesteps=params.get("custom_timesteps"),
         audio_format=params.get("audio_format"),
         lm_temperature=params.get("lm_temperature"),
         think_checkbox=params.get("think_checkbox"),
acestep/gradio_ui/i18n/en.json CHANGED
@@ -115,7 +115,7 @@
     "batch_size_info": "Number of audio to generate (max 8)",
     "advanced_settings": "🔧 Advanced Settings",
     "inference_steps_label": "DiT Inference Steps",
-    "inference_steps_info": "Turbo: max 8, Base: max 100",
+    "inference_steps_info": "Turbo: max 8, Base: max 200",
     "guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
     "guidance_scale_info": "Higher values follow text more closely",
     "seed_label": "Seed",
@@ -130,6 +130,8 @@
     "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
     "infer_method_label": "Inference Method",
     "infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
+    "custom_timesteps_label": "Custom Timesteps",
+    "custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
     "cfg_interval_start": "CFG Interval Start",
     "cfg_interval_end": "CFG Interval End",
     "lm_params_title": "🤖 LM Generation Parameters",
@@ -233,6 +235,9 @@
     "simple_example_loaded": "🎲 Loaded random example from {filename}",
     "format_success": "✅ Caption and lyrics formatted successfully",
     "format_failed": "❌ Format failed: {error}",
-    "skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)"
+    "skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
+    "invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
+    "timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
+    "timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
   }
 }
acestep/gradio_ui/i18n/ja.json CHANGED
@@ -115,7 +115,7 @@
     "batch_size_info": "生成するオーディオの数(最大8)",
     "advanced_settings": "🔧 詳細設定",
     "inference_steps_label": "DiT 推論ステップ",
-    "inference_steps_info": "Turbo: 最大8、Base: 最大100",
+    "inference_steps_info": "Turbo: 最大8、Base: 最大200",
     "guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
     "guidance_scale_info": "値が高いほどテキストに忠実に従う",
     "seed_label": "シード",
@@ -130,6 +130,8 @@
     "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
     "infer_method_label": "推論方法",
     "infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
+    "custom_timesteps_label": "カスタムタイムステップ",
+    "custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
     "cfg_interval_start": "CFG 間隔開始",
     "cfg_interval_end": "CFG 間隔終了",
     "lm_params_title": "🤖 LM 生成パラメータ",
@@ -233,6 +235,9 @@
     "simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
     "format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
     "format_failed": "❌ フォーマットに失敗しました: {error}",
-    "skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)"
+    "skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
+    "invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
+    "timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
+    "timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
   }
 }
acestep/gradio_ui/i18n/zh.json CHANGED
@@ -115,7 +115,7 @@
     "batch_size_info": "要生成的音频数量(最多8个)",
     "advanced_settings": "🔧 高级设置",
     "inference_steps_label": "DiT 推理步数",
-    "inference_steps_info": "Turbo: 最多8, Base: 最多100",
+    "inference_steps_info": "Turbo: 最多8, Base: 最多200",
     "guidance_scale_label": "DiT 引导比例(仅支持base模型)",
     "guidance_scale_info": "更高的值更紧密地遵循文本",
     "seed_label": "种子",
@@ -130,6 +130,8 @@
     "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
     "infer_method_label": "推理方法",
     "infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
+    "custom_timesteps_label": "自定义时间步",
+    "custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
     "cfg_interval_start": "CFG 间隔开始",
     "cfg_interval_end": "CFG 间隔结束",
     "lm_params_title": "🤖 LM 生成参数",
@@ -233,6 +235,9 @@
     "simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
     "format_success": "✅ 描述和歌词格式化成功",
     "format_failed": "❌ 格式化失败: {error}",
-    "skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)"
+    "skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
+    "invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
+    "timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
+    "timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
   }
 }
acestep/gradio_ui/interfaces/generation.py CHANGED
@@ -402,6 +402,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
     )
 
     # Advanced Settings
+    # Default UI settings use turbo mode (max 8 steps, hide CFG/ADG/shift)
+    # These will be updated after model initialization based on handler.is_turbo_model()
     with gr.Accordion(t("generation.advanced_settings"), open=False):
         with gr.Row():
             inference_steps = gr.Slider(
@@ -462,6 +464,14 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
                 info=t("generation.infer_method_info"),
             )
 
+        with gr.Row():
+            custom_timesteps = gr.Textbox(
+                label=t("generation.custom_timesteps_label"),
+                placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
+                value="",
+                info=t("generation.custom_timesteps_info"),
+            )
+
         with gr.Row():
             cfg_interval_start = gr.Slider(
                 minimum=0.0,
@@ -698,6 +708,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
         "cfg_interval_end": cfg_interval_end,
         "shift": shift,
         "infer_method": infer_method,
+        "custom_timesteps": custom_timesteps,
         "audio_format": audio_format,
         "output_alignment_preference": output_alignment_preference,
         "think_checkbox": think_checkbox,
acestep/handler.py CHANGED
@@ -108,6 +108,12 @@ class AceStepHandler:
         except ImportError:
             return False
 
+    def is_turbo_model(self) -> bool:
+        """Check if the currently loaded model is a turbo model"""
+        if self.config is None:
+            return False
+        return getattr(self.config, 'is_turbo', False)
+
     def initialize_service(
         self,
         project_root: str,
@@ -1786,6 +1792,7 @@ class AceStepHandler:
         shift: float = 1.0,
         audio_code_hints: Optional[Union[str, List[str]]] = None,
         infer_method: str = "ode",
+        timesteps: Optional[List[float]] = None,
     ) -> Dict[str, Any]:
 
         """
@@ -1949,6 +1956,9 @@ class AceStepHandler:
             "cfg_interval_end": cfg_interval_end,
             "shift": shift,
         }
+        # Add custom timesteps if provided (convert to tensor)
+        if timesteps is not None:
+            generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32)
         logger.info("[service_generate] Generating audio...")
         with self._load_model_context("model"):
             # Prepare condition tensors first (for LRC timestamp generation)
@@ -2081,6 +2091,7 @@ class AceStepHandler:
         shift: float = 1.0,
         infer_method: str = "ode",
         use_tiled_decode: bool = True,
+        timesteps: Optional[List[float]] = None,
         progress=None
     ) -> Dict[str, Any]:
         """
@@ -2230,7 +2241,8 @@ class AceStepHandler:
             shift=shift,  # Pass shift parameter
             infer_method=infer_method,  # Pass infer method (ode or sde)
             audio_code_hints=audio_code_hints_batch,  # Pass audio code hints as list
-            return_intermediate=should_return_intermediate
+            return_intermediate=should_return_intermediate,
+            timesteps=timesteps,  # Pass custom timesteps if provided
         )
 
         logger.info("[generate_music] Model generation completed. Decoding latents...")
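Since the handler simply turns the parsed list into a float32 tensor, one way to produce a schedule string for the new custom-timesteps box is sketched below. The shift mapping used here, t -> shift*t / (1 + (shift - 1)*t), is the common flow-matching convention; whether it matches ACE-Step's internal scheduler exactly is an assumption, so treat the output as a starting point to tweak rather than a drop-in equivalent.

import torch

def make_timesteps_string(steps: int = 8, shift: float = 3.0) -> str:
    # Uniform boundaries from 1.0 down to 0.0, warped by the shift factor
    t = torch.linspace(1.0, 0.0, steps + 1)
    t = shift * t / (1 + (shift - 1) * t)
    return ",".join(f"{v:.3f}" for v in t.tolist())

print(make_timesteps_string(8, 3.0))
# -> "1.000,0.955,0.900,0.833,0.750,0.643,0.500,0.300,0.000"
# The commit's example schedule starts at 0.97 rather than 1.0, so you may
# prefer to start the grid slightly below 1.0.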
acestep/inference.py CHANGED
@@ -97,6 +97,9 @@ class GenerationParams:
     cfg_interval_end: float = 1.0
     shift: float = 1.0
     infer_method: str = "ode"  # "ode" or "sde" - diffusion inference method
+    # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
+    # If provided, overrides inference_steps and shift
+    timesteps: Optional[List[float]] = None
 
     repainting_start: float = 0.0
     repainting_end: float = -1
@@ -534,6 +537,7 @@ def generate_music(
         cfg_interval_end=params.cfg_interval_end,
         shift=params.shift,
         infer_method=params.infer_method,
+        timesteps=params.timesteps,
         progress=progress,
     )
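Putting the pieces together, a minimal sketch of driving the new field programmatically; the fields shown are taken from this diff, and any GenerationParams fields not listed are assumed to keep their dataclass defaults:

from acestep.inference import GenerationParams

params = GenerationParams(
    inference_steps=8,  # kept consistent with the schedule below
    shift=1.0,
    infer_method="ode",
    # Explicit 8-step schedule (9 boundaries); per the UI hint this overrides
    # inference_steps and shift
    timesteps=[0.97, 0.76, 0.615, 0.5, 0.395, 0.28, 0.18, 0.085, 0.0],
)

generate_music(...) then forwards params.timesteps into the handler, where it becomes torch.tensor(timesteps, dtype=torch.float32) before the diffusion call.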