"""Gradio playground for Ekman-emotion text-classification models.

Loads one of several Hugging Face text-classification models on demand,
scores the input text against every label, and shows the scores as a
table plus a bar or radar chart in which labels at or above a
user-chosen threshold are highlighted.
"""

import os

import gradio as gr
import pandas as pd
import plotly.express as px
from transformers import pipeline

# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

MODEL_LIST = [
    "thethinkmachine/immune-resin",
    "thethinkmachine/bright-avocado",
    "thethinkmachine/kickass-supercluster",
]

# Fixed colours so "High"/"Low" keep the same colour on every run, instead of
# depending on which group happens to appear first in a given result.
HIGHLIGHT_COLORS = {"High": "#EF553B", "Low": "#636EFA"}

# Loaded pipelines keyed by model name, so each model is built at most once
# per process.
pipeline_cache = {}


def get_pipeline(model_name):
    """Return a cached text-classification pipeline for *model_name*.

    ``top_k=None`` asks the pipeline for a score for every label; it is the
    documented replacement for the deprecated ``return_all_scores=True``.
    """
    if model_name not in pipeline_cache:
        pipeline_cache[model_name] = pipeline(
            "text-classification",
            model=model_name,
            token=HF_TOKEN,
            top_k=None,
        )
    return pipeline_cache[model_name]


def _score_frame(scores, threshold):
    """Build a score table sorted high-to-low with a High/Low ``highlight`` column."""
    df = (
        pd.DataFrame(scores)
        .sort_values(by="score", ascending=False)
        .reset_index(drop=True)
    )
    df["highlight"] = df["score"].ge(threshold).map({True: "High", False: "Low"})
    return df


def _radar_chart(df, title):
    """Draw every label as ONE closed polygon and mark the highlighted labels.

    Colouring the polygon by ``highlight`` would split it into separate
    traces, each closed and filled on its own, so highlighted labels are
    overlaid as a marker trace instead.
    """
    fig = px.line_polar(df, r="score", theta="label", line_close=True, title=title)
    fig.update_traces(fill="toself", line_color=HIGHLIGHT_COLORS["Low"])
    high = df[df["highlight"] == "High"]
    if not high.empty:
        fig.add_scatterpolar(
            r=high["score"],
            theta=high["label"],
            mode="markers",
            marker=dict(size=10, color=HIGHLIGHT_COLORS["High"]),
            name="High",
        )
    fig.update_layout(polar=dict(radialaxis=dict(range=[0, 1])))
    return fig


def _bar_chart(df, title):
    """Draw one bar per label, coloured by ``highlight``, in descending score order."""
    fig = px.bar(
        df,
        x="label",
        y="score",
        color="highlight",
        color_discrete_map=HIGHLIGHT_COLORS,
        # Keep the sorted order even though colour splits bars into two traces.
        category_orders={"label": df["label"].tolist()},
        title=title,
    )
    fig.update_layout(yaxis=dict(range=[0, 1]))
    return fig


def classify(text, model_name, chart_type, threshold):
    """Score *text* with *model_name* and return ``(table, figure)``.

    Args:
        text: Input text from the textbox.
        model_name: Hugging Face model id, one of ``MODEL_LIST``.
        chart_type: ``"Radar Chart"`` for a polar chart; anything else
            gives a bar chart.
        threshold: Scores ``>=`` this value are labelled ``"High"``.

    Returns:
        A DataFrame with ``label``, ``score`` and ``highlight`` columns,
        sorted by score, and a Plotly figure. Returns ``(None, None)`` for
        empty or whitespace-only input, which clears both outputs.
    """
    if not text or not text.strip():
        return None, None

    clf = get_pipeline(model_name)
    # A list input yields one result per input item. With top_k=None, each
    # result is the list of {"label", "score"} dicts covering every label.
    scores = clf([text])[0]
    df = _score_frame(scores, threshold)

    title = f"Label Probabilities - {model_name}"
    if chart_type == "Radar Chart":
        fig = _radar_chart(df, title)
    else:
        fig = _bar_chart(df, title)
    return df, fig


with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Ekman Emotions Playground 🤖👾🦜")
    gr.Markdown(
        "### Why let humans play with your emotions when a robot can do it "
        "for you for free? ...Cheaper than therapy, 100% less effective!"
    )

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_LIST, label="Select Model", value=MODEL_LIST[0]
        )
        chart_dropdown = gr.Dropdown(
            choices=["Radar Chart", "Bar Chart"], label="Chart Type", value="Bar Chart"
        )
        threshold_slider = gr.Slider(
            0, 1, value=0.5, step=0.01, label="Highlight Threshold"
        )

    text_input = gr.Textbox(label="Input Text")

    with gr.Row():
        output_table = gr.DataFrame(label="Scores Table")
        output_plot = gr.Plot(label="Probability Chart")

    classify_btn = gr.Button("Run Classification")
    classify_btn.click(
        classify,
        inputs=[text_input, model_dropdown, chart_dropdown, threshold_slider],
        outputs=[output_table, output_plot],
    )

    gr.Examples(
        examples=[
            "I only saw her once and I'm head over heels!",
            "Pay you for what, just standing there?!",
            "Dumbass Broncos fans circa December 2015.",
            "It's great that you're a recovering addict, that's cool. Have you ever tried DMT?",
            "I'm scared to even ask my mom ,I might get yelled at 😟",
            "hurr durr I like using reddit, and anyone who doesn't agree with me is a retard",
            "They're really pushing me into this... once I go, there's no coming back you know?",
            "Considering I haven't eaten or drunk anything in about twenty hours, my head hurts.",
        ],
        inputs=[text_input],
    )

if __name__ == "__main__":
    demo.launch()