Spaces:
Running
on
A100
Running
on
A100
enforce upload audio duration limits
Browse files
- acestep/handler.py +108 -21
acestep/handler.py
CHANGED
|
@@ -765,6 +765,9 @@ class AceStepHandler:
|
|
| 765 |
# Normalize to stereo 48kHz
|
| 766 |
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
| 767 |
|
|
|
|
|
|
|
|
|
|
| 768 |
return audio
|
| 769 |
except Exception as e:
|
| 770 |
logger.exception("[process_target_audio] Error processing target audio")
|
|
@@ -791,25 +794,48 @@ class AceStepHandler:
|
|
| 791 |
if len(code_ids) == 0:
|
| 792 |
return None
|
| 793 |
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
|
| 814 |
def _create_default_meta(self) -> str:
|
| 815 |
"""Create default metadata string."""
|
|
@@ -1136,6 +1162,48 @@ class AceStepHandler:
|
|
| 1136 |
|
| 1137 |
return audio
|
| 1138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1139 |
def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
|
| 1140 |
"""Normalize audio_code_hints to list of correct length."""
|
| 1141 |
if audio_code_hints is None:
|
|
@@ -1385,6 +1453,9 @@ class AceStepHandler:
|
|
| 1385 |
# Normalize to stereo 48kHz
|
| 1386 |
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
| 1387 |
|
|
|
|
|
|
|
|
|
|
| 1388 |
return audio
|
| 1389 |
|
| 1390 |
except Exception as e:
|
|
@@ -1698,12 +1769,19 @@ class AceStepHandler:
|
|
| 1698 |
# Normalize audio_code_hints to batch list
|
| 1699 |
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
|
| 1700 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1701 |
for ii, refer_audio_list in enumerate(refer_audios):
|
|
|
|
|
|
|
| 1702 |
if isinstance(refer_audio_list, list):
|
| 1703 |
for idx, refer_audio in enumerate(refer_audio_list):
|
| 1704 |
-
|
|
|
|
| 1705 |
elif isinstance(refer_audio_list, torch.Tensor):
|
| 1706 |
-
refer_audios[ii] =
|
| 1707 |
|
| 1708 |
if vocal_languages is None:
|
| 1709 |
vocal_languages = self._create_fallback_vocal_languages(batch_size)
|
|
@@ -2860,6 +2938,15 @@ class AceStepHandler:
|
|
| 2860 |
except Exception as e:
|
| 2861 |
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
|
| 2862 |
logger.exception("[generate_music] Generation failed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2863 |
return {
|
| 2864 |
"audios": [],
|
| 2865 |
"status_message": error_msg,
|
|
|
|
| 765 |
# Normalize to stereo 48kHz
|
| 766 |
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
| 767 |
|
| 768 |
+
# Enforce duration limits (10-600 seconds)
|
| 769 |
+
audio = self._enforce_audio_duration_limits(audio)
|
| 770 |
+
|
| 771 |
return audio
|
| 772 |
except Exception as e:
|
| 773 |
logger.exception("[process_target_audio] Error processing target audio")
|
|
|
|
| 794 |
if len(code_ids) == 0:
|
| 795 |
return None
|
| 796 |
|
| 797 |
+
try:
|
| 798 |
+
with self._load_model_context("model"):
|
| 799 |
+
quantizer = self.model.tokenizer.quantizer
|
| 800 |
+
detokenizer = self.model.detokenizer
|
| 801 |
+
|
| 802 |
+
# Get codebook size for validation
|
| 803 |
+
codebook_size = getattr(quantizer, 'codebook_size', 65536)
|
| 804 |
+
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
|
| 805 |
+
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
|
| 806 |
+
|
| 807 |
+
# Validate code IDs are within valid range
|
| 808 |
+
invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
|
| 809 |
+
if invalid_codes:
|
| 810 |
+
logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
|
| 811 |
+
# Clamp invalid codes to valid range
|
| 812 |
+
code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]
|
| 813 |
+
|
| 814 |
+
num_quantizers = getattr(quantizer, "num_quantizers", 1)
|
| 815 |
+
# Create indices tensor: [T_5Hz]
|
| 816 |
+
indices = torch.tensor(code_ids, device=self.device, dtype=torch.long) # [T_5Hz]
|
| 817 |
+
|
| 818 |
+
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, T_5Hz, 1]
|
| 819 |
+
|
| 820 |
+
# Synchronize to catch any CUDA errors before proceeding
|
| 821 |
+
if torch.cuda.is_available():
|
| 822 |
+
torch.cuda.synchronize()
|
| 823 |
+
|
| 824 |
+
# Get quantized representation from indices
|
| 825 |
+
# The quantizer expects [batch, T_5Hz] format and handles quantizer dimension internally
|
| 826 |
+
quantized = quantizer.get_output_from_indices(indices)
|
| 827 |
+
if quantized.dtype != self.dtype:
|
| 828 |
+
quantized = quantized.to(self.dtype)
|
| 829 |
+
|
| 830 |
+
# Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
|
| 831 |
+
lm_hints_25hz = detokenizer(quantized)
|
| 832 |
+
return lm_hints_25hz
|
| 833 |
+
except Exception as e:
|
| 834 |
+
logger.exception(f"[_decode_audio_codes_to_latents] Error decoding audio codes: {e}")
|
| 835 |
+
# Clear CUDA error state
|
| 836 |
+
if torch.cuda.is_available():
|
| 837 |
+
torch.cuda.empty_cache()
|
| 838 |
+
return None
|
| 839 |
|
| 840 |
def _create_default_meta(self) -> str:
|
| 841 |
"""Create default metadata string."""
|
|
|
|
| 1162 |
|
| 1163 |
return audio
|
| 1164 |
|
| 1165 |
+
def _enforce_audio_duration_limits(
|
| 1166 |
+
self,
|
| 1167 |
+
audio: torch.Tensor,
|
| 1168 |
+
sample_rate: int = 48000,
|
| 1169 |
+
min_duration: float = 10.0,
|
| 1170 |
+
max_duration: float = 600.0
|
| 1171 |
+
) -> torch.Tensor:
|
| 1172 |
+
"""
|
| 1173 |
+
Enforce audio duration limits by truncating or repeating.
|
| 1174 |
+
|
| 1175 |
+
Args:
|
| 1176 |
+
audio: Audio tensor [channels, samples] at target sample rate
|
| 1177 |
+
sample_rate: Sample rate of the audio (default: 48000)
|
| 1178 |
+
min_duration: Minimum duration in seconds (default: 10.0)
|
| 1179 |
+
max_duration: Maximum duration in seconds (default: 600.0)
|
| 1180 |
+
|
| 1181 |
+
Returns:
|
| 1182 |
+
Audio tensor with enforced duration limits
|
| 1183 |
+
"""
|
| 1184 |
+
current_samples = audio.shape[-1]
|
| 1185 |
+
current_duration = current_samples / sample_rate
|
| 1186 |
+
|
| 1187 |
+
min_samples = int(min_duration * sample_rate)
|
| 1188 |
+
max_samples = int(max_duration * sample_rate)
|
| 1189 |
+
|
| 1190 |
+
# If audio is longer than max_duration, truncate
|
| 1191 |
+
if current_samples > max_samples:
|
| 1192 |
+
logger.info(f"[_enforce_audio_duration_limits] Truncating audio from {current_duration:.1f}s to {max_duration:.1f}s")
|
| 1193 |
+
audio = audio[..., :max_samples]
|
| 1194 |
+
|
| 1195 |
+
# If audio is shorter than min_duration, repeat to fill
|
| 1196 |
+
elif current_samples < min_samples:
|
| 1197 |
+
logger.info(f"[_enforce_audio_duration_limits] Repeating audio from {current_duration:.1f}s to reach {min_duration:.1f}s")
|
| 1198 |
+
# Calculate how many times to repeat
|
| 1199 |
+
repeat_times = int(math.ceil(min_samples / current_samples))
|
| 1200 |
+
# Repeat along the time dimension
|
| 1201 |
+
audio = audio.repeat(1, repeat_times)
|
| 1202 |
+
# Truncate to exactly min_samples
|
| 1203 |
+
audio = audio[..., :min_samples]
|
| 1204 |
+
|
| 1205 |
+
return audio
|
| 1206 |
+
|
| 1207 |
def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
|
| 1208 |
"""Normalize audio_code_hints to list of correct length."""
|
| 1209 |
if audio_code_hints is None:
|
|
|
|
| 1453 |
# Normalize to stereo 48kHz
|
| 1454 |
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
| 1455 |
|
| 1456 |
+
# Enforce duration limits (10-600 seconds)
|
| 1457 |
+
audio = self._enforce_audio_duration_limits(audio)
|
| 1458 |
+
|
| 1459 |
return audio
|
| 1460 |
|
| 1461 |
except Exception as e:
|
|
|
|
| 1769 |
# Normalize audio_code_hints to batch list
|
| 1770 |
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
|
| 1771 |
|
| 1772 |
+
# Synchronize CUDA to catch any pending errors from previous operations
|
| 1773 |
+
if torch.cuda.is_available():
|
| 1774 |
+
torch.cuda.synchronize()
|
| 1775 |
+
|
| 1776 |
for ii, refer_audio_list in enumerate(refer_audios):
|
| 1777 |
+
if refer_audio_list is None:
|
| 1778 |
+
continue
|
| 1779 |
if isinstance(refer_audio_list, list):
|
| 1780 |
for idx, refer_audio in enumerate(refer_audio_list):
|
| 1781 |
+
if refer_audio is not None and isinstance(refer_audio, torch.Tensor):
|
| 1782 |
+
refer_audio_list[idx] = refer_audio.to(self.device).to(torch.bfloat16)
|
| 1783 |
elif isinstance(refer_audio_list, torch.Tensor):
|
| 1784 |
+
refer_audios[ii] = refer_audio_list.to(self.device)
|
| 1785 |
|
| 1786 |
if vocal_languages is None:
|
| 1787 |
vocal_languages = self._create_fallback_vocal_languages(batch_size)
|
|
|
|
| 2938 |
except Exception as e:
|
| 2939 |
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
|
| 2940 |
logger.exception("[generate_music] Generation failed")
|
| 2941 |
+
|
| 2942 |
+
# Clean up CUDA state after any error (especially important for CUDA errors)
|
| 2943 |
+
if torch.cuda.is_available():
|
| 2944 |
+
try:
|
| 2945 |
+
torch.cuda.synchronize()
|
| 2946 |
+
except Exception:
|
| 2947 |
+
pass # Ignore sync errors during cleanup
|
| 2948 |
+
torch.cuda.empty_cache()
|
| 2949 |
+
|
| 2950 |
return {
|
| 2951 |
"audios": [],
|
| 2952 |
"status_message": error_msg,
|