Sayoyo committed on
Commit
59ce525
·
1 Parent(s): 8e83122

feat: huggingface_space support

Browse files
acestep/acestep_v15_pipeline.py CHANGED
@@ -64,14 +64,19 @@ def create_demo(init_params=None, language='en'):
64
  Returns:
65
  Gradio Blocks instance
66
  """
 
 
 
 
 
67
  # Use pre-initialized handlers if available, otherwise create new ones
68
  if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
69
  dit_handler = init_params['dit_handler']
70
  llm_handler = init_params['llm_handler']
71
  else:
72
- dit_handler = AceStepHandler() # DiT handler
73
- llm_handler = LLMHandler() # LM handler
74
-
75
  dataset_handler = DatasetHandler() # Dataset handler
76
 
77
  # Create Gradio interface with all handlers and initialization parameters
 
64
  Returns:
65
  Gradio Blocks instance
66
  """
67
+ # Get persistent storage path from init_params (for HuggingFace Space)
68
+ persistent_storage_path = None
69
+ if init_params:
70
+ persistent_storage_path = init_params.get('persistent_storage_path')
71
+
72
  # Use pre-initialized handlers if available, otherwise create new ones
73
  if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
74
  dit_handler = init_params['dit_handler']
75
  llm_handler = init_params['llm_handler']
76
  else:
77
+ dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
78
+ llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
79
+
80
  dataset_handler = DatasetHandler() # Dataset handler
81
 
82
  # Create Gradio interface with all handlers and initialization parameters
acestep/handler.py CHANGED
@@ -43,72 +43,121 @@ warnings.filterwarnings("ignore")
43
 
44
  class AceStepHandler:
45
  """ACE-Step Business Logic Handler"""
46
-
47
- def __init__(self):
 
 
 
48
  self.model = None
49
  self.config = None
50
  self.device = "cpu"
51
  self.dtype = torch.float32 # Will be set based on device in initialize_service
52
 
 
 
 
 
 
53
  # VAE for audio encoding/decoding
54
  self.vae = None
55
-
56
  # Text encoder and tokenizer
57
  self.text_encoder = None
58
  self.text_tokenizer = None
59
-
60
  # Silence latent for initialization
61
  self.silence_latent = None
62
-
63
  # Sample rate
64
  self.sample_rate = 48000
65
-
66
  # Reward model (temporarily disabled)
67
  self.reward_model = None
68
-
69
  # Batch size
70
  self.batch_size = 2
71
-
72
  # Custom layers config
73
  self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]}
74
  self.offload_to_cpu = False
75
  self.offload_dit_to_cpu = False
76
  self.current_offload_cost = 0.0
77
-
78
  # LoRA state
79
  self.lora_loaded = False
80
  self.use_lora = False
81
  self._base_decoder = None # Backup of original decoder
82
-
 
 
 
 
 
 
 
83
  def get_available_checkpoints(self) -> str:
84
  """Return project root directory path"""
85
- # Get project root (handler.py is in acestep/, so go up two levels to project root)
86
- project_root = self._get_project_root()
87
- # default checkpoints
88
- checkpoint_dir = os.path.join(project_root, "checkpoints")
89
  if os.path.exists(checkpoint_dir):
90
  return [checkpoint_dir]
91
  else:
92
  return []
93
-
94
  def get_available_acestep_v15_models(self) -> List[str]:
95
  """Scan and return all model directory names starting with 'acestep-v15-'"""
96
- # Get project root
97
- project_root = self._get_project_root()
98
- checkpoint_dir = os.path.join(project_root, "checkpoints")
99
-
100
  models = []
101
  if os.path.exists(checkpoint_dir):
102
- # Scan all directories starting with 'acestep-v15-' in checkpoints folder
103
  for item in os.listdir(checkpoint_dir):
104
  item_path = os.path.join(checkpoint_dir, item)
105
  if os.path.isdir(item_path) and item.startswith("acestep-v15-"):
106
  models.append(item)
107
-
108
- # Sort by name
109
  models.sort()
110
  return models
111
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def is_flash_attention_available(self) -> bool:
113
  """Check if flash attention is available on the system"""
114
  try:
@@ -309,11 +358,17 @@ class AceStepHandler:
309
 
310
  # Auto-detect project root (independent of passed project_root parameter)
311
  actual_project_root = self._get_project_root()
312
- checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
 
313
 
314
  # 1. Load main model
315
  # config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
316
  acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
 
 
 
 
 
317
  if os.path.exists(acestep_v15_checkpoint_path):
318
  # Determine attention implementation
319
  if use_flash_attention and self.is_flash_attention_available():
 
43
 
44
  class AceStepHandler:
45
  """ACE-Step Business Logic Handler"""
46
+
47
+ # HuggingFace Space environment detection
48
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
49
+
50
+ def __init__(self, persistent_storage_path: Optional[str] = None):
51
  self.model = None
52
  self.config = None
53
  self.device = "cpu"
54
  self.dtype = torch.float32 # Will be set based on device in initialize_service
55
 
