File size: 3,651 Bytes
8f051ed
a276191
 
 
 
 
8f051ed
 
a276191
 
0f8a47d
 
a276191
 
 
 
 
 
 
 
 
8f051ed
a276191
 
 
 
 
 
 
f8bb1c0
 
a276191
 
 
 
 
 
 
 
 
 
 
f8bb1c0
a276191
f8bb1c0
8f051ed
 
 
f8bb1c0
 
 
 
a276191
 
 
f8bb1c0
8f051ed
 
 
f8bb1c0
 
 
a276191
 
f8bb1c0
a276191
 
0f8a47d
8f051ed
 
a276191
 
 
 
 
 
 
 
 
 
0f8a47d
a276191
f8bb1c0
 
 
 
 
 
a276191
 
 
 
 
 
 
 
 
 
 
 
f8bb1c0
a276191
 
 
 
0f8a47d
 
 
 
 
a276191
0f8a47d
8f051ed
a276191
 
 
 
 
0f8a47d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import gradio as gr
from transformers import pipeline
import plotly.express as px
import pandas as pd

# Hugging Face access token taken from the environment; None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model identifiers offered in the model dropdown; the first entry is the
# dropdown's initial value.
MODEL_LIST = [
    "thethinkmachine/immune-resin",
    "thethinkmachine/bright-avocado",
    "thethinkmachine/kickass-supercluster",
]

# Maps a model name to its already-built pipeline so that each model is
# constructed at most once per process.
pipeline_cache = {}

def get_pipeline(model_name):
    """Return the text-classification pipeline for *model_name*.

    The pipeline is built on the first request for a given name and stored
    in the module-level ``pipeline_cache``; later requests reuse the stored
    object.

    Args:
        model_name: Model identifier passed to ``transformers.pipeline`` as
            ``model`` and used as the cache key.

    Returns:
        The cached (or newly created) pipeline object.
    """
    try:
        return pipeline_cache[model_name]
    except KeyError:
        # Not built yet; fall through and construct it outside the handler
        # so that construction errors propagate unchanged.
        pass

    # NOTE(review): ``return_all_scores`` is reported as deprecated in recent
    # transformers releases in favour of ``top_k=None``. Switching flags may
    # change the nesting of the output that ``classify`` indexes with ``[0]``,
    # so confirm against the installed version before changing it.
    classifier = pipeline(
        "text-classification",
        model=model_name,
        token=HF_TOKEN,
        return_all_scores=True,
    )
    pipeline_cache[model_name] = classifier
    return classifier


def classify(text, model_name, chart_type, threshold):
    """Score *text* with the selected model and build the table and chart.

    Args:
        text: Input text to classify. ``None``, an empty string, or a
            whitespace-only string short-circuits before any model is loaded.
        model_name: Model identifier forwarded to ``get_pipeline``.
        chart_type: ``"Radar Chart"`` produces a polar line chart; any other
            value produces a bar chart.
        threshold: Scores greater than or equal to this value are tagged
            ``"High"`` in the ``highlight`` column; all others ``"Low"``.

    Returns:
        A ``(DataFrame, figure)`` pair, where the frame is sorted by
        descending ``score`` and carries an added ``highlight`` column, or
        ``(None, None)`` when there is no text to classify.
    """
    # A cleared or never-populated input may arrive as None rather than "";
    # treat it like blank text instead of failing on ``None.strip()``.
    if text is None or not text.strip():
        return None, None

    clf = get_pipeline(model_name)
    results = clf(text)
    # The pipeline may hand back either a flat list of {label, score} dicts
    # or that list wrapped in an outer per-input list; unwrap only when the
    # result is actually nested so both shapes yield one row per label.
    if results and isinstance(results[0], list):
        results = results[0]
    df = pd.DataFrame(results)

    # Sort labels by probability, highest first.
    df = df.sort_values(by="score", ascending=False).reset_index(drop=True)

    # Tag each label by whether its score reaches the threshold; the tag
    # drives the colour grouping in both chart types.
    df["highlight"] = df["score"].apply(lambda x: "High" if x >= threshold else "Low")

    title = f"Label Probabilities - {model_name}"

    if chart_type == "Radar Chart":
        fig = px.line_polar(
            df,
            r="score",
            theta="label",
            line_close=True,
            color="highlight",
            title=title,
        )
        fig.update_traces(fill='toself')
        # Pin the radial axis to [0, 1] so charts are comparable across runs.
        fig.update_layout(polar=dict(radialaxis=dict(range=[0, 1])))
    else:
        fig = px.bar(
            df,
            x="label",
            y="score",
            color="highlight",
            title=title,
        )
        # Pin the y axis to [0, 1] so charts are comparable across runs.
        fig.update_layout(yaxis=dict(range=[0, 1]))

    return df, fig


# Build the Gradio UI: controls on top, text input, then the table and chart
# side by side, a run button, and clickable example inputs.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Ekman Emotions Playground 🤖👾🦜")
    gr.Markdown("### Why let humans play with your emotions when a robot can do it for you for free? ...Cheaper than therapy, 100% less effective!")

    # Controls: which model to run, how to draw the scores, and the score
    # cut-off used for the High/Low highlight.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_LIST,
            label="Select Model",
            value=MODEL_LIST[0]
        )
        chart_dropdown = gr.Dropdown(
            choices=["Radar Chart", "Bar Chart"],
            label="Chart Type",
            value="Bar Chart"
        )
        threshold_slider = gr.Slider(
            0, 1,
            value=0.5,
            step=0.01,
            label="Highlight Threshold"
        )

    text_input = gr.Textbox(label="Input Text")

    # Outputs filled by classify(): the score table and the chart.
    with gr.Row():
        output_table = gr.DataFrame(label="Scores Table")
        output_plot = gr.Plot(label="Probability Chart")

    classify_btn = gr.Button("Run Classification")

    # The input order here must match classify()'s parameter order.
    classify_btn.click(
        classify,
        inputs=[text_input, model_dropdown, chart_dropdown, threshold_slider],
        outputs=[output_table, output_plot]
    )

    # Sample sentences that populate the text box when clicked.
    gr.Examples(
        examples=[
            "I only saw her once and I'm head over heels!",
            "Pay you for what, just standing there?!",
            "Dumbass Broncos fans circa December 2015.",
            "It's great that you're a recovering addict, that's cool. Have you ever tried DMT?",
            "I'm scared to even ask my mom ,I might get yelled at 😟",
            "hurr durr I like using reddit, and anyone who doesn't agree with me is a retard",
            "They're really pushing me into this... once I go, there's no coming back you know?",
            "Considering I haven't eaten or drunk anything in about twenty hours, my head hurts.",
        ],
        inputs=[text_input]
    )

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()