Spaces: Running on Zero
Merge branch 'main' of https://github.com/ace-step/ACE-Step-1.5
Browse files
- acestep/api_server.py +17 -9
acestep/api_server.py
CHANGED
|
@@ -102,9 +102,13 @@ class GenerateMusicRequest(BaseModel):
|
|
| 102 |
cfg_interval_start: float = 0.0
|
| 103 |
cfg_interval_end: float = 1.0
|
| 104 |
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
timesteps: Optional[str] = Field(
|
| 106 |
default=None,
|
| 107 |
-
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')"
|
| 108 |
)
|
| 109 |
|
| 110 |
audio_format: str = "mp3"
|
|
@@ -748,6 +752,15 @@ def create_app() -> FastAPI:
|
|
| 748 |
print(f"[api_server] Warning: format_sample failed: {format_result.error}, using original input")
|
| 749 |
|
| 750 |
print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 751 |
print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
|
| 752 |
|
| 753 |
# Parse timesteps if provided
|
|
@@ -779,12 +792,13 @@ def create_app() -> FastAPI:
|
|
| 779 |
keyscale=key_scale,
|
| 780 |
timesignature=time_signature,
|
| 781 |
duration=audio_duration if audio_duration else -1.0,
|
| 782 |
-
inference_steps=
|
| 783 |
seed=req.seed,
|
| 784 |
guidance_scale=req.guidance_scale,
|
| 785 |
use_adg=req.use_adg,
|
| 786 |
cfg_interval_start=req.cfg_interval_start,
|
| 787 |
cfg_interval_end=req.cfg_interval_end,
|
|
|
|
| 788 |
infer_method=req.infer_method,
|
| 789 |
timesteps=parsed_timesteps,
|
| 790 |
repainting_start=req.repainting_start,
|
|
@@ -1069,6 +1083,7 @@ def create_app() -> FastAPI:
|
|
| 1069 |
cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
|
| 1070 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
| 1071 |
infer_method=str(_get_any("infer_method", "inferMethod", default="ode") or "ode"),
|
|
|
|
| 1072 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 1073 |
use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
|
| 1074 |
lm_model_path=str(get("lm_model_path") or "").strip() or None,
|
|
@@ -1321,12 +1336,5 @@ def main() -> None:
|
|
| 1321 |
workers=1,
|
| 1322 |
)
|
| 1323 |
|
| 1324 |
-
|
| 1325 |
-
if __name__ == "__main__":
|
| 1326 |
-
main()
|
| 1327 |
-
,
|
| 1328 |
-
)
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
if __name__ == "__main__":
|
| 1332 |
main()
|
|
|
|
| 102 |
cfg_interval_start: float = 0.0
|
| 103 |
cfg_interval_end: float = 1.0
|
| 104 |
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 105 |
+
shift: float = Field(
|
| 106 |
+
default=3.0,
|
| 107 |
+
description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models."
|
| 108 |
+
)
|
| 109 |
timesteps: Optional[str] = Field(
|
| 110 |
default=None,
|
| 111 |
+
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift."
|
| 112 |
)
|
| 113 |
|
| 114 |
audio_format: str = "mp3"
|
|
|
|
| 752 |
print(f"[api_server] Warning: format_sample failed: {format_result.error}, using original input")
|
| 753 |
|
| 754 |
print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
|
| 755 |
+
# Parse timesteps string to list of floats if provided
|
| 756 |
+
parsed_timesteps = None
|
| 757 |
+
if req.timesteps and req.timesteps.strip():
|
| 758 |
+
try:
|
| 759 |
+
parsed_timesteps = [float(t.strip()) for t in req.timesteps.split(",") if t.strip()]
|
| 760 |
+
except ValueError:
|
| 761 |
+
print(f"[api_server] Warning: Failed to parse timesteps '{req.timesteps}', using default")
|
| 762 |
+
parsed_timesteps = None
|
| 763 |
+
|
| 764 |
print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
|
| 765 |
|
| 766 |
# Parse timesteps if provided
|
|
|
|
| 792 |
keyscale=key_scale,
|
| 793 |
timesignature=time_signature,
|
| 794 |
duration=audio_duration if audio_duration else -1.0,
|
| 795 |
+
inference_steps=req.inference_steps,
|
| 796 |
seed=req.seed,
|
| 797 |
guidance_scale=req.guidance_scale,
|
| 798 |
use_adg=req.use_adg,
|
| 799 |
cfg_interval_start=req.cfg_interval_start,
|
| 800 |
cfg_interval_end=req.cfg_interval_end,
|
| 801 |
+
shift=req.shift,
|
| 802 |
infer_method=req.infer_method,
|
| 803 |
timesteps=parsed_timesteps,
|
| 804 |
repainting_start=req.repainting_start,
|
|
|
|
| 1083 |
cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
|
| 1084 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
| 1085 |
infer_method=str(_get_any("infer_method", "inferMethod", default="ode") or "ode"),
|
| 1086 |
+
shift=_to_float(_get_any("shift"), 3.0) or 3.0,
|
| 1087 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 1088 |
use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
|
| 1089 |
lm_model_path=str(get("lm_model_path") or "").strip() or None,
|
|
|
|
| 1336 |
workers=1,
|
| 1337 |
)
|
| 1338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1339 |
if __name__ == "__main__":
|
| 1340 |
main()
|