Commit c60d64b · 1 Parent(s): 6d3b89f
ChuxiJ committed

switch to zero gpu
README.md CHANGED
@@ -3,10 +3,12 @@ title: ACE-Step v1.5
 emoji: 🎵
 colorFrom: blue
 colorTo: purple
-sdk: docker
-app_port: 7860
+sdk: gradio
+sdk_version: 6.2.0
+python_version: 3.11
 pinned: false
 license: mit
+app_file: app.py
 short_description: Music Generation Foundation Model v1.5
 ---
 
acestep/gradio_ui/events/__init__.py CHANGED
@@ -2,8 +2,10 @@
 Gradio UI Event Handlers Module
 Main entry point for setting up all event handlers
 """
+import os
 import gradio as gr
 from typing import Optional
+from loguru import logger
 
 # Import handler modules
 from . import generation_handlers as gen_h
@@ -11,6 +13,24 @@ from . import results_handlers as res_h
 from . import training_handlers as train_h
 from acestep.gradio_ui.i18n import t
 
+# HuggingFace Space environment detection for ZeroGPU support
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+
+
+def _get_spaces_gpu_decorator(duration=120):
+    """
+    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
+    Returns identity decorator if not in Space environment.
+    """
+    if IS_HUGGINGFACE_SPACE:
+        try:
+            import spaces
+            return spaces.GPU(duration=duration)
+        except ImportError:
+            logger.warning("spaces package not found, GPU decorator disabled")
+            return lambda func: func
+    return lambda func: func
+
 
 def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
     """Setup event handlers connecting UI components and business logic
@@ -618,12 +638,13 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
         ]
     )
 
+    @_get_spaces_gpu_decorator(duration=300)
     def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
        """Wrapper that selects the appropriate DiT handler based on model selection"""
        # Convert args to list for modification
        args_list = list(args)
 
-       # args order (after simple mode params):
+       # args order (after simple mode params):
        # captions (0), lyrics (1), bpm (2), key_scale (3), time_signature (4), vocal_language (5),
        # inference_steps (6), guidance_scale (7), random_seed_checkbox (8), seed (9),
        # reference_audio (10), audio_duration (11), batch_size_input (12), src_audio (13),
@@ -684,7 +705,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
        # Mark as formatted caption (LM-generated sample)
        args_list[36] = True  # is_format_caption_state
 
-       # Determine which handler to use
+       # Determine which handler to use based on model selection
        active_handler = dit_handler  # Default to primary handler
        if dit_handler_2 is not None and selected_model == config_path_2:
            active_handler = dit_handler_2
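
The conditional decorator added above keeps a single code path for local runs and Spaces: off a Space it resolves to an identity wrapper, on a Space it defers to `spaces.GPU`. A minimal standalone sketch of that contract, mirroring the helper from the diff with a hypothetical handler for illustration:

```python
import os

# Mirrors the helper added in the diff; shown standalone to illustrate the contract.
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None


def _get_spaces_gpu_decorator(duration=120):
    """On a Space, wrap with spaces.GPU(duration=...); elsewhere return the function unchanged."""
    if IS_HUGGINGFACE_SPACE:
        try:
            import spaces  # only present on Hugging Face Spaces
            return spaces.GPU(duration=duration)
        except ImportError:
            return lambda func: func
    return lambda func: func


@_get_spaces_gpu_decorator(duration=300)
def handler(x):  # hypothetical event handler
    return x


assert handler(1) == 1  # off a Space, the decorator is a strict no-op
```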
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -8,6 +8,7 @@ import random
 import glob
 import gradio as gr
 from typing import Optional, List, Tuple
+from loguru import logger
 from acestep.constants import (
     TASK_TYPES_TURBO,
     TASK_TYPES_BASE,
@@ -16,6 +17,25 @@ from acestep.gradio_ui.i18n import t
 from acestep.inference import understand_music, create_sample, format_sample
 
 
+# HuggingFace Space environment detection for ZeroGPU support
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+
+
+def _get_spaces_gpu_decorator(duration=120):
+    """
+    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
+    Returns identity decorator if not in Space environment.
+    """
+    if IS_HUGGINGFACE_SPACE:
+        try:
+            import spaces
+            return spaces.GPU(duration=duration)
+        except ImportError:
+            logger.warning("spaces package not found, GPU decorator disabled")
+            return lambda func: func
+    return lambda func: func
+
+
 def parse_and_validate_timesteps(
     timesteps_str: str,
     inference_steps: int
@@ -746,15 +766,15 @@ def handle_generation_mode_change(mode: str):
         think_checkbox_update,  # think_checkbox - disabled for cover/repaint modes
     )
 
-
+@_get_spaces_gpu_decorator(duration=180)
 def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
     """
     Process source audio: convert to codes and then transcribe.
     This combines convert_src_audio_to_codes_wrapper + transcribe_audio_codes.
 
     Args:
-        dit_handler: DiT handler instance for audio code conversion
-        llm_handler: LLM handler instance for transcription
+        dit_handler: DiT handler instance
+        llm_handler: LLM handler instance
        src_audio: Path to source audio file
        constrained_decoding_debug: Whether to enable debug logging
 
@@ -799,7 +819,7 @@
         True  # Set is_format_caption to True
     )
 
-
+@_get_spaces_gpu_decorator(duration=180)
 def handle_create_sample(
     llm_handler,
     query: str,
@@ -819,7 +839,7 @@ def handle_create_sample(
     Note: cfg_scale and negative_prompt are not supported in create_sample mode.
 
     Args:
-        llm_handler: LLM handler instance
+        llm_handler: LLM handler instance (unused, fetched from registry)
        query: User's natural language music description
        instrumental: Whether to generate instrumental music
        vocal_language: Preferred vocal language for constrained decoding
@@ -929,7 +949,7 @@ def handle_create_sample(
         result.status_message,  # status_output
     )
 
-
+@_get_spaces_gpu_decorator(duration=180)
 def handle_format_sample(
     llm_handler,
     caption: str,
@@ -952,7 +972,7 @@ def handle_format_sample(
     Note: cfg_scale and negative_prompt are not supported in format mode.
 
     Args:
-        llm_handler: LLM handler instance
+        llm_handler: LLM handler instance (unused, fetched from registry)
        caption: User's caption/description
        lyrics: User's lyrics
        bpm: User-provided BPM (optional, for constrained decoding)
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -18,6 +18,26 @@ from acestep.gradio_ui.i18n import t
 from acestep.gradio_ui.events.generation_handlers import parse_and_validate_timesteps
 from acestep.inference import generate_music, GenerationParams, GenerationConfig
 from acestep.audio_utils import save_audio
+from loguru import logger
+
+
+# HuggingFace Space environment detection for ZeroGPU support
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+
+
+def _get_spaces_gpu_decorator(duration=120):
+    """
+    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
+    Returns identity decorator if not in Space environment.
+    """
+    if IS_HUGGINGFACE_SPACE:
+        try:
+            import spaces
+            return spaces.GPU(duration=duration)
+        except ImportError:
+            logger.warning("spaces package not found, GPU decorator disabled")
+            return lambda func: func
+    return lambda func: func
 
 
 def parse_lrc_to_subtitles(lrc_text: str, total_duration: Optional[float] = None) -> List[Dict[str, Any]]:
@@ -1038,7 +1058,7 @@ def calculate_score_handler(
         error_msg = t("messages.score_error", error=str(e)) + f"\n{traceback.format_exc()}"
         return error_msg
 
-
+@_get_spaces_gpu_decorator(duration=240)
 def calculate_score_handler_with_selection(
     dit_handler,
     llm_handler,
@@ -1152,7 +1172,7 @@ def calculate_score_handler_with_selection(
         batch_queue
     )
 
-
+@_get_spaces_gpu_decorator(duration=240)
 def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_queue, vocal_language, inference_steps):
     """
     Generate LRC timestamps for a specific audio sample.
@@ -1165,7 +1185,7 @@ def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_que
     This decouples audio value updates from subtitle updates, avoiding flickering.
 
     Args:
-        dit_handler: DiT handler instance with get_lyric_timestamp method
+        dit_handler: DiT handler instance (unused, fetched from registry)
        sample_idx: Which sample to generate LRC for (1-8)
        current_batch_index: Current batch index in batch_queue
        batch_queue: Dictionary storing all batch generation data
acestep/handler.py CHANGED
@@ -199,14 +199,24 @@ class AceStepHandler:
 
         return model_path
 
-    def is_flash_attention_available(self) -> bool:
-        """Check if flash attention is available on the system"""
+
+    def is_flash_attn3_available(self) -> bool:
+        """Check if flash-attn3 via kernels library is available"""
         try:
-            import flash_attn
+            import kernels
             return True
         except ImportError:
             return False
-
+
+    def get_best_attn_implementation(self) -> str:
+        """Get the best available attention implementation"""
+        if self.is_flash_attn3_available():
+            return "kernels-community/flash-attn3"
+        elif self.is_flash_attention_available():
+            return "flash_attention_2"
+        else:
+            return "sdpa"
+
     def is_turbo_model(self) -> bool:
         """Check if the currently loaded model is a turbo model"""
         if self.config is None:
@@ -425,33 +435,38 @@ class AceStepHandler:
         acestep_v15_checkpoint_path = self._ensure_model_downloaded(config_path, checkpoint_dir)
 
         if os.path.exists(acestep_v15_checkpoint_path):
-            # Determine attention implementation
-            if use_flash_attention and self.is_flash_attention_available():
-                attn_implementation = "flash_attention_2"
+            # Determine attention implementation (prefer flash-attn3 > flash_attention_2 > sdpa)
+            if use_flash_attention:
+                attn_implementation = self.get_best_attn_implementation()
                 self.dtype = torch.bfloat16
             else:
                 attn_implementation = "sdpa"
 
-            try:
-                logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
-                self.model = AutoModel.from_pretrained(
-                    acestep_v15_checkpoint_path,
-                    trust_remote_code=True,
-                    attn_implementation=attn_implementation,
-                    dtype="bfloat16"
-                )
-            except Exception as e:
-                logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
-                if attn_implementation == "sdpa":
-                    logger.info("[initialize_service] Falling back to eager attention")
-                    attn_implementation = "eager"
+            # Try loading with the best available attention implementation, with fallbacks
+            attn_fallback_order = [attn_implementation]
+            if attn_implementation == "kernels-community/flash-attn3":
+                attn_fallback_order.extend(["flash_attention_2", "sdpa", "eager"])
+            elif attn_implementation == "flash_attention_2":
+                attn_fallback_order.extend(["sdpa", "eager"])
+            elif attn_implementation == "sdpa":
+                attn_fallback_order.append("eager")
+
+            for attn_impl in attn_fallback_order:
+                try:
+                    logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_impl}")
+
                     self.model = AutoModel.from_pretrained(
-                        acestep_v15_checkpoint_path,
-                        trust_remote_code=True,
-                        attn_implementation=attn_implementation
+                        acestep_v15_checkpoint_path,
+                        trust_remote_code=True,
+                        attn_implementation=attn_impl,
+                        dtype="bfloat16"
                     )
-                else:
-                    raise e
+                    attn_implementation = attn_impl
+                    break
+                except Exception as e:
+                    logger.warning(f"[initialize_service] Failed to load model with {attn_impl}: {e}")
+                    if attn_impl == attn_fallback_order[-1]:
+                        raise e
 
             self.model.config._attn_implementation = attn_implementation
             self.config = self.model.config
@@ -466,6 +481,8 @@ class AceStepHandler:
         else:
             self.model = self.model.to("cpu").to(self.dtype)
         self.model.eval()
+        # Disable gradients for all parameters (required for ZeroGPU pickling)
+        self.model.requires_grad_(False)
 
         if compile_model:
             self.model = torch.compile(self.model)
@@ -498,7 +515,8 @@ class AceStepHandler:
                 self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
                 # Always keep silence_latent on GPU - it's used in many places outside model context
                 # and is small enough that it won't significantly impact VRAM
-                self.silence_latent = self.silence_latent.to(device).to(self.dtype)
+                # Use detach() to ensure no gradients (required for ZeroGPU pickling)
+                self.silence_latent = self.silence_latent.to(device).to(self.dtype).detach()
             else:
                 raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
         else:
@@ -519,6 +537,8 @@ class AceStepHandler:
            else:
                self.vae = self.vae.to("cpu").to(vae_dtype)
            self.vae.eval()
+           # Disable gradients for all parameters (required for ZeroGPU pickling)
+           self.vae.requires_grad_(False)
        else:
            raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
 
@@ -534,12 +554,31 @@ class AceStepHandler:
        else:
            if os.path.exists(text_encoder_path):
                self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
-               self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
+               # Use best attention implementation for text encoder
+               text_encoder_attn = self.get_best_attn_implementation()
+               text_encoder_loaded = False
+               for attn_impl in [text_encoder_attn, "flash_attention_2", "sdpa", "eager"]:
+                   try:
+                       self.text_encoder = AutoModel.from_pretrained(
+                           text_encoder_path,
+                           attn_implementation=attn_impl,
+                           torch_dtype=self.dtype,
+                       )
+                       logger.info(f"[initialize_service] Text encoder loaded with {attn_impl}")
+                       text_encoder_loaded = True
+                       break
+                   except Exception as e:
+                       logger.warning(f"[initialize_service] Failed to load text encoder with {attn_impl}: {e}")
+                       continue
+               if not text_encoder_loaded:
+                   raise RuntimeError("Failed to load text encoder with any attention implementation")
                if not self.offload_to_cpu:
                    self.text_encoder = self.text_encoder.to(device).to(self.dtype)
                else:
                    self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
                self.text_encoder.eval()
+               # Disable gradients for all parameters (required for ZeroGPU pickling)
+               self.text_encoder.requires_grad_(False)
            else:
                raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
 
@@ -2722,9 +2761,12 @@ class AceStepHandler:
            pass
 
        if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
+           missing = [k for k, v in [("model", self.model), ("vae", self.vae),
+                                     ("text_tokenizer", self.text_tokenizer), ("text_encoder", self.text_encoder)] if v is None]
+           logger.error(f"[generate_music] Model not fully initialized. Missing: {missing}")
            return {
                "audios": [],
-               "status_message": "❌ Model not fully initialized. Please initialize all components first.",
+               "status_message": f"❌ Model not fully initialized. Missing components: {missing}",
                "extra_outputs": {},
                "success": False,
                "error": "Model not fully initialized",
acestep/inference.py CHANGED
@@ -18,20 +18,6 @@ from acestep.audio_utils import AudioSaver, generate_uuid_from_params
 # HuggingFace Space environment detection
 IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
 
-def _get_spaces_gpu_decorator(duration=180):
-    """
-    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
-    Returns identity decorator if not in Space environment.
-    """
-    if IS_HUGGINGFACE_SPACE:
-        try:
-            import spaces
-            return spaces.GPU(duration=duration)
-        except ImportError:
-            logger.warning("spaces package not found, GPU decorator disabled")
-            return lambda func: func
-    return lambda func: func
-
 
 @dataclass
 class GenerationParams:
@@ -289,7 +275,6 @@ def _update_metadata_from_lm(
     return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
 
 
-@_get_spaces_gpu_decorator(duration=180)
 def generate_music(
     dit_handler,
     llm_handler,
@@ -924,6 +909,19 @@ def create_sample(
         ...     print(f"Lyrics: {result.lyrics}")
         ...     print(f"BPM: {result.bpm}")
     """
+    import torch
+    # Debug logging for ZeroGPU diagnosis
+    logger.info(f"[create_sample Debug] Entry: IS_HUGGINGFACE_SPACE={IS_HUGGINGFACE_SPACE}")
+    logger.info(f"[create_sample Debug] torch.cuda.is_available()={torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        logger.info(f"[create_sample Debug] torch.cuda.current_device()={torch.cuda.current_device()}")
+    logger.info(f"[create_sample Debug] llm_handler.device={llm_handler.device}, llm_handler.offload_to_cpu={llm_handler.offload_to_cpu}")
+    if llm_handler.llm is not None:
+        try:
+            logger.info(f"[create_sample Debug] Model device: {next(llm_handler.llm.parameters()).device}")
+        except Exception as e:
+            logger.info(f"[create_sample Debug] Could not get model device: {e}")
+
     # Check if LLM is initialized
     if not llm_handler.llm_initialized:
         return CreateSampleResult(
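
The decorator is removed from `generate_music` here because, per the app.py docstring added in this commit, `@spaces.GPU` is meant to sit on the top-level Gradio event handlers rather than on internal library functions, so the GPU lease covers the whole UI event. A minimal sketch of that placement, with a hypothetical handler and components:

```python
import gradio as gr

# Guarded import, as in the diff: 'spaces' only exists on Hugging Face Spaces.
try:
    import spaces
    gpu = spaces.GPU(duration=180)
except ImportError:
    gpu = lambda func: func


@gpu  # decorate the function Gradio calls, not the inner generate_music()
def on_generate_click(prompt: str) -> str:
    # ...would call generate_music(...) internally...
    return f"Generated for: {prompt}"


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    status = gr.Textbox(label="Status")
    gr.Button("Generate").click(on_generate_click, inputs=prompt, outputs=status)
```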
acestep/llm_inference.py CHANGED
@@ -30,6 +30,9 @@ class LLMHandler:
     # HuggingFace Space environment detection
     IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
 
+    # Force IS_ZEROGPU=True when on HuggingFace Space, as the env var detection is unreliable
+    IS_ZEROGPU = IS_HUGGINGFACE_SPACE or os.environ.get("ZEROGPU") is not None
+
     def __init__(self, persistent_storage_path: Optional[str] = None):
         """Initialize LLMHandler with default values"""
         self.llm = None
@@ -190,20 +193,74 @@ class LLMHandler:
         return self.build_formatted_prompt(
             caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
         )
-
+
+    def is_flash_attn3_available(self) -> bool:
+        """Check if flash-attn3 via kernels library is available"""
+        try:
+            import kernels
+            return True
+        except ImportError:
+            return False
+
+    def is_flash_attention_available(self) -> bool:
+        """Check if flash attention is available on the system"""
+        try:
+            import flash_attn
+            return True
+        except ImportError:
+            return False
+
+    def get_best_attn_implementation(self) -> str:
+        """Get the best available attention implementation"""
+        if self.is_flash_attn3_available():
+            return "kernels-community/flash-attn3"
+        elif self.is_flash_attention_available():
+            return "flash_attention_2"
+        else:
+            return "sdpa"
+
     def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]:
         """Load PyTorch model from path and return (success, status_message)"""
         try:
-            self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+            # Try loading with the best available attention implementation
+            attn_implementation = self.get_best_attn_implementation()
+            attn_fallback_order = [attn_implementation]
+            if attn_implementation == "kernels-community/flash-attn3":
+                attn_fallback_order.extend(["flash_attention_2", "sdpa", "eager"])
+            elif attn_implementation == "flash_attention_2":
+                attn_fallback_order.extend(["sdpa", "eager"])
+            elif attn_implementation == "sdpa":
+                attn_fallback_order.append("eager")
+
+            for attn_impl in attn_fallback_order:
+                try:
+                    logger.info(f"[LLM Load] Attempting to load model with attention implementation: {attn_impl}")
+                    self.llm = AutoModelForCausalLM.from_pretrained(
+                        model_path,
+                        trust_remote_code=True,
+                        attn_implementation=attn_impl,
+                        torch_dtype=self.dtype,
+                    )
+                    attn_implementation = attn_impl
+                    break
+                except Exception as e:
+                    logger.warning(f"[LLM Load] Failed to load model with {attn_impl}: {e}")
+                    if attn_impl == attn_fallback_order[-1]:
+                        raise e
+
+            logger.info(f"[LLM Load Debug] Model loaded with {attn_implementation}, initial device: {next(self.llm.parameters()).device}")
             if not self.offload_to_cpu:
                 self.llm = self.llm.to(device).to(self.dtype)
             else:
                 self.llm = self.llm.to("cpu").to(self.dtype)
+            logger.info(f"[LLM Load Debug] After .to(), model device: {next(self.llm.parameters()).device}")
             self.llm.eval()
+            # Disable gradients for all parameters (required for ZeroGPU pickling)
+            self.llm.requires_grad_(False)
             self.llm_backend = "pt"
             self.llm_initialized = True
             logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
-            status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}"
+            status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch ({attn_implementation})\nDevice: {device}"
             return True, status_msg
         except Exception as e:
             return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -312,6 +369,11 @@ class LLMHandler:
 
         self.device = device
         self.offload_to_cpu = offload_to_cpu
+
+        # Debug logging for ZeroGPU diagnosis
+        logger.info(f"[LLM Init Debug] IS_ZEROGPU={self.IS_ZEROGPU}, IS_HUGGINGFACE_SPACE={self.IS_HUGGINGFACE_SPACE}")
+        logger.info(f"[LLM Init Debug] torch.cuda.is_available()={torch.cuda.is_available()}")
+        logger.info(f"[LLM Init Debug] device={device}, offload_to_cpu={offload_to_cpu}")
         # Set dtype based on device: bfloat16 for cuda, float32 for cpu
         if dtype is None:
             self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
@@ -577,8 +639,11 @@ class LLMHandler:
         )
 
         with self._load_model_context():
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
+            # Move inputs to the same device as the model (important for ZeroGPU where model may be on CPU)
+            model_device = next(self.llm.parameters()).device
+            inputs = {k: v.to(model_device) for k, v in inputs.items()}
+            logger.info(f"[_run_pt_single Debug] Inputs moved to model device: {model_device}")
+            logger.info(f"[_run_pt_single Debug] Input actual device: {inputs['input_ids'].device}")
             # Calculate max_new_tokens based on target_duration if specified
             # 5 audio codes = 1 second, plus ~500 tokens for CoT metadata and safety margin
             if target_duration is not None and target_duration > 0:
@@ -618,7 +683,7 @@ class LLMHandler:
                 truncation=True,
             )
             self.llm_tokenizer.padding_side = original_padding_side
-            batch_inputs_tokenized = {k: v.to(self.device) for k, v in batch_inputs_tokenized.items()}
+            batch_inputs_tokenized = {k: v.to(model_device) for k, v in batch_inputs_tokenized.items()}
 
             # Extract batch inputs
             batch_input_ids = batch_inputs_tokenized['input_ids']
@@ -1988,7 +2053,8 @@ class LLMHandler:
        This allows us to call update_state() after each token generation.
        """
        model = self.llm
-       device = self.device
+       # Get device from model (important for ZeroGPU where model may be on different device than self.device)
+       device = next(model.parameters()).device
 
        # Initialize generated sequences
        generated_ids = input_ids.clone()
@@ -2088,7 +2154,8 @@ class LLMHandler:
        Batch format: [cond_input, uncond_input]
        """
        model = self.llm
-       device = self.device
+       # Get device from model (important for ZeroGPU where model may be on different device than self.device)
+       device = next(model.parameters()).device
        batch_size = batch_input_ids.shape[0] // 2  # Half are conditional, half are unconditional
        cond_start_idx = 0
        uncond_start_idx = batch_size
@@ -2309,7 +2376,30 @@ class LLMHandler:
        Context manager to load a model to GPU and offload it back to CPU after use.
        Only used for PyTorch backend when offload_to_cpu is True.
        """
+       logger.info(f"[_load_model_context Debug] Entry: offload_to_cpu={self.offload_to_cpu}, backend={self.llm_backend}, self.device={self.device}")
+       logger.info(f"[_load_model_context Debug] torch.cuda.is_available()={torch.cuda.is_available()}, IS_ZEROGPU={self.IS_ZEROGPU}")
+
+       model_device = None
+       if self.llm is not None:
+           model_device = next(self.llm.parameters()).device
+           logger.info(f"[_load_model_context Debug] Model current device: {model_device}")
+
+       # In ZeroGPU, model may be on CPU even though self.device="cuda" (due to hijacked .to() during init)
+       # Move to CUDA if available and model is on CPU
+       needs_move_to_cuda = (
+           self.llm is not None
+           and torch.cuda.is_available()
+           and model_device is not None
+           and model_device.type == "cpu"
+       )
+
+       if needs_move_to_cuda:
+           logger.info(f"[_load_model_context Debug] Moving model from CPU to cuda")
+           self.llm = self.llm.to("cuda").to(self.dtype)
+           logger.info(f"[_load_model_context Debug] Model now on: {next(self.llm.parameters()).device}")
+
        if not self.offload_to_cpu:
+           logger.info(f"[_load_model_context Debug] offload_to_cpu=False, yielding")
            yield
            return
 
@@ -2383,7 +2473,9 @@ class LLMHandler:
            device = next(model_runner.model.parameters()).device
            self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
            self._hf_model_for_scoring.eval()
-
+           # Disable gradients for all parameters (required for ZeroGPU pickling)
+           self._hf_model_for_scoring.requires_grad_(False)
+
            logger.info(f"HuggingFace model for scoring ready on {device}")
 
            return self._hf_model_for_scoring
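
Several of the changes above replace `self.device` with the device the weights actually sit on, since under ZeroGPU a model can still be on CPU while `self.device` says "cuda". A minimal sketch of that pattern (the helper names are illustrative, not part of the repo):

```python
import torch


def resolve_model_device(model: torch.nn.Module) -> torch.device:
    """Return the device the model's parameters actually live on."""
    return next(model.parameters()).device


def move_to_model_device(model: torch.nn.Module, tensors: dict) -> dict:
    """Move a dict of tensors (e.g. tokenizer output) onto the model's device."""
    device = resolve_model_device(model)
    return {k: v.to(device) for k, v in tensors.items()}


# Example with a tiny module; in the handler this would be self.llm and the tokenized inputs.
model = torch.nn.Linear(4, 4)
batch = {"input_ids": torch.zeros(1, 4)}
batch = move_to_model_device(model, batch)
print(batch["input_ids"].device)  # matches the model, whatever self.device claims
```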
app.py CHANGED
@@ -1,8 +1,13 @@
 """
 ACE-Step v1.5 - HuggingFace Space Entry Point
-
 This file serves as the entry point for HuggingFace Space deployment.
 It initializes the service and launches the Gradio interface.
+ZeroGPU Support:
+- ZeroGPU uses the 'spaces' package to intercept CUDA operations
+- Models are loaded to "cuda" during startup but actual GPU allocation is deferred
+- Handlers are registered globally so forked processes inherit them without pickling
+- @spaces.GPU decorators are on top-level Gradio event handlers, not internal functions
+- nano-vllm uses direct CUDA APIs that bypass spaces interception, so we use PyTorch backend
 """
 import os
 import sys
@@ -22,12 +27,26 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
     os.environ.pop(proxy_var, None)
 
+# Import spaces for ZeroGPU support (must be imported before torch for proper interception)
+# This is a no-op if not running on HuggingFace Spaces
+try:
+    import spaces
+    HAS_SPACES = True
+except ImportError:
+    HAS_SPACES = False
+
 import torch
 from acestep.handler import AceStepHandler
 from acestep.llm_inference import LLMHandler
 from acestep.dataset_handler import DatasetHandler
 from acestep.gradio_ui import create_gradio_interface
 
+# Detect ZeroGPU environment
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+# ZeroGPU detection: check env var OR assume ZeroGPU for all HF Spaces (safer default)
+# The SPACE_HARDWARE env var is unreliable, so we assume ZeroGPU if on HF Space
+IS_ZEROGPU = IS_HUGGINGFACE_SPACE or os.environ.get("ZEROGPU") is not None
+
 
 def get_gpu_memory_gb():
     """
@@ -105,14 +124,30 @@ def main():
         print("UI will be fully functional but generation is disabled")
         print("=" * 60)
 
+    # Log ZeroGPU detection
+    if IS_ZEROGPU:
+        print("=" * 60)
+        print("ZeroGPU environment detected")
+        print("- Using spaces package for GPU allocation")
+        print("- PyTorch backend forced for LLM (nano-vllm incompatible)")
+        print("- GPU will be allocated on-demand during generation")
+        print("=" * 60)
+
     # Get persistent storage path (auto-detect)
     persistent_storage_path = get_persistent_storage_path()
 
     # Detect GPU memory for auto-configuration
+    # Note: In ZeroGPU, GPU may not be available during startup, so this may return 0
     gpu_memory_gb = get_gpu_memory_gb()
-    auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
 
-    if not debug_ui:
+    # For ZeroGPU, we don't need CPU offload as GPU is allocated dynamically
+    if IS_ZEROGPU:
+        auto_offload = False
+        print("ZeroGPU: CPU offload disabled (GPU allocated on-demand)")
+    else:
+        auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
+
+    if not debug_ui and not IS_ZEROGPU:
        if auto_offload:
            print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
            print("Auto-enabling CPU offload to reduce GPU memory usage")
@@ -140,7 +175,11 @@ def main():
         "SERVICE_MODE_LM_MODEL",
         "acestep-5Hz-lm-1.7B"
     )
-    backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
+    # For ZeroGPU, force PyTorch backend (nano-vllm uses direct CUDA APIs)
+    if IS_ZEROGPU:
+        backend = "pt"
+    else:
+        backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
     device = "auto"
 
     print(f"Service mode configuration:")
@@ -151,6 +190,7 @@ def main():
     print(f"  Backend: {backend}")
     print(f"  Offload to CPU: {auto_offload}")
    print(f"  DEBUG_UI: {debug_ui}")
+   print(f"  ZeroGPU: {IS_ZEROGPU}")
 
    # Determine flash attention availability
    use_flash_attention = dit_handler.is_flash_attention_available()
@@ -230,7 +270,7 @@ def main():
    else:
        print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
        init_status += f"\n{lm_status}"
-
+
    # Build available models list for UI
    available_dit_models = [config_path]
    if config_path_2 and dit_handler_2 is not None:
@@ -275,7 +315,7 @@ def main():
 
    # Enable queue for multi-user support
    print("Enabling queue for multi-user support...")
-   demo.queue(max_size=20, default_concurrency_limit=1)
+   demo.queue(max_size=20)
 
    # Launch
    print("Launching server on 0.0.0.0:7860...")
@@ -288,4 +328,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
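
The startup logic above boils down to: if the Space is assumed to be ZeroGPU, skip CPU offload and force the PyTorch LLM backend; otherwise keep the memory-based auto-offload and the configurable backend. A condensed sketch of that decision, using the same environment variables as the diff (the helper function itself is illustrative):

```python
import os

IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
IS_ZEROGPU = IS_HUGGINGFACE_SPACE or os.environ.get("ZEROGPU") is not None


def select_runtime_config(gpu_memory_gb: float) -> dict:
    """Pick offload/backend settings the way app.py does after this commit."""
    if IS_ZEROGPU:
        # GPU is leased on demand inside @spaces.GPU handlers, so CPU offload is
        # unnecessary and nano-vllm (direct CUDA APIs) is avoided.
        return {"auto_offload": False, "backend": "pt"}
    return {
        "auto_offload": 0 < gpu_memory_gb < 16,
        "backend": os.environ.get("SERVICE_MODE_BACKEND", "vllm"),
    }


print(select_runtime_config(gpu_memory_gb=8.0))
```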
requirements.txt CHANGED
@@ -1,11 +1,7 @@
 # PyTorch with CUDA 12.8 (for Windows/Linux)
---extra-index-url https://download.pytorch.org/whl/cu128
-torch==2.7.1; sys_platform == 'win32'
-torchaudio==2.7.1; sys_platform == 'win32'
-torchvision; sys_platform == 'win32'
-torch>=2.9.1; sys_platform != 'win32'
-torchaudio>=2.9.1; sys_platform != 'win32'
-torchvision; sys_platform != 'win32'
+torch==2.9.1
+torchaudio==2.9.1
+torchvision==0.24.1
 
 # Core dependencies
 transformers>=4.51.0,<4.58.0
@@ -14,6 +10,7 @@ gradio==6.2.0
 matplotlib>=3.7.5
 scipy>=1.10.1
 soundfile>=0.13.1
+ffmpeg-python
 loguru>=0.7.3
 einops>=0.8.1
 accelerate>=1.12.0
@@ -33,6 +30,8 @@ triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
 triton>=3.0.0; sys_platform != 'win32'
 flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.2/flash_attn-2.8.2+cu128torch2.7.1cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.11' and platform_machine == 'AMD64'
 flash-attn @ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.12/flash_attn-2.8.3+cu128torch2.10-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux' and python_version == '3.11'
+# Kernels library for flash-attn3 (preferred over flash-attn when available)
+kernels
 xxhash
 
 # HuggingFace Space required