#!/usr/bin/env python3
"""
Healing Words E2E Translation Demo
Interactive demo for biomedical translation models covering:
- Amharic ↔ English
- Hausa ↔ English
- Hindi ↔ English
Usage: python app.py
"""
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import yaml
import os
from pathlib import Path
class TranslationDemo:
def __init__(self):
self.models = {}
self.tokenizers = {}
self.configs = {}
# Language pair information
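        # src_code / tgt_code are FLORES-200 language codes as expected by the
        # NLLB-200 tokenizer (e.g. amh_Ethi = Amharic in Ethiopic script,
        # hau_Latn = Hausa in Latin script, hin_Deva = Hindi in Devanagari).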
self.language_pairs = {
"amh_en": {
"name": "Amharic ↔ English",
"src_lang": "Amharic (አማርኛ)",
"tgt_lang": "English",
"src_code": "amh_Ethi",
"tgt_code": "eng_Latn",
"base_model": "facebook/nllb-200-distilled-600M"
},
"ha_en": {
"name": "Hausa ↔ English",
"src_lang": "Hausa",
"tgt_lang": "English",
"src_code": "hau_Latn",
"tgt_code": "eng_Latn",
"base_model": "facebook/nllb-200-distilled-600M"
},
"hi_en": {
"name": "Hindi ↔ English",
"src_lang": "Hindi (हिन्दी)",
"tgt_lang": "English",
"src_code": "hin_Deva",
"tgt_code": "eng_Latn",
"base_model": "facebook/nllb-200-distilled-600M"
}
}
self.load_models()
def load_models(self):
"""Load all available trained models"""
base_dir = Path("kit/outputs")
for pair_id, pair_info in self.language_pairs.items():
checkpoint_dir = base_dir / pair_id / "checkpoint-best"
config_path = Path(f"kit/configs/{pair_id}.yaml")
if checkpoint_dir.exists() and config_path.exists():
try:
print(f"Loading {pair_info['name']} model...")
# Load config
with open(config_path) as f:
config = yaml.safe_load(f)
self.configs[pair_id] = config
# Load base model and tokenizer
base_model = AutoModelForSeq2SeqLM.from_pretrained(pair_info["base_model"])
tokenizer = AutoTokenizer.from_pretrained(pair_info["base_model"])
# Set language codes
tokenizer.src_lang = pair_info["src_code"]
tokenizer.tgt_lang = pair_info["tgt_code"]
# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, checkpoint_dir)
model.eval()
self.models[pair_id] = model
self.tokenizers[pair_id] = tokenizer
print(f"✓ Loaded {pair_info['name']} model")
except Exception as e:
print(f"✗ Failed to load {pair_info['name']} model: {e}")
else:
print(f"✗ No trained model found for {pair_info['name']}")
def translate(self, text, language_pair, direction, domain="biomedical"):
"""Translate text using the specified model"""
if not text.strip():
return "Please enter some text to translate."
if language_pair not in self.models:
return f"Model for {self.language_pairs[language_pair]['name']} is not available."
try:
model = self.models[language_pair]
tokenizer = self.tokenizers[language_pair]
config = self.configs[language_pair]
# Add domain tag if configured
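            # The "[biomedical] " / "[general] " prefix is assumed to mirror the
            # tag format used during fine-tuning (toggled by `use_domain_tags` in
            # the kit config); if training used a different convention, adjust here.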
if config.get("use_domain_tags", True) and domain:
text = f"[{domain}] {text}"
# Set tokenizer language codes based on direction
if direction == "to_english":
tokenizer.src_lang = self.language_pairs[language_pair]["src_code"]
tokenizer.tgt_lang = self.language_pairs[language_pair]["tgt_code"]
else: # from_english
tokenizer.src_lang = self.language_pairs[language_pair]["tgt_code"]
tokenizer.tgt_lang = self.language_pairs[language_pair]["src_code"]
# Tokenize input
inputs = tokenizer(text, return_tensors="pt", max_length=256, truncation=True)
            # Generate translation. NLLB needs the target-language token forced as
            # the first decoded token; without forced_bos_token_id the model tends
            # to stay in (or copy) the source language.
            forced_bos = tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    forced_bos_token_id=forced_bos,
                    max_length=256,
                    num_beams=4,
                    early_stopping=True,
                    no_repeat_ngram_size=2
                )
# Decode output
translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
return translation
except Exception as e:
return f"Translation error: {str(e)}"
def get_example_texts(self, language_pair, direction):
"""Get example texts for the selected language pair and direction"""
examples = {
"amh_en": {
"to_english": [
"የጤና ሰራተኞች ኮቪድ-19 ን ለመከላከል ጭንብል መጠቀም አለባቸው።",
"ህመምተኛው ከፍተኛ ትኩሳት እና ሳል አለው።",
"መድሃኒቱን ምግብ በመብላት ይውሰዱት።"
],
"from_english": [
"The patient has a high fever and persistent cough.",
"Healthcare workers should wear masks to prevent COVID-19.",
"Take this medication with food for better absorption."
]
},
"ha_en": {
"to_english": [
"Ma'aikatan lafiya ya kamata su sa abin rufe fuska don karewa daga COVID-19.",
"Majiyyaci yana da zazzabi mai yawa da tari.",
"Ka sha wannan magani tare da abinci."
],
"from_english": [
"The patient needs immediate medical attention.",
"Wash your hands frequently to prevent infection.",
"The medication should be taken twice daily."
]
},
"hi_en": {
"to_english": [
"स्वास्थ्य कर्मचारियों को COVID-19 से बचने के लिए मास्क पहनना चाहिए।",
"मरीज़ को तेज़ बुखार और खांसी है।",
"इस दवा को भोजन के साथ लें।"
],
"from_english": [
"The patient requires urgent medical intervention.",
"Monitor vital signs every two hours.",
"Administer the injection intramuscularly."
]
}
}
return examples.get(language_pair, {}).get(direction, [])
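# Quick programmatic check, useful when debugging outside Gradio -- a minimal
# sketch assuming at least one adapter was found under kit/outputs/:
#
#     demo = TranslationDemo()
#     print(demo.translate("Majiyyaci yana da zazzabi mai yawa da tari.", "ha_en", "to_english"))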
# Initialize the demo
demo_instance = TranslationDemo()
def translate_wrapper(text, language_pair, direction, domain):
"""Wrapper function for Gradio interface"""
return demo_instance.translate(text, language_pair, direction, domain)
def update_examples(language_pair, direction):
"""Update example texts based on selection"""
examples = demo_instance.get_example_texts(language_pair, direction)
return gr.update(choices=examples, value=examples[0] if examples else "")
def load_example(example_text):
"""Load selected example into text input"""
return example_text
# Create Gradio interface
with gr.Blocks(title="Healing Words E2E Translation Demo", theme=gr.themes.Soft()) as interface:
gr.HTML("""
<div style="text-align: center; padding: 20px;">
<h1>🌍 Healing Words E2E Translation Demo</h1>
<p>Interactive biomedical translation for low-resource languages</p>
<p><em>Amharic • Hausa • Hindi ↔ English</em></p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
language_pair = gr.Dropdown(
choices=[
("Amharic ↔ English", "amh_en"),
("Hausa ↔ English", "ha_en"),
("Hindi ↔ English", "hi_en")
],
value="amh_en",
label="Language Pair"
)
direction = gr.Radio(
choices=[
("To English", "to_english"),
("From English", "from_english")
],
value="to_english",
label="Translation Direction"
)
domain = gr.Dropdown(
choices=["biomedical", "general"],
value="biomedical",
label="Domain"
)
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="Input Text",
placeholder="Enter text to translate...",
lines=4
)
examples = gr.Dropdown(
label="Example Texts",
choices=[],
interactive=True
)
with gr.Row():
translate_btn = gr.Button("🔄 Translate", variant="primary")
clear_btn = gr.Button("🗑️ Clear")
with gr.Column():
translation_output = gr.Textbox(
label="Translation",
lines=6,
interactive=False
)
# Model status information
with gr.Accordion("📊 Model Information", open=False):
model_status = []
for pair_id, pair_info in demo_instance.language_pairs.items():
status = "✅ Available" if pair_id in demo_instance.models else "❌ Not loaded"
model_status.append(f"**{pair_info['name']}**: {status}")
gr.Markdown("\n".join(model_status))
gr.Markdown("""
### About the Models
- **Base Model**: NLLB-200 Distilled 600M
- **Fine-tuning**: LoRA (Low-Rank Adaptation)
- **Domain**: Biomedical + General
- **Training Data**: Synthetic templates + Real biomedical text
""")
# Event handlers
language_pair.change(
fn=update_examples,
inputs=[language_pair, direction],
outputs=[examples]
)
direction.change(
fn=update_examples,
inputs=[language_pair, direction],
outputs=[examples]
)
examples.change(
fn=load_example,
inputs=[examples],
outputs=[text_input]
)
translate_btn.click(
fn=translate_wrapper,
inputs=[text_input, language_pair, direction, domain],
outputs=[translation_output]
)
clear_btn.click(
fn=lambda: ("", ""),
outputs=[text_input, translation_output]
)
# Initialize examples on load
interface.load(
fn=update_examples,
inputs=[language_pair, direction],
outputs=[examples]
)
if __name__ == "__main__":
print("Starting Healing Words E2E Translation Demo...")
print(f"Available models: {list(demo_instance.models.keys())}")
# Launch the interface
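    # server_name="0.0.0.0" and port 7860 match the defaults Hugging Face Spaces
    # expects; when running locally, the app is reachable at http://localhost:7860.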
interface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=True
)