Sayoyo commited on
Commit
b04b635
·
1 Parent(s): 7334110

fix: auto download model

Browse files
Files changed (2) hide show
  1. acestep/handler.py +5 -0
  2. acestep/llm_inference.py +7 -2
acestep/handler.py CHANGED
@@ -365,6 +365,11 @@ class AceStepHandler:
365
 
366
  # 1. Load main model
367
  # config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
 
 
 
 
 
368
  acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
369
 
370
  # Auto-download model if not exists (HuggingFace Space support)
 
365
 
366
  # 1. Load main model
367
  # config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
368
+ # If config_path is None (HuggingFace Space with empty checkpoint), use default and auto-download
369
+ if config_path is None:
370
+ config_path = "acestep-v15-turbo"
371
+ logger.info(f"[initialize_service] config_path is None, using default: {config_path}")
372
+
373
  acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
374
 
375
  # Auto-download model if not exists (HuggingFace Space support)
acestep/llm_inference.py CHANGED
@@ -309,7 +309,7 @@ class LLMHandler:
309
  try:
310
  if device == "auto":
311
  device = "cuda" if torch.cuda.is_available() else "cpu"
312
-
313
  self.device = device
314
  self.offload_to_cpu = offload_to_cpu
315
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
@@ -317,7 +317,12 @@ class LLMHandler:
317
  self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
318
  else:
319
  self.dtype = dtype
320
-
 
 
 
 
 
321
  full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
322
  if not os.path.exists(full_lm_model_path):
323
  return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
 
309
  try:
310
  if device == "auto":
311
  device = "cuda" if torch.cuda.is_available() else "cpu"
312
+
313
  self.device = device
314
  self.offload_to_cpu = offload_to_cpu
315
  # Set dtype based on device: bfloat16 for cuda, float32 for cpu
 
317
  self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
318
  else:
319
  self.dtype = dtype
320
+
321
+ # If lm_model_path is None, use default
322
+ if lm_model_path is None:
323
+ lm_model_path = "acestep-5Hz-lm-1.7B"
324
+ logger.info(f"[initialize] lm_model_path is None, using default: {lm_model_path}")
325
+
326
  full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
327
  if not os.path.exists(full_lm_model_path):
328
  return f"❌ 5Hz LM model not found at {full_lm_model_path}", False