feat: update argparse for api_server
Files changed:
- acestep/api_server.py +32 -5
- acestep/llm_inference.py +1 -1
- pyproject.toml +12 -4
acestep/api_server.py
CHANGED
@@ -33,8 +33,8 @@ from fastapi import FastAPI, HTTPException, Request
 from pydantic import BaseModel, Field
 from starlette.datastructures import UploadFile as StarletteUploadFile

-from .handler import AceStepHandler
-from .llm_inference import LLMHandler
+from acestep.handler import AceStepHandler
+from acestep.llm_inference import LLMHandler


 JobStatus = Literal["queued", "running", "succeeded", "failed"]
@@ -1069,6 +1069,15 @@ def create_app() -> FastAPI:
             error=rec.error,
         )

+    @app.get("/health")
+    async def health_check():
+        """Health check endpoint for service status."""
+        return {
+            "status": "ok",
+            "service": "ACE-Step API",
+            "version": "1.0",
+        }
+
     return app

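The new GET /health route returns a static JSON payload from inside create_app(), so a load balancer or container probe can use it as a liveness check. A minimal client-side sketch (assumptions: the server is already running on the new defaults of 127.0.0.1:8001, and probe_health is an illustrative name, not part of the repo):

import json
import urllib.request


def probe_health(base_url: str = "http://127.0.0.1:8001") -> dict:
    """Fetch /health and return the parsed payload."""
    with urllib.request.urlopen(f"{base_url}/health", timeout=5) as resp:
        return json.loads(resp.read().decode("utf-8"))


if __name__ == "__main__":
    # Expected per the handler above:
    # {'status': 'ok', 'service': 'ACE-Step API', 'version': '1.0'}
    print(probe_health())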
@@ -1076,13 +1085,31 @@ app = create_app()


 def main() -> None:
+    import argparse
     import uvicorn

-
-
+    parser = argparse.ArgumentParser(description="ACE-Step API server")
+    parser.add_argument(
+        "--host",
+        default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"),
+        help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("ACESTEP_API_PORT", "8001")),
+        help="Bind port (default from ACESTEP_API_PORT or 8001)",
+    )
+    args = parser.parse_args()

     # IMPORTANT: in-memory queue/store -> workers MUST be 1
-    uvicorn.run(
+    uvicorn.run(
+        "acestep.api_server:app",
+        host=str(args.host),
+        port=int(args.port),
+        reload=False,
+        workers=1,
+    )


 if __name__ == "__main__":
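Two points about this hunk. First, combined with the absolute imports above, passing the app as the import string "acestep.api_server:app" lets uvicorn re-import the application by module path (the form uvicorn requires if reload or multiple workers were ever turned on), and the server can also be launched directly, e.g. python -m acestep.api_server --host 0.0.0.0 --port 8001, assuming the acestep package is importable. Second, workers stays at 1 deliberately: as the comment notes, the job queue and record store are held in process memory, so extra worker processes would each see their own disconnected copy. Configuration precedence also falls out of how the defaults are built: a flag given on the command line beats ACESTEP_API_HOST / ACESTEP_API_PORT, which beat the hard-coded fallbacks, because the environment is only read when argparse computes its defaults. A standalone sketch of that pattern (not project code):

import argparse
import os

# The env var is consulted once, while building the default, so an explicit
# command-line flag always overrides it.
parser = argparse.ArgumentParser(description="precedence demo")
parser.add_argument("--port", type=int, default=int(os.getenv("ACESTEP_API_PORT", "8001")))

print(parser.parse_args([]).port)                  # env var value if set, else 8001
print(parser.parse_args(["--port", "9000"]).port)  # 9000, regardless of the env var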
acestep/llm_inference.py
CHANGED
@@ -16,7 +16,7 @@ from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
 )
-from .constrained_logits_processor import MetadataConstrainedLogitsProcessor
+from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor


 class LLMHandler:
pyproject.toml
CHANGED
@@ -6,9 +6,17 @@ readme = "README.md"
 requires-python = ">=3.12,<3.13"
 license = {text = "Apache-2.0"}
 dependencies = [
-
-    "
-    "
+    # PyTorch for Linux/Windows with CUDA
+    "torch>=2.9.1; sys_platform != 'darwin'",
+    "torchvision; sys_platform != 'darwin'",
+    "torchaudio>=2.9.1; sys_platform != 'darwin'",
+
+    # PyTorch for macOS (CPU / MPS)
+    "torch>=2.9.1; sys_platform == 'darwin'",
+    "torchvision; sys_platform == 'darwin'",
+    "torchaudio>=2.9.1; sys_platform == 'darwin'",
+
+    # Common dependencies
     "transformers",
     "diffusers",
     "gradio",
@@ -38,4 +46,4 @@ name = "pytorch"
 url = "https://download.pytorch.org/whl/cu128"

 [tool.hatch.build.targets.wheel]
-packages = ["acestep"]
+packages = ["acestep"]
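The environment markers keep a single dependency list while letting the resolver pick the right PyTorch build per platform: the sys_platform != 'darwin' entries are presumably meant to resolve against the cu128 (CUDA 12.8) index whose URL appears in the second hunk, while the == 'darwin' entries use the standard macOS wheels with CPU/MPS support. A quick post-install sanity check (a sketch, not part of the repo) to confirm which accelerator the resolved torch build actually exposes:

import torch

# Report which accelerator the installed torch build can use.
if torch.cuda.is_available():
    device = f"cuda ({torch.cuda.get_device_name(0)})"  # Linux/Windows, cu128 wheels
elif torch.backends.mps.is_available():
    device = "mps"                                      # macOS, Metal backend
else:
    device = "cpu"

print(torch.__version__, device)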