#!/usr/bin/env python3
"""
Healing Words E2E Translation Demo

Interactive demo for biomedical translation models covering:
- Amharic ↔ English
- Hausa ↔ English
- Hindi ↔ English

Usage: python demo.py
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import yaml
from pathlib import Path
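
# The demo looks for one trained LoRA adapter per language pair:
#   kit/configs/<pair_id>.yaml               (training/inference config)
#   kit/outputs/<pair_id>/checkpoint-best/   (PEFT adapter weights)
# Pairs whose artifacts are missing are skipped at startup.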


class TranslationDemo:
    """Loads the available fine-tuned adapters and runs biomedical translations."""

    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.configs = {}

        # Language pair information
        self.language_pairs = {
            "amh_en": {
                "name": "Amharic ↔ English",
                "src_lang": "Amharic (አማርኛ)",
                "tgt_lang": "English",
                "src_code": "amh_Ethi",
                "tgt_code": "eng_Latn",
                "base_model": "facebook/nllb-200-distilled-600M"
            },
            "ha_en": {
                "name": "Hausa ↔ English",
                "src_lang": "Hausa",
                "tgt_lang": "English",
                "src_code": "hau_Latn",
                "tgt_code": "eng_Latn",
                "base_model": "facebook/nllb-200-distilled-600M"
            },
            "hi_en": {
                "name": "Hindi ↔ English",
                "src_lang": "Hindi (हिन्दी)",
                "tgt_lang": "English",
                "src_code": "hin_Deva",
                "tgt_code": "eng_Latn",
                "base_model": "facebook/nllb-200-distilled-600M"
            }
        }

        self.load_models()

    def load_models(self):
        """Load all available trained models"""
        base_dir = Path("kit/outputs")

        for pair_id, pair_info in self.language_pairs.items():
            checkpoint_dir = base_dir / pair_id / "checkpoint-best"
            config_path = Path(f"kit/configs/{pair_id}.yaml")

            if checkpoint_dir.exists() and config_path.exists():
                try:
                    print(f"Loading {pair_info['name']} model...")

                    # Load config
                    with open(config_path) as f:
                        config = yaml.safe_load(f)
                    self.configs[pair_id] = config

                    # Load base model and tokenizer
                    base_model = AutoModelForSeq2SeqLM.from_pretrained(pair_info["base_model"])
                    tokenizer = AutoTokenizer.from_pretrained(pair_info["base_model"])

                    # Set language codes
                    tokenizer.src_lang = pair_info["src_code"]
                    tokenizer.tgt_lang = pair_info["tgt_code"]

                    # Load LoRA adapter
                    model = PeftModel.from_pretrained(base_model, checkpoint_dir)
                    model.eval()

                    self.models[pair_id] = model
                    self.tokenizers[pair_id] = tokenizer
                    print(f"✓ Loaded {pair_info['name']} model")

                except Exception as e:
                    print(f"✗ Failed to load {pair_info['name']} model: {e}")
            else:
                print(f"✗ No trained model found for {pair_info['name']}")

    def translate(self, text, language_pair, direction, domain="biomedical"):
        """Translate text using the specified model"""
        if not text.strip():
            return "Please enter some text to translate."

        if language_pair not in self.models:
            return f"Model for {self.language_pairs[language_pair]['name']} is not available."

        try:
            model = self.models[language_pair]
            tokenizer = self.tokenizers[language_pair]
            config = self.configs[language_pair]

            # Add domain tag if configured
            if config.get("use_domain_tags", True) and domain:
                text = f"[{domain}] {text}"

            # Pick source/target language codes based on direction
            pair_info = self.language_pairs[language_pair]
            if direction == "to_english":
                src_code, tgt_code = pair_info["src_code"], pair_info["tgt_code"]
            else:  # from_english
                src_code, tgt_code = pair_info["tgt_code"], pair_info["src_code"]

            # NLLB tokenizers only use src_lang when encoding the input; the
            # target language is enforced at generation time via
            # forced_bos_token_id below.
            tokenizer.src_lang = src_code

            # Tokenize input
            inputs = tokenizer(text, return_tensors="pt", max_length=256, truncation=True)

            # Generate translation
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                    max_length=256,
                    num_beams=4,
                    early_stopping=True,
                    no_repeat_ngram_size=2
                )

            # Decode output
            translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return translation

        except Exception as e:
            return f"Translation error: {str(e)}"

    def get_example_texts(self, language_pair, direction):
        """Get example texts for the selected language pair and direction"""
        examples = {
            "amh_en": {
                "to_english": [
                    "የጤና ሰራተኞች ኮቪድ-19 ን ለመከላከል ጭንብል መጠቀም አለባቸው።",
                    "ህመምተኛው ከፍተኛ ትኩሳት እና ሳል አለው።",
                    "መድሃኒቱን ምግብ በመብላት ይውሰዱት።"
                ],
                "from_english": [
                    "The patient has a high fever and persistent cough.",
                    "Healthcare workers should wear masks to prevent COVID-19.",
                    "Take this medication with food for better absorption."
                ]
            },
            "ha_en": {
                "to_english": [
                    "Ma'aikatan lafiya ya kamata su sa abin rufe fuska don karewa daga COVID-19.",
                    "Majiyyaci yana da zazzabi mai yawa da tari.",
                    "Ka sha wannan magani tare da abinci."
                ],
                "from_english": [
                    "The patient needs immediate medical attention.",
                    "Wash your hands frequently to prevent infection.",
                    "The medication should be taken twice daily."
                ]
            },
            "hi_en": {
                "to_english": [
                    "स्वास्थ्य कर्मचारियों को COVID-19 से बचने के लिए मास्क पहनना चाहिए।",
                    "मरीज़ को तेज़ बुखार और खांसी है।",
                    "इस दवा को भोजन के साथ लें।"
                ],
                "from_english": [
                    "The patient requires urgent medical intervention.",
                    "Monitor vital signs every two hours.",
                    "Administer the injection intramuscularly."
                ]
            }
        }
        return examples.get(language_pair, {}).get(direction, [])
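

# Example of programmatic use (no Gradio UI), assuming the hi_en adapter has been trained:
#   demo = TranslationDemo()
#   print(demo.translate("मरीज़ को तेज़ बुखार और खांसी है।", "hi_en", "to_english"))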

# Initialize the demo
demo_instance = TranslationDemo()


def translate_wrapper(text, language_pair, direction, domain):
    """Wrapper function for Gradio interface"""
    return demo_instance.translate(text, language_pair, direction, domain)


def update_examples(language_pair, direction):
    """Update example texts based on selection"""
    examples = demo_instance.get_example_texts(language_pair, direction)
    return gr.update(choices=examples, value=examples[0] if examples else "")


def load_example(example_text):
    """Load selected example into text input"""
    return example_text


# Create Gradio interface
with gr.Blocks(title="Healing Words E2E Translation Demo", theme=gr.themes.Soft()) as interface:
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1>🌍 Healing Words E2E Translation Demo</h1>
        <p>Interactive biomedical translation for low-resource languages</p>
        <p><em>Amharic • Hausa • Hindi ↔ English</em></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            language_pair = gr.Dropdown(
                choices=[
                    ("Amharic ↔ English", "amh_en"),
                    ("Hausa ↔ English", "ha_en"),
                    ("Hindi ↔ English", "hi_en")
                ],
                value="amh_en",
                label="Language Pair"
            )
            direction = gr.Radio(
                choices=[
                    ("To English", "to_english"),
                    ("From English", "from_english")
                ],
                value="to_english",
                label="Translation Direction"
            )
            domain = gr.Dropdown(
                choices=["biomedical", "general"],
                value="biomedical",
                label="Domain"
            )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to translate...",
                lines=4
            )
            examples = gr.Dropdown(
                label="Example Texts",
                choices=[],
                interactive=True
            )
            with gr.Row():
                translate_btn = gr.Button("🔄 Translate", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")
        with gr.Column():
            translation_output = gr.Textbox(
                label="Translation",
                lines=6,
                interactive=False
            )

    # Model status information
    with gr.Accordion("📊 Model Information", open=False):
        model_status = []
        for pair_id, pair_info in demo_instance.language_pairs.items():
            status = "✅ Available" if pair_id in demo_instance.models else "❌ Not loaded"
            model_status.append(f"**{pair_info['name']}**: {status}")
        gr.Markdown("\n".join(model_status))

        gr.Markdown("""
        ### About the Models
        - **Base Model**: NLLB-200 Distilled 600M
        - **Fine-tuning**: LoRA (Low-Rank Adaptation)
        - **Domain**: Biomedical + General
        - **Training Data**: Synthetic templates + Real biomedical text
        """)

    # Event handlers (must be registered inside the Blocks context)
    language_pair.change(
        fn=update_examples,
        inputs=[language_pair, direction],
        outputs=[examples]
    )
    direction.change(
        fn=update_examples,
        inputs=[language_pair, direction],
        outputs=[examples]
    )
    examples.change(
        fn=load_example,
        inputs=[examples],
        outputs=[text_input]
    )
    translate_btn.click(
        fn=translate_wrapper,
        inputs=[text_input, language_pair, direction, domain],
        outputs=[translation_output]
    )
    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[text_input, translation_output]
    )

    # Initialize examples on load
    interface.load(
        fn=update_examples,
        inputs=[language_pair, direction],
        outputs=[examples]
    )


if __name__ == "__main__":
    print("Starting Healing Words E2E Translation Demo...")
    print(f"Available models: {list(demo_instance.models.keys())}")

    # Launch the interface (0.0.0.0:7860 is the address/port expected by hosted
    # environments such as Hugging Face Spaces)
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True
    )