File size: 11,959 Bytes
f59143c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
#!/usr/bin/env python3
"""
Healing Words E2E Translation Demo

Interactive demo for biomedical translation models covering:
- Amharic ↔ English
- Hausa ↔ English  
- Hindi ↔ English

Usage: python demo.py
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import yaml
import os
from pathlib import Path

class TranslationDemo:
    """Load fine-tuned NLLB + LoRA translation models and run inference for the demo UI.

    Attributes:
        models: pair id -> loaded ``PeftModel`` (only pairs whose checkpoint loaded).
        tokenizers: pair id -> tokenizer of that pair's base model.
        configs: pair id -> parsed ``kit/configs/<pair_id>.yaml`` (loaded pairs only).
        language_pairs: static metadata for every supported pair, keyed by pair id.
    """

    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.configs = {}

        # Language pair information. The "src_*" fields describe the non-English
        # side and "tgt_*" describe English; translate() swaps them for the
        # "from_english" direction.
        self.language_pairs = {
            "amh_en": {
                "name": "Amharic ↔ English",
                "src_lang": "Amharic (አማርኛ)",
                "tgt_lang": "English",
                "src_code": "amh_Ethi",
                "tgt_code": "eng_Latn",
                "base_model": "facebook/nllb-200-distilled-600M"
            },
            "ha_en": {
                "name": "Hausa ↔ English",
                "src_lang": "Hausa",
                "tgt_lang": "English",
                "src_code": "hau_Latn",
                "tgt_code": "eng_Latn",
                "base_model": "facebook/nllb-200-distilled-600M"
            },
            "hi_en": {
                "name": "Hindi ↔ English",
                "src_lang": "Hindi (हिन्दी)",
                "tgt_lang": "English",
                "src_code": "hin_Deva",
                "tgt_code": "eng_Latn",
                "base_model": "facebook/nllb-200-distilled-600M"
            }
        }

        self.load_models()

    def load_models(self):
        """Load every language pair whose LoRA checkpoint and config both exist.

        Looks for ``kit/outputs/<pair_id>/checkpoint-best`` and
        ``kit/configs/<pair_id>.yaml`` relative to the current working
        directory. Pairs that are missing or fail to load are reported on
        stdout and skipped, so the demo still starts with whatever is available.
        """
        base_dir = Path("kit/outputs")
        config_dir = Path("kit/configs")

        for pair_id, pair_info in self.language_pairs.items():
            checkpoint_dir = base_dir / pair_id / "checkpoint-best"
            config_path = config_dir / f"{pair_id}.yaml"

            if not (checkpoint_dir.exists() and config_path.exists()):
                print(f"✗ No trained model found for {pair_info['name']}")
                continue

            try:
                print(f"Loading {pair_info['name']} model...")

                # safe_load returns None for an empty file; normalise to {} so
                # translate() can always call .get() on the config.
                with open(config_path, encoding="utf-8") as f:
                    config = yaml.safe_load(f) or {}

                # Load base model and tokenizer
                base_model = AutoModelForSeq2SeqLM.from_pretrained(pair_info["base_model"])
                tokenizer = AutoTokenizer.from_pretrained(pair_info["base_model"])

                # Default to the "to English" direction; translate() resets
                # these on every call.
                tokenizer.src_lang = pair_info["src_code"]
                tokenizer.tgt_lang = pair_info["tgt_code"]

                # Wrap the base model with the trained LoRA adapter.
                model = PeftModel.from_pretrained(base_model, checkpoint_dir)
                model.eval()
            except Exception as e:
                # Best-effort: one broken pair must not prevent the demo from starting.
                print(f"✗ Failed to load {pair_info['name']} model: {e}")
                continue

            # Register only after every step succeeded so the three dicts stay in sync.
            self.configs[pair_id] = config
            self.models[pair_id] = model
            self.tokenizers[pair_id] = tokenizer

            print(f"✓ Loaded {pair_info['name']} model")

    def translate(self, text, language_pair, direction, domain="biomedical"):
        """Translate ``text`` with the model loaded for ``language_pair``.

        Args:
            text: Source text. ``None`` or whitespace-only input yields a prompt message.
            language_pair: Key of ``self.language_pairs`` (e.g. ``"amh_en"``).
            direction: ``"to_english"``; any other value translates from English.
            domain: Tag prepended as ``"[domain] "`` when the pair's config enables
                ``use_domain_tags`` (the default); a falsy value disables tagging.

        Returns:
            The decoded translation, or a human-readable message when the input
            is empty, the model is unavailable, or generation raises.
        """
        if text is None or not text.strip():
            return "Please enter some text to translate."

        pair_info = self.language_pairs.get(language_pair)
        if pair_info is None or language_pair not in self.models:
            name = pair_info["name"] if pair_info is not None else language_pair
            return f"Model for {name} is not available."

        try:
            model = self.models[language_pair]
            tokenizer = self.tokenizers[language_pair]
            config = self.configs.get(language_pair, {})

            # Add domain tag if configured
            if config.get("use_domain_tags", True) and domain:
                text = f"[{domain}] {text}"

            if direction == "to_english":
                src_code, tgt_code = pair_info["src_code"], pair_info["tgt_code"]
            else:  # from_english
                src_code, tgt_code = pair_info["tgt_code"], pair_info["src_code"]

            # src_lang selects the source-language special token the NLLB
            # tokenizer adds to the encoded input.
            tokenizer.src_lang = src_code
            tokenizer.tgt_lang = tgt_code

            inputs = tokenizer(text, return_tensors="pt", max_length=256, truncation=True)

            # NLLB expects the target-language code as the first generated
            # token. tokenizer.tgt_lang does not reach generate(), so force it
            # explicitly; otherwise the output language is not guaranteed.
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                    max_length=256,
                    num_beams=4,
                    early_stopping=True,
                    no_repeat_ngram_size=2,
                )

            return tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            return f"Translation error: {str(e)}"

    def get_example_texts(self, language_pair, direction):
        """Return sample sentences for ``language_pair`` in ``direction``.

        Unknown pairs or directions yield an empty list.
        """
        examples = {
            "amh_en": {
                "to_english": [
                    "የጤና ሰራተኞች ኮቪድ-19 ን ለመከላከል ጭንብል መጠቀም አለባቸው።",
                    "ህመምተኛው ከፍተኛ ትኩሳት እና ሳል አለው።",
                    "መድሃኒቱን ምግብ በመብላት ይውሰዱት።"
                ],
                "from_english": [
                    "The patient has a high fever and persistent cough.",
                    "Healthcare workers should wear masks to prevent COVID-19.",
                    "Take this medication with food for better absorption."
                ]
            },
            "ha_en": {
                "to_english": [
                    "Ma'aikatan lafiya ya kamata su sa abin rufe fuska don karewa daga COVID-19.",
                    "Majiyyaci yana da zazzabi mai yawa da tari.",
                    "Ka sha wannan magani tare da abinci."
                ],
                "from_english": [
                    "The patient needs immediate medical attention.",
                    "Wash your hands frequently to prevent infection.",
                    "The medication should be taken twice daily."
                ]
            },
            "hi_en": {
                "to_english": [
                    "स्वास्थ्य कर्मचारियों को COVID-19 से बचने के लिए मास्क पहनना चाहिए।",
                    "मरीज़ को तेज़ बुखार और खांसी है।",
                    "इस दवा को भोजन के साथ लें।"
                ],
                "from_english": [
                    "The patient requires urgent medical intervention.",
                    "Monitor vital signs every two hours.",
                    "Administer the injection intramuscularly."
                ]
            }
        }
        return examples.get(language_pair, {}).get(direction, [])

# Initialize the demo
# Module-level singleton shared by the Gradio callbacks and the UI builder
# below. Constructing it calls TranslationDemo.load_models(), so every
# available model is loaded (and progress printed) at import time; this can
# take a while.
demo_instance = TranslationDemo()

def translate_wrapper(text, language_pair, direction, domain):
    """Gradio callback for the Translate button.

    Forwards the four UI inputs unchanged to the shared ``demo_instance`` and
    returns whatever it produces (a translation or a status message).
    """
    translator = demo_instance
    result = translator.translate(text, language_pair, direction, domain)
    return result

def update_examples(language_pair, direction):
    """Refresh the example dropdown for the selected pair and direction.

    Returns a ``gr.update`` whose choices are the matching sample sentences,
    with the first one preselected (or ``""`` when there are none).
    """
    choices = demo_instance.get_example_texts(language_pair, direction)
    if choices:
        preselected = choices[0]
    else:
        preselected = ""
    return gr.update(choices=choices, value=preselected)

def load_example(example_text):
    """Copy the chosen example sentence into the input textbox as-is."""
    selected = example_text
    return selected

# Create Gradio interface.
# Components are rendered in the order they are created inside the Blocks
# context, so the statement order below defines the page layout.
with gr.Blocks(title="Healing Words E2E Translation Demo", theme=gr.themes.Soft()) as interface:
    
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1>🌍 Healing Words E2E Translation Demo</h1>
        <p>Interactive biomedical translation for low-resource languages</p>
        <p><em>Amharic • Hausa • Hindi ↔ English</em></p>
    </div>
    """)
    
    # Controls: which model, which direction, and which domain tag to prepend.
    with gr.Row():
        with gr.Column(scale=1):
            language_pair = gr.Dropdown(
                choices=[
                    ("Amharic ↔ English", "amh_en"),
                    ("Hausa ↔ English", "ha_en"), 
                    ("Hindi ↔ English", "hi_en")
                ],
                value="amh_en",
                label="Language Pair"
            )
            
            direction = gr.Radio(
                choices=[
                    ("To English", "to_english"),
                    ("From English", "from_english")
                ],
                value="to_english",
                label="Translation Direction"
            )
            
            domain = gr.Dropdown(
                choices=["biomedical", "general"],
                value="biomedical", 
                label="Domain"
            )
    
    # Input (left) and output (right).
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to translate...",
                lines=4
            )
            
            # Choices are filled in by update_examples() on load and whenever
            # the pair or direction changes.
            examples = gr.Dropdown(
                label="Example Texts",
                choices=[],
                interactive=True
            )
            
            with gr.Row():
                translate_btn = gr.Button("🔄 Translate", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")
        
        with gr.Column():
            translation_output = gr.Textbox(
                label="Translation",
                lines=6,
                interactive=False
            )
    
    # Model status information (computed once, when the UI is built).
    with gr.Accordion("📊 Model Information", open=False):
        model_status = []
        for pair_id, pair_info in demo_instance.language_pairs.items():
            status = "✅ Available" if pair_id in demo_instance.models else "❌ Not loaded"
            # Render each pair as a list item: gr.Markdown does not turn single
            # newlines into line breaks by default, so bare "\n"-joined lines
            # would collapse into one paragraph.
            model_status.append(f"- **{pair_info['name']}**: {status}")
        
        gr.Markdown("\n".join(model_status))
        
        gr.Markdown("""
        ### About the Models
        - **Base Model**: NLLB-200 Distilled 600M
        - **Fine-tuning**: LoRA (Low-Rank Adaptation)  
        - **Domain**: Biomedical + General
        - **Training Data**: Synthetic templates + Real biomedical text
        """)
    
    # Event handlers
    # Changing the pair or direction swaps in the matching example sentences.
    language_pair.change(
        fn=update_examples,
        inputs=[language_pair, direction],
        outputs=[examples]
    )
    
    direction.change(
        fn=update_examples, 
        inputs=[language_pair, direction],
        outputs=[examples]
    )
    
    # .change also fires on programmatic updates, so the example preselected by
    # update_examples() is copied into the input box as well.
    examples.change(
        fn=load_example,
        inputs=[examples],
        outputs=[text_input]
    )
    
    translate_btn.click(
        fn=translate_wrapper,
        inputs=[text_input, language_pair, direction, domain],
        outputs=[translation_output]
    )
    
    # Clear resets both the input and the output textboxes.
    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[text_input, translation_output]
    )
    
    # Initialize examples on load
    interface.load(
        fn=update_examples,
        inputs=[language_pair, direction], 
        outputs=[examples]
    )

if __name__ == "__main__":
    print("Starting Healing Words E2E Translation Demo...")
    loaded_pairs = list(demo_instance.models.keys())
    print(f"Available models: {loaded_pairs}")

    # Listen on all interfaces at Gradio's usual port, without a public share
    # link, and with debug output enabled.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
    )
    interface.launch(**launch_options)