Sayoyo commited on
Commit
db9ad53
·
2 Parent(s): 3c7cb5d 4be7cc1

Merge branch 'main' of https://github.com/ace-step/ACE-Step-1.5

Browse files
Files changed (1) hide show
  1. acestep/api_server.py +17 -9
acestep/api_server.py CHANGED
@@ -102,9 +102,13 @@ class GenerateMusicRequest(BaseModel):
102
  cfg_interval_start: float = 0.0
103
  cfg_interval_end: float = 1.0
104
  infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
 
 
 
 
105
  timesteps: Optional[str] = Field(
106
  default=None,
107
- description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')"
108
  )
109
 
110
  audio_format: str = "mp3"
@@ -748,6 +752,15 @@ def create_app() -> FastAPI:
748
  print(f"[api_server] Warning: format_sample failed: {format_result.error}, using original input")
749
 
750
  print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
 
 
 
 
 
 
 
 
 
751
  print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
752
 
753
  # Parse timesteps if provided
@@ -779,12 +792,13 @@ def create_app() -> FastAPI:
779
  keyscale=key_scale,
780
  timesignature=time_signature,
781
  duration=audio_duration if audio_duration else -1.0,
782
- inference_steps=actual_inference_steps,
783
  seed=req.seed,
784
  guidance_scale=req.guidance_scale,
785
  use_adg=req.use_adg,
786
  cfg_interval_start=req.cfg_interval_start,
787
  cfg_interval_end=req.cfg_interval_end,
 
788
  infer_method=req.infer_method,
789
  timesteps=parsed_timesteps,
790
  repainting_start=req.repainting_start,
@@ -1069,6 +1083,7 @@ def create_app() -> FastAPI:
1069
  cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
1070
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
1071
  infer_method=str(_get_any("infer_method", "inferMethod", default="ode") or "ode"),
 
1072
  audio_format=str(get("audio_format", "mp3") or "mp3"),
1073
  use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
1074
  lm_model_path=str(get("lm_model_path") or "").strip() or None,
@@ -1321,12 +1336,5 @@ def main() -> None:
1321
  workers=1,
1322
  )
1323
 
1324
-
1325
- if __name__ == "__main__":
1326
- main()
1327
- ,
1328
- )
1329
-
1330
-
1331
  if __name__ == "__main__":
1332
  main()
 
102
  cfg_interval_start: float = 0.0
103
  cfg_interval_end: float = 1.0
104
  infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
105
+ shift: float = Field(
106
+ default=3.0,
107
+ description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models."
108
+ )
109
  timesteps: Optional[str] = Field(
110
  default=None,
111
+ description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift."
112
  )
113
 
114
  audio_format: str = "mp3"
 
752
  print(f"[api_server] Warning: format_sample failed: {format_result.error}, using original input")
753
 
754
  print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
755
+ # Parse timesteps string to list of floats if provided
756
+ parsed_timesteps = None
757
+ if req.timesteps and req.timesteps.strip():
758
+ try:
759
+ parsed_timesteps = [float(t.strip()) for t in req.timesteps.split(",") if t.strip()]
760
+ except ValueError:
761
+ print(f"[api_server] Warning: Failed to parse timesteps '{req.timesteps}', using default")
762
+ parsed_timesteps = None
763
+
764
  print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
765
 
766
  # Parse timesteps if provided
 
792
  keyscale=key_scale,
793
  timesignature=time_signature,
794
  duration=audio_duration if audio_duration else -1.0,
795
+ inference_steps=req.inference_steps,
796
  seed=req.seed,
797
  guidance_scale=req.guidance_scale,
798
  use_adg=req.use_adg,
799
  cfg_interval_start=req.cfg_interval_start,
800
  cfg_interval_end=req.cfg_interval_end,
801
+ shift=req.shift,
802
  infer_method=req.infer_method,
803
  timesteps=parsed_timesteps,
804
  repainting_start=req.repainting_start,
 
1083
  cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
1084
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
1085
  infer_method=str(_get_any("infer_method", "inferMethod", default="ode") or "ode"),
1086
+ shift=_to_float(_get_any("shift"), 3.0) or 3.0,
1087
  audio_format=str(get("audio_format", "mp3") or "mp3"),
1088
  use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
1089
  lm_model_path=str(get("lm_model_path") or "").strip() or None,
 
1336
  workers=1,
1337
  )
1338
 
 
 
 
 
 
 
 
1339
  if __name__ == "__main__":
1340
  main()