# voxtream/app.py
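"""Gradio demo for VoXtream text-to-speech: clone a voice from a short
audio prompt plus its transcript, then synthesize arbitrary target text."""
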
import os

# Disable PyTorch dynamo/inductor globally; the env vars are set before
# torch is imported so they take effect for the whole process.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCHINDUCTOR_DISABLE"] = "1"

import torch._dynamo as dynamo
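# Fall back to eager execution instead of raising if dynamo still kicks in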
dynamo.config.suppress_errors = True

import json
from pathlib import Path

import nltk
import torch
import spaces
import gradio as gr
import numpy as np

from voxtream.generator import SpeechGenerator, SpeechGeneratorConfig

with open("configs/generator.json") as f:
    config = SpeechGeneratorConfig(**json.load(f))

# Pre-download the speaker encoder checkpoint; the return value is discarded,
# so this call just warms the torch.hub cache before SpeechGenerator loads it
torch.hub.load(
    config.spk_enc_repo,
    config.spk_enc_model,
    model_name=config.spk_enc_model_name,
    train_type=config.spk_enc_train_type,
    dataset=config.spk_enc_dataset,
    trust_repo=True,
    verbose=False,
)

# Download the NLTK data needed for text processing: the English POS tagger
# and the Punkt sentence tokenizer
nltk.download("averaged_perceptron_tagger_eng", quiet=True, raise_on_error=True)
nltk.download("punkt", quiet=True, raise_on_error=True)

# Initialize the speech generator (weights start on CPU; see synthesize_fn)
speech_generator = SpeechGenerator(config)

CUSTOM_CSS = """
/* overall width */
.gradio-container {max-width: 1100px !important}
/* stack labels tighter and even heights */
#cols .wrap > .form {gap: 10px}
#left-col, #right-col {gap: 14px}
/* make submit centered + bigger */
#submit {width: 260px; margin: 10px auto 0 auto;}
/* make clear align left and look secondary */
#clear {width: 120px;}
/* give audio a little breathing room */
audio {outline: none;}
"""
@spaces.GPU
def synthesize_fn(prompt_audio_path, prompt_text, target_text):
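    # Move all components to the GPU the first time the handler runs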
    if next(speech_generator.model.parameters()).device.type == "cpu":
        speech_generator.model.to("cuda")
        speech_generator.mimi.to("cuda")
        speech_generator.spk_enc.to("cuda")
        speech_generator.aligner.aligner.to("cuda")
        speech_generator.aligner.device = "cuda"
        speech_generator.device = "cuda"
    # Nothing to synthesize without a prompt recording and target text
    if not prompt_audio_path or not target_text:
        return None
    stream = speech_generator.generate_stream(
        prompt_text=prompt_text,
        prompt_audio_path=Path(prompt_audio_path),
        text=target_text,
    )
    # Drain the generator, keeping only the audio frame from each yielded pair
    frames = [frame for frame, _ in stream]
    if not frames:
        return None
    waveform = np.concatenate(frames).astype(np.float32)
    # Fade out the last 100 ms to avoid an audible click at the end; clamp the
    # fade window so a clip shorter than that does not crash the broadcast
    fade_len_sec = 0.1
    fade_len = min(len(waveform), int(config.mimi_sr * fade_len_sec))
    if fade_len > 0:
        waveform[-fade_len:] *= np.linspace(1.0, 0.0, fade_len)
    return (config.mimi_sr, waveform)


def main():
    with gr.Blocks(css=CUSTOM_CSS, title="VoXtream") as demo:
        gr.Markdown("# VoXtream TTS demo")
        with gr.Row(equal_height=True, elem_id="cols"):
            with gr.Column(scale=1, elem_id="left-col"):
                prompt_audio = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label="Prompt audio (3-5 sec of the target voice; max 10 sec)",
                )
                prompt_text = gr.Textbox(
                    lines=3,
                    max_length=config.max_prompt_chars,
                    label=f"Prompt transcript (required). Max characters: {config.max_prompt_chars}",
                    placeholder="Text that matches the prompt audio",
                )
            with gr.Column(scale=1, elem_id="right-col"):
                target_text = gr.Textbox(
                    lines=3,
                    # reuse the phone-token budget as the character cap
                    max_length=config.max_phone_tokens,
                    label=f"Target text. Max characters: {config.max_phone_tokens}",
                    placeholder="What you want the model to say",
                )
                output_audio = gr.Audio(
                    type="numpy",
                    label="Synthesized audio",
                    interactive=False,
                )
        with gr.Row():
            clear_btn = gr.Button("Clear", elem_id="clear", variant="secondary")
            submit_btn = gr.Button("Submit", elem_id="submit", variant="primary")

        # wire up actions
        submit_btn.click(
            fn=synthesize_fn,
            inputs=[prompt_audio, prompt_text, target_text],
            outputs=output_audio,
        )
        # reset everything
        clear_btn.click(
            fn=lambda: (None, "", "", None),
            inputs=[],
            outputs=[prompt_audio, prompt_text, target_text, output_audio],
        )

    demo.launch()


if __name__ == "__main__":
    main()