Gong Junmin committed
Commit 22101a6 · 1 Parent(s): b04b635

fix service mode

Files changed (1): app.py +136 -12
app.py CHANGED
@@ -2,7 +2,7 @@
 ACE-Step v1.5 - HuggingFace Space Entry Point

 This file serves as the entry point for HuggingFace Space deployment.
-It imports and uses the existing v1.5 Gradio implementation without modification.
+It initializes the service and launches the Gradio interface.
 """
 import os
 import sys
@@ -22,27 +22,151 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
     os.environ.pop(proxy_var, None)

-from acestep.acestep_v15_pipeline import create_demo
+import torch
+from acestep.handler import AceStepHandler
+from acestep.llm_inference import LLMHandler
+from acestep.dataset_handler import DatasetHandler
+from acestep.gradio_ui import create_gradio_interface
+
+
+def get_gpu_memory_gb():
+    """
+    Get GPU memory in GB. Returns 0 if no GPU is available.
+    """
+    try:
+        if torch.cuda.is_available():
+            total_memory = torch.cuda.get_device_properties(0).total_memory
+            memory_gb = total_memory / (1024**3)
+            return memory_gb
+        else:
+            return 0
+    except Exception as e:
+        print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
+        return 0


 def main():
     """Main entry point for HuggingFace Space"""
-
-    # HuggingFace Space initialization parameters
+
+    # HuggingFace Space persistent storage path
+    persistent_storage_path = "/data"
+
+    # Detect GPU memory for auto-configuration
+    gpu_memory_gb = get_gpu_memory_gb()
+    auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
+
+    if auto_offload:
+        print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
+        print("Auto-enabling CPU offload to reduce GPU memory usage")
+    elif gpu_memory_gb > 0:
+        print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
+        print("CPU offload disabled by default")
+    else:
+        print("No GPU detected, running on CPU")
+
+    # Create handler instances
+    print("Creating handlers...")
+    dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
+    llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
+    dataset_handler = DatasetHandler()
+
+    # Service mode configuration from environment variables
+    config_path = os.environ.get(
+        "SERVICE_MODE_DIT_MODEL",
+        "acestep-v15-turbo-fix-inst-shift-dynamic"
+    )
+    lm_model_path = os.environ.get(
+        "SERVICE_MODE_LM_MODEL",
+        "acestep-5Hz-lm-1.7B-v4-fix"
+    )
+    backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
+    device = "auto"
+
+    print(f"Service mode configuration:")
+    print(f"  DiT model: {config_path}")
+    print(f"  LM model: {lm_model_path}")
+    print(f"  Backend: {backend}")
+    print(f"  Offload to CPU: {auto_offload}")
+
+    # Determine flash attention availability
+    use_flash_attention = dit_handler.is_flash_attention_available()
+    print(f"  Flash Attention: {use_flash_attention}")
+
+    # Initialize DiT model
+    print(f"Initializing DiT model: {config_path}...")
+    init_status, enable_generate = dit_handler.initialize_service(
+        project_root=current_dir,
+        config_path=config_path,
+        device=device,
+        use_flash_attention=use_flash_attention,
+        compile_model=False,
+        offload_to_cpu=auto_offload,
+        offload_dit_to_cpu=False
+    )
+
+    if not enable_generate:
+        print(f"Warning: DiT model initialization issue: {init_status}", file=sys.stderr)
+    else:
+        print("DiT model initialized successfully")
+
+    # Initialize LM model
+    checkpoint_dir = dit_handler._get_checkpoint_dir()
+    print(f"Initializing 5Hz LM: {lm_model_path}...")
+    lm_status, lm_success = llm_handler.initialize(
+        checkpoint_dir=checkpoint_dir,
+        lm_model_path=lm_model_path,
+        backend=backend,
+        device=device,
+        offload_to_cpu=auto_offload,
+        dtype=dit_handler.dtype
+    )
+
+    if lm_success:
+        print("5Hz LM initialized successfully")
+        init_status += f"\n{lm_status}"
+    else:
+        print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
+        init_status += f"\n{lm_status}"
+
+    # Prepare initialization parameters for UI
     init_params = {
-        'pre_initialized': False,  # Lazy initialization
-        'service_mode': True,  # Service mode
+        'pre_initialized': True,
+        'service_mode': True,
+        'checkpoint': None,
+        'config_path': config_path,
+        'device': device,
+        'init_llm': True,
+        'lm_model_path': lm_model_path,
+        'backend': backend,
+        'use_flash_attention': use_flash_attention,
+        'offload_to_cpu': auto_offload,
+        'offload_dit_to_cpu': False,
+        'init_status': init_status,
+        'enable_generate': enable_generate,
+        'dit_handler': dit_handler,
+        'llm_handler': llm_handler,
         'language': 'en',
-        'persistent_storage_path': '/data',  # HuggingFace Space persistent storage
+        'persistent_storage_path': persistent_storage_path,
     }
-
-    # Create demo using existing v1.5 implementation
-    demo = create_demo(init_params=init_params, language='en')
-
+
+    print("Service initialization completed!")
+
+    # Create Gradio interface with pre-initialized handlers
+    print("Creating Gradio interface...")
+    demo = create_gradio_interface(
+        dit_handler,
+        llm_handler,
+        dataset_handler,
+        init_params=init_params,
+        language='en'
+    )
+
     # Enable queue for multi-user support
+    print("Enabling queue for multi-user support...")
    demo.queue(max_size=20)
-
+
     # Launch
+    print("Launching server on 0.0.0.0:7860...")
     demo.launch(
         server_name="0.0.0.0",
         server_port=7860,
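
The three SERVICE_MODE_* variables above make the deployed checkpoints configurable without editing the file. A minimal sketch of a local override, assuming it is run from the repository root; the model IDs shown are just the defaults from the diff, and "vllm" is the only backend value this diff confirms, so treat any other string as untested:

import os
import subprocess

# Set the service-mode checkpoints before app.py starts, since it reads
# os.environ at startup. These two IDs are the defaults from the diff;
# substitute your own model names as needed.
os.environ["SERVICE_MODE_DIT_MODEL"] = "acestep-v15-turbo-fix-inst-shift-dynamic"
os.environ["SERVICE_MODE_LM_MODEL"] = "acestep-5Hz-lm-1.7B-v4-fix"
# "vllm" is the default; what else LLMHandler.initialize accepts is not
# shown in this commit, so check that code before changing it.
os.environ["SERVICE_MODE_BACKEND"] = "vllm"

# Launch the Space entry point with the overridden environment.
subprocess.run(["python", "app.py"], check=True)

Because app.py only calls os.environ.get during startup, the overrides just need to be in place before the process launches; nothing in the module itself has to change.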