ChuxiJ committed on
Commit
f4d9d31
·
1 Parent(s): de88c1d

add infer_method

Browse files
acestep/api_server.py CHANGED
@@ -94,6 +94,7 @@ class GenerateMusicRequest(BaseModel):
94
  use_adg: bool = False
95
  cfg_interval_start: float = 0.0
96
  cfg_interval_end: float = 1.0
 
97
 
98
  audio_format: str = "mp3"
99
  use_tiled_decode: bool = True
@@ -584,6 +585,7 @@ def create_app() -> FastAPI:
584
  use_adg=req.use_adg,
585
  cfg_interval_start=req.cfg_interval_start,
586
  cfg_interval_end=req.cfg_interval_end,
 
587
  repainting_start=req.repainting_start,
588
  repainting_end=req.repainting_end if req.repainting_end else -1,
589
  audio_cover_strength=req.audio_cover_strength,
@@ -854,6 +856,7 @@ def create_app() -> FastAPI:
854
  use_adg=_to_bool(get("use_adg"), False),
855
  cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
856
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
 
857
  audio_format=str(get("audio_format", "mp3") or "mp3"),
858
  use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
859
  lm_model_path=str(get("lm_model_path") or "").strip() or None,
 
94
  use_adg: bool = False
95
  cfg_interval_start: float = 0.0
96
  cfg_interval_end: float = 1.0
97
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
98
 
99
  audio_format: str = "mp3"
100
  use_tiled_decode: bool = True
 
585
  use_adg=req.use_adg,
586
  cfg_interval_start=req.cfg_interval_start,
587
  cfg_interval_end=req.cfg_interval_end,
588
+ infer_method=req.infer_method,
589
  repainting_start=req.repainting_start,
590
  repainting_end=req.repainting_end if req.repainting_end else -1,
591
  audio_cover_strength=req.audio_cover_strength,
 
856
  use_adg=_to_bool(get("use_adg"), False),
857
  cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
858
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
859
+ infer_method=str(_get_any("infer_method", "inferMethod", default="ode") or "ode"),
860
  audio_format=str(get("audio_format", "mp3") or "mp3"),
861
  use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
862
  lm_model_path=str(get("lm_model_path") or "").strip() or None,
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -86,6 +86,7 @@ def load_metadata(file_obj):
86
  track_name = metadata.get('track_name')
87
  complete_track_classes = metadata.get('complete_track_classes', [])
88
  shift = metadata.get('shift', 3.0) # Default 3.0 for base models
 
89
  instrumental = metadata.get('instrumental', False) # Added: read instrumental
90
 
91
  gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
