File size: 3,651 Bytes
8f051ed a276191 8f051ed a276191 0f8a47d a276191 8f051ed a276191 f8bb1c0 a276191 f8bb1c0 a276191 f8bb1c0 8f051ed f8bb1c0 a276191 f8bb1c0 8f051ed f8bb1c0 a276191 f8bb1c0 a276191 0f8a47d 8f051ed a276191 0f8a47d a276191 f8bb1c0 a276191 f8bb1c0 a276191 0f8a47d a276191 0f8a47d 8f051ed a276191 0f8a47d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
import gradio as gr
from transformers import pipeline
import plotly.express as px
import pandas as pd
# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model identifiers the UI lets the user choose between.
MODEL_LIST = [
    "thethinkmachine/immune-resin",
    "thethinkmachine/bright-avocado",
    "thethinkmachine/kickass-supercluster",
]

# model name -> constructed pipeline; filled lazily by get_pipeline().
pipeline_cache = {}
def get_pipeline(model_name):
    """Return the text-classification pipeline for ``model_name``.

    Pipelines are memoized in the module-level ``pipeline_cache`` dict, so
    ``transformers.pipeline`` is only invoked the first time a given model
    name is requested.

    Args:
        model_name: Model identifier forwarded to ``pipeline`` as ``model=``.

    Returns:
        The pipeline object stored in ``pipeline_cache`` for ``model_name``.
    """
    # Fast path: this model has already been constructed.
    if model_name in pipeline_cache:
        return pipeline_cache[model_name]

    classifier = pipeline(
        "text-classification",
        model=model_name,
        token=HF_TOKEN,
        return_all_scores=True,
    )
    pipeline_cache[model_name] = classifier
    return classifier
def classify(text, model_name, chart_type, threshold):
    """Classify ``text`` with the chosen model and build the table and chart.

    Args:
        text: Input text. ``None``, empty, or whitespace-only input
            short-circuits before any model is loaded.
        model_name: Model identifier passed to ``get_pipeline``.
        chart_type: ``"Radar Chart"`` produces a polar line chart; any other
            value produces a bar chart.
        threshold: Scores greater than or equal to this value are tagged
            ``"High"`` in the ``highlight`` column, all others ``"Low"``.

    Returns:
        A ``(df, fig)`` tuple: the scores DataFrame sorted by descending
        ``score`` with an added ``highlight`` column, and the Plotly figure.
        ``(None, None)`` when the input is blank, which leaves both outputs
        empty.
    """
    # Guard against None as well as blank strings; calling .strip() on None
    # would raise AttributeError.
    if text is None or not text.strip():
        return None, None

    clf = get_pipeline(model_name)
    # The first element of the pipeline output is used as the per-label
    # records for this single input.
    results = clf(text)[0]
    df = pd.DataFrame(results)

    # Sort labels by probability, highest first.
    df = df.sort_values(by="score", ascending=False).reset_index(drop=True)

    # Tag each label relative to the threshold; used to colour the chart.
    df["highlight"] = df["score"].apply(
        lambda x: "High" if x >= threshold else "Low"
    )

    title = f"Label Probabilities - {model_name}"

    if chart_type == "Radar Chart":
        fig = px.line_polar(
            df,
            r="score",
            theta="label",
            line_close=True,
            color="highlight",
            title=title,
        )
        fig.update_traces(fill='toself')
        # Pin the radial axis to the 0-1 range.
        fig.update_layout(polar=dict(radialaxis=dict(range=[0, 1])))
    else:
        fig = px.bar(
            df,
            x="label",
            y="score",
            color="highlight",
            title=title,
        )
        # Pin the y axis to the 0-1 range.
        fig.update_layout(yaxis=dict(range=[0, 1]))

    return df, fig
# NOTE(review): the nesting under each gr.Row() below is reconstructed from
# statement order -- confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Page header.
    gr.Markdown("# Ekman Emotions Playground π€πΎπ¦")
    gr.Markdown("### Why let humans play with your emotions when a robot can do it for you for free? ...Cheaper than therapy, 100% less effective!")

    # Controls: model choice, chart type and highlight threshold.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=MODEL_LIST,
            value=MODEL_LIST[0],
        )
        chart_dropdown = gr.Dropdown(
            label="Chart Type",
            choices=["Radar Chart", "Bar Chart"],
            value="Bar Chart",
        )
        threshold_slider = gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.01,
            label="Highlight Threshold",
        )

    text_input = gr.Textbox(label="Input Text")

    # Outputs: scores table next to the probability chart.
    with gr.Row():
        output_table = gr.DataFrame(label="Scores Table")
        output_plot = gr.Plot(label="Probability Chart")

    classify_btn = gr.Button("Run Classification")
    # Wire the button to classify(); argument order matches its signature.
    classify_btn.click(
        fn=classify,
        inputs=[text_input, model_dropdown, chart_dropdown, threshold_slider],
        outputs=[output_table, output_plot],
    )

    # Sample sentences offered as examples for the input textbox.
    gr.Examples(
        examples=[
            "I only saw her once and I'm head over heels!",
            "Pay you for what, just standing there?!",
            "Dumbass Broncos fans circa December 2015.",
            "It's great that you're a recovering addict, that's cool. Have you ever tried DMT?",
            "I'm scared to even ask my mom ,I might get yelled at π",
            "hurr durr I like using reddit, and anyone who doesn't agree with me is a retard",
            "They're really pushing me into this... once I go, there's no coming back you know?",
            "Considering I haven't eaten or drunk anything in about twenty hours, my head hurts.",
        ],
        inputs=[text_input],
    )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()