Sayoyo committed commit 4893237 (unverified)
Parents: 1d5e812 6b03a4e

Merge pull request #11 from ace-step/fix_dit_offload

Files changed (1):
  1. acestep/handler.py +127 -12
acestep/handler.py CHANGED
@@ -241,11 +241,9 @@ class AceStepHandler:
             silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
             if os.path.exists(silence_latent_path):
                 self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
-                # If DiT is on GPU, silence_latent should also be on GPU
-                if not self.offload_to_cpu or not self.offload_dit_to_cpu:
-                    self.silence_latent = self.silence_latent.to(device).to(self.dtype)
-                else:
-                    self.silence_latent = self.silence_latent.to("cpu").to(self.dtype)
+                # Always keep silence_latent on GPU - it's used in many places outside model context
+                # and is small enough that it won't significantly impact VRAM
+                self.silence_latent = self.silence_latent.to(device).to(self.dtype)
             else:
                 raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
         else:
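
A note on the fix above: once the DiT is offloaded, any op that mixes a CPU-resident silence_latent with GPU activations raises a device-mismatch RuntimeError, which is what pinning the latent to the GPU avoids. A minimal standalone reproduction (not from the repo):

import torch

if torch.cuda.is_available():
    gpu_t = torch.zeros(2, device="cuda")
    cpu_t = torch.zeros(2)  # stayed on CPU, like an offloaded silence_latent
    try:
        gpu_t + cpu_t
    except RuntimeError as exc:
        print(exc)  # complains that tensors are on cuda:0 and cpu
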
@@ -301,6 +299,113 @@ class AceStepHandler:
             logger.exception("[initialize_service] Error initializing model")
             return error_msg, False

+    def _is_on_target_device(self, tensor, target_device):
+        """Check if tensor is on the target device (handles cuda vs cuda:0 comparison)."""
+        if tensor is None:
+            return True
+        target_type = "cpu" if target_device == "cpu" else "cuda"
+        return tensor.device.type == target_type
+
+    def _ensure_silence_latent_on_device(self):
+        """Ensure silence_latent is on the correct device (self.device)."""
+        if hasattr(self, "silence_latent") and self.silence_latent is not None:
+            if not self._is_on_target_device(self.silence_latent, self.device):
+                self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
+
+    def _move_module_recursive(self, module, target_device, dtype=None, visited=None):
+        """
+        Recursively move a module and all its submodules to the target device.
+        This handles modules that may not be properly registered.
+        """
+        if visited is None:
+            visited = set()
+
+        module_id = id(module)
+        if module_id in visited:
+            return
+        visited.add(module_id)
+
+        # Move the module itself
+        module.to(target_device)
+        if dtype is not None:
+            module.to(dtype)
+
+        # Move all direct parameters
+        for param_name, param in module._parameters.items():
+            if param is not None and not self._is_on_target_device(param, target_device):
+                module._parameters[param_name] = param.to(target_device)
+                if dtype is not None:
+                    module._parameters[param_name] = module._parameters[param_name].to(dtype)
+
+        # Move all direct buffers
+        for buf_name, buf in module._buffers.items():
+            if buf is not None and not self._is_on_target_device(buf, target_device):
+                module._buffers[buf_name] = buf.to(target_device)
+
+        # Recursively process all submodules (registered and unregistered)
+        for name, child in module._modules.items():
+            if child is not None:
+                self._move_module_recursive(child, target_device, dtype, visited)
+
+        # Also check for any nn.Module attributes that might not be in _modules
+        for attr_name in dir(module):
+            if attr_name.startswith('_'):
+                continue
+            try:
+                attr = getattr(module, attr_name, None)
+                if isinstance(attr, torch.nn.Module) and id(attr) not in visited:
+                    self._move_module_recursive(attr, target_device, dtype, visited)
+            except Exception:
+                pass
+
+    def _recursive_to_device(self, model, device, dtype=None):
+        """
+        Recursively move all parameters and buffers of a model to the specified device.
+        This is more thorough than model.to() for some custom HuggingFace models.
+        """
+        target_device = torch.device(device) if isinstance(device, str) else device
+
+        # Method 1: Standard .to() call
+        model.to(target_device)
+        if dtype is not None:
+            model.to(dtype)
+
+        # Method 2: Use our thorough recursive moving for any missed modules
+        self._move_module_recursive(model, target_device, dtype)
+
+        # Method 3: Force move via state_dict if there are still parameters on wrong device
+        wrong_device_params = []
+        for name, param in model.named_parameters():
+            if not self._is_on_target_device(param, device):
+                wrong_device_params.append(name)
+
+        if wrong_device_params and device != "cpu":
+            logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device, using state_dict method")
+            # Get current state dict and move all tensors
+            state_dict = model.state_dict()
+            moved_state_dict = {}
+            for key, value in state_dict.items():
+                if isinstance(value, torch.Tensor):
+                    moved_state_dict[key] = value.to(target_device)
+                    if dtype is not None and moved_state_dict[key].is_floating_point():
+                        moved_state_dict[key] = moved_state_dict[key].to(dtype)
+                else:
+                    moved_state_dict[key] = value
+            model.load_state_dict(moved_state_dict)
+
+        # Synchronize CUDA to ensure all transfers are complete
+        if device != "cpu" and torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+        # Final verification
+        if device != "cpu":
+            still_wrong = []
+            for name, param in model.named_parameters():
+                if not self._is_on_target_device(param, device):
+                    still_wrong.append(f"{name} on {param.device}")
+            if still_wrong:
+                logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}")
+
     @contextmanager
     def _load_model_context(self, model_name: str):
         """
@@ -324,7 +429,7 @@ class AceStepHandler:
                 param = next(model.parameters())
                 if param.device.type == "cpu":
                     logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
-                    model.to(self.device).to(self.dtype)
+                    self._recursive_to_device(model, self.device, self.dtype)
                     if hasattr(self, "silence_latent"):
                         self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
             except StopIteration:
@@ -342,9 +447,9 @@ class AceStepHandler:
             start_time = time.time()
             if model_name == "vae":
                 vae_dtype = self._get_vae_dtype()
-                model.to(self.device).to(vae_dtype)
+                self._recursive_to_device(model, self.device, vae_dtype)
             else:
-                model.to(self.device).to(self.dtype)
+                self._recursive_to_device(model, self.device, self.dtype)

             if model_name == "model" and hasattr(self, "silence_latent"):
                 self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
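
For context on why the plain model.to() calls above are swapped for _recursive_to_device: a submodule stored outside PyTorch's registration machinery is invisible to .to(), which is the case the attribute scan in _move_module_recursive is meant to catch. A standalone sketch of that failure mode (the Wrapper class is purely illustrative, not from the repo):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 4)  # registered normally: .to() moves it
        # Writing straight into __dict__ bypasses nn.Module.__setattr__,
        # so this submodule is never registered and .to() never sees it.
        self.__dict__["tail"] = nn.Linear(4, 4)

if torch.cuda.is_available():
    m = Wrapper().to("cuda")
    print(next(m.head.parameters()).device)  # cuda:0
    print(next(m.tail.parameters()).device)  # cpu -- left behind
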
@@ -359,10 +464,11 @@ class AceStepHandler:
             # Offload to CPU
             logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
             start_time = time.time()
-            model.to("cpu")
+            self._recursive_to_device(model, "cpu")

-            if model_name == "model" and hasattr(self, "silence_latent"):
-                self.silence_latent = self.silence_latent.to("cpu")
+            # NOTE: Do NOT offload silence_latent to CPU here!
+            # silence_latent is used in many places outside of model context,
+            # so it should stay on GPU to avoid device mismatch errors.

             torch.cuda.empty_cache()
             offload_time = time.time() - start_time
@@ -1269,6 +1375,9 @@ class AceStepHandler:
             Batch dictionary ready for model input
         """
         batch_size = len(captions)
+
+        # Ensure silence_latent is on the correct device for batch preparation
+        self._ensure_silence_latent_on_device()

         # Normalize audio_code_hints to batch list
         audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
@@ -1638,6 +1747,9 @@ class AceStepHandler:
     def infer_refer_latent(self, refer_audioss):
         refer_audio_order_mask = []
         refer_audio_latents = []
+
+        # Ensure silence_latent is on the correct device
+        self._ensure_silence_latent_on_device()

         def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
             """Normalize audio tensor to [2, T] on current device."""
@@ -1932,6 +2044,9 @@ class AceStepHandler:
         else:
             seed_param = random.randint(0, 2**32 - 1)

+        # Ensure silence_latent is on the correct device before creating generate_kwargs
+        self._ensure_silence_latent_on_device()
+
         generate_kwargs = {
             "text_hidden_states": text_hidden_states,
             "text_attention_mask": text_attention_mask,
@@ -1995,7 +2110,7 @@ class AceStepHandler:

         return outputs

-    def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False):
+    def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=True):
         """
         Decode latents using tiling to reduce VRAM usage.
         Uses overlap-discard strategy to avoid boundary artifacts.
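
The only change to tiled_decode in this commit is the offload_wav_to_cpu default flipping from False to True; the body is outside the diff. For readers unfamiliar with the overlap-discard strategy its docstring names, here is a rough standalone sketch (hypothetical names, assuming a decoder with a fixed latent-to-sample upsampling ratio; not the repo's implementation): each chunk is decoded with extra frames of context on both sides, and the decoded borders are discarded before stitching, so no chunk boundary ever reaches the output.

import torch

def tiled_decode_sketch(decode, latents, chunk_size=512, overlap=64):
    # latents: [B, C, T] latent frames; decode maps a latent chunk to waveform.
    T = latents.shape[-1]
    stride = chunk_size - 2 * overlap  # frames each chunk contributes
    pieces = []
    for start in range(0, T, stride):
        lo = max(start - overlap, 0)           # add left context
        hi = min(start + stride + overlap, T)  # add right context
        wav = decode(latents[..., lo:hi])
        ratio = wav.shape[-1] // (hi - lo)     # samples per latent frame
        keep_lo = (start - lo) * ratio         # discard decoded left overlap
        keep_hi = keep_lo + min(stride, T - start) * ratio
        pieces.append(wav[..., keep_lo:keep_hi])
    return torch.cat(pieces, dim=-1)

With chunk_size=512 and overlap=64, each decode call sees a window of at most 512 latent frames, but only the central 384 contribute to the stitched waveform.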
 