@@ -93,7 +94,7 @@ def load_metadata(file_obj):
93
  return (
94
  task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
95
  audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
96
- use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format,
97
  lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
98
  use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
99
  think, audio_codes, repainting_start, repainting_end,
@@ -103,10 +104,10 @@ def load_metadata(file_obj):
103
 
104
  except json.JSONDecodeError as e:
105
  gr.Warning(t("messages.invalid_json", error=str(e)))
106
- return [None] * 34 + [False]
107
  except Exception as e:
108
  gr.Warning(t("messages.load_error", error=str(e)))
109
- return [None] * 34 + [False]
110
 
111
 
112
  def load_random_example(task_type: str):
 
86
  track_name = metadata.get('track_name')
87
  complete_track_classes = metadata.get('complete_track_classes', [])
88
  shift = metadata.get('shift', 3.0) # Default 3.0 for base models
89
+ infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
90
  instrumental = metadata.get('instrumental', False) # Added: read instrumental
91
 
92
  gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
 
94
  return (
95
  task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
96
  audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
97
+ use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format,
98
  lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
99
  use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
100
  think, audio_codes, repainting_start, repainting_end,
 
104
 
105
  except json.JSONDecodeError as e:
106
  gr.Warning(t("messages.invalid_json", error=str(e)))
107
+ return [None] * 35 + [False]
108
  except Exception as e:
109
  gr.Warning(t("messages.load_error", error=str(e)))
110
+ return [None] * 35 + [False]
111
 
112
 
113
  def load_random_example(task_type: str):
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -452,7 +452,7 @@ def generate_with_progress(
452
  reference_audio, audio_duration, batch_size_input, src_audio,
453
  text2music_audio_code_string, repainting_start, repainting_end,
454
  instruction_display_gen, audio_cover_strength, task_type,
455
- use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format, lm_temperature,
456
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
457
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
458
  constrained_decoding_debug,
@@ -495,6 +495,7 @@ def generate_with_progress(
495
  cfg_interval_start=cfg_interval_start,
496
  cfg_interval_end=cfg_interval_end,
497
  shift=shift,
 
498
  repainting_start=repainting_start,
499
  repainting_end=repainting_end,
500
  audio_cover_strength=audio_cover_strength,
 
452
  reference_audio, audio_duration, batch_size_input, src_audio,
453
  text2music_audio_code_string, repainting_start, repainting_end,
454
  instruction_display_gen, audio_cover_strength, task_type,
455
+ use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
456
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
457
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
458
  constrained_decoding_debug,
 
495
  cfg_interval_start=cfg_interval_start,
496
  cfg_interval_end=cfg_interval_end,
497
  shift=shift,
498
+ infer_method=infer_method,
499
  repainting_start=repainting_start,
500
  repainting_end=repainting_end,
501
  audio_cover_strength=audio_cover_strength,
acestep/gradio_ui/i18n/en.json CHANGED
@@ -128,6 +128,8 @@
128
  "use_adg_info": "Enable Angle Domain Guidance",
129
  "shift_label": "Shift",
130
  "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
 
 
131
  "cfg_interval_start": "CFG Interval Start",
132
  "cfg_interval_end": "CFG Interval End",
133
  "lm_params_title": "🤖 LM Generation Parameters",
 
128
  "use_adg_info": "Enable Angle Domain Guidance",
129
  "shift_label": "Shift",
130
  "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
131
+ "infer_method_label": "Inference Method",
132
+ "infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
133
  "cfg_interval_start": "CFG Interval Start",
134
  "cfg_interval_end": "CFG Interval End",
135
  "lm_params_title": "🤖 LM Generation Parameters",
acestep/gradio_ui/i18n/ja.json CHANGED
@@ -128,6 +128,8 @@
128
  "use_adg_info": "角度ドメインガイダンスを有効化",
129
  "shift_label": "シフト",
130
  "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
 
 
131
  "cfg_interval_start": "CFG 間隔開始",
132
  "cfg_interval_end": "CFG 間隔終了",
133
  "lm_params_title": "🤖 LM 生成パラメータ",
 
128
  "use_adg_info": "角度ドメインガイダンスを有効化",
129
  "shift_label": "シフト",
130
  "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
131
+ "infer_method_label": "推論方法",
132
+ "infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
133
  "cfg_interval_start": "CFG 間隔開始",
134
  "cfg_interval_end": "CFG 間隔終了",
135
  "lm_params_title": "🤖 LM 生成パラメータ",
acestep/gradio_ui/i18n/zh.json CHANGED
@@ -128,6 +128,8 @@
128
  "use_adg_info": "启用角域引导",
129
  "shift_label": "Shift",
130
  "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
 
 
131
  "cfg_interval_start": "CFG 间隔开始",
132
  "cfg_interval_end": "CFG 间隔结束",
133
  "lm_params_title": "🤖 LM 生成参数",
 
128
  "use_adg_info": "启用角域引导",
129
  "shift_label": "Shift",
130
  "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
131
+ "infer_method_label": "推理方法",
132
+ "infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
133
  "cfg_interval_start": "CFG 间隔开始",
134
  "cfg_interval_end": "CFG 间隔结束",
135
  "lm_params_title": "🤖 LM 生成参数",
acestep/gradio_ui/interfaces/generation.py CHANGED
@@ -455,6 +455,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
455
  info=t("generation.shift_info"),
456
  visible=False
457
  )
 
 
 
 
 
 
458
 
459
  with gr.Row():
460
  cfg_interval_start = gr.Slider(
@@ -691,6 +697,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
691
  "cfg_interval_start": cfg_interval_start,
692
  "cfg_interval_end": cfg_interval_end,
693
  "shift": shift,
 
694
  "audio_format": audio_format,
695
  "output_alignment_preference": output_alignment_preference,
696
  "think_checkbox": think_checkbox,
 
455
  info=t("generation.shift_info"),
456
  visible=False
457
  )
458
+ infer_method = gr.Dropdown(
459
+ choices=["ode", "sde"],
460
+ value="ode",
461
+ label=t("generation.infer_method_label"),
462
+ info=t("generation.infer_method_info"),
463
+ )
464
 
465
  with gr.Row():
466
  cfg_interval_start = gr.Slider(
 
697
  "cfg_interval_start": cfg_interval_start,
698
  "cfg_interval_end": cfg_interval_end,
699
  "shift": shift,
700
+ "infer_method": infer_method,
701
  "audio_format": audio_format,
702
  "output_alignment_preference": output_alignment_preference,
703
  "think_checkbox": think_checkbox,
acestep/handler.py CHANGED
@@ -2079,6 +2079,7 @@ class AceStepHandler:
2079
  cfg_interval_start: float = 0.0,
2080
  cfg_interval_end: float = 1.0,
2081
  shift: float = 1.0,
 
2082
  use_tiled_decode: bool = True,
2083
  progress=None
2084
  ) -> Dict[str, Any]:
@@ -2227,6 +2228,7 @@ class AceStepHandler:
2227
  cfg_interval_start=cfg_interval_start, # Pass CFG interval start
2228
  cfg_interval_end=cfg_interval_end, # Pass CFG interval end
2229
  shift=shift, # Pass shift parameter
 
2230
  audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
2231
  return_intermediate=should_return_intermediate
2232
  )
 
2079
  cfg_interval_start: float = 0.0,
2080
  cfg_interval_end: float = 1.0,
2081
  shift: float = 1.0,
2082
+ infer_method: str = "ode",
2083
  use_tiled_decode: bool = True,
2084
  progress=None
2085
  ) -> Dict[str, Any]:
 
2228
  cfg_interval_start=cfg_interval_start, # Pass CFG interval start
2229
  cfg_interval_end=cfg_interval_end, # Pass CFG interval end
2230
  shift=shift, # Pass shift parameter
2231
+ infer_method=infer_method, # Pass infer method (ode or sde)
2232
  audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
2233
  return_intermediate=should_return_intermediate
2234
  )
acestep/inference.py CHANGED
@@ -96,6 +96,7 @@ class GenerationParams:
96
  cfg_interval_start: float = 0.0
97
  cfg_interval_end: float = 1.0
98
  shift: float = 1.0
 
99
 
100
  repainting_start: float = 0.0
101
  repainting_end: float = -1
@@ -532,6 +533,7 @@ def generate_music(
532
  cfg_interval_start=params.cfg_interval_start,
533
  cfg_interval_end=params.cfg_interval_end,
534
  shift=params.shift,
 
535
  progress=progress,
536
  )
537
 
 
96
  cfg_interval_start: float = 0.0
97
  cfg_interval_end: float = 1.0
98
  shift: float = 1.0
99
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
100
 
101
  repainting_start: float = 0.0
102
  repainting_end: float = -1
 
533
  cfg_interval_start=params.cfg_interval_start,
534
  cfg_interval_end=params.cfg_interval_end,
535
  shift=params.shift,
536
+ infer_method=params.infer_method,
537
  progress=progress,
538
  )
539