maitreyi2906 commited on
Commit
f59143c
·
verified ·
1 Parent(s): db6adf2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -0
app.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Healing Words E2E Translation Demo
4
+
5
+ Interactive demo for biomedical translation models covering:
6
+ - Amharic ↔ English
7
+ - Hausa ↔ English
8
+ - Hindi ↔ English
9
+
10
+ Usage: python demo.py
11
+ """
12
+
13
+ import gradio as gr
14
+ import torch
15
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
16
+ from peft import PeftModel
17
+ import yaml
18
+ import os
19
+ from pathlib import Path
20
+
21
+ class TranslationDemo:
22
+ def __init__(self):
23
+ self.models = {}
24
+ self.tokenizers = {}
25
+ self.configs = {}
26
+
27
+ # Language pair information
28
+ self.language_pairs = {
29
+ "amh_en": {
30
+ "name": "Amharic ↔ English",
31
+ "src_lang": "Amharic (አማርኛ)",
32
+ "tgt_lang": "English",
33
+ "src_code": "amh_Ethi",
34
+ "tgt_code": "eng_Latn",
35
+ "base_model": "facebook/nllb-200-distilled-600M"
36
+ },
37
+ "ha_en": {
38
+ "name": "Hausa ↔ English",
39
+ "src_lang": "Hausa",
40
+ "tgt_lang": "English",
41
+ "src_code": "hau_Latn",
42
+ "tgt_code": "eng_Latn",
43
+ "base_model": "facebook/nllb-200-distilled-600M"
44
+ },
45
+ "hi_en": {
46
+ "name": "Hindi ↔ English",
47
+ "src_lang": "Hindi (हिन्दी)",
48
+ "tgt_lang": "English",
49
+ "src_code": "hin_Deva",
50
+ "tgt_code": "eng_Latn",
51
+ "base_model": "facebook/nllb-200-distilled-600M"
52
+ }
53
+ }
54
+
55
+ self.load_models()
56
+
57
+ def load_models(self):
58
+ """Load all available trained models"""
59
+ base_dir = Path("kit/outputs")
60
+
61
+ for pair_id, pair_info in self.language_pairs.items():
62
+ checkpoint_dir = base_dir / pair_id / "checkpoint-best"
63
+ config_path = Path(f"kit/configs/{pair_id}.yaml")
64
+
65
+ if checkpoint_dir.exists() and config_path.exists():
66
+ try:
67
+ print(f"Loading {pair_info['name']} model...")
68
+
69
+ # Load config
70
+ with open(config_path) as f:
71
+ config = yaml.safe_load(f)
72
+ self.configs[pair_id] = config
73
+
74
+ # Load base model and tokenizer
75
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(pair_info["base_model"])
76
+ tokenizer = AutoTokenizer.from_pretrained(pair_info["base_model"])
77
+
78
+ # Set language codes
79
+ tokenizer.src_lang = pair_info["src_code"]
80
+ tokenizer.tgt_lang = pair_info["tgt_code"]
81
+
82
+ # Load LoRA adapter
83
+ model = PeftModel.from_pretrained(base_model, checkpoint_dir)
84
+ model.eval()
85
+
86
+ self.models[pair_id] = model
87
+ self.tokenizers[pair_id] = tokenizer
88
+
89
+ print(f"✓ Loaded {pair_info['name']} model")
90
+
91
+ except Exception as e:
92
+ print(f"✗ Failed to load {pair_info['name']} model: {e}")
93
+ else:
94
+ print(f"✗ No trained model found for {pair_info['name']}")
95
+
96
+ def translate(self, text, language_pair, direction, domain="biomedical"):
97
+ """Translate text using the specified model"""
98
+ if not text.strip():
99
+ return "Please enter some text to translate."
100
+
101
+ if language_pair not in self.models:
102
+ return f"Model for {self.language_pairs[language_pair]['name']} is not available."
103
+
104
+ try:
105
+ model = self.models[language_pair]
106
+ tokenizer = self.tokenizers[language_pair]
107
+ config = self.configs[language_pair]
108
+
109
+ # Add domain tag if configured
110
+ if config.get("use_domain_tags", True) and domain:
111
+ text = f"[{domain}] {text}"
112
+
113
+ # Set tokenizer language codes based on direction
114
+ if direction == "to_english":
115
+ tokenizer.src_lang = self.language_pairs[language_pair]["src_code"]
116
+ tokenizer.tgt_lang = self.language_pairs[language_pair]["tgt_code"]
117
+ else: # from_english
118
+ tokenizer.src_lang = self.language_pairs[language_pair]["tgt_code"]
119
+ tokenizer.tgt_lang = self.language_pairs[language_pair]["src_code"]
120
+
121
+ # Tokenize input
122
+ inputs = tokenizer(text, return_tensors="pt", max_length=256, truncation=True)
123
+
124
+ # Generate translation
125
+ with torch.no_grad():
126
+ outputs = model.generate(
127
+ **inputs,
128
+ max_length=256,
129
+ num_beams=4,
130
+ early_stopping=True,
131
+ no_repeat_ngram_size=2
132
+ )
133
+
134
+ # Decode output
135
+ translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
+ return translation
137
+
138
+ except Exception as e:
139
+ return f"Translation error: {str(e)}"
140
+
141
+ def get_example_texts(self, language_pair, direction):
142
+ """Get example texts for the selected language pair and direction"""
143
+ examples = {
144
+ "amh_en": {
145
+ "to_english": [
146
+ "የጤና ሰራተኞች ኮቪድ-19 ን ለመከላከል ጭንብል መጠቀም አለባቸው።",
147
+ "ህመምተኛው ከፍተኛ ትኩሳት እና ሳል አለው።",
148
+ "መድሃኒቱን ምግብ በመብላት ይውሰዱት።"
149
+ ],
150
+ "from_english": [
151
+ "The patient has a high fever and persistent cough.",
152
+ "Healthcare workers should wear masks to prevent COVID-19.",
153
+ "Take this medication with food for better absorption."
154
+ ]
155
+ },
156
+ "ha_en": {
157
+ "to_english": [
158
+ "Ma'aikatan lafiya ya kamata su sa abin rufe fuska don karewa daga COVID-19.",
159
+ "Majiyyaci yana da zazzabi mai yawa da tari.",
160
+ "Ka sha wannan magani tare da abinci."
161
+ ],
162
+ "from_english": [
163
+ "The patient needs immediate medical attention.",
164
+ "Wash your hands frequently to prevent infection.",
165
+ "The medication should be taken twice daily."
166
+ ]
167
+ },
168
+ "hi_en": {
169
+ "to_english": [
170
+ "स्वास्थ्य कर्मचारियों को COVID-19 से बचने के लिए मास्क पहनना चाहिए।",
171
+ "मरीज़ को तेज़ बुखार और खांसी है।",
172
+ "इस दवा को भोजन के साथ लें।"
173
+ ],
174
+ "from_english": [
175
+ "The patient requires urgent medical intervention.",
176
+ "Monitor vital signs every two hours.",
177
+ "Administer the injection intramuscularly."
178
+ ]
179
+ }
180
+ }
181
+ return examples.get(language_pair, {}).get(direction, [])
182
+
183
+ # Initialize the demo
184
+ demo_instance = TranslationDemo()
185
+
186
+ def translate_wrapper(text, language_pair, direction, domain):
187
+ """Wrapper function for Gradio interface"""
188
+ return demo_instance.translate(text, language_pair, direction, domain)
189
+
190
+ def update_examples(language_pair, direction):
191
+ """Update example texts based on selection"""
192
+ examples = demo_instance.get_example_texts(language_pair, direction)
193
+ return gr.update(choices=examples, value=examples[0] if examples else "")
194
+
195
+ def load_example(example_text):
196
+ """Load selected example into text input"""
197
+ return example_text
198
+
199
+ # Create Gradio interface
200
+ with gr.Blocks(title="Healing Words E2E Translation Demo", theme=gr.themes.Soft()) as interface:
201
+
202
+ gr.HTML("""
203
+ <div style="text-align: center; padding: 20px;">
204
+ <h1>🌍 Healing Words E2E Translation Demo</h1>
205
+ <p>Interactive biomedical translation for low-resource languages</p>
206
+ <p><em>Amharic • Hausa • Hindi ↔ English</em></p>
207
+ </div>
208
+ """)
209
+
210
+ with gr.Row():
211
+ with gr.Column(scale=1):
212
+ language_pair = gr.Dropdown(
213
+ choices=[
214
+ ("Amharic ↔ English", "amh_en"),
215
+ ("Hausa ↔ English", "ha_en"),
216
+ ("Hindi ↔ English", "hi_en")
217
+ ],
218
+ value="amh_en",
219
+ label="Language Pair"
220
+ )
221
+
222
+ direction = gr.Radio(
223
+ choices=[
224
+ ("To English", "to_english"),
225
+ ("From English", "from_english")
226
+ ],
227
+ value="to_english",
228
+ label="Translation Direction"
229
+ )
230
+
231
+ domain = gr.Dropdown(
232
+ choices=["biomedical", "general"],
233
+ value="biomedical",
234
+ label="Domain"
235
+ )
236
+
237
+ with gr.Row():
238
+ with gr.Column():
239
+ text_input = gr.Textbox(
240
+ label="Input Text",
241
+ placeholder="Enter text to translate...",
242
+ lines=4
243
+ )
244
+
245
+ examples = gr.Dropdown(
246
+ label="Example Texts",
247
+ choices=[],
248
+ interactive=True
249
+ )
250
+
251
+ with gr.Row():
252
+ translate_btn = gr.Button("🔄 Translate", variant="primary")
253
+ clear_btn = gr.Button("🗑️ Clear")
254
+
255
+ with gr.Column():
256
+ translation_output = gr.Textbox(
257
+ label="Translation",
258
+ lines=6,
259
+ interactive=False
260
+ )
261
+
262
+ # Model status information
263
+ with gr.Accordion("📊 Model Information", open=False):
264
+ model_status = []
265
+ for pair_id, pair_info in demo_instance.language_pairs.items():
266
+ status = "✅ Available" if pair_id in demo_instance.models else "❌ Not loaded"
267
+ model_status.append(f"**{pair_info['name']}**: {status}")
268
+
269
+ gr.Markdown("\n".join(model_status))
270
+
271
+ gr.Markdown("""
272
+ ### About the Models
273
+ - **Base Model**: NLLB-200 Distilled 600M
274
+ - **Fine-tuning**: LoRA (Low-Rank Adaptation)
275
+ - **Domain**: Biomedical + General
276
+ - **Training Data**: Synthetic templates + Real biomedical text
277
+ """)
278
+
279
+ # Event handlers
280
+ language_pair.change(
281
+ fn=update_examples,
282
+ inputs=[language_pair, direction],
283
+ outputs=[examples]
284
+ )
285
+
286
+ direction.change(
287
+ fn=update_examples,
288
+ inputs=[language_pair, direction],
289
+ outputs=[examples]
290
+ )
291
+
292
+ examples.change(
293
+ fn=load_example,
294
+ inputs=[examples],
295
+ outputs=[text_input]
296
+ )
297
+
298
+ translate_btn.click(
299
+ fn=translate_wrapper,
300
+ inputs=[text_input, language_pair, direction, domain],
301
+ outputs=[translation_output]
302
+ )
303
+
304
+ clear_btn.click(
305
+ fn=lambda: ("", ""),
306
+ outputs=[text_input, translation_output]
307
+ )
308
+
309
+ # Initialize examples on load
310
+ interface.load(
311
+ fn=update_examples,
312
+ inputs=[language_pair, direction],
313
+ outputs=[examples]
314
+ )
315
+
316
+ if __name__ == "__main__":
317
+ print("Starting Healing Words E2E Translation Demo...")
318
+ print(f"Available models: {list(demo_instance.models.keys())}")
319
+
320
+ # Launch the interface
321
+ interface.launch(
322
+ server_name="0.0.0.0",
323
+ server_port=7860,
324
+ share=False,
325
+ debug=True
326
+ )