Spaces: Building on Zero
add infer_method
Browse files
- acestep/api_server.py +3 -0
- acestep/gradio_ui/events/generation_handlers.py +4 -3
- acestep/gradio_ui/events/results_handlers.py +2 -1
- acestep/gradio_ui/i18n/en.json +2 -0
- acestep/gradio_ui/i18n/ja.json +2 -0
- acestep/gradio_ui/i18n/zh.json +2 -0
- acestep/gradio_ui/interfaces/generation.py +7 -0
- acestep/handler.py +2 -0
- acestep/inference.py +2 -0
acestep/api_server.py
CHANGED
|
@@ -94,6 +94,7 @@ class GenerateMusicRequest(BaseModel):
|
|
| 94 |
use_adg: bool = False
|
| 95 |
cfg_interval_start: float = 0.0
|
| 96 |
cfg_interval_end: float = 1.0
|
|
|
|
| 97 |
|
| 98 |
audio_format: str = "mp3"
|
| 99 |
use_tiled_decode: bool = True
|
|
@@ -584,6 +585,7 @@ def create_app() -> FastAPI:
|
|
| 584 |
use_adg=req.use_adg,
|
| 585 |
cfg_interval_start=req.cfg_interval_start,
|
| 586 |
cfg_interval_end=req.cfg_interval_end,
|
|
|
|
| 587 |
repainting_start=req.repainting_start,
|
| 588 |
repainting_end=req.repainting_end if req.repainting_end else -1,
|
| 589 |
audio_cover_strength=req.audio_cover_strength,
|
|
@@ -854,6 +856,7 @@ def create_app() -> FastAPI:
|
|
| 854 |
use_adg=_to_bool(get("use_adg"), False),
|
| 855 |
cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
|
| 856 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
|
|
|
| 857 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 858 |
use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
|
| 859 |
lm_model_path=str(get("lm_model_path") or "").strip() or None,
|
|
|
|
| 94 |
use_adg: bool = False
|
| 95 |
cfg_interval_start: float = 0.0
|
| 96 |
cfg_interval_end: float = 1.0
|
| 97 |
+
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 98 |
|
| 99 |
audio_format: str = "mp3"
|
| 100 |
use_tiled_decode: bool = True
|
|
|
|
| 585 |
use_adg=req.use_adg,
|
| 586 |
cfg_interval_start=req.cfg_interval_start,
|
| 587 |
cfg_interval_end=req.cfg_interval_end,
|
| 588 |
+
infer_method=req.infer_method,
|
| 589 |
repainting_start=req.repainting_start,
|
| 590 |
repainting_end=req.repainting_end if req.repainting_end else -1,
|
| 591 |
audio_cover_strength=req.audio_cover_strength,
|
|
|
|
| 856 |
use_adg=_to_bool(get("use_adg"), False),
|
| 857 |
cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
|
| 858 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
| 859 |
+
infer_method=str(_get_any("infer_method", "inferMethod", default="ode") or "ode"),
|
| 860 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 861 |
use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
|
| 862 |
lm_model_path=str(get("lm_model_path") or "").strip() or None,
|
acestep/gradio_ui/events/generation_handlers.py
CHANGED
|
@@ -86,6 +86,7 @@ def load_metadata(file_obj):
|
|
| 86 |
track_name = metadata.get('track_name')
|
| 87 |
complete_track_classes = metadata.get('complete_track_classes', [])
|
| 88 |
shift = metadata.get('shift', 3.0) # Default 3.0 for base models
|
|
|
|
| 89 |
instrumental = metadata.get('instrumental', False) # Added: read instrumental
|
| 90 |
|
| 91 |
gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
|
|
@@ -93,7 +94,7 @@ def load_metadata(file_obj):
|
|
| 93 |
return (
|
| 94 |
task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
|
| 95 |
audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
|
| 96 |
-
use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format,
|
| 97 |
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 98 |
use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
|
| 99 |
think, audio_codes, repainting_start, repainting_end,
|
|
@@ -103,10 +104,10 @@ def load_metadata(file_obj):
|
|
| 103 |
|
| 104 |
except json.JSONDecodeError as e:
|
| 105 |
gr.Warning(t("messages.invalid_json", error=str(e)))
|
| 106 |
-
return [None] *
|
| 107 |
except Exception as e:
|
| 108 |
gr.Warning(t("messages.load_error", error=str(e)))
|
| 109 |
-
return [None] *
|
| 110 |
|
| 111 |
|
| 112 |
def load_random_example(task_type: str):
|
|
|
|
| 86 |
track_name = metadata.get('track_name')
|
| 87 |
complete_track_classes = metadata.get('complete_track_classes', [])
|
| 88 |
shift = metadata.get('shift', 3.0) # Default 3.0 for base models
|
| 89 |
+
infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
|
| 90 |
instrumental = metadata.get('instrumental', False) # Added: read instrumental
|
| 91 |
|
| 92 |
gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
|
|
|
|
| 94 |
return (
|
| 95 |
task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
|
| 96 |
audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
|
| 97 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format,
|
| 98 |
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 99 |
use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
|
| 100 |
think, audio_codes, repainting_start, repainting_end,
|
|
|
|
| 104 |
|
| 105 |
except json.JSONDecodeError as e:
|
| 106 |
gr.Warning(t("messages.invalid_json", error=str(e)))
|
| 107 |
+
return [None] * 35 + [False]
|
| 108 |
except Exception as e:
|
| 109 |
gr.Warning(t("messages.load_error", error=str(e)))
|
| 110 |
+
return [None] * 35 + [False]
|
| 111 |
|
| 112 |
|
| 113 |
def load_random_example(task_type: str):
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -452,7 +452,7 @@ def generate_with_progress(
|
|
| 452 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 453 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 454 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 455 |
-
use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format, lm_temperature,
|
| 456 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 457 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 458 |
constrained_decoding_debug,
|
|
@@ -495,6 +495,7 @@ def generate_with_progress(
|
|
| 495 |
cfg_interval_start=cfg_interval_start,
|
| 496 |
cfg_interval_end=cfg_interval_end,
|
| 497 |
shift=shift,
|
|
|
|
| 498 |
repainting_start=repainting_start,
|
| 499 |
repainting_end=repainting_end,
|
| 500 |
audio_cover_strength=audio_cover_strength,
|
|
|
|
| 452 |
reference_audio, audio_duration, batch_size_input, src_audio,
|
| 453 |
text2music_audio_code_string, repainting_start, repainting_end,
|
| 454 |
instruction_display_gen, audio_cover_strength, task_type,
|
| 455 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, audio_format, lm_temperature,
|
| 456 |
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 457 |
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
| 458 |
constrained_decoding_debug,
|
|
|
|
| 495 |
cfg_interval_start=cfg_interval_start,
|
| 496 |
cfg_interval_end=cfg_interval_end,
|
| 497 |
shift=shift,
|
| 498 |
+
infer_method=infer_method,
|
| 499 |
repainting_start=repainting_start,
|
| 500 |
repainting_end=repainting_end,
|
| 501 |
audio_cover_strength=audio_cover_strength,
|
acestep/gradio_ui/i18n/en.json
CHANGED
|
@@ -128,6 +128,8 @@
|
|
| 128 |
"use_adg_info": "Enable Angle Domain Guidance",
|
| 129 |
"shift_label": "Shift",
|
| 130 |
"shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
|
|
|
|
|
|
|
| 131 |
"cfg_interval_start": "CFG Interval Start",
|
| 132 |
"cfg_interval_end": "CFG Interval End",
|
| 133 |
"lm_params_title": "🤖 LM Generation Parameters",
|
|
|
|
| 128 |
"use_adg_info": "Enable Angle Domain Guidance",
|
| 129 |
"shift_label": "Shift",
|
| 130 |
"shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
|
| 131 |
+
"infer_method_label": "Inference Method",
|
| 132 |
+
"infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 133 |
"cfg_interval_start": "CFG Interval Start",
|
| 134 |
"cfg_interval_end": "CFG Interval End",
|
| 135 |
"lm_params_title": "🤖 LM Generation Parameters",
|
acestep/gradio_ui/i18n/ja.json
CHANGED
|
@@ -128,6 +128,8 @@
|
|
| 128 |
"use_adg_info": "角度ドメインガイダンスを有効化",
|
| 129 |
"shift_label": "シフト",
|
| 130 |
"shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
|
|
|
|
|
|
|
| 131 |
"cfg_interval_start": "CFG 間隔開始",
|
| 132 |
"cfg_interval_end": "CFG 間隔終了",
|
| 133 |
"lm_params_title": "🤖 LM 生成パラメータ",
|
|
|
|
| 128 |
"use_adg_info": "角度ドメインガイダンスを有効化",
|
| 129 |
"shift_label": "シフト",
|
| 130 |
"shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
|
| 131 |
+
"infer_method_label": "推論方法",
|
| 132 |
+
"infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
|
| 133 |
"cfg_interval_start": "CFG 間隔開始",
|
| 134 |
"cfg_interval_end": "CFG 間隔終了",
|
| 135 |
"lm_params_title": "🤖 LM 生成パラメータ",
|
acestep/gradio_ui/i18n/zh.json
CHANGED
|
@@ -128,6 +128,8 @@
|
|
| 128 |
"use_adg_info": "启用角域引导",
|
| 129 |
"shift_label": "Shift",
|
| 130 |
"shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
|
|
|
|
|
|
|
| 131 |
"cfg_interval_start": "CFG 间隔开始",
|
| 132 |
"cfg_interval_end": "CFG 间隔结束",
|
| 133 |
"lm_params_title": "🤖 LM 生成参数",
|
|
|
|
| 128 |
"use_adg_info": "启用角域引导",
|
| 129 |
"shift_label": "Shift",
|
| 130 |
"shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
|
| 131 |
+
"infer_method_label": "推理方法",
|
| 132 |
+
"infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
|
| 133 |
"cfg_interval_start": "CFG 间隔开始",
|
| 134 |
"cfg_interval_end": "CFG 间隔结束",
|
| 135 |
"lm_params_title": "🤖 LM 生成参数",
|
acestep/gradio_ui/interfaces/generation.py
CHANGED
|
@@ -455,6 +455,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 455 |
info=t("generation.shift_info"),
|
| 456 |
visible=False
|
| 457 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
|
| 459 |
with gr.Row():
|
| 460 |
cfg_interval_start = gr.Slider(
|
|
@@ -691,6 +697,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 691 |
"cfg_interval_start": cfg_interval_start,
|
| 692 |
"cfg_interval_end": cfg_interval_end,
|
| 693 |
"shift": shift,
|
|
|
|
| 694 |
"audio_format": audio_format,
|
| 695 |
"output_alignment_preference": output_alignment_preference,
|
| 696 |
"think_checkbox": think_checkbox,
|
|
|
|
| 455 |
info=t("generation.shift_info"),
|
| 456 |
visible=False
|
| 457 |
)
|
| 458 |
+
infer_method = gr.Dropdown(
|
| 459 |
+
choices=["ode", "sde"],
|
| 460 |
+
value="ode",
|
| 461 |
+
label=t("generation.infer_method_label"),
|
| 462 |
+
info=t("generation.infer_method_info"),
|
| 463 |
+
)
|
| 464 |
|
| 465 |
with gr.Row():
|
| 466 |
cfg_interval_start = gr.Slider(
|
|
|
|
| 697 |
"cfg_interval_start": cfg_interval_start,
|
| 698 |
"cfg_interval_end": cfg_interval_end,
|
| 699 |
"shift": shift,
|
| 700 |
+
"infer_method": infer_method,
|
| 701 |
"audio_format": audio_format,
|
| 702 |
"output_alignment_preference": output_alignment_preference,
|
| 703 |
"think_checkbox": think_checkbox,
|
acestep/handler.py
CHANGED
|
@@ -2079,6 +2079,7 @@ class AceStepHandler:
|
|
| 2079 |
cfg_interval_start: float = 0.0,
|
| 2080 |
cfg_interval_end: float = 1.0,
|
| 2081 |
shift: float = 1.0,
|
|
|
|
| 2082 |
use_tiled_decode: bool = True,
|
| 2083 |
progress=None
|
| 2084 |
) -> Dict[str, Any]:
|
|
@@ -2227,6 +2228,7 @@ class AceStepHandler:
|
|
| 2227 |
cfg_interval_start=cfg_interval_start, # Pass CFG interval start
|
| 2228 |
cfg_interval_end=cfg_interval_end, # Pass CFG interval end
|
| 2229 |
shift=shift, # Pass shift parameter
|
|
|
|
| 2230 |
audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
|
| 2231 |
return_intermediate=should_return_intermediate
|
| 2232 |
)
|
|
|
|
| 2079 |
cfg_interval_start: float = 0.0,
|
| 2080 |
cfg_interval_end: float = 1.0,
|
| 2081 |
shift: float = 1.0,
|
| 2082 |
+
infer_method: str = "ode",
|
| 2083 |
use_tiled_decode: bool = True,
|
| 2084 |
progress=None
|
| 2085 |
) -> Dict[str, Any]:
|
|
|
|
| 2228 |
cfg_interval_start=cfg_interval_start, # Pass CFG interval start
|
| 2229 |
cfg_interval_end=cfg_interval_end, # Pass CFG interval end
|
| 2230 |
shift=shift, # Pass shift parameter
|
| 2231 |
+
infer_method=infer_method, # Pass infer method (ode or sde)
|
| 2232 |
audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
|
| 2233 |
return_intermediate=should_return_intermediate
|
| 2234 |
)
|
acestep/inference.py
CHANGED
|
@@ -96,6 +96,7 @@ class GenerationParams:
|
|
| 96 |
cfg_interval_start: float = 0.0
|
| 97 |
cfg_interval_end: float = 1.0
|
| 98 |
shift: float = 1.0
|
|
|
|
| 99 |
|
| 100 |
repainting_start: float = 0.0
|
| 101 |
repainting_end: float = -1
|
|
@@ -532,6 +533,7 @@ def generate_music(
|
|
| 532 |
cfg_interval_start=params.cfg_interval_start,
|
| 533 |
cfg_interval_end=params.cfg_interval_end,
|
| 534 |
shift=params.shift,
|
|
|
|
| 535 |
progress=progress,
|
| 536 |
)
|
| 537 |
|
|
|
|
| 96 |
cfg_interval_start: float = 0.0
|
| 97 |
cfg_interval_end: float = 1.0
|
| 98 |
shift: float = 1.0
|
| 99 |
+
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 100 |
|
| 101 |
repainting_start: float = 0.0
|
| 102 |
repainting_end: float = -1
|
|
|
|
| 533 |
cfg_interval_start=params.cfg_interval_start,
|
| 534 |
cfg_interval_end=params.cfg_interval_end,
|
| 535 |
shift=params.shift,
|
| 536 |
+
infer_method=params.infer_method,
|
| 537 |
progress=progress,
|
| 538 |
)
|
| 539 |
|