xushengyuan committed
Commit 497ca57 · Parent: bc9e2c6

tiled decoding

Files changed (2)
  1. acestep/handler.py +86 -2
  2. test.py +20 -1
acestep/handler.py CHANGED
@@ -14,6 +14,7 @@ import torch
 import torchaudio
 import soundfile as sf
 import time
+from tqdm import tqdm
 from loguru import logger
 import warnings
 
@@ -1134,6 +1135,7 @@ class AceStepHandler:
             code_hint = audio_code_hints[i]
             # Prefer decoding from provided audio codes
             if code_hint:
+                print(f"[generate_music] Decoding audio codes for item {i}...")
                 decoded_latents = self._decode_audio_codes_to_latents(code_hint)
                 if decoded_latents is not None:
                     decoded_latents = decoded_latents.squeeze(0)
@@ -1150,6 +1152,7 @@
                 target_latent = self.silence_latent[0, :expected_latent_length, :]
             else:
                 # Ensure input is in VAE's dtype
+                print(f"[generate_music] Encoding target audio to latents for item {i}...")
                 vae_input = current_wav.to(self.device).to(self.vae.dtype)
                 target_latent = self.vae.encode(vae_input).latent_dist.sample()
                 # Cast back to model dtype
@@ -1319,6 +1322,7 @@
         for i in range(batch_size):
             if audio_code_hints[i] is not None:
                 # Decode audio codes to 25Hz latents
+                print(f"[generate_music] Decoding audio codes for LM hints for item {i}...")
                 hints = self._decode_audio_codes_to_latents(audio_code_hints[i])
                 if hints is not None:
                     # Pad or crop to match max_latent_length
@@ -1563,7 +1567,9 @@
         lyric_attention_mask = batch["lyric_attention_masks"]
         text_inputs = batch["text_inputs"]
 
+        print("[preprocess_batch] Inferring prompt embeddings...")
         text_hidden_states = self.infer_text_embeddings(text_token_idss)
+        print("[preprocess_batch] Inferring lyric embeddings...")
         lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
 
         is_covers = batch["is_covers"]
@@ -1576,6 +1582,7 @@
         non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
         non_cover_text_hidden_states = None
         if non_cover_text_input_ids is not None:
+            print("[preprocess_batch] Inferring non-cover text embeddings...")
             non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
 
         return (
@@ -1803,10 +1810,78 @@
             "cfg_interval_start": cfg_interval_start,
             "cfg_interval_end": cfg_interval_end,
         }
-
+        print("[service_generate] Generating audio...")
         outputs = self.model.generate_audio(**generate_kwargs)
         return outputs
 
+    def tiled_decode(self, latents, chunk_size=512, overlap=64):
+        """
+        Decode latents using tiling to reduce VRAM usage.
+        Uses overlap-discard strategy to avoid boundary artifacts.
+
+        Args:
+            latents: [Batch, Channels, Length]
+            chunk_size: Size of latent chunk to process at once
+            overlap: Overlap size in latent frames
+        """
+        B, C, T = latents.shape
+
+        # If short enough, decode directly
+        if T <= chunk_size:
+            return self.vae.decode(latents).sample
+
+        # Calculate stride (core size)
+        stride = chunk_size - 2 * overlap
+        if stride <= 0:
+            raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}")
+
+        decoded_audio_list = []
+
+        # We need to determine upsample factor to trim audio correctly
+        upsample_factor = None
+
+        num_steps = math.ceil(T / stride)
+
+        for i in tqdm(range(num_steps), desc="Decoding audio chunks"):
+            # Core range in latents
+            core_start = i * stride
+            core_end = min(core_start + stride, T)
+
+            # Window range (with overlap)
+            win_start = max(0, core_start - overlap)
+            win_end = min(T, core_end + overlap)
+
+            # Extract chunk
+            latent_chunk = latents[:, :, win_start:win_end]
+
+            # Decode
+            # [Batch, Channels, AudioSamples]
+            audio_chunk = self.vae.decode(latent_chunk).sample
+
+            # Determine upsample factor from the first chunk
+            if upsample_factor is None:
+                upsample_factor = audio_chunk.shape[-1] / latent_chunk.shape[-1]
+
+            # Calculate trim amounts in audio samples
+            # How much overlap was added at the start?
+            added_start = core_start - win_start  # latent frames
+            trim_start = int(round(added_start * upsample_factor))
+
+            # How much overlap was added at the end?
+            added_end = win_end - core_end  # latent frames
+            trim_end = int(round(added_end * upsample_factor))
+
+            # Trim audio
+            audio_len = audio_chunk.shape[-1]
+            end_idx = audio_len - trim_end
+
+            audio_core = audio_chunk[:, :, trim_start:end_idx]
+            decoded_audio_list.append(audio_core)
+
+        # Concatenate
+        final_audio = torch.cat(decoded_audio_list, dim=-1)
+        return final_audio
+
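The new tiled_decode splits the latent sequence into fixed-size windows, decodes each window independently, and then discards the overlapped margins so only each window's core contributes to the output. (It also relies on math.ceil and the newly imported tqdm; math is presumably already imported near the top of handler.py, since the diff only adds tqdm.) A minimal standalone sketch, not part of the commit, that checks the overlap-discard arithmetic: the trimmed cores must tile [0, T) exactly, and every decoded window must fit within chunk_size latent frames:

    import math

    def tile_plan(T: int, chunk_size: int = 512, overlap: int = 64):
        # Mirror tiled_decode's core/window bookkeeping without a VAE.
        stride = chunk_size - 2 * overlap
        plan = []
        for i in range(math.ceil(T / stride)):
            core = (i * stride, min(i * stride + stride, T))
            win = (max(0, core[0] - overlap), min(T, core[1] + overlap))
            plan.append((core, win))
        return plan

    plan = tile_plan(T=3000)
    assert all(w1 - w0 <= 512 for _, (w0, w1) in plan)   # each decode sees at most one window
    assert sum(c1 - c0 for (c0, c1), _ in plan) == 3000  # cores tile the full length exactly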
@@ -1834,6 +1909,7 @@
         cfg_interval_end: float = 1.0,
         audio_format: str = "mp3",
         lm_temperature: float = 0.6,
+        use_tiled_decode: bool = True,
         progress=None
     ) -> Tuple[Optional[str], Optional[str], List[str], str, str, str, str, str, Optional[Any], str, str, Optional[Any]]:
         """
@@ -1887,6 +1963,7 @@
         # 1. Process reference audio
         refer_audios = None
         if reference_audio is not None:
+            print("[generate_music] Processing reference audio...")
            processed_ref_audio = self.process_reference_audio(reference_audio)
            if processed_ref_audio is not None:
                # Convert to the format expected by the service: List[List[torch.Tensor]]
@@ -1898,6 +1975,7 @@
         # 2. Process source audio
         processed_src_audio = None
         if src_audio is not None:
+            print("[generate_music] Processing source audio...")
            processed_src_audio = self.process_src_audio(src_audio)
 
         # 3. Prepare batch data
@@ -1975,7 +2053,13 @@
         pred_latents_for_decode = pred_latents.transpose(1, 2)
         # Ensure input is in VAE's dtype
         pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
-        pred_wavs = self.vae.decode(pred_latents_for_decode).sample  # [batch, channels, samples]
+
+        if use_tiled_decode:
+            print("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
+            pred_wavs = self.tiled_decode(pred_latents_for_decode)  # [batch, channels, samples]
+        else:
+            pred_wavs = self.vae.decode(pred_latents_for_decode).sample
+
         # Cast output to float32 for audio processing/saving
         pred_wavs = pred_wavs.to(torch.float32)
         end_time = time.time()
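Because each VAE call now sees at most chunk_size latent frames, peak decode VRAM is bounded by the window size rather than the full track length, at the cost of extra compute on the overlapped margins. A hedged sketch (the wrapper and its OOM handling are assumptions, not part of the commit) of how a caller might shrink the window when memory is still tight:

    import torch

    def safe_tiled_decode(handler, latents, chunk_size=512, overlap=64):
        # Halve the window on CUDA OOM; stop once the window can no longer
        # hold two overlap margins plus a non-empty core.
        while True:
            try:
                return handler.tiled_decode(latents, chunk_size=chunk_size, overlap=overlap)
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                chunk_size //= 2
                if chunk_size <= 2 * overlap:
                    raise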
test.py CHANGED
@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import shutil
+import resource
 from acestep.handler import AceStepHandler
 
 
@@ -91,6 +92,12 @@ def main():
     seeds = "320145306, 1514681811"
 
     print("Starting generation...")
+
+    # Reset peak memory stats
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()
+    elif torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
 
     # Call generate_music
     results = handler.generate_music(
@@ -109,7 +116,8 @@
         task_type="text2music",
         cfg_interval_start=0.0,
         cfg_interval_end=0.95,
-        audio_format="wav"
+        audio_format="wav",
+        use_tiled_decode=True,
     )
 
     # Unpack results
@@ -118,6 +126,17 @@
      align_score2, align_text2, align_plot2) = results
 
     print("\nGeneration Complete!")
+
+    # Print memory stats
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        peak_vram = torch.xpu.max_memory_allocated() / (1024 ** 3)
+        print(f"Peak VRAM usage: {peak_vram:.2f} GB")
+    elif torch.cuda.is_available():
+        peak_vram = torch.cuda.max_memory_allocated() / (1024 ** 3)
+        print(f"Peak VRAM usage: {peak_vram:.2f} GB")
+
+    peak_ram = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 ** 2)
+    print(f"Peak RAM usage: {peak_ram:.2f} GB")
     print(f"Status: {status_msg}")
     print(f"Info: {info}")
     print(f"Seeds used: {seed_val}")