| """ | |
| Seed-VC Streaming API Server | |
| architecture.md と model_ref.md に基づいて実装 | |
| """ | |
import io
import os
import sys
import time
import uuid
from typing import Optional, Dict
from argparse import Namespace

import numpy as np
import soundfile as sf
import librosa
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
# Make the bundled Seed-VC repository importable
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'seed-vc'))

# Hugging Face cache directory (absolute path)
cache_dir = '/app/checkpoints'
os.makedirs(cache_dir, exist_ok=True)
os.environ['HF_HOME'] = cache_dir
os.environ['HF_HUB_CACHE'] = cache_dir
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Disable MPS to force CPU
torch.backends.mps.is_available = lambda: False

from inference import load_models  # noqa: E402  (requires the sys.path insert above)
# =============================================================================
# Configuration (architecture.md Section 5)
# =============================================================================
DEFAULT_SAMPLE_RATE = 16000
DEFAULT_CHUNK_LEN_MS = 1000
DEFAULT_OVERLAP_MS = 200
SESSION_EXPIRE_SEC = 600

# model_ref.md Section 3.1
# Reference audio is downloaded from the Hugging Face Hub.
# Repository: Akatuki25/seed-vc-ref-audios (dataset)
DEFAULT_REF_PRESET = "default_female"
REF_PRESETS = {
    "default_female": ("Akatuki25/seed-vc-ref-audios", "default_female.wav"),
    "default_male": ("Akatuki25/seed-vc-ref-audios", "default_male.wav"),
}

# Cache of downloaded reference audio paths
downloaded_ref_cache = {}
# =============================================================================
# Global Variables
# =============================================================================
# Avoid MPS (compatibility issues with seed-vc)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed-VC models (return values of load_models() in inference.py)
model = None
semantic_fn = None
f0_fn = None
vocoder_fn = None
campplus_model = None
to_mel = None
mel_fn_args = None
model_sr = 22050
# =============================================================================
# Session State (architecture.md Section 4.1)
# =============================================================================
class SessionState:
    def __init__(self, sample_rate: int, tgt_speaker_id: Optional[str] = None):
        self.sample_rate = sample_rate
        self.tgt_speaker_id = tgt_speaker_id
        self.last_output_tail: Optional[np.ndarray] = None
        # model_ref.md Section 3: reference audio management
        self.ref_audio_tensor = None  # reference audio (at model_sr, float tensor)
        self.ref_mel = None
        self.ref_semantic = None
        self.style_embed = None
        self.last_access_ts = time.time()
        self.chunk_len_ms = DEFAULT_CHUNK_LEN_MS
        self.overlap_ms = DEFAULT_OVERLAP_MS


SESSIONS: Dict[str, SessionState] = {}
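# NOTE: SESSION_EXPIRE_SEC is defined above but never applied in this file.
# The helper below is a minimal sketch of how stale sessions could be swept,
# assuming it is called periodically (e.g. from a background task); it is an
# illustrative addition, not part of architecture.md.
def sweep_expired_sessions(now: Optional[float] = None) -> int:
    """Remove sessions idle for longer than SESSION_EXPIRE_SEC; return the count removed."""
    now = now if now is not None else time.time()
    expired = [sid for sid, s in SESSIONS.items()
               if now - s.last_access_ts > SESSION_EXPIRE_SEC]
    for sid in expired:
        SESSIONS.pop(sid, None)
    return len(expired)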
# =============================================================================
# FastAPI App
# =============================================================================
app = FastAPI(title="Seed-VC Streaming API", version="1.0.0")


@app.on_event("startup")
async def startup_event():
    """Load models (architecture.md Section 4.3.1)"""
    global model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args, model_sr
    print(f"Device: {device}")
    print("Loading Seed-VC models...")
    # Use load_models() from inference.py as-is
    args = Namespace(
        f0_condition=False,  # model_ref.md: use the 22050 Hz pipeline
        checkpoint=None,
        config=None,
        fp16=False
    )
    model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args = load_models(args)
    model_sr = mel_fn_args['sampling_rate']
    print(f"Models loaded! SR={model_sr}")
# =============================================================================
# Pydantic Models (architecture.md Section 3.2)
# =============================================================================
class SessionCreateRequest(BaseModel):
    sample_rate: int = DEFAULT_SAMPLE_RATE
    tgt_speaker_id: Optional[str] = None
    ref_preset_id: Optional[str] = None
    use_uploaded_ref: bool = False
    chunk_len_ms: int = DEFAULT_CHUNK_LEN_MS
    overlap_ms: int = DEFAULT_OVERLAP_MS


class SessionCreateResponse(BaseModel):
    session_id: str
    sample_rate: int
    chunk_len_ms: int
    overlap_ms: int


class SessionEndRequest(BaseModel):
    session_id: str
# =============================================================================
# Utility Functions
# =============================================================================
def load_wav_to_numpy(file_bytes: bytes, target_sr: int) -> tuple[np.ndarray, int]:
    """Decode a WAV file into an int16 numpy array at target_sr."""
    audio, sr = sf.read(io.BytesIO(file_bytes))
    if len(audio.shape) > 1:
        audio = audio.mean(axis=1)  # downmix to mono
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        sr = target_sr
    if audio.dtype in (np.float32, np.float64):
        # Clip before scaling so resampling overshoot cannot wrap around in int16
        audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    return audio, sr


def numpy_to_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
    """Encode a numpy array as 16-bit PCM WAV bytes."""
    buffer = io.BytesIO()
    sf.write(buffer, audio, sr, format="WAV", subtype="PCM_16")
    buffer.seek(0)
    return buffer.read()
def crossfade(prev_tail: Optional[np.ndarray], new_chunk: np.ndarray, fade_len: int) -> np.ndarray:
    """Crossfade the previous output tail into the new chunk (architecture.md Section 4.2.1)."""
    if prev_tail is None:
        return new_chunk
    fade_len = min(fade_len, len(prev_tail), len(new_chunk))
    if fade_len <= 0:
        return new_chunk
    fade_in = np.linspace(0.0, 1.0, fade_len, endpoint=True)
    fade_out = 1.0 - fade_in
    mixed_head = (prev_tail[-fade_len:] * fade_out + new_chunk[:fade_len] * fade_in).astype(np.int16)
    tail = new_chunk[fade_len:]
    return np.concatenate([mixed_head, tail])
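# Illustrative example (not in the original): with fade_len=3 the stored tail
# ramps out while the first three samples of the new chunk ramp in, e.g.
#   crossfade(np.array([100, 100, 100], dtype=np.int16),
#             np.array([0, 0, 0, 0, 0], dtype=np.int16), fade_len=3)
# yields [100, 50, 0, 0, 0], so consecutive chunks join without a hard edge.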
def download_ref_preset(preset_id: str) -> str:
    """
    Download a reference-audio preset from the Hugging Face Hub.

    Returns: local file path
    """
    if preset_id in downloaded_ref_cache:
        return downloaded_ref_cache[preset_id]
    if preset_id not in REF_PRESETS:
        raise ValueError(f"Unknown preset_id: {preset_id}")
    repo_id, filename = REF_PRESETS[preset_id]
    print(f"Downloading reference audio from {repo_id}/{filename}...")
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset",
        cache_dir=cache_dir
    )
    downloaded_ref_cache[preset_id] = local_path
    print(f"Downloaded to {local_path}")
    return local_path
def prepare_reference_audio(audio_path: str, state: SessionState):
    """
    Prepare the reference audio (model_ref.md Section 3).
    Same logic as main() in inference.py.
    """
    # Load the reference audio
    ref_audio, file_sr = librosa.load(audio_path, sr=model_sr)
    ref_audio = ref_audio[:model_sr * 25]  # limit to 25 seconds

    # Convert to tensor
    ref_audio_tensor = torch.tensor(ref_audio).unsqueeze(0).float().to(device)
    state.ref_audio_tensor = ref_audio_tensor

    # Mel spectrogram
    state.ref_mel = to_mel(ref_audio_tensor)

    # Whisper semantic features
    ref_waves_16k = torchaudio.functional.resample(ref_audio_tensor, model_sr, 16000)
    state.ref_semantic = semantic_fn(ref_waves_16k)

    # CAMPPlus style embedding
    feat = torchaudio.compliance.kaldi.fbank(
        ref_waves_16k,
        num_mel_bins=80,
        dither=0,
        sample_frequency=16000
    )
    feat = feat - feat.mean(dim=0, keepdim=True)
    state.style_embed = campplus_model(feat.unsqueeze(0))

    print(f"Reference prepared: mel={state.ref_mel.shape}, semantic={state.ref_semantic.shape}")
def seed_vc_infer(chunk_np: np.ndarray, chunk_sr: int, state: SessionState) -> np.ndarray:
    """
    Convert one audio chunk with Seed-VC (architecture.md Section 4.3.2).
    Uses the logic of main() in inference.py.
    """
    # int16 -> float32
    if chunk_np.dtype == np.int16:
        source_audio = chunk_np.astype(np.float32) / 32768.0
    else:
        source_audio = chunk_np.astype(np.float32)

    # Resample to model_sr
    if chunk_sr != model_sr:
        source_audio = librosa.resample(source_audio, orig_sr=chunk_sr, target_sr=model_sr)

    # Convert to tensor
    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)

    # Resample to 16 kHz and extract Whisper semantic features
    converted_waves_16k = torchaudio.functional.resample(source_audio, model_sr, 16000)
    S_alt = semantic_fn(converted_waves_16k)

    # Mel spectrogram
    mel = to_mel(source_audio.to(device).float())

    # Target lengths
    target_lengths = torch.LongTensor([mel.size(2)]).to(device)
    target2_lengths = torch.LongTensor([state.ref_mel.size(2)]).to(device)

    # Length regulator (inference.py line 354-360)
    with torch.no_grad():
        cond, _, _, _, _ = model.length_regulator(
            S_alt, ylens=target_lengths, n_quantizers=3, f0=None
        )
        prompt_condition, _, _, _, _ = model.length_regulator(
            state.ref_semantic, ylens=target2_lengths, n_quantizers=3, f0=None
        )

    # Concatenate prompt and source conditions
    cat_condition = torch.cat([prompt_condition, cond], dim=1)

    # CFM inference (inference.py line 373-376)
    with torch.no_grad():
        vc_target = model.cfm.inference(
            cat_condition,
            torch.LongTensor([cat_condition.size(1)]).to(device),
            state.ref_mel,
            state.style_embed,
            None,
            10,  # diffusion_steps
            inference_cfg_rate=0.7
        )
    # Drop the prompt (reference) portion of the output
    vc_target = vc_target[:, :, state.ref_mel.size(-1):]

    # Vocoder (inference.py line 378)
    with torch.no_grad():
        vc_wave = vocoder_fn(vc_target.float()).squeeze()
        vc_wave = vc_wave[None, :]

    # Convert to numpy
    output_wave = vc_wave[0].cpu().numpy()

    # Back to int16
    output_int16 = (output_wave * 32767).clip(-32768, 32767).astype(np.int16)
    return output_int16
# =============================================================================
# Endpoints (architecture.md Section 3.2)
# =============================================================================
@app.get("/health")
async def health_check():
    """3.2.1 GET /health"""
    return {"status": "ok"}


@app.post("/session")
async def create_session(body: SessionCreateRequest):
    """
    3.2.2 POST /session
    model_ref.md Section 2.2(A)
    """
    session_id = str(uuid.uuid4())
    state = SessionState(
        sample_rate=body.sample_rate,
        tgt_speaker_id=body.tgt_speaker_id
    )
    state.chunk_len_ms = body.chunk_len_ms
    state.overlap_ms = body.overlap_ms

    # Reference audio setup (model_ref.md Section 3.2)
    if not body.use_uploaded_ref:
        preset_id = body.ref_preset_id or DEFAULT_REF_PRESET
        if preset_id is None:
            raise HTTPException(status_code=400, detail="ref_preset_id or use_uploaded_ref=true required")
        wav_path = download_ref_preset(preset_id)
        prepare_reference_audio(wav_path, state)

    SESSIONS[session_id] = state
    return SessionCreateResponse(
        session_id=session_id,
        sample_rate=body.sample_rate,
        chunk_len_ms=body.chunk_len_ms,
        overlap_ms=body.overlap_ms,
    )
@app.post("/ref_audio")  # NOTE: route path assumed; the original decorator is missing from the source
async def upload_ref_audio(
    session_id: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """
    model_ref.md Section 2.2(B)
    """
    if session_id not in SESSIONS:
        raise HTTPException(status_code=400, detail="Invalid session_id")
    state = SESSIONS[session_id]

    # Save the upload to a temporary file
    import tempfile
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        content = await ref_audio.read()
        tmp.write(content)
        tmp_path = tmp.name
    try:
        prepare_reference_audio(tmp_path, state)
    finally:
        os.unlink(tmp_path)

    state.last_access_ts = time.time()
    return {"status": "ok"}
@app.post("/chunk")
async def process_chunk(
    session_id: str = Form(...),
    chunk_id: int = Form(...),
    audio: UploadFile = File(...)
):
    """
    3.2.3 POST /chunk
    Server-side processing flow (architecture.md Section 3.2.3)
    """
    if session_id not in SESSIONS:
        raise HTTPException(status_code=400, detail="Invalid session_id")
    state = SESSIONS[session_id]
    if chunk_id < 0:
        raise HTTPException(status_code=400, detail="chunk_id must be non-negative")

    # Step 2: read the audio
    audio_bytes = await audio.read()
    chunk_np, chunk_sr = load_wav_to_numpy(audio_bytes, target_sr=state.sample_rate)

    # Step 3: sample-rate check
    if chunk_sr != state.sample_rate:
        raise HTTPException(
            status_code=400,
            detail=f"Sample rate mismatch: expected {state.sample_rate}, got {chunk_sr}"
        )

    # Step 4: convert with Seed-VC
    converted = seed_vc_infer(chunk_np, chunk_sr, state)

    # Step 5: crossfade with the previous chunk's tail
    fade_len = int(model_sr * state.overlap_ms / 1000)
    output = crossfade(state.last_output_tail, converted, fade_len)

    # Step 6: update the stored tail
    if len(output) >= fade_len:
        state.last_output_tail = output[-fade_len:].copy()
    else:
        state.last_output_tail = output.copy()
    state.last_access_ts = time.time()

    # Step 7: encode to WAV
    wav_bytes = numpy_to_wav_bytes(output, model_sr)
    return Response(
        content=wav_bytes,
        media_type="audio/wav",
        headers={"X-Chunk-Id": str(chunk_id)}
    )
@app.post("/end")
async def end_session(body: SessionEndRequest):
    """3.2.4 POST /end"""
    SESSIONS.pop(body.session_id, None)
    return {"status": "ended"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
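# -----------------------------------------------------------------------------
# Example client flow (illustrative sketch, not part of the server).
# Assumes the server is reachable at http://localhost:7860 and that the client
# uses the `requests` library; the /session, /chunk, and /end routes follow the
# endpoint docstrings above.
#
#   import requests
#
#   base = "http://localhost:7860"
#   sess = requests.post(f"{base}/session", json={"sample_rate": 16000}).json()
#   with open("chunk0.wav", "rb") as f:
#       resp = requests.post(
#           f"{base}/chunk",
#           data={"session_id": sess["session_id"], "chunk_id": 0},
#           files={"audio": ("chunk0.wav", f, "audio/wav")},
#       )
#   with open("converted0.wav", "wb") as out:
#       out.write(resp.content)
#   requests.post(f"{base}/end", json={"session_id": sess["session_id"]})
# -----------------------------------------------------------------------------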