Merge pull request #11 from ace-step/fix_dit_offload
Changed files: acestep/handler.py (+127 -12)
@@ -241,11 +241,9 @@ class AceStepHandler:
             silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
             if os.path.exists(silence_latent_path):
                 self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
-                #
-
-
-                else:
-                    self.silence_latent = self.silence_latent.to("cpu").to(self.dtype)
+                # Always keep silence_latent on GPU - it's used in many places outside model context
+                # and is small enough that it won't significantly impact VRAM
+                self.silence_latent = self.silence_latent.to(device).to(self.dtype)
             else:
                 raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
         else:
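
The comment in the hunk above gives the rationale; the failure mode it avoids is the standard PyTorch device-mismatch error raised when a CPU-resident tensor meets CUDA activations. A minimal, self-contained illustration with stand-in shapes (not the handler's real tensors):

import torch

if torch.cuda.is_available():
    silence_latent = torch.zeros(1, 8, 16)                # stand-in latent left on the CPU
    hidden_states = torch.zeros(1, 8, 16, device="cuda")  # activations living on the GPU
    try:
        torch.cat([silence_latent, hidden_states], dim=-1)
    except RuntimeError as err:
        # "Expected all tensors to be on the same device, but found at least two devices ..."
        print(err)
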
@@ -301,6 +299,113 @@ class AceStepHandler:
             logger.exception("[initialize_service] Error initializing model")
             return error_msg, False
 
+    def _is_on_target_device(self, tensor, target_device):
+        """Check if tensor is on the target device (handles cuda vs cuda:0 comparison)."""
+        if tensor is None:
+            return True
+        target_type = "cpu" if target_device == "cpu" else "cuda"
+        return tensor.device.type == target_type
+
+    def _ensure_silence_latent_on_device(self):
+        """Ensure silence_latent is on the correct device (self.device)."""
+        if hasattr(self, "silence_latent") and self.silence_latent is not None:
+            if not self._is_on_target_device(self.silence_latent, self.device):
+                self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
+
+    def _move_module_recursive(self, module, target_device, dtype=None, visited=None):
+        """
+        Recursively move a module and all its submodules to the target device.
+        This handles modules that may not be properly registered.
+        """
+        if visited is None:
+            visited = set()
+
+        module_id = id(module)
+        if module_id in visited:
+            return
+        visited.add(module_id)
+
+        # Move the module itself
+        module.to(target_device)
+        if dtype is not None:
+            module.to(dtype)
+
+        # Move all direct parameters
+        for param_name, param in module._parameters.items():
+            if param is not None and not self._is_on_target_device(param, target_device):
+                module._parameters[param_name] = param.to(target_device)
+                if dtype is not None:
+                    module._parameters[param_name] = module._parameters[param_name].to(dtype)
+
+        # Move all direct buffers
+        for buf_name, buf in module._buffers.items():
+            if buf is not None and not self._is_on_target_device(buf, target_device):
+                module._buffers[buf_name] = buf.to(target_device)
+
+        # Recursively process all submodules (registered and unregistered)
+        for name, child in module._modules.items():
+            if child is not None:
+                self._move_module_recursive(child, target_device, dtype, visited)
+
+        # Also check for any nn.Module attributes that might not be in _modules
+        for attr_name in dir(module):
+            if attr_name.startswith('_'):
+                continue
+            try:
+                attr = getattr(module, attr_name, None)
+                if isinstance(attr, torch.nn.Module) and id(attr) not in visited:
+                    self._move_module_recursive(attr, target_device, dtype, visited)
+            except Exception:
+                pass
+
+    def _recursive_to_device(self, model, device, dtype=None):
+        """
+        Recursively move all parameters and buffers of a model to the specified device.
+        This is more thorough than model.to() for some custom HuggingFace models.
+        """
+        target_device = torch.device(device) if isinstance(device, str) else device
+
+        # Method 1: Standard .to() call
+        model.to(target_device)
+        if dtype is not None:
+            model.to(dtype)
+
+        # Method 2: Use our thorough recursive moving for any missed modules
+        self._move_module_recursive(model, target_device, dtype)
+
+        # Method 3: Force move via state_dict if there are still parameters on wrong device
+        wrong_device_params = []
+        for name, param in model.named_parameters():
+            if not self._is_on_target_device(param, device):
+                wrong_device_params.append(name)
+
+        if wrong_device_params and device != "cpu":
+            logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device, using state_dict method")
+            # Get current state dict and move all tensors
+            state_dict = model.state_dict()
+            moved_state_dict = {}
+            for key, value in state_dict.items():
+                if isinstance(value, torch.Tensor):
+                    moved_state_dict[key] = value.to(target_device)
+                    if dtype is not None and moved_state_dict[key].is_floating_point():
+                        moved_state_dict[key] = moved_state_dict[key].to(dtype)
+                else:
+                    moved_state_dict[key] = value
+            model.load_state_dict(moved_state_dict)
+
+        # Synchronize CUDA to ensure all transfers are complete
+        if device != "cpu" and torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+        # Final verification
+        if device != "cpu":
+            still_wrong = []
+            for name, param in model.named_parameters():
+                if not self._is_on_target_device(param, device):
+                    still_wrong.append(f"{name} on {param.device}")
+            if still_wrong:
+                logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}")
+
     @contextmanager
     def _load_model_context(self, model_name: str):
         """
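
A note on the new _is_on_target_device helper: it compares tensor.device.type rather than the full torch.device, because torch.device equality includes the index, so "cuda" (index None) and "cuda:0" do not compare equal even when they refer to the same card. A small standalone check of that behaviour:

import torch

# Full device equality includes the index, so these are *not* equal...
print(torch.device("cuda") == torch.device("cuda:0"))            # False
# ...while the .type comparison used by _is_on_target_device matches.
print(torch.device("cuda").type == torch.device("cuda:0").type)  # True ("cuda")

if torch.cuda.is_available():
    t = torch.zeros(4, device="cuda")
    print(t.device)       # cuda:0 - the index is filled in at allocation time
    print(t.device.type)  # "cuda" - what the helper actually compares
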
@@ -324,7 +429,7 @@ class AceStepHandler:
             param = next(model.parameters())
             if param.device.type == "cpu":
                 logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
-
+                self._recursive_to_device(model, self.device, self.dtype)
                 if hasattr(self, "silence_latent"):
                     self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
         except StopIteration:
@@ -342,9 +447,9 @@ class AceStepHandler:
             start_time = time.time()
             if model_name == "vae":
                 vae_dtype = self._get_vae_dtype()
-
+                self._recursive_to_device(model, self.device, vae_dtype)
             else:
-
+                self._recursive_to_device(model, self.device, self.dtype)
 
             if model_name == "model" and hasattr(self, "silence_latent"):
                 self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
@@ -359,10 +464,11 @@ class AceStepHandler:
             # Offload to CPU
             logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
             start_time = time.time()
-
+            self._recursive_to_device(model, "cpu")
 
-
-
+            # NOTE: Do NOT offload silence_latent to CPU here!
+            # silence_latent is used in many places outside of model context,
+            # so it should stay on GPU to avoid device mismatch errors.
 
             torch.cuda.empty_cache()
             offload_time = time.time() - start_time
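
_load_model_context itself is only partially visible in this diff; the hunks above slot _recursive_to_device into its load and offload paths. For orientation, a simplified sketch of that load-then-offload pattern, with names and structure that are illustrative only and not copied from the repository:

import time
from contextlib import contextmanager

import torch


@contextmanager
def load_model_context(model: torch.nn.Module, device: str = "cuda"):
    """Keep `model` on the GPU for the duration of the block, then offload it."""
    start = time.time()
    model.to(device)              # the patch replaces this with _recursive_to_device(model, device, dtype)
    print(f"load took {time.time() - start:.2f}s")
    try:
        yield model
    finally:
        start = time.time()
        model.to("cpu")           # the patch replaces this with _recursive_to_device(model, "cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # hand freed VRAM back to the allocator/driver
        print(f"offload took {time.time() - start:.2f}s")

Usage would look like `with load_model_context(vae) as m: ...`, which is presumably how the handler wraps its VAE and DiT calls.
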
@@ -1269,6 +1375,9 @@ class AceStepHandler:
             Batch dictionary ready for model input
         """
         batch_size = len(captions)
+
+        # Ensure silence_latent is on the correct device for batch preparation
+        self._ensure_silence_latent_on_device()
 
         # Normalize audio_code_hints to batch list
         audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
@@ -1638,6 +1747,9 @@ class AceStepHandler:
     def infer_refer_latent(self, refer_audioss):
         refer_audio_order_mask = []
         refer_audio_latents = []
+
+        # Ensure silence_latent is on the correct device
+        self._ensure_silence_latent_on_device()
 
         def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
             """Normalize audio tensor to [2, T] on current device."""
@@ -1932,6 +2044,9 @@ class AceStepHandler:
         else:
             seed_param = random.randint(0, 2**32 - 1)
 
+        # Ensure silence_latent is on the correct device before creating generate_kwargs
+        self._ensure_silence_latent_on_device()
+
         generate_kwargs = {
             "text_hidden_states": text_hidden_states,
             "text_attention_mask": text_attention_mask,
@@ -1995,7 +2110,7 @@ class AceStepHandler:
 
         return outputs
 
-    def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=
+    def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=True):
         """
         Decode latents using tiling to reduce VRAM usage.
         Uses overlap-discard strategy to avoid boundary artifacts.
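
Only the offload_wav_to_cpu=True default is visible in this hunk; the body of tiled_decode is not part of the diff. For context, the overlap-discard strategy its docstring names works roughly like the sketch below, where decode_fn and samples_per_latent are hypothetical placeholders for the real VAE decoder, not the handler's actual code:

import torch


def tiled_decode_sketch(latents, decode_fn, chunk_size=512, overlap=64,
                        samples_per_latent=1, offload_wav_to_cpu=True):
    """Decode [B, C, T] latents chunk by chunk, discarding the overlapped edges.

    Each chunk is decoded with `overlap` extra frames of context on both sides;
    only the interior is kept, so decoder boundary artifacts land in the
    discarded margins.
    """
    total = latents.shape[-1]
    step = chunk_size - 2 * overlap
    pieces = []
    for start in range(0, total, step):
        lo = max(start - overlap, 0)
        hi = min(start + step + overlap, total)
        wav = decode_fn(latents[..., lo:hi])             # decode one padded chunk
        left = (start - lo) * samples_per_latent         # trim the left-hand context
        keep = min(step, total - start) * samples_per_latent
        piece = wav[..., left:left + keep]               # interior of this chunk only
        pieces.append(piece.cpu() if offload_wav_to_cpu else piece)
    return torch.cat(pieces, dim=-1)

The new default presumably just moves each decoded chunk to CPU as it is produced, trading an extra copy for lower peak VRAM during decoding.
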