thethinkmachine commited on
Commit
a276191
·
verified ·
1 Parent(s): 2aa4103

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import plotly.express as px
4
+ import pandas as pd
5
+
6
+ MODEL_LIST = [
7
+ "thethinkmachine/immune-resin",
8
+ ]
9
+
10
+ pipeline_cache = {}
11
+
12
+ def get_pipeline(model_name):
13
+ if model_name not in pipeline_cache:
14
+ pipeline_cache[model_name] = pipeline(
15
+ "text-classification",
16
+ model=model_name,
17
+ return_all_scores=True
18
+ )
19
+ return pipeline_cache[model_name]
20
+
21
+
22
+ def classify(text, model_name, chart_type, threshold):
23
+ if not text.strip():
24
+ return "", None, None
25
+
26
+ clf = get_pipeline(model_name)
27
+ results = clf(text)[0]
28
+ df = pd.DataFrame(results)
29
+
30
+ # Sort labels by probability
31
+ df = df.sort_values(by="score", ascending=False).reset_index(drop=True)
32
+
33
+ # Highlight labels above threshold
34
+ df["highlight"] = df["score"].apply(lambda x: "High" if x >= threshold else "Low")
35
+
36
+ if chart_type == "Radar Chart":
37
+ fig = px.line_polar(df, r="score", theta="label", line_close=True,
38
+ color="highlight",
39
+ title=f"Label Probabilities - {model_name}")
40
+ fig.update_traces(fill='toself')
41
+ fig.update_layout(polar=dict(radialaxis=dict(range=[0, 1])))
42
+ else:
43
+ fig = px.bar(df, x="label", y="score", color="highlight",
44
+ title=f"Label Probabilities - {model_name}")
45
+ fig.update_layout(yaxis=dict(range=[0, 1]))
46
+
47
+ return df, fig, None
48
+
49
+
50
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
51
+ gr.Markdown("# LetMeGuessYourMood-5000")
52
+
53
+ with gr.Row():
54
+ model_dropdown = gr.Dropdown(
55
+ choices=MODEL_LIST,
56
+ label="Select Model",
57
+ value=MODEL_LIST[0]
58
+ )
59
+ chart_dropdown = gr.Dropdown(
60
+ choices=["Radar Chart", "Bar Chart"],
61
+ label="Chart Type",
62
+ value="Radar Chart"
63
+ )
64
+ threshold_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Highlight Threshold")
65
+
66
+ text_input = gr.Textbox(label="Input Text")
67
+
68
+ with gr.Row():
69
+ output_table = gr.DataFrame(label="Scores Table")
70
+ output_plot = gr.Plot(label="Probability Chart")
71
+
72
+ classify_btn = gr.Button("Run Classification")
73
+
74
+ classify_btn.click(
75
+ classify,
76
+ inputs=[text_input, model_dropdown, chart_dropdown, threshold_slider],
77
+ outputs=[output_table, output_plot, None]
78
+ )
79
+
80
+ gr.Examples(
81
+ examples=[
82
+ "Sorry for not uninstalling you faster",
83
+ "hurr durr I like using reddit, and anyone who doesn't agree with me is a retard",
84
+ "Considering I haven't eaten or drunk anything in about twenty hours, plus I only slept for about four of them before getting back to beating it, I'm gonna say too much. My head hurts.",
85
+ ],
86
+ inputs=[text_input]
87
+ )
88
+
89
+ if __name__ == "__main__":
90
+ demo.launch()