# app.py — author: thethinkmachine
# Commit 0f8a47d ("Update app.py", verified)
import os
import gradio as gr
from transformers import pipeline
import plotly.express as px
import pandas as pd
# Hugging Face access token taken from the environment; None when it is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model ids offered in the model dropdown; the first entry is the default.
MODEL_LIST = [
    "thethinkmachine/immune-resin",
    "thethinkmachine/bright-avocado",
    "thethinkmachine/kickass-supercluster",
]

# model name -> constructed pipeline, filled lazily by get_pipeline().
pipeline_cache = dict()
def get_pipeline(model_name):
    """Return the text-classification pipeline for *model_name*, building it once.

    Constructed pipelines are memoised in the module-level ``pipeline_cache``
    dict, so only the first call for a given model pays the construction cost;
    later calls return the very same object.
    """
    if model_name in pipeline_cache:
        return pipeline_cache[model_name]

    classifier = pipeline(
        "text-classification",
        model=model_name,
        token=HF_TOKEN,
        # NOTE(review): newer transformers releases deprecate
        # return_all_scores in favour of top_k=None; switching may change the
        # nesting of the call result that classify() indexes — verify first.
        return_all_scores=True,
    )
    pipeline_cache[model_name] = classifier
    return classifier
# Fixed colours so "High"/"Low" mean the same thing on every run. Without an
# explicit map, plotly express hands out colours by order of first appearance,
# which flips the palette whenever no label reaches the threshold.
_HIGHLIGHT_COLORS = {"High": "#EF553B", "Low": "#636EFA"}
_HIGHLIGHT_ORDER = {"highlight": ["High", "Low"]}


def _bar_chart(df, title):
    """Bar chart of ``score`` per ``label``, coloured by the ``highlight`` tag."""
    fig = px.bar(
        df,
        x="label",
        y="score",
        color="highlight",
        color_discrete_map=_HIGHLIGHT_COLORS,
        category_orders=_HIGHLIGHT_ORDER,
        title=title,
    )
    fig.update_layout(yaxis=dict(range=[0, 1]))
    return fig


def _radar_chart(df, title):
    """Radar chart: one closed, filled polygon over all labels; "High" as markers.

    Passing ``color="highlight"`` to ``px.line_polar`` would split the polygon
    into one closed shape per group, so a single "High" label ended up as an
    isolated point missing from the main profile.
    """
    fig = px.line_polar(df, r="score", theta="label", line_close=True, title=title)
    # Fill only the profile polygon; done before adding the marker trace.
    fig.update_traces(fill="toself")
    high = df[df["highlight"] == "High"]
    if not high.empty:
        fig.add_scatterpolar(
            r=high["score"],
            theta=high["label"],
            mode="markers",
            name="High",
            marker=dict(size=10, color=_HIGHLIGHT_COLORS["High"]),
        )
    fig.update_layout(polar=dict(radialaxis=dict(range=[0, 1])))
    return fig


def classify(text, model_name, chart_type, threshold):
    """Classify *text* with *model_name* and chart the per-label scores.

    Args:
        text: Input text. ``None`` or whitespace-only input yields empty
            outputs without loading any model.
        model_name: Model id handed to ``get_pipeline``.
        chart_type: ``"Radar Chart"`` selects a polar chart; any other value
            produces a bar chart.
        threshold: Scores ``>=`` this value are tagged ``"High"``, the rest
            ``"Low"``.

    Returns:
        ``(df, fig)``: ``df`` holds the pipeline's per-label records sorted by
        descending ``score`` plus a ``highlight`` column, and ``fig`` is the
        plotly figure. ``(None, None)`` for empty input.
    """
    if text is None or not text.strip():
        # Empty outputs for the table and the plot.
        return None, None

    raw = get_pipeline(model_name)(text)
    # A single string scored with return_all_scores=True comes back as
    # [[{...}, ...]]; unwrap that level, but also accept a flat list of dicts.
    results = raw[0] if raw and isinstance(raw[0], list) else raw

    df = pd.DataFrame(results)
    df = df.sort_values(by="score", ascending=False).reset_index(drop=True)
    df["highlight"] = df["score"].apply(lambda x: "High" if x >= threshold else "Low")

    title = f"Label Probabilities - {model_name}"
    if chart_type == "Radar Chart":
        fig = _radar_chart(df, title)
    else:
        fig = _bar_chart(df, title)
    return df, fig
# Gradio UI: model/chart/threshold controls, a text box, and the score table
# plus chart produced by classify(). Title and example strings carry UTF-8
# emoji (previously mis-decoded as Windows code-page mojibake).
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Ekman Emotions Playground 🤖👾🦜")
    gr.Markdown("### Why let humans play with your emotions when a robot can do it for you for free? ...Cheaper than therapy, 100% less effective!")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_LIST,
            label="Select Model",
            value=MODEL_LIST[0]
        )
        chart_dropdown = gr.Dropdown(
            choices=["Radar Chart", "Bar Chart"],
            label="Chart Type",
            value="Bar Chart"
        )
        # Scores >= this value are tagged "High" by classify().
        threshold_slider = gr.Slider(
            0, 1,
            value=0.5,
            step=0.01,
            label="Highlight Threshold"
        )
    text_input = gr.Textbox(label="Input Text")

    with gr.Row():
        output_table = gr.DataFrame(label="Scores Table")
        output_plot = gr.Plot(label="Probability Chart")

    classify_btn = gr.Button("Run Classification")
    # Input order must match classify(text, model_name, chart_type, threshold).
    classify_btn.click(
        classify,
        inputs=[text_input, model_dropdown, chart_dropdown, threshold_slider],
        outputs=[output_table, output_plot]
    )

    # Clicking an example only fills the text box; classification still
    # requires the button.
    gr.Examples(
        examples=[
            "I only saw her once and I'm head over heels!",
            "Pay you for what, just standing there?!",
            "Dumbass Broncos fans circa December 2015.",
            "It's great that you're a recovering addict, that's cool. Have you ever tried DMT?",
            "I'm scared to even ask my mom ,I might get yelled at 😟",
            "hurr durr I like using reddit, and anyone who doesn't agree with me is a retard",
            "They're really pushing me into this... once I go, there's no coming back you know?",
            "Considering I haven't eaten or drunk anything in about twenty hours, my head hurts.",
        ],
        inputs=[text_input]
    )

if __name__ == "__main__":
    demo.launch()