# Source: Ace-Step-v1.5 / acestep/audio_utils.py (10.7 kB)
# Last commit: "feat: save audio add soundfile fallback" (140aee6), committed by ChuxiJ
"""
Audio saving and transcoding utility module
Independent audio file operations outside of handler, supporting:
- Save audio tensor/numpy to files (default FLAC format, fast)
- Format conversion (FLAC/WAV/MP3)
- Batch processing
"""
import os
import hashlib
import json
from pathlib import Path
from typing import Union, Optional, List, Tuple
import torch
import numpy as np
import torchaudio
from loguru import logger
class AudioSaver:
    """Audio saving and transcoding utility class.

    Writes audio to FLAC, WAV or MP3 files with ``torchaudio.save`` and
    retries with the ``soundfile`` package when that call raises.
    """

    # Formats this class can write (lowercase, without the leading dot).
    SUPPORTED_FORMATS = ("flac", "wav", "mp3")

    def __init__(self, default_format: str = "flac"):
        """
        Initialize audio saver

        Args:
            default_format: Default save format ('flac', 'wav', 'mp3').
                Anything else is replaced by 'flac' with a warning.
        """
        self.default_format = default_format.lower()
        if self.default_format not in self.SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {default_format}, using 'flac'")
            self.default_format = "flac"

    @staticmethod
    def _resolve_output_path(output_path: Union[str, Path], fmt: str) -> Path:
        """Return ``output_path`` with a supported audio extension.

        A path that already ends in .flac/.wav/.mp3 (any case) is returned
        unchanged. Otherwise ``.{fmt}`` is appended to the file name. The
        extension is appended rather than substituted because
        ``Path.with_suffix`` would cut dotted stems: 'audio_1.5_0003' would
        become 'audio_1.flac'.
        """
        path = Path(output_path)
        if path.suffix.lower().lstrip(".") in AudioSaver.SUPPORTED_FORMATS:
            return path
        return path.with_name(f"{path.name}.{fmt}")

    @staticmethod
    def _to_channels_first(
        audio_data: Union[torch.Tensor, np.ndarray],
        channels_first: bool,
    ) -> torch.Tensor:
        """Return the audio as a contiguous float32 CPU tensor [channels, samples].

        - 1-D input is treated as mono and becomes [1, samples].
        - torch tensors are trusted when ``channels_first`` is True; when it
          is False a 2-D tensor is transposed if its first dimension is the
          larger one (i.e. it looks like [samples, channels]).
        - numpy arrays are oriented by shape regardless of ``channels_first``:
          a 2-D array whose first dimension is larger is transposed. This
          accepts both the [samples, channels] layout and the
          [channels, samples] layout.
        """
        if isinstance(audio_data, np.ndarray):
            tensor = torch.from_numpy(audio_data).float()
            may_be_samples_first = True
        else:
            # detach() so tensors that still track gradients can be converted
            # to numpy in the soundfile fallback.
            tensor = audio_data.detach().cpu().float()
            may_be_samples_first = not channels_first

        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        elif (
            may_be_samples_first
            and tensor.dim() == 2
            and tensor.shape[0] > tensor.shape[1]
        ):
            tensor = tensor.T

        # Transposed views are not contiguous; make the memory contiguous.
        return tensor.contiguous()

    def save_audio(
        self,
        audio_data: Union[torch.Tensor, np.ndarray],
        output_path: Union[str, Path],
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> str:
        """
        Save audio data to file

        Args:
            audio_data: Audio data, torch.Tensor or numpy.ndarray. 1-D data is
                saved as mono.
            output_path: Output file path. If it does not end in a supported
                extension, '.<format>' is appended.
            sample_rate: Sample rate
            format: Audio format ('flac', 'wav', 'mp3'), defaults to
                default_format. If output_path already carries a supported
                extension, that extension takes precedence.
            channels_first: If True, a tensor is [channels, samples], else
                [samples, channels]. numpy arrays are oriented by shape.

        Returns:
            Actual saved file path

        Raises:
            Exception: whatever the soundfile fallback raises when both the
                torchaudio save and the fallback fail.
        """
        format = (format or self.default_format).lower()
        if format not in self.SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {format}, using {self.default_format}")
            format = self.default_format

        output_path = self._resolve_output_path(output_path, format)

        # torchaudio.save is called without an explicit format, so the file
        # extension decides the container. Keep backend choice, the soundfile
        # fallback and the log messages in agreement with it.
        extension_format = output_path.suffix.lower().lstrip(".")
        if extension_format != format:
            logger.warning(
                f"[AudioSaver] Extension of {output_path} does not match format "
                f"'{format}', saving as '{extension_format}'"
            )
            format = extension_format

        audio_tensor = self._to_channels_first(audio_data, channels_first)

        # MP3 goes through the ffmpeg backend; FLAC and WAV use soundfile.
        backend = "ffmpeg" if format == "mp3" else "soundfile"
        try:
            torchaudio.save(
                str(output_path),
                audio_tensor,
                sample_rate,
                channels_first=True,
                backend=backend,
            )
        except Exception as primary_error:
            # Report the first failure instead of discarding it, then retry
            # by calling soundfile directly.
            logger.warning(
                f"[AudioSaver] torchaudio save failed ({primary_error}), "
                f"falling back to soundfile"
            )
            try:
                import soundfile as sf
                audio_np = audio_tensor.transpose(0, 1).numpy()  # -> [samples, channels]
                sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
            except Exception as e:
                logger.error(f"[AudioSaver] Failed to save audio: {e}")
                raise
            logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
            return str(output_path)

        logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
        return str(output_path)

    def convert_audio(
        self,
        input_path: Union[str, Path],
        output_path: Union[str, Path],
        output_format: str,
        remove_input: bool = False,
    ) -> str:
        """
        Convert audio format

        Args:
            input_path: Input audio file path
            output_path: Output audio file path
            output_format: Target format ('flac', 'wav', 'mp3')
            remove_input: Whether to delete input file. The input is kept
                when the saved file is the input file itself.

        Returns:
            Output file path

        Raises:
            FileNotFoundError: if input_path does not exist.
        """
        input_path = Path(input_path)
        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Load audio
        audio_tensor, sample_rate = torchaudio.load(str(input_path))

        # Save as new format
        saved_path = self.save_audio(
            audio_tensor,
            Path(output_path),
            sample_rate=sample_rate,
            format=output_format,
            channels_first=True,
        )

        if remove_input:
            if Path(saved_path).resolve() == input_path.resolve():
                # Deleting the input here would delete the file just written.
                logger.warning(
                    f"[AudioSaver] Output path equals input path, not removing: {input_path}"
                )
            else:
                input_path.unlink()
                logger.debug(f"[AudioSaver] Removed input file: {input_path}")
        return saved_path

    def save_batch(
        self,
        audio_batch: Union[List[torch.Tensor], torch.Tensor],
        output_dir: Union[str, Path],
        file_prefix: str = "audio",
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> List[str]:
        """
        Save audio batch

        Files are named '<file_prefix>_<index:04d>.<ext>' inside output_dir,
        which is created if it does not exist.

        Args:
            audio_batch: Audio batch: a list/tuple of items, a 3-D tensor or
                array [batch, channels, samples], or a single item
            output_dir: Output directory
            file_prefix: File prefix
            sample_rate: Sample rate
            format: Audio format
            channels_first: Tensor format flag, passed to save_audio

        Returns:
            List of saved file paths
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(audio_batch, (torch.Tensor, np.ndarray)) and audio_batch.ndim == 3:
            # Split along the leading batch dimension.
            audio_list = list(audio_batch)
        elif isinstance(audio_batch, (list, tuple)):
            audio_list = list(audio_batch)
        else:
            # A single item is saved as a batch of one.
            audio_list = [audio_batch]

        return [
            self.save_audio(
                audio,
                output_dir / f"{file_prefix}_{index:04d}",
                sample_rate=sample_rate,
                format=format,
                channels_first=channels_first,
            )
            for index, audio in enumerate(audio_list)
        ]
def get_audio_file_hash(audio_file) -> str:
    """
    Get hash identifier for an audio file.

    - None gives an empty string.
    - A str naming an existing path is hashed by file content.
    - Any other str is hashed as text.
    - An object with a ``name`` attribute is hashed by ``str(name)``.
    - Anything else is hashed by ``str(audio_file)``.

    All hashes are MD5 hex digests. If anything goes wrong (for example the
    file cannot be read), the hash of ``str(audio_file)`` is returned instead.

    Args:
        audio_file: Path to audio file (str) or file-like object

    Returns:
        Hash string or empty string
    """
    if audio_file is None:
        return ""
    try:
        if isinstance(audio_file, str):
            if os.path.exists(audio_file):
                # Hash in 1 MiB chunks so large audio files are never held in
                # memory all at once; the digest equals hashing the whole file.
                digest = hashlib.md5()
                with open(audio_file, 'rb') as f:
                    for chunk in iter(lambda: f.read(1 << 20), b""):
                        digest.update(chunk)
                return digest.hexdigest()
            return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
        elif hasattr(audio_file, 'name'):
            return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
    except Exception:
        # Deliberate best effort: always return some identifier.
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
def generate_uuid_from_params(params_dict) -> str:
    """
    Build a deterministic UUID-shaped identifier from generation parameters.

    The parameters are serialized to JSON with sorted keys, hashed with
    SHA-256, and the first 32 hex characters are grouped 8-4-4-4-12. Equal
    parameters therefore always give the same string.

    Args:
        params_dict: Dictionary of parameters (must be JSON-serializable)

    Returns:
        UUID string
    """
    serialized = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(serialized.encode('utf-8')).hexdigest()
    # Cut points of the 8-4-4-4-12 grouping within the hex digest.
    bounds = (0, 8, 12, 16, 20, 32)
    groups = [digest[start:stop] for start, stop in zip(bounds, bounds[1:])]
    return "-".join(groups)
def generate_uuid_from_audio_data(
    audio_data: Union[torch.Tensor, np.ndarray],
    seed: Optional[int] = None
) -> str:
    """
    Generate an identifier from audio data (for caching/deduplication)

    The result is the MD5 hex digest (32 hex characters, no dashes) of the
    raw sample bytes. When ``seed`` is given, the digest of
    "<data digest>_<seed>" is returned instead. Shape and dtype are not
    hashed separately; only the bytes from ``tobytes()`` are.

    Args:
        audio_data: Audio data (tensor, numpy array, or array-like)
        seed: Optional seed value

    Returns:
        UUID string
    """
    if isinstance(audio_data, torch.Tensor):
        # detach() first: numpy() raises on tensors that require grad.
        audio_np = audio_data.detach().cpu().numpy()
    else:
        # asarray returns an ndarray unchanged and accepts plain sequences.
        audio_np = np.asarray(audio_data)

    data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
    if seed is None:
        return data_hash
    combined = f"{data_hash}_{seed}"
    return hashlib.md5(combined.encode()).hexdigest()
# Global default instance: a module-level AudioSaver whose default format is
# FLAC, used by the save_audio() convenience function in this module.
_default_saver = AudioSaver(default_format="flac")
def save_audio(
    audio_data: Union[torch.Tensor, np.ndarray],
    output_path: Union[str, Path],
    sample_rate: int = 48000,
    format: Optional[str] = None,
    channels_first: bool = True,
) -> str:
    """
    Convenience function: save audio through the module-level default saver.

    All arguments are forwarded unchanged to ``_default_saver.save_audio``.

    Args:
        audio_data: Audio data
        output_path: Output path
        sample_rate: Sample rate
        format: Format (None lets the default saver pick its default format)
        channels_first: Tensor format flag

    Returns:
        Saved file path
    """
    saved_path = _default_saver.save_audio(
        audio_data,
        output_path,
        sample_rate=sample_rate,
        format=format,
        channels_first=channels_first,
    )
    return saved_path