Sayoyo commited on
Commit
ef09b17
·
1 Parent(s): a657594

fix: update model repo_id

Browse files
Files changed (1) hide show
  1. acestep/handler.py +17 -15
acestep/handler.py CHANGED
@@ -122,6 +122,9 @@ class AceStepHandler:
122
  Ensure model is downloaded from HuggingFace Hub.
123
  Used for HuggingFace Space auto-download support.
124
 
 
 
 
125
  Args:
126
  model_name: Model directory name (e.g., "acestep-v15-turbo")
127
  checkpoint_dir: Target checkpoint directory
@@ -131,29 +134,28 @@ class AceStepHandler:
131
  """
132
  from huggingface_hub import snapshot_download
133
 
134
- # Model name to HuggingFace repo ID mapping
135
- MODEL_REPO_MAP = {
136
- "acestep-v15-turbo": "ACE-Step/ACE-Step-v1-3.5B-turbo",
137
- "acestep-v15-base": "ACE-Step/ACE-Step-v1-3.5B",
138
- }
139
-
140
- repo_id = MODEL_REPO_MAP.get(model_name)
141
- if repo_id is None:
142
- # Try using model_name as repo_id directly
143
- repo_id = f"ACE-Step/{model_name}"
144
 
145
  model_path = os.path.join(checkpoint_dir, model_name)
146
- logger.info(f"Downloading model {repo_id} to {model_path}...")
 
 
 
 
 
 
 
147
 
148
  try:
149
  snapshot_download(
150
- repo_id=repo_id,
151
- local_dir=model_path,
152
  local_dir_use_symlinks=False,
153
  )
154
- logger.info(f"Model {repo_id} downloaded successfully")
155
  except Exception as e:
156
- logger.error(f"Failed to download model {repo_id}: {e}")
157
  raise
158
 
159
  return model_path
 
122
  Ensure model is downloaded from HuggingFace Hub.
123
  Used for HuggingFace Space auto-download support.
124
 
125
+ Downloads the unified ACE-Step/Ace-Step1.5 repository which contains
126
+ both acestep-v15-turbo and acestep-5Hz-lm-1.7B models.
127
+
128
  Args:
129
  model_name: Model directory name (e.g., "acestep-v15-turbo")
130
  checkpoint_dir: Target checkpoint directory
 
134
  """
135
  from huggingface_hub import snapshot_download
136
 
137
+ # Unified repository containing all models
138
+ REPO_ID = "ACE-Step/Ace-Step1.5"
 
 
 
 
 
 
 
 
139
 
140
  model_path = os.path.join(checkpoint_dir, model_name)
141
+
142
+ # Check if model already exists
143
+ if os.path.exists(model_path) and os.listdir(model_path):
144
+ logger.info(f"Model {model_name} already exists at {model_path}")
145
+ return model_path
146
+
147
+ # Download the entire repository to checkpoint_dir
148
+ logger.info(f"Downloading {REPO_ID} to {checkpoint_dir}...")
149
 
150
  try:
151
  snapshot_download(
152
+ repo_id=REPO_ID,
153
+ local_dir=checkpoint_dir,
154
  local_dir_use_symlinks=False,
155
  )
156
+ logger.info(f"Repository {REPO_ID} downloaded successfully")
157
  except Exception as e:
158
+ logger.error(f"Failed to download repository {REPO_ID}: {e}")
159
  raise
160
 
161
  return model_path