ChuxiJ commited on
Commit
875a989
·
1 Parent(s): 7f5c13a

feat: update argparse for api_server

Browse files
acestep/api_server.py CHANGED
@@ -33,8 +33,8 @@ from fastapi import FastAPI, HTTPException, Request
33
  from pydantic import BaseModel, Field
34
  from starlette.datastructures import UploadFile as StarletteUploadFile
35
 
36
- from .handler import AceStepHandler
37
- from .llm_inference import LLMHandler
38
 
39
 
40
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
@@ -1069,6 +1069,15 @@ def create_app() -> FastAPI:
1069
  error=rec.error,
1070
  )
1071
 
 
 
 
 
 
 
 
 
 
1072
  return app
1073
 
1074
 
@@ -1076,13 +1085,31 @@ app = create_app()
1076
 
1077
 
1078
  def main() -> None:
 
1079
  import uvicorn
1080
 
1081
- host = os.getenv("ACESTEP_API_HOST", "127.0.0.1")
1082
- port = int(os.getenv("ACESTEP_API_PORT", "8001"))
 
 
 
 
 
 
 
 
 
 
 
1083
 
1084
  # IMPORTANT: in-memory queue/store -> workers MUST be 1
1085
- uvicorn.run("acestep.api_server:app", host=host, port=port, reload=False, workers=1)
 
 
 
 
 
 
1086
 
1087
 
1088
  if __name__ == "__main__":
 
33
  from pydantic import BaseModel, Field
34
  from starlette.datastructures import UploadFile as StarletteUploadFile
35
 
36
+ from acestep.handler import AceStepHandler
37
+ from acestep.llm_inference import LLMHandler
38
 
39
 
40
  JobStatus = Literal["queued", "running", "succeeded", "failed"]
 
1069
  error=rec.error,
1070
  )
1071
 
1072
+ @app.get("/health")
1073
+ async def health_check():
1074
+ """Health check endpoint for service status."""
1075
+ return {
1076
+ "status": "ok",
1077
+ "service": "ACE-Step API",
1078
+ "version": "1.0",
1079
+ }
1080
+
1081
  return app
1082
 
1083
 
 
1085
 
1086
 
1087
  def main() -> None:
1088
+ import argparse
1089
  import uvicorn
1090
 
1091
+ parser = argparse.ArgumentParser(description="ACE-Step API server")
1092
+ parser.add_argument(
1093
+ "--host",
1094
+ default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"),
1095
+ help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)",
1096
+ )
1097
+ parser.add_argument(
1098
+ "--port",
1099
+ type=int,
1100
+ default=int(os.getenv("ACESTEP_API_PORT", "8001")),
1101
+ help="Bind port (default from ACESTEP_API_PORT or 8001)",
1102
+ )
1103
+ args = parser.parse_args()
1104
 
1105
  # IMPORTANT: in-memory queue/store -> workers MUST be 1
1106
+ uvicorn.run(
1107
+ "acestep.api_server:app",
1108
+ host=str(args.host),
1109
+ port=int(args.port),
1110
+ reload=False,
1111
+ workers=1,
1112
+ )
1113
 
1114
 
1115
  if __name__ == "__main__":
acestep/llm_inference.py CHANGED
@@ -16,7 +16,7 @@ from transformers.generation.logits_process import (
16
  LogitsProcessorList,
17
  RepetitionPenaltyLogitsProcessor,
18
  )
19
- from .constrained_logits_processor import MetadataConstrainedLogitsProcessor
20
 
21
 
22
  class LLMHandler:
 
16
  LogitsProcessorList,
17
  RepetitionPenaltyLogitsProcessor,
18
  )
19
+ from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor
20
 
21
 
22
  class LLMHandler:
pyproject.toml CHANGED
@@ -6,9 +6,17 @@ readme = "README.md"
6
  requires-python = ">=3.12,<3.13"
7
  license = {text = "Apache-2.0"}
8
  dependencies = [
9
- "torch>=2.9.1",
10
- "torchvision",
11
- "torchaudio>=2.9.1",
 
 
 
 
 
 
 
 
12
  "transformers",
13
  "diffusers",
14
  "gradio",
@@ -38,4 +46,4 @@ name = "pytorch"
38
  url = "https://download.pytorch.org/whl/cu128"
39
 
40
  [tool.hatch.build.targets.wheel]
41
- packages = ["acestep"]
 
6
  requires-python = ">=3.12,<3.13"
7
  license = {text = "Apache-2.0"}
8
  dependencies = [
9
+ # PyTorch for Linux/Windows with CUDA
10
+ "torch>=2.9.1; sys_platform != 'darwin'",
11
+ "torchvision; sys_platform != 'darwin'",
12
+ "torchaudio>=2.9.1; sys_platform != 'darwin'",
13
+
14
+ # PyTorch for macOS (CPU / MPS)
15
+ "torch>=2.9.1; sys_platform == 'darwin'",
16
+ "torchvision; sys_platform == 'darwin'",
17
+ "torchaudio>=2.9.1; sys_platform == 'darwin'",
18
+
19
+ # Common dependencies
20
  "transformers",
21
  "diffusers",
22
  "gradio",
 
46
  url = "https://download.pytorch.org/whl/cu128"
47
 
48
  [tool.hatch.build.targets.wheel]
49
+ packages = ["acestep"]