56
+ # HuggingFace Space persistent storage support
57
+ if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE:
58
+ persistent_storage_path = "/data"
59
+ self.persistent_storage_path = persistent_storage_path
60
+
61
  # VAE for audio encoding/decoding
62
  self.vae = None
63
+
64
  # Text encoder and tokenizer
65
  self.text_encoder = None
66
  self.text_tokenizer = None
67
+
68
  # Silence latent for initialization
69
  self.silence_latent = None
70
+
71
  # Sample rate
72
  self.sample_rate = 48000
73
+
74
  # Reward model (temporarily disabled)
75
  self.reward_model = None
76
+
77
  # Batch size
78
  self.batch_size = 2
79
+
80
  # Custom layers config
81
  self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]}
82
  self.offload_to_cpu = False
83
  self.offload_dit_to_cpu = False
84
  self.current_offload_cost = 0.0
85
+
86
  # LoRA state
87
  self.lora_loaded = False
88
  self.use_lora = False
89
  self._base_decoder = None # Backup of original decoder
90
+
91
+ def _get_checkpoint_dir(self) -> str:
92
+ """Get checkpoint directory, prioritizing persistent storage if available"""
93
+ if self.persistent_storage_path:
94
+ return os.path.join(self.persistent_storage_path, "checkpoints")
95
+ project_root = self._get_project_root()
96
+ return os.path.join(project_root, "checkpoints")
97
+
98
def get_available_checkpoints(self) -> List[str]:
    """Return the available checkpoint directories.

    Note: the original annotation said ``-> str`` and the docstring claimed
    "Return project root directory path", but the body returns a list —
    fixed to match actual behavior.

    Returns:
        A one-element list containing the resolved checkpoint directory if
        it exists on disk, otherwise an empty list.
    """
    checkpoint_dir = self._get_checkpoint_dir()
    # Only advertise the directory when it actually exists on disk.
    if os.path.exists(checkpoint_dir):
        return [checkpoint_dir]
    else:
        return []
105
+
106
def get_available_acestep_v15_models(self) -> List[str]:
    """List model directory names that start with 'acestep-v15-'.

    Scans the resolved checkpoint directory and returns the matching
    subdirectory names in sorted order (empty list when the checkpoint
    directory does not exist).
    """
    checkpoint_dir = self._get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return []
    # Keep only subdirectories whose name carries the v1.5 model prefix.
    return sorted(
        entry
        for entry in os.listdir(checkpoint_dir)
        if entry.startswith("acestep-v15-")
        and os.path.isdir(os.path.join(checkpoint_dir, entry))
    )
119
+
120
def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str:
    """
    Ensure a model snapshot is present locally, downloading it from the
    HuggingFace Hub if needed (HuggingFace Space auto-download support).

    Args:
        model_name: Model directory name (e.g., "acestep-v15-turbo")
        checkpoint_dir: Target checkpoint directory

    Returns:
        Path to the downloaded model directory.

    Raises:
        Exception: Re-raises whatever ``snapshot_download`` raised after
            logging the failure.
    """
    from huggingface_hub import snapshot_download

    # Known model-name -> HuggingFace repo ID mapping.
    MODEL_REPO_MAP = {
        "acestep-v15-turbo": "ACE-Step/ACE-Step-v1-3.5B-turbo",
        "acestep-v15-base": "ACE-Step/ACE-Step-v1-3.5B",
    }

    repo_id = MODEL_REPO_MAP.get(model_name)
    if repo_id is None:
        # Unknown name: fall back to treating it as a repo under the
        # ACE-Step organization.
        repo_id = f"ACE-Step/{model_name}"

    model_path = os.path.join(checkpoint_dir, model_name)
    logger.info(f"Downloading model {repo_id} to {model_path}...")

    try:
        # FIX: dropped `local_dir_use_symlinks=False` — the parameter is
        # deprecated and ignored by huggingface_hub >= 0.23; files are
        # already written directly into `local_dir`.
        snapshot_download(
            repo_id=repo_id,
            local_dir=model_path,
        )
        logger.info(f"Model {repo_id} downloaded successfully")
    except Exception as e:
        logger.error(f"Failed to download model {repo_id}: {e}")
        raise

    return model_path
160
+
161
  def is_flash_attention_available(self) -> bool:
162
  """Check if flash attention is available on the system"""
163
  try:
 
358
 
359
  # Auto-detect project root (independent of passed project_root parameter)
360
  actual_project_root = self._get_project_root()
361
+ checkpoint_dir = self._get_checkpoint_dir()
362
+ os.makedirs(checkpoint_dir, exist_ok=True)
363
 
364
  # 1. Load main model
365
  # config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
366
  acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
367
+
368
+ # Auto-download model if not exists (HuggingFace Space support)
369
+ if not os.path.exists(acestep_v15_checkpoint_path):
370
+ acestep_v15_checkpoint_path = self._ensure_model_downloaded(config_path, checkpoint_dir)
371
+
372
  if os.path.exists(acestep_v15_checkpoint_path):
373
  # Determine attention implementation
374
  if use_flash_attention and self.is_flash_attention_available():
acestep/inference.py CHANGED
@@ -2,7 +2,7 @@
2
  ACE-Step Inference API Module
3
 
