ChuxiJ committed on
Commit
5f3faee
·
1 Parent(s): 76de6b9

enforce upload audio duration limits

Browse files
Files changed (1) hide show
  1. acestep/handler.py +108 -21
acestep/handler.py CHANGED
@@ -765,6 +765,9 @@ class AceStepHandler:
765
  # Normalize to stereo 48kHz
766
  audio = self._normalize_audio_to_stereo_48k(audio, sr)
767
 
 
 
 
768
  return audio
769
  except Exception as e:
770
  logger.exception("[process_target_audio] Error processing target audio")
@@ -791,25 +794,48 @@ class AceStepHandler:
791
  if len(code_ids) == 0:
792
  return None
793
 
794
- with self._load_model_context("model"):
795
- quantizer = self.model.tokenizer.quantizer
796
- detokenizer = self.model.detokenizer
797
-
798
- num_quantizers = getattr(quantizer, "num_quantizers", 1)
799
- # Create indices tensor: [T_5Hz]
800
- indices = torch.tensor(code_ids, device=self.device, dtype=torch.long) # [T_5Hz]
801
-
802
- indices = indices.unsqueeze(0).unsqueeze(-1) # [1, T_5Hz, 1]
803
-
804
- # Get quantized representation from indices
805
- # The quantizer expects [batch, T_5Hz] format and handles quantizer dimension internally
806
- quantized = quantizer.get_output_from_indices(indices)
807
- if quantized.dtype != self.dtype:
808
- quantized = quantized.to(self.dtype)
809
-
810
- # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
811
- lm_hints_25hz = detokenizer(quantized)
812
- return lm_hints_25hz
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
 
814
  def _create_default_meta(self) -> str:
815
  """Create default metadata string."""
@@ -1136,6 +1162,48 @@ class AceStepHandler:
1136
 
1137
  return audio
1138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1139
  def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
1140
  """Normalize audio_code_hints to list of correct length."""
1141
  if audio_code_hints is None:
@@ -1385,6 +1453,9 @@ class AceStepHandler:
1385
  # Normalize to stereo 48kHz
1386
  audio = self._normalize_audio_to_stereo_48k(audio, sr)
1387
 
 
 
 
1388
  return audio
1389
 
1390
  except Exception as e:
@@ -1698,12 +1769,19 @@ class AceStepHandler:
1698
  # Normalize audio_code_hints to batch list
1699
  audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
1700
 
 
 
 
 
1701
  for ii, refer_audio_list in enumerate(refer_audios):
 
 
1702
  if isinstance(refer_audio_list, list):
1703
  for idx, refer_audio in enumerate(refer_audio_list):
1704
- refer_audio_list[idx] = refer_audio_list[idx].to(self.device).to(torch.bfloat16)
 
1705
  elif isinstance(refer_audio_list, torch.Tensor):
1706
- refer_audios[ii] = refer_audios[ii].to(self.device)
1707
 
1708
  if vocal_languages is None:
1709
  vocal_languages = self._create_fallback_vocal_languages(batch_size)
@@ -2860,6 +2938,15 @@ class AceStepHandler:
2860
  except Exception as e:
2861
  error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
2862
  logger.exception("[generate_music] Generation failed")
 
 
 
 
 
 
 
 
 
2863
  return {
2864
  "audios": [],
2865
  "status_message": error_msg,
 
765
  # Normalize to stereo 48kHz
766
  audio = self._normalize_audio_to_stereo_48k(audio, sr)
767
 
768
+ # Enforce duration limits (10-600 seconds)
769
+ audio = self._enforce_audio_duration_limits(audio)
770
+
771
  return audio
772
  except Exception as e:
773
  logger.exception("[process_target_audio] Error processing target audio")
 
794
  if len(code_ids) == 0:
795
  return None
796
 
797
+ try:
798
+ with self._load_model_context("model"):
799
+ quantizer = self.model.tokenizer.quantizer
800
+ detokenizer = self.model.detokenizer
801
+
802
+ # Get codebook size for validation
803
+ codebook_size = getattr(quantizer, 'codebook_size', 65536)
804
+ if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
805
+ codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
806
+
807
+ # Validate code IDs are within valid range
808
+ invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
809
+ if invalid_codes:
810
+ logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
811
+ # Clamp invalid codes to valid range
812
+ code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]
813
+
814
+ num_quantizers = getattr(quantizer, "num_quantizers", 1)
815
+ # Create indices tensor: [T_5Hz]
816
+ indices = torch.tensor(code_ids, device=self.device, dtype=torch.long) # [T_5Hz]
817
+
818
+ indices = indices.unsqueeze(0).unsqueeze(-1) # [1, T_5Hz, 1]
819
+
820
+ # Synchronize to catch any CUDA errors before proceeding
821
+ if torch.cuda.is_available():
822
+ torch.cuda.synchronize()
823
+
824
+ # Get quantized representation from indices
825
+ # The quantizer expects [batch, T_5Hz] format and handles quantizer dimension internally
826
+ quantized = quantizer.get_output_from_indices(indices)
827
+ if quantized.dtype != self.dtype:
828
+ quantized = quantized.to(self.dtype)
829
+
830
+ # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
831
+ lm_hints_25hz = detokenizer(quantized)
832
+ return lm_hints_25hz
833
+ except Exception as e:
834
+ logger.exception(f"[_decode_audio_codes_to_latents] Error decoding audio codes: {e}")
835
+ # Release cached allocator memory (note: empty_cache() does not reset a CUDA error state)
836
+ if torch.cuda.is_available():
837
+ torch.cuda.empty_cache()
838
+ return None
839
 
