switch to zero gpu
- README.md +4 -2
- acestep/gradio_ui/events/__init__.py +23 -2
- acestep/gradio_ui/events/generation_handlers.py +27 -7
- acestep/gradio_ui/events/results_handlers.py +23 -3
- acestep/handler.py +70 -28
- acestep/inference.py +13 -15
- acestep/llm_inference.py +101 -9
- app.py +47 -7
- requirements.txt +6 -7
README.md
CHANGED
@@ -3,10 +3,12 @@ title: ACE-Step v1.5
 emoji: 🎵
 colorFrom: blue
 colorTo: purple
-sdk:
-
+sdk: gradio
+sdk_version: 6.2.0
+python_version: 3.11
 pinned: false
 license: mit
+app_file: app.py
 short_description: Music Generation Foundation Model v1.5
 ---
 
acestep/gradio_ui/events/__init__.py
CHANGED
@@ -2,8 +2,10 @@
 Gradio UI Event Handlers Module
 Main entry point for setting up all event handlers
 """
+import os
 import gradio as gr
 from typing import Optional
+from loguru import logger
 
 # Import handler modules
 from . import generation_handlers as gen_h
@@ -11,6 +13,24 @@ from . import results_handlers as res_h
 from . import training_handlers as train_h
 from acestep.gradio_ui.i18n import t
 
+# HuggingFace Space environment detection for ZeroGPU support
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+
+
+def _get_spaces_gpu_decorator(duration=120):
+    """
+    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
+    Returns identity decorator if not in Space environment.
+    """
+    if IS_HUGGINGFACE_SPACE:
+        try:
+            import spaces
+            return spaces.GPU(duration=duration)
+        except ImportError:
+            logger.warning("spaces package not found, GPU decorator disabled")
+            return lambda func: func
+    return lambda func: func
+
 
 def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
     """Setup event handlers connecting UI components and business logic
@@ -618,12 +638,13 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
         ]
     )
 
+    @_get_spaces_gpu_decorator(duration=300)
     def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
         """Wrapper that selects the appropriate DiT handler based on model selection"""
         # Convert args to list for modification
         args_list = list(args)
 
-        # args order (after simple mode params):
+        # args order (after simple mode params):
         # captions (0), lyrics (1), bpm (2), key_scale (3), time_signature (4), vocal_language (5),
         # inference_steps (6), guidance_scale (7), random_seed_checkbox (8), seed (9),
         # reference_audio (10), audio_duration (11), batch_size_input (12), src_audio (13),
@@ -684,7 +705,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
         # Mark as formatted caption (LM-generated sample)
         args_list[36] = True  # is_format_caption_state
 
-        # Determine which handler to use
+        # Determine which handler to use based on model selection
        active_handler = dit_handler  # Default to primary handler
         if dit_handler_2 is not None and selected_model == config_path_2:
             active_handler = dit_handler_2
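The decorator factory added above degrades to an identity decorator when the `spaces` package is missing or the code is not running on a Space, so the same event-handler code runs locally and on ZeroGPU. Below is a minimal, self-contained sketch of that pattern wired into a toy Gradio app; the `gpu_decorator` and `run_generation` names and the handler body are illustrative stand-ins, not part of this commit, while the duration cap and `demo.queue(max_size=20)` mirror values used in the change.

import os
import gradio as gr

IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None

def gpu_decorator(duration=120):
    # On a Space, return spaces.GPU so each call gets a short-lived GPU allocation;
    # everywhere else, return a no-op decorator.
    if IS_HUGGINGFACE_SPACE:
        try:
            import spaces
            return spaces.GPU(duration=duration)
        except ImportError:
            pass
    return lambda func: func

@gpu_decorator(duration=300)  # cap the ZeroGPU allocation for one generation call
def run_generation(prompt):
    # Illustrative body; the real handler dispatches to the DiT/LLM handlers.
    return f"would generate audio for: {prompt!r}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Status")
    gr.Button("Generate").click(run_generation, inputs=box, outputs=out)

demo.queue(max_size=20)

if __name__ == "__main__":
    demo.launch()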
acestep/gradio_ui/events/generation_handlers.py
CHANGED
@@ -8,6 +8,7 @@ import random
 import glob
 import gradio as gr
 from typing import Optional, List, Tuple
+from loguru import logger
 from acestep.constants import (
     TASK_TYPES_TURBO,
     TASK_TYPES_BASE,
@@ -16,6 +17,25 @@ from acestep.gradio_ui.i18n import t
 from acestep.inference import understand_music, create_sample, format_sample
 
 
+# HuggingFace Space environment detection for ZeroGPU support
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+
+
+def _get_spaces_gpu_decorator(duration=120):
+    """
+    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
+    Returns identity decorator if not in Space environment.
+    """
+    if IS_HUGGINGFACE_SPACE:
+        try:
+            import spaces
+            return spaces.GPU(duration=duration)
+        except ImportError:
+            logger.warning("spaces package not found, GPU decorator disabled")
+            return lambda func: func
+    return lambda func: func
+
+
 def parse_and_validate_timesteps(
     timesteps_str: str,
     inference_steps: int
@@ -746,15 +766,15 @@ def handle_generation_mode_change(mode: str):
         think_checkbox_update,  # think_checkbox - disabled for cover/repaint modes
     )
 
-
+@_get_spaces_gpu_decorator(duration=180)
 def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
     """
     Process source audio: convert to codes and then transcribe.
     This combines convert_src_audio_to_codes_wrapper + transcribe_audio_codes.
 
     Args:
-        dit_handler: DiT handler instance
-        llm_handler: LLM handler instance
+        dit_handler: DiT handler instance
+        llm_handler: LLM handler instance
         src_audio: Path to source audio file
         constrained_decoding_debug: Whether to enable debug logging
 
@@ -799,7 +819,7 @@ def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decodi
         True  # Set is_format_caption to True
     )
 
-
+@_get_spaces_gpu_decorator(duration=180)
 def handle_create_sample(
     llm_handler,
     query: str,
@@ -819,7 +839,7 @@ def handle_create_sample(
     Note: cfg_scale and negative_prompt are not supported in create_sample mode.
 
     Args:
-        llm_handler: LLM handler instance
+        llm_handler: LLM handler instance (unused, fetched from registry)
         query: User's natural language music description
         instrumental: Whether to generate instrumental music
         vocal_language: Preferred vocal language for constrained decoding
@@ -929,7 +949,7 @@ def handle_create_sample(
         result.status_message,  # status_output
     )
 
-
+@_get_spaces_gpu_decorator(duration=180)
 def handle_format_sample(
     llm_handler,
     caption: str,
@@ -952,7 +972,7 @@ def handle_format_sample(
     Note: cfg_scale and negative_prompt are not supported in format mode.
 
     Args:
-        llm_handler: LLM handler instance
+        llm_handler: LLM handler instance (unused, fetched from registry)
        caption: User's caption/description
         lyrics: User's lyrics
         bpm: User-provided BPM (optional, for constrained decoding)
acestep/gradio_ui/events/results_handlers.py
CHANGED
@@ -18,6 +18,26 @@ from acestep.gradio_ui.i18n import t
 from acestep.gradio_ui.events.generation_handlers import parse_and_validate_timesteps
 from acestep.inference import generate_music, GenerationParams, GenerationConfig
 from acestep.audio_utils import save_audio
+from loguru import logger
+
+
+# HuggingFace Space environment detection for ZeroGPU support
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+
+
+def _get_spaces_gpu_decorator(duration=120):
+    """
+    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
+    Returns identity decorator if not in Space environment.
+    """
+    if IS_HUGGINGFACE_SPACE:
+        try:
+            import spaces
+            return spaces.GPU(duration=duration)
+        except ImportError:
+            logger.warning("spaces package not found, GPU decorator disabled")
+            return lambda func: func
+    return lambda func: func
 
 
 def parse_lrc_to_subtitles(lrc_text: str, total_duration: Optional[float] = None) -> List[Dict[str, Any]]:
@@ -1038,7 +1058,7 @@ calculate_score_handler(
         error_msg = t("messages.score_error", error=str(e)) + f"\n{traceback.format_exc()}"
         return error_msg
 
-
+@_get_spaces_gpu_decorator(duration=240)
 def calculate_score_handler_with_selection(
     dit_handler,
     llm_handler,
@@ -1152,7 +1172,7 @@ def calculate_score_handler_with_selection(
         batch_queue
     )
 
-
+@_get_spaces_gpu_decorator(duration=240)
 def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_queue, vocal_language, inference_steps):
     """
     Generate LRC timestamps for a specific audio sample.
@@ -1165,7 +1185,7 @@ def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_que
     This decouples audio value updates from subtitle updates, avoiding flickering.
 
     Args:
-        dit_handler: DiT handler instance
+        dit_handler: DiT handler instance (unused, fetched from registry)
         sample_idx: Which sample to generate LRC for (1-8)
         current_batch_index: Current batch index in batch_queue
         batch_queue: Dictionary storing all batch generation data
acestep/handler.py
CHANGED
@@ -199,14 +199,24 @@ class AceStepHandler:
 
         return model_path
 
-
-
+
+    def is_flash_attn3_available(self) -> bool:
+        """Check if flash-attn3 via kernels library is available"""
         try:
-            import
+            import kernels
             return True
         except ImportError:
             return False
-
+
+    def get_best_attn_implementation(self) -> str:
+        """Get the best available attention implementation"""
+        if self.is_flash_attn3_available():
+            return "kernels-community/flash-attn3"
+        elif self.is_flash_attention_available():
+            return "flash_attention_2"
+        else:
+            return "sdpa"
+
     def is_turbo_model(self) -> bool:
         """Check if the currently loaded model is a turbo model"""
         if self.config is None:
@@ -425,33 +435,38 @@ class AceStepHandler:
             acestep_v15_checkpoint_path = self._ensure_model_downloaded(config_path, checkpoint_dir)
 
             if os.path.exists(acestep_v15_checkpoint_path):
-                # Determine attention implementation
-                if use_flash_attention
-                attn_implementation =
+                # Determine attention implementation (prefer flash-attn3 > flash_attention_2 > sdpa)
+                if use_flash_attention:
+                    attn_implementation = self.get_best_attn_implementation()
                     self.dtype = torch.bfloat16
                 else:
                     attn_implementation = "sdpa"
 
-
-
-
-
-
-
-
-                )
-
-
-
-                logger.info("[initialize_service]
-
+                # Try loading with the best available attention implementation, with fallbacks
+                attn_fallback_order = [attn_implementation]
+                if attn_implementation == "kernels-community/flash-attn3":
+                    attn_fallback_order.extend(["flash_attention_2", "sdpa", "eager"])
+                elif attn_implementation == "flash_attention_2":
+                    attn_fallback_order.extend(["sdpa", "eager"])
+                elif attn_implementation == "sdpa":
+                    attn_fallback_order.append("eager")
+
+                for attn_impl in attn_fallback_order:
+                    try:
+                        logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_impl}")
+
                         self.model = AutoModel.from_pretrained(
-                    acestep_v15_checkpoint_path,
-                    trust_remote_code=True,
-                    attn_implementation=
+                            acestep_v15_checkpoint_path,
+                            trust_remote_code=True,
+                            attn_implementation=attn_impl,
+                            dtype="bfloat16"
                         )
-
-
+                        attn_implementation = attn_impl
+                        break
+                    except Exception as e:
+                        logger.warning(f"[initialize_service] Failed to load model with {attn_impl}: {e}")
+                        if attn_impl == attn_fallback_order[-1]:
+                            raise e
 
                 self.model.config._attn_implementation = attn_implementation
                 self.config = self.model.config
@@ -466,6 +481,8 @@ class AceStepHandler:
             else:
                 self.model = self.model.to("cpu").to(self.dtype)
             self.model.eval()
+            # Disable gradients for all parameters (required for ZeroGPU pickling)
+            self.model.requires_grad_(False)
 
             if compile_model:
                 self.model = torch.compile(self.model)
@@ -498,7 +515,8 @@ class AceStepHandler:
                 self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
                 # Always keep silence_latent on GPU - it's used in many places outside model context
                 # and is small enough that it won't significantly impact VRAM
-
+                # Use detach() to ensure no gradients (required for ZeroGPU pickling)
+                self.silence_latent = self.silence_latent.to(device).to(self.dtype).detach()
             else:
                 raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
         else:
@@ -519,6 +537,8 @@ class AceStepHandler:
             else:
                 self.vae = self.vae.to("cpu").to(vae_dtype)
             self.vae.eval()
+            # Disable gradients for all parameters (required for ZeroGPU pickling)
+            self.vae.requires_grad_(False)
         else:
             raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
 
@@ -534,12 +554,31 @@ class AceStepHandler:
         else:
             if os.path.exists(text_encoder_path):
                 self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
-
+                # Use best attention implementation for text encoder
+                text_encoder_attn = self.get_best_attn_implementation()
+                text_encoder_loaded = False
+                for attn_impl in [text_encoder_attn, "flash_attention_2", "sdpa", "eager"]:
+                    try:
+                        self.text_encoder = AutoModel.from_pretrained(
+                            text_encoder_path,
+                            attn_implementation=attn_impl,
+                            torch_dtype=self.dtype,
+                        )
+                        logger.info(f"[initialize_service] Text encoder loaded with {attn_impl}")
+                        text_encoder_loaded = True
+                        break
+                    except Exception as e:
+                        logger.warning(f"[initialize_service] Failed to load text encoder with {attn_impl}: {e}")
+                        continue
+                if not text_encoder_loaded:
+                    raise RuntimeError("Failed to load text encoder with any attention implementation")
                 if not self.offload_to_cpu:
                     self.text_encoder = self.text_encoder.to(device).to(self.dtype)
                 else:
                     self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
                 self.text_encoder.eval()
+                # Disable gradients for all parameters (required for ZeroGPU pickling)
+                self.text_encoder.requires_grad_(False)
             else:
                 raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
 
@@ -2722,9 +2761,12 @@ class AceStepHandler:
             pass
 
         if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
+            missing = [k for k, v in [("model", self.model), ("vae", self.vae),
+                                      ("text_tokenizer", self.text_tokenizer), ("text_encoder", self.text_encoder)] if v is None]
+            logger.error(f"[generate_music] Model not fully initialized. Missing: {missing}")
             return {
                 "audios": [],
-                "status_message": "❌ Model not fully initialized.
+                "status_message": f"❌ Model not fully initialized. Missing components: {missing}",
                 "extra_outputs": {},
                 "success": False,
                 "error": "Model not fully initialized",
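The loader above walks a fallback ladder of attention implementations (kernels flash-attn3, then flash_attention_2, then sdpa, then eager) and keeps the first one that loads. A condensed sketch of the same pattern as a standalone helper follows; `load_with_attn_fallback` and `model_id` are illustrative names, not part of the commit, and the ladder mirrors the order used in the change.

from loguru import logger
from transformers import AutoModel

FALLBACKS = {
    "kernels-community/flash-attn3": ["kernels-community/flash-attn3", "flash_attention_2", "sdpa", "eager"],
    "flash_attention_2": ["flash_attention_2", "sdpa", "eager"],
    "sdpa": ["sdpa", "eager"],
}

def load_with_attn_fallback(model_id: str, preferred: str = "sdpa"):
    """Try each attention implementation in order; return (model, implementation used)."""
    last_error = None
    for attn_impl in FALLBACKS.get(preferred, [preferred, "eager"]):
        try:
            model = AutoModel.from_pretrained(
                model_id,
                trust_remote_code=True,
                attn_implementation=attn_impl,
            )
            logger.info(f"Loaded {model_id} with {attn_impl}")
            return model, attn_impl
        except Exception as exc:  # e.g. missing flash-attn wheel or unsupported kernel
            logger.warning(f"{attn_impl} failed for {model_id}: {exc}")
            last_error = exc
    raise last_error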
acestep/inference.py
CHANGED
@@ -18,20 +18,6 @@ from acestep.audio_utils import AudioSaver, generate_uuid_from_params
 # HuggingFace Space environment detection
 IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
 
-def _get_spaces_gpu_decorator(duration=180):
-    """
-    Get the @spaces.GPU decorator if running in HuggingFace Space environment.
-    Returns identity decorator if not in Space environment.
-    """
-    if IS_HUGGINGFACE_SPACE:
-        try:
-            import spaces
-            return spaces.GPU(duration=duration)
-        except ImportError:
-            logger.warning("spaces package not found, GPU decorator disabled")
-            return lambda func: func
-    return lambda func: func
-
 
 @dataclass
 class GenerationParams:
@@ -289,7 +275,6 @@ def _update_metadata_from_lm(
     return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
 
 
-@_get_spaces_gpu_decorator(duration=180)
 def generate_music(
     dit_handler,
     llm_handler,
@@ -924,6 +909,19 @@ def create_sample(
         ... print(f"Lyrics: {result.lyrics}")
         ... print(f"BPM: {result.bpm}")
     """
+    import torch
+    # Debug logging for ZeroGPU diagnosis
+    logger.info(f"[create_sample Debug] Entry: IS_HUGGINGFACE_SPACE={IS_HUGGINGFACE_SPACE}")
+    logger.info(f"[create_sample Debug] torch.cuda.is_available()={torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        logger.info(f"[create_sample Debug] torch.cuda.current_device()={torch.cuda.current_device()}")
+    logger.info(f"[create_sample Debug] llm_handler.device={llm_handler.device}, llm_handler.offload_to_cpu={llm_handler.offload_to_cpu}")
+    if llm_handler.llm is not None:
+        try:
+            logger.info(f"[create_sample Debug] Model device: {next(llm_handler.llm.parameters()).device}")
+        except Exception as e:
+            logger.info(f"[create_sample Debug] Could not get model device: {e}")
+
     # Check if LLM is initialized
     if not llm_handler.llm_initialized:
         return CreateSampleResult(
acestep/llm_inference.py
CHANGED
@@ -30,6 +30,9 @@ class LLMHandler:
     # HuggingFace Space environment detection
     IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
 
+    # Force IS_ZEROGPU=True when on HuggingFace Space, as the env var detection is unreliable
+    IS_ZEROGPU = IS_HUGGINGFACE_SPACE or os.environ.get("ZEROGPU") is not None
+
     def __init__(self, persistent_storage_path: Optional[str] = None):
         """Initialize LLMHandler with default values"""
         self.llm = None
@@ -190,20 +193,74 @@ class LLMHandler:
         return self.build_formatted_prompt(
             caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
         )
-
+
+    def is_flash_attn3_available(self) -> bool:
+        """Check if flash-attn3 via kernels library is available"""
+        try:
+            import kernels
+            return True
+        except ImportError:
+            return False
+
+    def is_flash_attention_available(self) -> bool:
+        """Check if flash attention is available on the system"""
+        try:
+            import flash_attn
+            return True
+        except ImportError:
+            return False
+
+    def get_best_attn_implementation(self) -> str:
+        """Get the best available attention implementation"""
+        if self.is_flash_attn3_available():
+            return "kernels-community/flash-attn3"
+        elif self.is_flash_attention_available():
+            return "flash_attention_2"
+        else:
+            return "sdpa"
+
     def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]:
         """Load PyTorch model from path and return (success, status_message)"""
         try:
-
+            # Try loading with the best available attention implementation
+            attn_implementation = self.get_best_attn_implementation()
+            attn_fallback_order = [attn_implementation]
+            if attn_implementation == "kernels-community/flash-attn3":
+                attn_fallback_order.extend(["flash_attention_2", "sdpa", "eager"])
+            elif attn_implementation == "flash_attention_2":
+                attn_fallback_order.extend(["sdpa", "eager"])
+            elif attn_implementation == "sdpa":
+                attn_fallback_order.append("eager")
+
+            for attn_impl in attn_fallback_order:
+                try:
+                    logger.info(f"[LLM Load] Attempting to load model with attention implementation: {attn_impl}")
+                    self.llm = AutoModelForCausalLM.from_pretrained(
+                        model_path,
+                        trust_remote_code=True,
+                        attn_implementation=attn_impl,
+                        torch_dtype=self.dtype,
+                    )
+                    attn_implementation = attn_impl
+                    break
+                except Exception as e:
+                    logger.warning(f"[LLM Load] Failed to load model with {attn_impl}: {e}")
+                    if attn_impl == attn_fallback_order[-1]:
+                        raise e
+
+            logger.info(f"[LLM Load Debug] Model loaded with {attn_implementation}, initial device: {next(self.llm.parameters()).device}")
             if not self.offload_to_cpu:
                 self.llm = self.llm.to(device).to(self.dtype)
             else:
                 self.llm = self.llm.to("cpu").to(self.dtype)
+            logger.info(f"[LLM Load Debug] After .to(), model device: {next(self.llm.parameters()).device}")
             self.llm.eval()
+            # Disable gradients for all parameters (required for ZeroGPU pickling)
+            self.llm.requires_grad_(False)
             self.llm_backend = "pt"
             self.llm_initialized = True
             logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
-            status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}"
+            status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch ({attn_implementation})\nDevice: {device}"
             return True, status_msg
         except Exception as e:
             return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -312,6 +369,11 @@ class LLMHandler:
 
         self.device = device
         self.offload_to_cpu = offload_to_cpu
+
+        # Debug logging for ZeroGPU diagnosis
+        logger.info(f"[LLM Init Debug] IS_ZEROGPU={self.IS_ZEROGPU}, IS_HUGGINGFACE_SPACE={self.IS_HUGGINGFACE_SPACE}")
+        logger.info(f"[LLM Init Debug] torch.cuda.is_available()={torch.cuda.is_available()}")
+        logger.info(f"[LLM Init Debug] device={device}, offload_to_cpu={offload_to_cpu}")
         # Set dtype based on device: bfloat16 for cuda, float32 for cpu
         if dtype is None:
             self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
@@ -577,8 +639,11 @@ class LLMHandler:
         )
 
         with self._load_model_context():
-            inputs
-
+            # Move inputs to the same device as the model (important for ZeroGPU where model may be on CPU)
+            model_device = next(self.llm.parameters()).device
+            inputs = {k: v.to(model_device) for k, v in inputs.items()}
+            logger.info(f"[_run_pt_single Debug] Inputs moved to model device: {model_device}")
+            logger.info(f"[_run_pt_single Debug] Input actual device: {inputs['input_ids'].device}")
             # Calculate max_new_tokens based on target_duration if specified
             # 5 audio codes = 1 second, plus ~500 tokens for CoT metadata and safety margin
             if target_duration is not None and target_duration > 0:
@@ -618,7 +683,7 @@ class LLMHandler:
                 truncation=True,
             )
             self.llm_tokenizer.padding_side = original_padding_side
-            batch_inputs_tokenized = {k: v.to(
+            batch_inputs_tokenized = {k: v.to(model_device) for k, v in batch_inputs_tokenized.items()}
 
             # Extract batch inputs
             batch_input_ids = batch_inputs_tokenized['input_ids']
@@ -1988,7 +2053,8 @@ class LLMHandler:
         This allows us to call update_state() after each token generation.
         """
         model = self.llm
-        device
+        # Get device from model (important for ZeroGPU where model may be on different device than self.device)
+        device = next(model.parameters()).device
 
         # Initialize generated sequences
         generated_ids = input_ids.clone()
@@ -2088,7 +2154,8 @@ class LLMHandler:
         Batch format: [cond_input, uncond_input]
         """
         model = self.llm
-        device
+        # Get device from model (important for ZeroGPU where model may be on different device than self.device)
+        device = next(model.parameters()).device
         batch_size = batch_input_ids.shape[0] // 2  # Half are conditional, half are unconditional
         cond_start_idx = 0
         uncond_start_idx = batch_size
@@ -2309,7 +2376,30 @@ class LLMHandler:
         Context manager to load a model to GPU and offload it back to CPU after use.
         Only used for PyTorch backend when offload_to_cpu is True.
         """
+        logger.info(f"[_load_model_context Debug] Entry: offload_to_cpu={self.offload_to_cpu}, backend={self.llm_backend}, self.device={self.device}")
+        logger.info(f"[_load_model_context Debug] torch.cuda.is_available()={torch.cuda.is_available()}, IS_ZEROGPU={self.IS_ZEROGPU}")
+
+        model_device = None
+        if self.llm is not None:
+            model_device = next(self.llm.parameters()).device
+            logger.info(f"[_load_model_context Debug] Model current device: {model_device}")
+
+        # In ZeroGPU, model may be on CPU even though self.device="cuda" (due to hijacked .to() during init)
+        # Move to CUDA if available and model is on CPU
+        needs_move_to_cuda = (
+            self.llm is not None
+            and torch.cuda.is_available()
+            and model_device is not None
+            and model_device.type == "cpu"
+        )
+
+        if needs_move_to_cuda:
+            logger.info(f"[_load_model_context Debug] Moving model from CPU to cuda")
+            self.llm = self.llm.to("cuda").to(self.dtype)
+            logger.info(f"[_load_model_context Debug] Model now on: {next(self.llm.parameters()).device}")
+
         if not self.offload_to_cpu:
+            logger.info(f"[_load_model_context Debug] offload_to_cpu=False, yielding")
             yield
             return
 
@@ -2383,7 +2473,9 @@ class LLMHandler:
             device = next(model_runner.model.parameters()).device
             self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
             self._hf_model_for_scoring.eval()
-
+            # Disable gradients for all parameters (required for ZeroGPU pickling)
+            self._hf_model_for_scoring.requires_grad_(False)
+
             logger.info(f"HuggingFace model for scoring ready on {device}")
 
         return self._hf_model_for_scoring
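A recurring change in this file is to stop trusting the cached self.device string and instead ask the model where it currently lives, since ZeroGPU can leave the weights on CPU until a GPU is attached. A minimal sketch of that pattern follows; `generate_on_model_device` is an illustrative name and the model and tokenizer are placeholders, not part of the commit.

import torch

def generate_on_model_device(model, tokenizer, prompt: str, max_new_tokens: int = 32):
    # Derive the device from the parameters themselves; under ZeroGPU this reflects
    # the actual placement at call time, unlike a device string cached at init.
    model_device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model_device) for k, v in inputs.items()}
    with torch.no_grad():
        return model.generate(**inputs, max_new_tokens=max_new_tokens)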
app.py
CHANGED
@@ -1,8 +1,13 @@
 """
 ACE-Step v1.5 - HuggingFace Space Entry Point
-
 This file serves as the entry point for HuggingFace Space deployment.
 It initializes the service and launches the Gradio interface.
+ZeroGPU Support:
+- ZeroGPU uses the 'spaces' package to intercept CUDA operations
+- Models are loaded to "cuda" during startup but actual GPU allocation is deferred
+- Handlers are registered globally so forked processes inherit them without pickling
+- @spaces.GPU decorators are on top-level Gradio event handlers, not internal functions
+- nano-vllm uses direct CUDA APIs that bypass spaces interception, so we use PyTorch backend
 """
 import os
 import sys
@@ -22,12 +27,26 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
     os.environ.pop(proxy_var, None)
 
+# Import spaces for ZeroGPU support (must be imported before torch for proper interception)
+# This is a no-op if not running on HuggingFace Spaces
+try:
+    import spaces
+    HAS_SPACES = True
+except ImportError:
+    HAS_SPACES = False
+
 import torch
 from acestep.handler import AceStepHandler
 from acestep.llm_inference import LLMHandler
 from acestep.dataset_handler import DatasetHandler
 from acestep.gradio_ui import create_gradio_interface
 
+# Detect ZeroGPU environment
+IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
+# ZeroGPU detection: check env var OR assume ZeroGPU for all HF Spaces (safer default)
+# The SPACE_HARDWARE env var is unreliable, so we assume ZeroGPU if on HF Space
+IS_ZEROGPU = IS_HUGGINGFACE_SPACE or os.environ.get("ZEROGPU") is not None
+
 
 def get_gpu_memory_gb():
     """
@@ -105,14 +124,30 @@ def main():
         print("UI will be fully functional but generation is disabled")
         print("=" * 60)
 
+    # Log ZeroGPU detection
+    if IS_ZEROGPU:
+        print("=" * 60)
+        print("ZeroGPU environment detected")
+        print("- Using spaces package for GPU allocation")
+        print("- PyTorch backend forced for LLM (nano-vllm incompatible)")
+        print("- GPU will be allocated on-demand during generation")
+        print("=" * 60)
+
     # Get persistent storage path (auto-detect)
     persistent_storage_path = get_persistent_storage_path()
 
     # Detect GPU memory for auto-configuration
+    # Note: In ZeroGPU, GPU may not be available during startup, so this may return 0
     gpu_memory_gb = get_gpu_memory_gb()
-    auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
 
-
+    # For ZeroGPU, we don't need CPU offload as GPU is allocated dynamically
+    if IS_ZEROGPU:
+        auto_offload = False
+        print("ZeroGPU: CPU offload disabled (GPU allocated on-demand)")
+    else:
+        auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
+
+    if not debug_ui and not IS_ZEROGPU:
         if auto_offload:
             print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
             print("Auto-enabling CPU offload to reduce GPU memory usage")
@@ -140,7 +175,11 @@ def main():
         "SERVICE_MODE_LM_MODEL",
         "acestep-5Hz-lm-1.7B"
     )
-
+    # For ZeroGPU, force PyTorch backend (nano-vllm uses direct CUDA APIs)
+    if IS_ZEROGPU:
+        backend = "pt"
+    else:
+        backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
     device = "auto"
 
     print(f"Service mode configuration:")
@@ -151,6 +190,7 @@ def main():
     print(f" Backend: {backend}")
     print(f" Offload to CPU: {auto_offload}")
     print(f" DEBUG_UI: {debug_ui}")
+    print(f" ZeroGPU: {IS_ZEROGPU}")
 
     # Determine flash attention availability
     use_flash_attention = dit_handler.is_flash_attention_available()
@@ -230,7 +270,7 @@ def main():
     else:
        print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
        init_status += f"\n{lm_status}"
-
+
     # Build available models list for UI
     available_dit_models = [config_path]
     if config_path_2 and dit_handler_2 is not None:
@@ -275,7 +315,7 @@ def main():
 
     # Enable queue for multi-user support
     print("Enabling queue for multi-user support...")
-    demo.queue(max_size=20
+    demo.queue(max_size=20)
 
     # Launch
     print("Launching server on 0.0.0.0:7860...")
@@ -288,4 +328,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
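Condensed, the startup ordering that the new module docstring describes looks roughly like the sketch below: import `spaces` before `torch` so CUDA calls can be intercepted, treat any detected Space as ZeroGPU, then pick the backend and offload policy from that. The names follow the commit; `get_gpu_memory_gb` is re-sketched here only to keep the example self-contained, and this is a trimmed-down illustration rather than the real entry point.

import os

try:
    import spaces  # must be imported before torch so ZeroGPU can intercept CUDA calls
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

import torch

IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
IS_ZEROGPU = IS_HUGGINGFACE_SPACE or os.environ.get("ZEROGPU") is not None

def get_gpu_memory_gb() -> float:
    # On ZeroGPU no GPU is attached at startup, so this returns 0 and offload logic is skipped.
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.get_device_properties(0).total_memory / 1024 ** 3

backend = "pt" if IS_ZEROGPU else os.environ.get("SERVICE_MODE_BACKEND", "vllm")
auto_offload = False if IS_ZEROGPU else 0 < get_gpu_memory_gb() < 16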
requirements.txt
CHANGED
@@ -1,11 +1,7 @@
 # PyTorch with CUDA 12.8 (for Windows/Linux)
-
-
-
-torchvision; sys_platform == 'win32'
-torch>=2.9.1; sys_platform != 'win32'
-torchaudio>=2.9.1; sys_platform != 'win32'
-torchvision; sys_platform != 'win32'
+torch==2.9.1
+torchaudio==2.9.1
+torchvision==0.24.1
 
 # Core dependencies
 transformers>=4.51.0,<4.58.0
@@ -14,6 +10,7 @@ gradio==6.2.0
 matplotlib>=3.7.5
 scipy>=1.10.1
 soundfile>=0.13.1
+ffmpeg-python
 loguru>=0.7.3
 einops>=0.8.1
 accelerate>=1.12.0
@@ -33,6 +30,8 @@ triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
 triton>=3.0.0; sys_platform != 'win32'
 flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.2/flash_attn-2.8.2+cu128torch2.7.1cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.11' and platform_machine == 'AMD64'
 flash-attn @ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.12/flash_attn-2.8.3+cu128torch2.10-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux' and python_version == '3.11'
+# Kernels library for flash-attn3 (preferred over flash-attn when available)
+kernels
 xxhash
 
 # HuggingFace Space required