4
  This module provides a standardized inference interface for music generation,
5
- designed for third-party integration. It offers both a simplified API and
6
  backward-compatible Gradio UI support.
7
  """
8
 
@@ -15,6 +15,23 @@ from loguru import logger
15
 
16
  from acestep.audio_utils import AudioSaver, generate_uuid_from_params
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  @dataclass
20
  class GenerationParams:
@@ -272,6 +289,7 @@ def _update_metadata_from_lm(
272
  return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
273
 
274
 
 
275
  def generate_music(
276
  dit_handler,
277
  llm_handler,
 
2
  ACE-Step Inference API Module
3
 
4
  This module provides a standardized inference interface for music generation,
5
+ designed for third-party integration. It offers both a simplified API and
6
  backward-compatible Gradio UI support.
7
  """
8
 
 
15
 
16
  from acestep.audio_utils import AudioSaver, generate_uuid_from_params
17
 
18
+ # HuggingFace Space environment detection
19
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
20
+
21
+ def _get_spaces_gpu_decorator(duration=180):
22
+ """
23
+ Get the @spaces.GPU decorator if running in HuggingFace Space environment.
24
+ Returns identity decorator if not in Space environment.
25
+ """
26
+ if IS_HUGGINGFACE_SPACE:
27
+ try:
28
+ import spaces
29
+ return spaces.GPU(duration=duration)
30
+ except ImportError:
31
+ logger.warning("spaces package not found, GPU decorator disabled")
32
+ return lambda func: func
33
+ return lambda func: func
34
+
35
 
36
  @dataclass
37
  class GenerationParams:
 
289
  return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
290
 
291
 
292
+ @_get_spaces_gpu_decorator(duration=180)
293
  def generate_music(
294
  dit_handler,
295
  llm_handler,
acestep/llm_inference.py CHANGED
@@ -26,8 +26,11 @@ class LLMHandler:
26
  """5Hz LM Handler for audio code generation"""
27
 
28
  STOP_REASONING_TAG = "</think>"
29
-
30
- def __init__(self):
 
 
 
31
  """Initialize LLMHandler with default values"""
32
  self.llm = None
33
  self.llm_tokenizer = None
@@ -37,26 +40,37 @@ class LLMHandler:
37
  self.device = "cpu"
38
  self.dtype = torch.float32
39
  self.offload_to_cpu = False
40
-
41
- # Shared constrained decoding processor (initialized once when LLM is loaded)
 
 
 
 
 
42
  self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
43
-
44
- # Shared HuggingFace model for perplexity calculation (when using vllm backend)
45
  self._hf_model_for_scoring = None
46
-
47
- def get_available_5hz_lm_models(self) -> List[str]:
48
- """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
 
 
49
  current_file = os.path.abspath(__file__)
50
  project_root = os.path.dirname(os.path.dirname(current_file))
51
- checkpoint_dir = os.path.join(project_root, "checkpoints")
52
-
 
 
 
 
53
  models = []
54
  if os.path.exists(checkpoint_dir):
55
  for item in os.listdir(checkpoint_dir):
56
  item_path = os.path.join(checkpoint_dir, item)
57
  if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"):
58
  models.append(item)
59
-
60
  models.sort()
61
  return models
62
 
 
26
  """5Hz LM Handler for audio code generation"""
27
 
28
  STOP_REASONING_TAG = "</think>"
29
+
30
+ # HuggingFace Space environment detection
31
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
32
+
33
+ def __init__(self, persistent_storage_path: Optional[str] = None):
34
  """Initialize LLMHandler with default values"""
35
  self.llm = None
36
  self.llm_tokenizer = None
 
40
  self.device = "cpu"
41
  self.dtype = torch.float32
42
  self.offload_to_cpu = False
43
+
44
+ # HuggingFace Space persistent storage support
45
+ if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE:
46
+ persistent_storage_path = "/data"
47
+ self.persistent_storage_path = persistent_storage_path
48
+
49
+ # Shared constrained decoding processor
50
  self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
51
+
52
+ # Shared HuggingFace model for perplexity calculation
53
  self._hf_model_for_scoring = None
54
+
55
+ def _get_checkpoint_dir(self) -> str:
56
+ """Get checkpoint directory, prioritizing persistent storage"""
57
+ if self.persistent_storage_path:
58
+ return os.path.join(self.persistent_storage_path, "checkpoints")
59
  current_file = os.path.abspath(__file__)
60
  project_root = os.path.dirname(os.path.dirname(current_file))
61
+ return os.path.join(project_root, "checkpoints")
62
+
63
def get_available_5hz_lm_models(self) -> List[str]:
    """List model directory names that start with 'acestep-5Hz-lm-'.

    Scans the resolved checkpoint directory and returns matching
    subdirectory names in sorted order (empty list when the checkpoint
    directory does not exist).
    """
    checkpoint_dir = self._get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return []
    # Keep only subdirectories whose name carries the 5Hz LM prefix.
    return sorted(
        entry
        for entry in os.listdir(checkpoint_dir)
        if entry.startswith("acestep-5Hz-lm-")
        and os.path.isdir(os.path.join(checkpoint_dir, entry))
    )
76