840
  def _create_default_meta(self) -> str:
841
  """Create default metadata string."""
 
1162
 
1163
  return audio
1164
 
1165
def _enforce_audio_duration_limits(
    self,
    audio: torch.Tensor,
    sample_rate: int = 48000,
    min_duration: float = 10.0,
    max_duration: float = 600.0
) -> torch.Tensor:
    """
    Enforce audio duration limits by truncating or repeating.

    Audio longer than ``max_duration`` is cut to its first ``max_duration``
    seconds. Audio shorter than ``min_duration`` is tiled along its last
    (time) dimension and then cut to exactly ``min_duration`` seconds.
    Audio already inside the limits is returned unchanged.

    Args:
        audio: Audio tensor whose LAST dimension is time (samples), e.g.
            [channels, samples]. Any number of leading dimensions is
            accepted; they are preserved as-is.
        sample_rate: Sample rate of the audio (default: 48000)
        min_duration: Minimum duration in seconds (default: 10.0)
        max_duration: Maximum duration in seconds (default: 600.0)

    Returns:
        Audio tensor with enforced duration limits (same rank as the input)

    Raises:
        ValueError: If the audio has zero samples but must be extended to
            reach ``min_duration`` (there is nothing to repeat).
    """
    current_samples = audio.shape[-1]

    min_samples = int(min_duration * sample_rate)
    max_samples = int(max_duration * sample_rate)

    # If audio is longer than max_duration, truncate
    if current_samples > max_samples:
        current_duration = current_samples / sample_rate
        logger.info(f"[_enforce_audio_duration_limits] Truncating audio from {current_duration:.1f}s to {max_duration:.1f}s")
        audio = audio[..., :max_samples]

    # If audio is shorter than min_duration, repeat to fill
    elif current_samples < min_samples:
        # An empty clip cannot be tiled; fail with a descriptive error
        # instead of an opaque ZeroDivisionError from the ceil division.
        if current_samples == 0:
            raise ValueError(
                "[_enforce_audio_duration_limits] Audio has no samples; "
                f"cannot repeat it to reach {min_duration:.1f}s"
            )
        current_duration = current_samples / sample_rate
        logger.info(f"[_enforce_audio_duration_limits] Repeating audio from {current_duration:.1f}s to reach {min_duration:.1f}s")
        # Integer ceil division: how many copies are needed to cover min_samples
        repeat_times = -(-min_samples // current_samples)
        # Repeat only along the time (last) dimension, whatever the rank:
        # a fixed repeat(1, n) would add a dimension to 1-D input and raise
        # on tensors with more than two dimensions.
        repeat_dims = [1] * (audio.dim() - 1) + [repeat_times]
        audio = audio.repeat(*repeat_dims)
        # Truncate to exactly min_samples
        audio = audio[..., :min_samples]

    return audio
1206
+
1207
  def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
1208
  """Normalize audio_code_hints to list of correct length."""
1209
  if audio_code_hints is None:
 
1453
  # Normalize to stereo 48kHz
1454
  audio = self._normalize_audio_to_stereo_48k(audio, sr)
1455
 
1456
+ # Enforce duration limits (10-600 seconds)
1457
+ audio = self._enforce_audio_duration_limits(audio)
1458
+
1459
  return audio
1460
 
1461
  except Exception as e:
 
1769
  # Normalize audio_code_hints to batch list
1770
  audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
1771
 
1772
+ # Synchronize CUDA to catch any pending errors from previous operations
1773
+ if torch.cuda.is_available():
1774
+ torch.cuda.synchronize()
1775
+
1776
  for ii, refer_audio_list in enumerate(refer_audios):
1777
+ if refer_audio_list is None:
1778
+ continue
1779
  if isinstance(refer_audio_list, list):
1780
  for idx, refer_audio in enumerate(refer_audio_list):
1781
+ if refer_audio is not None and isinstance(refer_audio, torch.Tensor):
1782
+ refer_audio_list[idx] = refer_audio.to(self.device).to(torch.bfloat16)
1783
  elif isinstance(refer_audio_list, torch.Tensor):
1784
+ refer_audios[ii] = refer_audio_list.to(self.device)
1785
 
1786
  if vocal_languages is None:
1787
  vocal_languages = self._create_fallback_vocal_languages(batch_size)
 
2938
  except Exception as e:
2939
  error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
2940
  logger.exception("[generate_music] Generation failed")
2941
+
2942
+ # Clean up CUDA state after any error (especially important for CUDA errors)
2943
+ if torch.cuda.is_available():
2944
+ try:
2945
+ torch.cuda.synchronize()
2946
+ except Exception:
2947
+ pass # Ignore sync errors during cleanup
2948
+ torch.cuda.empty_cache()
2949
+
2950
  return {
2951
  "audios": [],
2952
  "status_message": error_msg,