|
|
import os |
|
|
import gradio as gr |
|
|
from transformers import pipeline |
|
|
import plotly.express as px |
|
|
import pandas as pd |
|
|
|
|
|
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model ids offered in the "Select Model" dropdown; the first entry is the
# dropdown's initial value.
MODEL_LIST = [
    "thethinkmachine/immune-resin",
    "thethinkmachine/bright-avocado",
    "thethinkmachine/kickass-supercluster",
]

# Pipelines already built by get_pipeline, keyed by model id.
pipeline_cache = {}
|
|
|
|
|
def get_pipeline(model_name):
    """Return the text-classification pipeline for ``model_name``.

    The pipeline is built on the first request for a given name and stored in
    the module-level ``pipeline_cache`` dict; later calls with the same name
    return the stored object instead of building a new one.

    Args:
        model_name: Model id passed to ``transformers.pipeline`` as ``model``
            and used as the cache key.

    Returns:
        The cached or newly created pipeline object.
    """
    if model_name in pipeline_cache:
        return pipeline_cache[model_name]

    # NOTE(review): recent transformers releases flag `return_all_scores` as
    # deprecated in favour of `top_k=None`; confirm the output shape that
    # `classify` relies on before switching.
    classifier = pipeline(
        "text-classification",
        model=model_name,
        token=HF_TOKEN,
        return_all_scores=True,
    )
    pipeline_cache[model_name] = classifier
    return classifier
|
|
|
|
|
|
|
|
def classify(text, model_name, chart_type, threshold):
    """Score ``text`` with the selected model and build a table and chart.

    Args:
        text: Input string to classify. ``None``, empty and whitespace-only
            values are treated as "nothing to classify".
        model_name: Model id handed to ``get_pipeline``.
        chart_type: ``"Radar Chart"`` selects a polar line chart; any other
            value produces a bar chart.
        threshold: Scores greater than or equal to this value are tagged
            ``"High"`` in the ``highlight`` column; all others ``"Low"``.

    Returns:
        A ``(DataFrame, figure)`` pair: the rows sorted by descending score
        with an added ``highlight`` column, and the Plotly figure drawn from
        that frame. ``(None, None)`` when there is no text to classify.
    """
    # Guard against a missing value as well as blank text; calling .strip()
    # on None would raise AttributeError.
    if text is None or not text.strip():
        return None, None

    clf = get_pipeline(model_name)
    # Only the first element of the pipeline output is used. It is expected to
    # be a list of records with "label" and "score" fields, since those are
    # the columns read below.
    scores = pd.DataFrame(clf(text)[0])
    scores = scores.sort_values(by="score", ascending=False).reset_index(drop=True)

    # Vectorised High/Low tag; drives the colour grouping in both chart types.
    scores["highlight"] = scores["score"].ge(threshold).map({True: "High", False: "Low"})

    title = f"Label Probabilities - {model_name}"
    if chart_type == "Radar Chart":
        fig = px.line_polar(
            scores,
            r="score",
            theta="label",
            line_close=True,
            color="highlight",
            title=title,
        )
        fig.update_traces(fill="toself")
        # Fix the radial axis to [0, 1] instead of letting it auto-scale.
        fig.update_layout(polar=dict(radialaxis=dict(range=[0, 1])))
    else:
        fig = px.bar(
            scores,
            x="label",
            y="score",
            color="highlight",
            title=title,
        )
        # Same fixed [0, 1] range for the bar chart's y axis.
        fig.update_layout(yaxis=dict(range=[0, 1]))

    return scores, fig
|
|
|
|
|
|
|
|
# UI definition: builds the Blocks app bound to `demo` and wires the button to classify.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # NOTE(review): the characters after "Playground" look like mis-decoded
    # emoji; confirm the intended glyphs against a correctly encoded copy.
    gr.Markdown("# Ekman Emotions Playground π€πΎπ¦")
    gr.Markdown("### Why let humans play with your emotions when a robot can do it for you for free? ...Cheaper than therapy, 100% less effective!")

    # Settings row: which model to run, which chart to draw, and the High/Low cut-off.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_LIST,
            label="Select Model",
            value=MODEL_LIST[0]
        )
        chart_dropdown = gr.Dropdown(
            choices=["Radar Chart", "Bar Chart"],
            label="Chart Type",
            value="Bar Chart"
        )
        # Slider spans 0..1 in 0.01 steps and starts at 0.5.
        threshold_slider = gr.Slider(
            0, 1,
            value=0.5,
            step=0.01,
            label="Highlight Threshold"
        )

    text_input = gr.Textbox(label="Input Text")

    # Output row: the two values returned by classify land here.
    with gr.Row():
        output_table = gr.DataFrame(label="Scores Table")
        output_plot = gr.Plot(label="Probability Chart")

    classify_btn = gr.Button("Run Classification")

    # The input order must match classify's parameters
    # (text, model_name, chart_type, threshold); the outputs match its
    # (table, figure) return pair.
    classify_btn.click(
        classify,
        inputs=[text_input, model_dropdown, chart_dropdown, threshold_slider],
        outputs=[output_table, output_plot]
    )

    # Sample sentences offered for the text box.
    # NOTE(review): the trailing character in the "yelled at" example also
    # looks like a mis-decoded emoji; confirm against the original text.
    gr.Examples(
        examples=[
            "I only saw her once and I'm head over heels!",
            "Pay you for what, just standing there?!",
            "Dumbass Broncos fans circa December 2015.",
            "It's great that you're a recovering addict, that's cool. Have you ever tried DMT?",
            "I'm scared to even ask my mom ,I might get yelled at π",
            "hurr durr I like using reddit, and anyone who doesn't agree with me is a retard",
            "They're really pushing me into this... once I go, there's no coming back you know?",
            "Considering I haven't eaten or drunk anything in about twenty hours, my head hurts.",
        ],
        inputs=[text_input]
    )
|
|
|
|
|
# Launch the Blocks app only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    demo.launch()