import gradio as gr
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import torchaudio
import torch
from datasets import load_dataset
import os
# Load lightweight models
ASR_MODEL = "openai/whisper-tiny" # Faster ASR model
TRANSLATION_MODEL = "Helsinki-NLP/opus-mt-en-mul" # Lightweight translation model
# Load ASR model
asr = pipeline("automatic-speech-recognition", model=ASR_MODEL, device=0 if torch.cuda.is_available() else -1)
# Load translation model and tokenizer
translator_model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL)
translator_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
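
# The opus-mt-en-mul checkpoint selects its output language with a ">>xxx<<" token
# prepended to the source text. The codes below are assumed ISO 639-3 identifiers;
# verify them against the model card if a target language misbehaves.
LANG_TOKENS = {
    "en": ">>eng<<",
    "hi": ">>hin<<",
    "kn": ">>kan<<",
    "ta": ">>tam<<",
    "te": ">>tel<<",
    "es": ">>spa<<",
    "de": ">>deu<<",
    "fr": ">>fra<<",
    "hu": ">>hun<<",
}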
# Load TTS processor, model, and vocoder (float16 on GPU for speed; full precision on CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
tts_dtype = torch.float16 if device == "cuda" else torch.float32
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device, dtype=tts_dtype)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device, dtype=tts_dtype)

# Cache speaker embeddings (x-vector) once so they are not reloaded on every request
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device, dtype=tts_dtype)
# Ensure output directory exists
os.makedirs("output", exist_ok=True)
# Processing function
def process_audio(audio, target_language):
    if not audio:
        return "Error: No audio file provided.", None, None
    try:
        # Step 1: Transcribe the audio
        result = asr(audio)["text"]
        if not result:
            return "Error: Failed to transcribe audio.", None, None

        # Step 2: Translate the text, prepending the target-language token expected by opus-mt-en-mul
        lang_token = LANG_TOKENS.get(target_language, ">>eng<<")
        inputs = translator_tokenizer(f"{lang_token} {result}", return_tensors="pt", padding=True)
        outputs = translator_model.generate(**inputs)
        translated_text = translator_tokenizer.decode(outputs[0], skip_special_tokens=True)
        if not translated_text:
            return "Error: Translation failed.", None, None

        # Step 3: Generate speech from the translated text
        inputs = processor(text=translated_text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            speech = tts.generate_speech(input_ids, speaker_embeddings, vocoder=vocoder)

        # Save generated speech (SpeechT5's HiFi-GAN vocoder outputs 16 kHz mono audio)
        output_audio_path = "output/generated_speech.wav"
        torchaudio.save(output_audio_path, speech.unsqueeze(0).to(torch.float32).cpu(), 16000)

        # Step 4: Create Braille-compatible file (plain UTF-8 text)
        braille_output_path = "output/braille.txt"
        with open(braille_output_path, "w", encoding="utf-8") as f:
            f.write(translated_text)

        return translated_text, output_audio_path, braille_output_path
    except Exception as e:
        return f"Error: {str(e)}", None, None
# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Language Voice Translator")
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        target_language = gr.Dropdown(
            choices=["en", "hi", "kn", "ta", "te", "es", "de", "fr", "hu"],
            value="en",
            label="Target Language"
        )
    with gr.Row():
        submit_button = gr.Button("Translate & Synthesize")
        clear_button = gr.Button("Clear")
    with gr.Row():
        translated_text = gr.Textbox(label="Translated Text")
        generated_speech = gr.Audio(label="Generated Speech", interactive=False)
        braille_file = gr.File(label="Download Braille File")

    # Link functions to buttons
    submit_button.click(
        fn=process_audio,
        inputs=[audio_input, target_language],
        outputs=[translated_text, generated_speech, braille_file],
    )
    clear_button.click(
        fn=lambda: ("", None, None),
        inputs=[],
        outputs=[translated_text, generated_speech, braille_file],
    )
# Launch the app
demo.launch()
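
# Note: on low-tier CPU hardware, enabling Gradio's request queue can keep the UI
# responsive while inference runs; a possible alternative to the bare launch above:
#   demo.queue().launch()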