# Hugging Face Space — status: Running
| """ | |
| Beautiful Medical NER Demo using OpenMed Models | |
| A comprehensive Named Entity Recognition demo for medical professionals | |
| featuring multiple specialized medical models with beautiful entity visualization. | |
| """ | |
| import gradio as gr | |
| import spacy | |
| from spacy import displacy | |
| from transformers import pipeline | |
| import warnings | |
| import logging | |
| from typing import Dict, List, Tuple | |
| import random # Added for random color generation | |
| # Suppress warnings for cleaner output | |
| warnings.filterwarnings("ignore") | |
| logging.getLogger("transformers").setLevel(logging.ERROR) | |
# Model configurations.
# Maps the human-readable model name shown in the UI dropdown to:
#   - model_id:     Hugging Face hub id passed to transformers.pipeline
#   - description:  short blurb rendered in the model-info panel
MODELS = {
    "Oncology Detection": {
        "model_id": "OpenMed/OpenMed-NER-OncologyDetect-SuperMedical-355M",
        "description": "Specialized in cancer, genetics, and oncology entities",
    },
    "Pharmaceutical Detection": {
        "model_id": "OpenMed/OpenMed-NER-PharmaDetect-SuperClinical-434M",
        "description": "Detects drugs, chemicals, and pharmaceutical entities",
    },
    "Disease Detection": {
        "model_id": "OpenMed/OpenMed-NER-DiseaseDetect-SuperClinical-434M",
        "description": "Identifies diseases, conditions, and pathologies",
    },
    "Genome Detection": {
        "model_id": "OpenMed/OpenMed-NER-GenomeDetect-ModernClinical-395M",
        "description": "Recognizes genes, proteins, and genomic entities",
    },
}
# Medical text examples for each model.
# Keys must match MODELS keys exactly — the example buttons and the
# dropdown-change handler look examples up by the selected model name.
# Three examples per model (one per "Example N" button in the UI).
EXAMPLES = {
    "Oncology Detection": [
        "The patient presented with metastatic adenocarcinoma of the lung with mutations in EGFR and KRAS genes. Treatment with erlotinib was initiated, targeting the epidermal growth factor receptor pathway.",
        "Histological examination revealed invasive ductal carcinoma with high-grade nuclear features. The tumor showed positive estrogen receptor and HER2 amplification, indicating potential for targeted therapy.",
        "The oncologist recommended adjuvant chemotherapy with doxorubicin and cyclophosphamide, followed by paclitaxel, to target rapidly dividing cancer cells in the breast tissue.",
    ],
    "Pharmaceutical Detection": [
        "The patient was prescribed metformin 500mg twice daily for diabetes management, along with lisinopril 10mg for hypertension control and atorvastatin 20mg for cholesterol reduction.",
        "Administration of morphine sulfate provided effective pain relief, while ondansetron prevented chemotherapy-induced nausea. The patient also received dexamethasone as an anti-inflammatory agent.",
        "The pharmacokinetic study evaluated the absorption of ibuprofen and its interaction with warfarin, monitoring plasma concentrations and potential bleeding risks.",
    ],
    "Disease Detection": [
        "The patient was diagnosed with type 2 diabetes mellitus, hypertension, and coronary artery disease. Additional findings included diabetic nephropathy and peripheral neuropathy.",
        "Clinical presentation was consistent with acute myocardial infarction complicated by cardiogenic shock. The patient also had a history of chronic obstructive pulmonary disease and atrial fibrillation.",
        "Laboratory results confirmed the diagnosis of rheumatoid arthritis with elevated inflammatory markers. The patient also exhibited symptoms of Sjögren's syndrome and osteoporosis.",
    ],
    "Genome Detection": [
        "Genetic analysis revealed mutations in the BRCA1 and BRCA2 genes, significantly increasing the risk of hereditary breast and ovarian cancer. The p53 tumor suppressor gene also showed alterations.",
        "Expression profiling identified upregulation of MYC oncogene and downregulation of PTEN tumor suppressor. The mTOR signaling pathway showed significant activation in the tumor samples.",
        "Whole genome sequencing detected variants in CFTR gene associated with cystic fibrosis, along with polymorphisms in CYP2D6 affecting drug metabolism and APOE influencing Alzheimer's risk.",
    ],
}
class MedicalNERApp:
    """Medical NER application.

    Loads the OpenMed token-classification models declared in ``MODELS``,
    groups their raw BIO token predictions into character-span entities,
    and renders the result as displaCy HTML plus a markdown summary.
    """

    def __init__(self):
        # One transformers pipeline per model name; None marks a failed load.
        self.pipelines = {}
        # Blank English pipeline: used only for tokenization and displaCy
        # rendering, never for prediction.
        self.nlp = spacy.blank("en")
        self.load_models()

    def load_models(self):
        """Load and cache all configured models; failures are stored as None."""
        print("🏥 Loading Medical NER Models...")
        for model_name, config in MODELS.items():
            print(f"Loading {model_name}...")
            try:
                # aggregation_strategy=None yields raw BIO tokens; grouping is
                # done manually in group_entities for finer control over gaps
                # and orphan I- tags.
                ner_pipeline = pipeline(
                    "ner", model=config["model_id"], aggregation_strategy=None
                )
                self.pipelines[model_name] = ner_pipeline
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                print(f"❌ Error loading {model_name}: {str(e)}")
                self.pipelines[model_name] = None
        print("🎉 All models loaded and cached!")

    @staticmethod
    def _start_entity(label: str, token: Dict, score: float, text: str) -> Dict:
        """Begin a new entity from a single BIO token."""
        return {
            'label': label,
            'start': token['start'],
            'end': token['end'],
            'text': text[token['start']:token['end']],
            'tokens': [token['word']],
            'score': score,
        }

    @staticmethod
    def _extend_entity(entity: Dict, token: Dict, score: float, text: str) -> None:
        """Absorb *token* into *entity*, keeping 'score' a true mean.

        Fix over the original implementation: the score is a token-count
        weighted mean instead of the running ``(old + new) / 2``, which
        weighted the last token exponentially more than earlier ones.
        """
        n = len(entity['tokens'])
        entity['score'] = (entity['score'] * n + score) / (n + 1)
        entity['end'] = token['end']
        entity['text'] = text[entity['start']:token['end']]
        entity['tokens'].append(token['word'])

    def group_entities(self, ner_results: List[Dict], text: str) -> List[Dict]:
        """Group raw BIO-tagged token predictions into span entities.

        Each returned dict carries: ``label``, ``start``, ``end``, ``text``,
        ``tokens`` (the contributing wordpieces) and ``score`` (mean token
        confidence). Adjacent entities of the same label separated by a
        small character gap are merged.
        """
        final_entities = []
        current_entity = None

        for token in ner_results:
            # Skip special tokens and whitespace-only tokens.
            if not token['word'].strip():
                continue

            label = token['entity']
            score = token['score']

            # An O tag closes any open entity.
            if label == 'O':
                if current_entity:
                    final_entities.append(current_entity)
                    current_entity = None
                continue

            clean_label = label.replace('B-', '').replace('I-', '')

            if label.startswith('B-'):
                # Some models emit consecutive B- tags for one entity; merge
                # same-label B- tokens that start within 2 chars of the
                # current entity's end.
                if (current_entity and
                        clean_label == current_entity['label'] and
                        token['start'] <= current_entity['end'] + 2):
                    self._extend_entity(current_entity, token, score, text)
                else:
                    if current_entity:
                        final_entities.append(current_entity)
                    current_entity = self._start_entity(clean_label, token, score, text)
            elif label.startswith('I-'):
                if current_entity and clean_label == current_entity['label']:
                    self._extend_entity(current_entity, token, score, text)
                else:
                    # Orphan I- tag (no matching open entity): treat as B-.
                    if current_entity:
                        final_entities.append(current_entity)
                    current_entity = self._start_entity(clean_label, token, score, text)

        # Flush the trailing entity, if any.
        if current_entity:
            final_entities.append(current_entity)

        # Post-process: merge adjacent same-label entities whose gap is <= 3
        # characters (e.g. split across punctuation).
        merged_entities = []
        for entity in final_entities:
            if (merged_entities and
                    merged_entities[-1]['label'] == entity['label'] and
                    entity['start'] <= merged_entities[-1]['end'] + 3):
                last_entity = merged_entities[-1]
                n_last = len(last_entity['tokens'])
                n_new = len(entity['tokens'])
                merged_entities[-1] = {
                    'label': entity['label'],
                    'start': last_entity['start'],
                    'end': entity['end'],
                    'text': text[last_entity['start']:entity['end']],
                    'tokens': last_entity['tokens'] + entity['tokens'],
                    # Token-count weighted mean keeps the score a true average.
                    'score': (last_entity['score'] * n_last + entity['score'] * n_new)
                             / (n_last + n_new),
                }
            else:
                merged_entities.append(entity)
        return merged_entities

    def create_spacy_visualization(self, text: str, entities: List[Dict], model_name: str) -> str:
        """Render *entities* over *text* as displaCy HTML with per-label colors."""
        doc = self.nlp(text)
        spacy_ents = []
        for entity in entities:
            try:
                start = entity['start']
                end = entity['end']
                # char_span fails when a boundary falls inside whitespace,
                # so trim leading/trailing whitespace from the span first.
                while start < end and text[start].isspace():
                    start += 1
                while end > start and text[end - 1].isspace():
                    end -= 1
                span = doc.char_span(start, end, label=entity['label'])
                if span is None:
                    # Fallback: retry with the original (untrimmed) boundaries.
                    span = doc.char_span(entity['start'], entity['end'], label=entity['label'])
                if span is not None:
                    spacy_ents.append(span)
            except Exception as e:
                print(f"Error creating span for entity {entity}: {str(e)}")
        # doc.ents rejects overlaps; keep the longest non-overlapping spans.
        spacy_ents = spacy.util.filter_spans(spacy_ents)
        doc.ents = spacy_ents
        # Fixed bright palette for known labels; anything else gets a random
        # bright color (channels in 100..255 so text stays readable).
        color_palette = {
            "DISEASE": "#FF5733",              # bright red-orange
            "CHEM": "#33FF57",                 # bright green
            "GENE/PROTEIN": "#3357FF",         # bright blue
            "Cancer": "#FF33F6",               # bright pink
            "Cell": "#33FFF6",                 # bright cyan
            "Organ": "#F6FF33",                # bright yellow
            "Tissue": "#FF8333",               # bright orange
            "Simple_chemical": "#8333FF",      # bright purple
            "Gene_or_gene_product": "#33FF83", # bright mint
        }
        unique_labels = sorted(set(ent.label_ for ent in doc.ents))
        colors = {
            label: color_palette.get(
                label,
                "#{:02x}{:02x}{:02x}".format(
                    random.randint(100, 255),
                    random.randint(100, 255),
                    random.randint(100, 255),
                ),
            )
            for label in unique_labels
        }
        # Only "ents" and "colors" are valid displaCy ent options; the
        # original also passed a "style" key, which the renderer ignores.
        options = {"ents": unique_labels, "colors": colors}
        return displacy.render(doc, style="ent", options=options, page=False)

    def predict_entities(self, text: str, model_name: str) -> Tuple[str, str]:
        """Run NER on *text* with the selected model.

        Returns (HTML visualization, markdown summary). All errors are
        reported in-band so the Gradio UI never raises.
        """
        if not text.strip():
            return "<p>Please enter medical text to analyze.</p>", "No text provided"
        if model_name not in self.pipelines or self.pipelines[model_name] is None:
            return f"<p>❌ Model {model_name} is not available.</p>", "Model not available"
        try:
            # Raw token predictions, then manual BIO grouping.
            raw_tokens = self.pipelines[model_name](text)
            if not raw_tokens:
                return "<p>No entities detected.</p>", "No entities found"
            final_entities = self.group_entities(raw_tokens, text)
            if not final_entities:
                return "<p>No entities detected.</p>", "No entities found"
            html_output = self.create_spacy_visualization(text, final_entities, model_name)
            wrapped_html = self.wrap_displacy_output(html_output, model_name, len(final_entities))
            summary = self.create_summary(final_entities, model_name)
            return wrapped_html, summary
        except Exception as e:
            import traceback
            traceback.print_exc()
            error_msg = f"Error during prediction: {str(e)}"
            return f"<p>❌ {error_msg}</p>", error_msg

    def wrap_displacy_output(self, displacy_html: str, model_name: str, entity_count: int) -> str:
        """Wrap displaCy output in a styled card with a model/count header."""
        return f"""
        <div style="font-family: 'Segoe UI', Arial, sans-serif;
                    border-radius: 10px;
                    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
                    overflow: hidden;">
            <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                        color: white; padding: 15px; text-align: center;">
                <h3 style="margin: 0; font-size: 18px;">{model_name}</h3>
                <p style="margin: 5px 0 0 0; opacity: 0.9; font-size: 14px;">
                    Found {entity_count} medical entities
                </p>
            </div>
            <div style="padding: 20px; margin: 0; line-height: 2.5;">
                {displacy_html}
            </div>
        </div>
        """

    def create_summary(self, entities: List[Dict], model_name: str) -> str:
        """Build a markdown summary of detected entities.

        Includes per-label counts, mean confidences, up to three example
        texts per label, and a short explanation of BIO tagging.
        """
        if not entities:
            return "No entities detected."
        # Bucket entities by label.
        entity_counts = {}
        for entity in entities:
            entity_counts.setdefault(entity["label"], []).append(entity)
        summary_parts = [f"📊 **{model_name} Summary**\n"]
        summary_parts.append(f"Total entities detected: **{len(entities)}**\n")
        for label, ents in sorted(entity_counts.items()):
            avg_confidence = sum(e["score"] for e in ents) / len(ents)
            unique_texts = sorted(set(e["text"] for e in ents))
            summary_parts.append(
                f"• **{label}**: {len(ents)} instances "
                f"(avg confidence: {avg_confidence:.2f})\n"
                f"   Examples: {', '.join(unique_texts[:3])}"
                f"{'...' if len(unique_texts) > 3 else ''}\n"
            )
        # Educational footer about the tagging scheme.
        summary_parts.append("\n🏷️ **BIO Tagging Info**\n")
        summary_parts.append("The model uses BIO (Beginning-Inside-Outside) tagging scheme:\n")
        summary_parts.append("• `B-LABEL`: Beginning of an entity\n")
        summary_parts.append("• `I-LABEL`: Inside/continuation of an entity\n")
        summary_parts.append("• `O`: Outside any entity (not shown in results)\n")
        if entity_counts:
            summary_parts.append("\nDetected entity types with their BIO tags:\n")
            for label in sorted(entity_counts.keys()):
                summary_parts.append(f"• `B-{label}`, `I-{label}`: {label} entities\n")
        return "\n".join(summary_parts)
