import json

import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr

# Initialize tokenizer and model
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('ethical-spectacle/social-bias-ner')
model.eval()
model.to('cuda' if torch.cuda.is_available() else 'cpu')
# Mapping IDs to labels
id2label = {
    0: 'O',
    1: 'B-STEREO',
    2: 'I-STEREO',
    3: 'B-GEN',
    4: 'I-GEN',
    5: 'B-UNFAIR',
    6: 'I-UNFAIR'
}

# Entity colors for highlights
label_colors = {
    "STEREO": "rgba(255, 0, 0, 0.2)",   # Light Red
    "GEN": "rgba(0, 0, 255, 0.2)",      # Light Blue
    "UNFAIR": "rgba(0, 255, 0, 0.2)"    # Light Green
}
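
# Each entity type (GEN, UNFAIR, STEREO) uses BIO tagging: "B-" marks the first
# token of a span, "I-" marks continuation tokens, and "O" means no entity.
# Because the classification head is multi-label, a single token can carry
# labels from several entity types at once (e.g. both B-GEN and B-STEREO).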

# Post-process entity tags
def post_process_entities(result):
    prev_entity_type = None
    for token_data in result:
        labels = token_data["labels"]

        # Handle sequence rules
        new_labels = []
        for label_data in labels:
            label = label_data['label']
            if label.startswith("B-") and prev_entity_type == label[2:]:
                new_labels.append({"label": f"I-{label[2:]}", "confidence": label_data["confidence"]})
            elif label.startswith("I-") and prev_entity_type != label[2:]:
                new_labels.append({"label": f"B-{label[2:]}", "confidence": label_data["confidence"]})
            else:
                new_labels.append(label_data)
            prev_entity_type = label[2:]
        token_data["labels"] = new_labels
    return result
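
# Example of the BIO repair above: if the model tags two consecutive tokens
# as B-STEREO, B-STEREO, the second is rewritten to I-STEREO (continuation of
# the same span); conversely, an I- tag with no open span of that type is
# promoted to B-.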

# Generate HTML matrix and JSON results with probabilities
def predict_ner_tags_with_json(sentence):
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.sigmoid(logits)
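    # Multi-label scoring: sigmoid gives an independent probability per class
    # (rather than softmax over classes), so a token can clear the 0.52
    # threshold below for more than one entity label at the same time.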
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    result = []
    for i, token in enumerate(tokens):
        if token not in tokenizer.all_special_tokens:
            label_indices = (probabilities[0][i] > 0.52).nonzero(as_tuple=False).squeeze(-1)
            labels = [
                {
                    "label": id2label[idx.item()],
                    "confidence": round(probabilities[0][i][idx].item() * 100, 2)
                }
                for idx in label_indices
            ]
            result.append({"token": token.replace("##", ""), "labels": labels})

    result = post_process_entities(result)

    # Create table rows
    word_row = []
    stereo_row = []
    gen_row = []
    unfair_row = []

    for token_data in result:
        token = token_data["token"]
        labels = token_data["labels"]
        word_row.append(f"<span style='font-weight:bold;'>{token}</span>")

        # STEREO
        stereo_labels = [
            f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "STEREO" in label_data["label"]
        ]
        stereo_row.append(
            f"<span style='background:{label_colors['STEREO']}; border-radius:6px; padding:2px 5px;'>{', '.join(stereo_labels)}</span>"
            if stereo_labels else " "
        )

        # GEN
        gen_labels = [
            f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "GEN" in label_data["label"]
        ]
        gen_row.append(
            f"<span style='background:{label_colors['GEN']}; border-radius:6px; padding:2px 5px;'>{', '.join(gen_labels)}</span>"
            if gen_labels else " "
        )

        # UNFAIR
        unfair_labels = [
            f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "UNFAIR" in label_data["label"]
        ]
        unfair_row.append(
            f"<span style='background:{label_colors['UNFAIR']}; border-radius:6px; padding:2px 5px;'>{', '.join(unfair_labels)}</span>"
            if unfair_labels else " "
        )
    matrix_html = f"""
    <table style='border-collapse:collapse; width:100%; font-family:monospace; text-align:left;'>
        <tr>
            <td><strong>Text Sequence</strong></td>
            {''.join(f"<td>{word}</td>" for word in word_row)}
        </tr>
        <tr>
            <td><strong>Generalizations</strong></td>
            {''.join(f"<td>{cell}</td>" for cell in gen_row)}
        </tr>
        <tr>
            <td><strong>Unfairness</strong></td>
            {''.join(f"<td>{cell}</td>" for cell in unfair_row)}
        </tr>
        <tr>
            <td><strong>Stereotypes</strong></td>
            {''.join(f"<td>{cell}</td>" for cell in stereo_row)}
        </tr>
    </table>
    """

    # JSON string
    json_result = json.dumps(result, indent=4)
    return f"{matrix_html}<br><pre>{json_result}</pre>"

# Gradio Interface
iface = gr.Blocks()

with iface:
    with gr.Row():
        gr.Markdown(
            """
# GUS-Net 🕵

[GUS-Net](https://huggingface.co/ethical-spectacle/social-bias-ner) is a `BertForTokenClassification`-based model, trained on the [GUS dataset](https://huggingface.co/datasets/ethical-spectacle/gus-dataset-v1). It performs multi-label named-entity recognition of socially biased entities, intended to reveal the underlying structure of bias rather than a one-size-fits-all definition.

You can find the full collection of resources introduced in our paper [here](https://huggingface.co/collections/ethical-spectacle/gus-net-66edfe93801ea45d7a26a10f).
This [blog post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition) walks through the training and architecture of the model.

Enter a sentence for named-entity recognition of biased entities:
- **Generalizations (GEN)**
- **Unfairness (UNFAIR)**
- **Stereotypes (STEREO)**

Labels follow the BIO format. Try it out:
"""
        )
    with gr.Row():
        input_box = gr.Textbox(label="Input Sentence")
    with gr.Row():
        output_box = gr.HTML(label="Entity Matrix and JSON Output")

    input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
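
# launch() starts the app; share=True additionally creates a temporary public
# gradio.live link (not needed when this file runs as a hosted Space).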
iface.launch(share=True)