# acestep/acestep_v15_pipeline.py
"""
ACE-Step V1.5 Pipeline
Handler wrapper connecting the models to the Gradio UI
"""
import os
import sys
# Load environment variables from .env file in project root
# This allows configuration without hardcoding values
# Falls back to .env.example if .env is not found
try:
from dotenv import load_dotenv
# Get project root directory
_current_file = os.path.abspath(__file__)
_project_root = os.path.dirname(os.path.dirname(_current_file))
_env_path = os.path.join(_project_root, '.env')
_env_example_path = os.path.join(_project_root, '.env.example')
if os.path.exists(_env_path):
load_dotenv(_env_path)
print(f"Loaded configuration from {_env_path}")
elif os.path.exists(_env_example_path):
load_dotenv(_env_example_path)
print(f"Loaded configuration from {_env_example_path} (fallback)")
except ImportError:
# python-dotenv not installed, skip loading .env
pass
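# A minimal .env sketch (illustrative; these are the only variables this file itself
# reads, via os.environ.get in main(), and other modules may read additional keys):
#   SERVICE_MODE_DIT_MODEL=acestep-v15-turbo-fix-inst-shift-dynamic
#   SERVICE_MODE_LM_MODEL=acestep-5Hz-lm-1.7B-v4-fix
#   SERVICE_MODE_BACKEND=vllm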
# Clear proxy settings that may affect Gradio
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
os.environ.pop(proxy_var, None)
try:
# When executed as a module: `python -m acestep.acestep_v15_pipeline`
from .handler import AceStepHandler
from .llm_inference import LLMHandler
from .dataset_handler import DatasetHandler
from .gradio_ui import create_gradio_interface
except ImportError:
# When executed as a script: `python acestep/acestep_v15_pipeline.py`
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from acestep.handler import AceStepHandler
from acestep.llm_inference import LLMHandler
from acestep.dataset_handler import DatasetHandler
from acestep.gradio_ui import create_gradio_interface
def create_demo(init_params=None, language='en'):
"""
Create Gradio demo interface
Args:
init_params: Dictionary containing initialization parameters and state.
If None, service will not be pre-initialized.
Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
'language' (UI language code)
language: UI language code ('en', 'zh', 'ja', default: 'en')
Returns:
Gradio Blocks instance
"""
# Get persistent storage path from init_params (for HuggingFace Space)
persistent_storage_path = None
if init_params:
persistent_storage_path = init_params.get('persistent_storage_path')
# Use pre-initialized handlers if available, otherwise create new ones
if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
dit_handler = init_params['dit_handler']
llm_handler = init_params['llm_handler']
else:
dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
dataset_handler = DatasetHandler() # Always created fresh; not part of the pre-initialized handlers
# Create Gradio interface with all handlers and initialization parameters
demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
return demo
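# Example (a minimal sketch): building and launching the demo directly, bypassing main().
# The argument values below mirror the defaults used elsewhere in this file:
#
#   demo = create_demo(init_params=None, language="en")
#   demo.queue(max_size=20)
#   demo.launch(server_name="127.0.0.1", server_port=7860)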
def get_gpu_memory_gb():
"""
Get GPU memory in GB. Returns 0 if no GPU is available.
"""
try:
import torch
if torch.cuda.is_available():
# Get total memory of the first GPU in GB
total_memory = torch.cuda.get_device_properties(0).total_memory
memory_gb = total_memory / (1024**3) # Convert bytes to GB
return memory_gb
else:
return 0
except Exception as e:
print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
return 0
def main():
"""Main entry function"""
import argparse
# Detect GPU memory to auto-configure offload settings
gpu_memory_gb = get_gpu_memory_gb()
auto_offload = 0 < gpu_memory_gb < 16
if auto_offload:
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
print("Auto-enabling CPU offload to reduce GPU memory usage")
elif gpu_memory_gb > 0:
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
print("CPU offload disabled by default")
else:
print("No GPU detected, running on CPU")
parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
parser.add_argument("--share", action="store_true", help="Create a public link")
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "ja"], help="UI language: en (English), zh (中文), ja (日本語)")
# Service mode argument
parser.add_argument("--service_mode", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False,
help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")
# Service initialization arguments
parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
args = parser.parse_args()
# Service mode defaults (can be configured via .env file)
if args.service_mode:
print("Service mode enabled - applying preset configurations...")
# Force init_service in service mode
args.init_service = True
# Default DiT model for service mode (from env or fallback)
if args.config_path is None:
args.config_path = os.environ.get(
"SERVICE_MODE_DIT_MODEL",
"acestep-v15-turbo-fix-inst-shift-dynamic"
)
# Default LM model for service mode (from env or fallback)
if args.lm_model_path is None:
args.lm_model_path = os.environ.get(
"SERVICE_MODE_LM_MODEL",
"acestep-5Hz-lm-1.7B-v4-fix"
)
# Backend for service mode (from env or fallback to vllm)
args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
print(f" DiT model: {args.config_path}")
print(f" LM model: {args.lm_model_path}")
print(f" Backend: {args.backend}")
try:
init_params = None
# If init_service is True, perform initialization before creating UI
if args.init_service:
print("Initializing service from command line...")
# Create handler instances for initialization
dit_handler = AceStepHandler()
llm_handler = LLMHandler()
# Auto-select config_path if not provided
if args.config_path is None:
available_models = dit_handler.get_available_acestep_v15_models()
if available_models:
args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
print(f"Auto-selected config_path: {args.config_path}")
else:
print("Error: No available models found. Please specify --config_path", file=sys.stderr)
sys.exit(1)
# Get project root (same logic as in handler)
current_file = os.path.abspath(__file__)
project_root = os.path.dirname(os.path.dirname(current_file))
# Determine flash attention setting
use_flash_attention = args.use_flash_attention
if use_flash_attention is None:
use_flash_attention = dit_handler.is_flash_attention_available()
# Initialize DiT handler
print(f"Initializing DiT model: {args.config_path} on {args.device}...")
init_status, enable_generate = dit_handler.initialize_service(
project_root=project_root,
config_path=args.config_path,
device=args.device,
use_flash_attention=use_flash_attention,
compile_model=False,
offload_to_cpu=args.offload_to_cpu,
offload_dit_to_cpu=args.offload_dit_to_cpu
)
if not enable_generate:
print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
sys.exit(1)
print(f"DiT model initialized successfully")
# Initialize LM handler if requested
lm_status = ""
if args.init_llm:
if args.lm_model_path is None:
# Try to get default LM model
available_lm_models = llm_handler.get_available_5hz_lm_models()
if available_lm_models:
args.lm_model_path = available_lm_models[0]
print(f"Using default LM model: {args.lm_model_path}")
else:
print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
args.init_llm = False
if args.init_llm and args.lm_model_path:
checkpoint_dir = os.path.join(project_root, "checkpoints")
print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
lm_status, lm_success = llm_handler.initialize(
checkpoint_dir=checkpoint_dir,
lm_model_path=args.lm_model_path,
backend=args.backend,
device=args.device,
offload_to_cpu=args.offload_to_cpu,
dtype=dit_handler.dtype
)
if lm_success:
print(f"5Hz LM initialized successfully")
init_status += f"\n{lm_status}"
else:
print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
init_status += f"\n{lm_status}"
# Prepare initialization parameters for UI
init_params = {
'pre_initialized': True,
'service_mode': args.service_mode,
'checkpoint': args.checkpoint,
'config_path': args.config_path,
'device': args.device,
'init_llm': args.init_llm,
'lm_model_path': args.lm_model_path,
'backend': args.backend,
'use_flash_attention': use_flash_attention,
'offload_to_cpu': args.offload_to_cpu,
'offload_dit_to_cpu': args.offload_dit_to_cpu,
'init_status': init_status,
'enable_generate': enable_generate,
'dit_handler': dit_handler,
'llm_handler': llm_handler,
'language': args.language
}
print("Service initialization completed successfully!")
# Create and launch demo
print(f"Creating Gradio interface with language: {args.language}...")
demo = create_demo(init_params=init_params, language=args.language)
# Enable queue for multi-user support
# This ensures proper request queuing and prevents concurrent generation conflicts
print("Enabling queue for multi-user support...")
demo.queue(
max_size=20, # Maximum queue size (adjust based on your needs)
status_update_rate="auto", # Update rate for queue status
)
print(f"Launching server on {args.server_name}:{args.port}...")
demo.launch(
server_name=args.server_name,
server_port=args.port,
share=args.share,
debug=args.debug,
show_error=True,
prevent_thread_lock=False, # Block the main thread so the server keeps running
inbrowser=False, # Don't auto-open browser
)
except Exception as e:
print(f"Error launching Gradio: {e}", file=sys.stderr)
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()