# Initialize the app: downloads/loads all four models once at import time,
# so a Spaces cold start pays the cost up front rather than per request.
print("🚀 Initializing Medical NER Application...")
ner_app = MedicalNERApp()
# Run a short warmup for each model here so it's not the first time
# inference happens on a real user request (first call compiles/caches).
print("🔥 Warming up models...")
warmup_text = "The patient has diabetes and takes metformin."
for model_name in MODELS.keys():
    # Skip models whose load failed (recorded as None by load_models).
    if ner_app.pipelines[model_name] is not None:
        try:
            print(f"Warming up {model_name}...")
            _ = ner_app.predict_entities(warmup_text, model_name)
            print(f"✅ {model_name} warmed up successfully")
        except Exception as e:
            # Warmup is best-effort; a failure here must not kill the app.
            print(f"⚠️ Warmup failed for {model_name}: {str(e)}")
print("🎉 Model warmup complete!")
def predict_wrapper(text: str, model_name: str):
    """Gradio callback: delegate analysis to the shared app instance.

    Returns the (visualization HTML, markdown summary) pair produced by
    ``MedicalNERApp.predict_entities``.
    """
    return ner_app.predict_entities(text, model_name)
def load_example(model_name: str, example_idx: int):
    """Return example text *example_idx* for *model_name*.

    Falls back to the empty string when the model has no examples or the
    index is out of range (negative indices are rejected).
    """
    candidates = EXAMPLES.get(model_name, [])
    in_range = 0 <= example_idx < len(candidates)
    return candidates[example_idx] if in_range else ""
