Commit 497ca57
Parent(s): bc9e2c6

tiled decoding

Files changed:
- acestep/handler.py +86 -2
- test.py +20 -1
acestep/handler.py
CHANGED
@@ -14,6 +14,7 @@ import torch
 import torchaudio
 import soundfile as sf
 import time
+from tqdm import tqdm
 from loguru import logger
 import warnings
 
@@ -1134,6 +1135,7 @@ class AceStepHandler:
            code_hint = audio_code_hints[i]
            # Prefer decoding from provided audio codes
            if code_hint:
+                print(f"[generate_music] Decoding audio codes for item {i}...")
                decoded_latents = self._decode_audio_codes_to_latents(code_hint)
                if decoded_latents is not None:
                    decoded_latents = decoded_latents.squeeze(0)
@@ -1150,6 +1152,7 @@
                target_latent = self.silence_latent[0, :expected_latent_length, :]
            else:
                # Ensure input is in VAE's dtype
+                print(f"[generate_music] Encoding target audio to latents for item {i}...")
                vae_input = current_wav.to(self.device).to(self.vae.dtype)
                target_latent = self.vae.encode(vae_input).latent_dist.sample()
                # Cast back to model dtype
@@ -1319,6 +1322,7 @@ class AceStepHandler:
        for i in range(batch_size):
            if audio_code_hints[i] is not None:
                # Decode audio codes to 25Hz latents
+                print(f"[generate_music] Decoding audio codes for LM hints for item {i}...")
                hints = self._decode_audio_codes_to_latents(audio_code_hints[i])
                if hints is not None:
                    # Pad or crop to match max_latent_length
@@ -1563,7 +1567,9 @@
        lyric_attention_mask = batch["lyric_attention_masks"]
        text_inputs = batch["text_inputs"]
 
+        print("[preprocess_batch] Inferring prompt embeddings...")
        text_hidden_states = self.infer_text_embeddings(text_token_idss)
+        print("[preprocess_batch] Inferring lyric embeddings...")
        lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
 
        is_covers = batch["is_covers"]
@@ -1576,6 +1582,7 @@
        non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
        non_cover_text_hidden_states = None
        if non_cover_text_input_ids is not None:
+            print("[preprocess_batch] Inferring non-cover text embeddings...")
            non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
 
        return (
@@ -1803,10 +1810,78 @@ class AceStepHandler:
            "cfg_interval_start": cfg_interval_start,
            "cfg_interval_end": cfg_interval_end,
        }
-
+        print("[service_generate] Generating audio...")
        outputs = self.model.generate_audio(**generate_kwargs)
        return outputs
 
+    def tiled_decode(self, latents, chunk_size=512, overlap=64):
+        """
+        Decode latents using tiling to reduce VRAM usage.
+        Uses overlap-discard strategy to avoid boundary artifacts.
+
+        Args:
+            latents: [Batch, Channels, Length]
+            chunk_size: Size of latent chunk to process at once
+            overlap: Overlap size in latent frames
+        """
+        B, C, T = latents.shape
+
+        # If short enough, decode directly
+        if T <= chunk_size:
+            return self.vae.decode(latents).sample
+
+        # Calculate stride (core size)
+        stride = chunk_size - 2 * overlap
+        if stride <= 0:
+            raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}")
+
+        decoded_audio_list = []
+
+        # We need to determine upsample factor to trim audio correctly
+        upsample_factor = None
+
+        num_steps = math.ceil(T / stride)
+
+        for i in tqdm(range(num_steps), desc="Decoding audio chunks"):
+            # Core range in latents
+            core_start = i * stride
+            core_end = min(core_start + stride, T)
+
+            # Window range (with overlap)
+            win_start = max(0, core_start - overlap)
+            win_end = min(T, core_end + overlap)
+
+            # Extract chunk
+            latent_chunk = latents[:, :, win_start:win_end]
+
+            # Decode
+            # [Batch, Channels, AudioSamples]
+            audio_chunk = self.vae.decode(latent_chunk).sample
+
+            # Determine upsample factor from the first chunk
+            if upsample_factor is None:
+                upsample_factor = audio_chunk.shape[-1] / latent_chunk.shape[-1]
+
+            # Calculate trim amounts in audio samples
+            # How much overlap was added at the start?
+            added_start = core_start - win_start  # latent frames
+            trim_start = int(round(added_start * upsample_factor))
+
+            # How much overlap was added at the end?
+            added_end = win_end - core_end  # latent frames
+            trim_end = int(round(added_end * upsample_factor))
+
+            # Trim audio
+            audio_len = audio_chunk.shape[-1]
+            end_idx = audio_len - trim_end
+
+            audio_core = audio_chunk[:, :, trim_start:end_idx]
+            decoded_audio_list.append(audio_core)
+
+        # Concatenate
+        final_audio = torch.cat(decoded_audio_list, dim=-1)
+        return final_audio
+
    def generate_music(
        self,
        captions: str,
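Review note on the new tiled_decode: the overlap-discard index arithmetic is easy to sanity-check in isolation. The sketch below mirrors the window/core computation above with the same defaults (chunk_size=512, overlap=64; window_plan is a hypothetical helper, not part of this commit) and asserts that the kept cores tile [0, T) exactly once, which is why the concatenated chunks line up without seams. It also assumes math is already imported at the top of handler.py, since this commit only adds the tqdm import.

import math

def window_plan(T, chunk_size=512, overlap=64):
    # Mirrors tiled_decode: each step decodes [win_start, win_end)
    # but keeps only the core [core_start, core_end).
    stride = chunk_size - 2 * overlap
    for i in range(math.ceil(T / stride)):
        core_start = i * stride
        core_end = min(core_start + stride, T)
        win_start = max(0, core_start - overlap)
        win_end = min(T, core_end + overlap)
        yield core_start, core_end, win_start, win_end

# The kept cores must cover [0, T) exactly once, with no gaps or double-counting.
cores = [(cs, ce) for cs, ce, _, _ in window_plan(2000)]
assert cores[0][0] == 0 and cores[-1][1] == 2000
assert all(prev[1] == nxt[0] for prev, nxt in zip(cores, cores[1:]))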
@@ -1834,6 +1909,7 @@
        cfg_interval_end: float = 1.0,
        audio_format: str = "mp3",
        lm_temperature: float = 0.6,
+        use_tiled_decode: bool = True,
        progress=None
    ) -> Tuple[Optional[str], Optional[str], List[str], str, str, str, str, str, Optional[Any], str, str, Optional[Any]]:
        """
@@ -1887,6 +1963,7 @@
        # 1. Process reference audio
        refer_audios = None
        if reference_audio is not None:
+            print("[generate_music] Processing reference audio...")
            processed_ref_audio = self.process_reference_audio(reference_audio)
            if processed_ref_audio is not None:
                # Convert to the format expected by the service: List[List[torch.Tensor]]
@@ -1898,6 +1975,7 @@
        # 2. Process source audio
        processed_src_audio = None
        if src_audio is not None:
+            print("[generate_music] Processing source audio...")
            processed_src_audio = self.process_src_audio(src_audio)
 
        # 3. Prepare batch data
@@ -1975,7 +2053,13 @@
        pred_latents_for_decode = pred_latents.transpose(1, 2)
        # Ensure input is in VAE's dtype
        pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
-        pred_wavs = self.vae.decode(pred_latents_for_decode).sample
+
+        if use_tiled_decode:
+            print("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
+            pred_wavs = self.tiled_decode(pred_latents_for_decode)  # [batch, channels, samples]
+        else:
+            pred_wavs = self.vae.decode(pred_latents_for_decode).sample
+
        # Cast output to float32 for audio processing/saving
        pred_wavs = pred_wavs.to(torch.float32)
        end_time = time.time()
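The two decode paths above should produce near-identical audio; only peak activation memory differs. A rough equivalence check, assuming a loaded handler and a hypothetical latent channel count of 8 (adjust shape, device, and tolerance to the actual VAE):

import torch

# Hypothetical smoke test for the tiled path against the direct path.
latents = torch.randn(1, 8, 1024, device=handler.device, dtype=handler.vae.dtype)
wav_direct = handler.vae.decode(latents).sample
wav_tiled = handler.tiled_decode(latents)
assert wav_direct.shape == wav_tiled.shape
# Expect a small but nonzero difference: the VAE decoder is not perfectly
# shift-invariant across chunk boundaries.
print((wav_direct - wav_tiled).abs().max())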
test.py
CHANGED
@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import shutil
+import resource
 from acestep.handler import AceStepHandler
 
 
@@ -91,6 +92,12 @@ def main():
    seeds = "320145306, 1514681811"
 
    print("Starting generation...")
+
+    # Reset peak memory stats
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()
+    elif torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
 
    # Call generate_music
    results = handler.generate_music(
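A note on the measurement pattern added above: reset_peak_memory_stats / max_memory_allocated only track tensors allocated through PyTorch's caching allocator on the current device. The same reset/read pair can be wrapped in a small context manager if it gets reused (a sketch, CUDA-only, not part of this commit):

import contextlib
import torch

@contextlib.contextmanager
def track_peak_vram():
    # Covers only torch-allocated memory on the current CUDA device.
    torch.cuda.reset_peak_memory_stats()
    yield
    peak = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Peak VRAM usage: {peak:.2f} GB")

# Usage:
# with track_peak_vram():
#     results = handler.generate_music(...)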
@@ -109,7 +116,8 @@
        task_type="text2music",
        cfg_interval_start=0.0,
        cfg_interval_end=0.95,
-        audio_format="wav"
+        audio_format="wav",
+        use_tiled_decode=True,
    )
 
    # Unpack results
@@ -118,6 +126,17 @@
        align_score2, align_text2, align_plot2) = results
 
    print("\nGeneration Complete!")
+
+    # Print memory stats
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        peak_vram = torch.xpu.max_memory_allocated() / (1024 ** 3)
+        print(f"Peak VRAM usage: {peak_vram:.2f} GB")
+    elif torch.cuda.is_available():
+        peak_vram = torch.cuda.max_memory_allocated() / (1024 ** 3)
+        print(f"Peak VRAM usage: {peak_vram:.2f} GB")
+
+    peak_ram = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 ** 2)
+    print(f"Peak RAM usage: {peak_ram:.2f} GB")
    print(f"Status: {status_msg}")
    print(f"Info: {info}")
    print(f"Seeds used: {seed_val}")