# Create Gradio interface.
# Layout: left column = model picker + info panel + input + example buttons;
# right column = rendered entity HTML + markdown summary.
with gr.Blocks(
    title="🏥 Medical NER Expert",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .main-header {
        text-align: center;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 2rem;
        border-radius: 15px;
        margin-bottom: 2rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
    }
    .model-info {
        padding: 1rem;
        border-radius: 10px;
        border-left: 4px solid #667eea;
        margin: 1rem 0;
    }
    """,
) as demo:
    # Header banner.
    gr.HTML(
        """
        <div class="main-header">
            <h1>🏥 Medical NER Expert</h1>
            <p>Advanced Named Entity Recognition for Medical Professionals</p>
            <p>Powered by OpenMed's specialized medical AI models with spaCy displaCy visualization</p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            # Model selection — choices come straight from the MODELS config.
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="Oncology Detection",
                label="🔬 Select Medical NER Model",
                info="Choose the specialized model for your analysis",
            )
            # Model info display; initial value matches the default dropdown
            # selection and is kept in sync by update_model_info below.
            model_info = gr.HTML(
                value=f"""
                <div class="model-info">
                    <strong>Oncology Detection</strong><br>
                    {MODELS["Oncology Detection"]["description"]}
                </div>
                """
            )
            # Text input, pre-filled with the default model's first example.
            text_input = gr.Textbox(
                lines=8,
                placeholder="Enter medical text here for entity recognition...",
                label="📝 Medical Text Input",
                value=EXAMPLES["Oncology Detection"][0],
            )
            # Example buttons — one per example (EXAMPLES has 3 per model).
            with gr.Row():
                example_buttons = []
                for i in range(3):
                    btn = gr.Button(f"Example {i+1}", size="sm", variant="secondary")
                    example_buttons.append(btn)
            # Analyze button.
            analyze_btn = gr.Button("🔍 Analyze Text", variant="primary", size="lg")
        with gr.Column(scale=3):
            # Entity visualization output (displaCy HTML wrapped in a card).
            results_html = gr.HTML(
                label="🎯 Entity Recognition Results",
                value="<p>Select a model and enter text to see entity recognition results.</p>",
            )
            # Markdown analysis summary.
            summary_output = gr.Markdown(
                value="Analysis summary will appear here...",
                label="📊 Analysis Summary",
            )

    # Update model info panel when the selected model changes.
    def update_model_info(model_name):
        """Return the info-panel HTML for *model_name* ("" if unknown)."""
        if model_name in MODELS:
            return f"""
            <div class="model-info">
                <strong>{model_name}</strong><br>
                {MODELS[model_name]["description"]}<br>
                <small>Model: {MODELS[model_name]["model_id"]}</small>
            </div>
            """
        return ""

    model_dropdown.change(
        update_model_info, inputs=[model_dropdown], outputs=[model_info]
    )
    # Example button handlers. idx=i is default-bound so each closure keeps
    # its own button index (late-binding closures would all use the last i).
    for i, btn in enumerate(example_buttons):
        btn.click(
            lambda model_name, idx=i: load_example(model_name, idx),
            inputs=[model_dropdown],
            outputs=[text_input],
        )
    # Main analysis function.
    analyze_btn.click(
        predict_wrapper,
        inputs=[text_input, model_dropdown],
        outputs=[results_html, summary_output],
    )
    # Second change-handler on the same dropdown: also load the new model's
    # first example into the textbox (both handlers fire on each change).
    model_dropdown.change(
        lambda model_name: load_example(model_name, 0),
        inputs=[model_dropdown],
        outputs=[text_input],
    )
if __name__ == "__main__":
    # Serve on all interfaces at the conventional Hugging Face Spaces port.
    launch_kwargs = {
        "share": False,  # Not needed on Spaces
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_kwargs)