"""Gradio demo: multimodal (text + speech) emotion classification.

A text classifier and a speech classifier each score the input; their
distributions are fused with user-controlled weights and shown as bar charts.
"""

import gradio as gr
from transformers import pipeline
import pandas as pd
import plotly.express as px

# ------------------------------
# Load pretrained models
# ------------------------------
text_classifier = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    # Ask for a score for every label; replaces the deprecated
    # `return_all_scores=True`.
    top_k=None,
)
audio_classifier = pipeline(
    "audio-classification",
    model="superb/wav2vec2-base-superb-er",
)

# ------------------------------
# Map emotion to emoji
# ------------------------------
EMOJI_MAP = {
    "anger": "😡",
    "disgust": "🤢",
    "fear": "😨",
    "joy": "😄",
    "neutral": "😐",
    "sadness": "😢",
    "surprise": "😲",
    "hap": "😄",  # for audio model
    "neu": "😐",
    "sad": "😢",
    "ang": "😡",
}

# The audio model uses abbreviated labels. Map them onto the text model's
# vocabulary so both modalities vote for the *same* emotion during fusion.
LABEL_ALIASES = {
    "hap": "joy",
    "neu": "neutral",
    "sad": "sadness",
    "ang": "anger",
}

NO_INPUT_MESSAGE = "### Please enter some text and/or upload an audio clip."


# ------------------------------
# Fusion function
# ------------------------------
def _normalize_scores(preds):
    """Turn raw predictions into a distribution over canonical labels.

    Labels are translated through ``LABEL_ALIASES``; scores of labels that
    collapse onto the same canonical name are summed. Returns ``{}`` when the
    scores do not add up to a positive total (nothing to normalise by).
    """
    totals = {}
    for pred in preds:
        label = LABEL_ALIASES.get(pred["label"], pred["label"])
        totals[label] = totals.get(label, 0.0) + pred["score"]

    total = sum(totals.values())
    if total <= 0:
        return {}
    return {label: score / total for label, score in totals.items()}


def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5):
    """Fuse text and audio predictions into one weighted distribution.

    Args:
        text_preds: list of ``{"label": str, "score": float}`` dicts, or
            ``None``/empty when there is no text input.
        audio_preds: same shape as ``text_preds``, for the audio model.
        w_text: relative weight of the text modality.
        w_audio: relative weight of the audio modality.

    Returns:
        A dict with ``fused_label`` (best label, ``"none"`` if there is
        nothing to fuse), ``fused_score`` (rounded to 3 decimals) and
        ``all_scores`` (label -> fused score).

    Weights are renormalised over the modalities that are actually present,
    so the fused scores always sum to 1. If every active weight is zero the
    modalities are weighted equally instead of producing an all-zero result.
    """
    modalities = []
    for preds, weight in ((text_preds, w_text), (audio_preds, w_audio)):
        if not preds:
            continue
        distribution = _normalize_scores(preds)
        if distribution:
            # Negative weights make no sense for a mixture; treat them as 0.
            modalities.append((distribution, max(float(weight), 0.0)))

    if not modalities:
        return {"fused_label": "none", "fused_score": 0, "all_scores": {}}

    total_weight = sum(weight for _, weight in modalities)
    if total_weight <= 0:
        modalities = [(distribution, 1.0) for distribution, _ in modalities]
        total_weight = float(len(modalities))

    # A dict keeps first-seen order (text labels first), so ties in `max`
    # below resolve deterministically.
    scores = {}
    for distribution, weight in modalities:
        share = weight / total_weight
        for label, prob in distribution.items():
            scores[label] = scores.get(label, 0.0) + share * prob

    best_label, best_score = max(scores.items(), key=lambda item: item[1])
    return {
        "fused_label": best_label,
        "fused_score": round(best_score, 3),
        "all_scores": scores,
    }


# ------------------------------
# Create bar chart
# ------------------------------
def make_bar_chart(scores_dict, title="Emotion Scores"):
    """Build a Plotly bar chart from a ``{label: score}`` mapping."""
    df = pd.DataFrame(
        {
            "Emotion": list(scores_dict.keys()),
            "Score": list(scores_dict.values()),
        }
    )
    fig = px.bar(
        df,
        x="Emotion",
        y="Score",
        text="Score",
        title=title,
        range_y=[0, 1],
        color="Emotion",
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
    fig.update_traces(texttemplate='%{text:.2f}', textposition='outside')
    fig.update_layout(yaxis_title="Probability", xaxis_title="Emotion", showlegend=False)
    return fig


# ------------------------------
# Prediction function
# ------------------------------
def _flatten_predictions(raw):
    """Unwrap the text pipeline result to a flat list of prediction dicts.

    Depending on the transformers version, a single string input yields
    either ``[[{...}, ...]]`` or ``[{...}, ...]``; accept both.
    """
    if raw and isinstance(raw[0], list):
        return raw[0]
    return raw


def _analyze(text, audio, w_text, w_audio):
    """Run both models and build the outputs for the UI.

    Returns ``(markdown, text_fig, audio_fig, fused_fig)``; a figure is
    ``None`` when its modality was not provided.
    """
    text_preds, audio_preds = None, None
    if text and text.strip():
        text_preds = _flatten_predictions(text_classifier(text))
    if audio:
        audio_preds = audio_classifier(audio)

    fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)
    if not fused["all_scores"]:
        # Nothing was supplied (or nothing usable came back): prompt the user
        # instead of rendering "NONE" and an empty chart.
        return NO_INPUT_MESSAGE, None, None, None

    # Display final predicted emotion with emoji
    label = fused['fused_label']
    emoji = EMOJI_MAP.get(label, "")
    final_emotion = f"### Final Predicted Emotion: {label.upper()} {emoji} (score: {fused['fused_score']})"

    # Bar charts
    text_fig = None
    if text_preds:
        text_fig = make_bar_chart(
            {p['label']: p['score'] for p in text_preds}, "Text Emotion Scores"
        )
    audio_fig = None
    if audio_preds:
        audio_fig = make_bar_chart(
            {p['label']: p['score'] for p in audio_preds}, "Audio Emotion Scores"
        )
    fused_fig = make_bar_chart(fused['all_scores'], "Fused Emotion Scores")
    return final_emotion, text_fig, audio_fig, fused_fig


def predict(text, audio, w_text, w_audio):
    """Classify the inputs and return ``(markdown, charts)``.

    ``charts`` is a list with the text, audio and fused figures that could be
    produced, in that order (empty when no input was given). The UI uses
    ``_analyze`` directly because a ``gr.Plot`` shows exactly one figure.
    """
    final_emotion, *figures = _analyze(text, audio, w_text, w_audio)
    charts = [fig for fig in figures if fig is not None]
    return final_emotion, charts


# ------------------------------
# Build Gradio interface
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech)")
    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
            aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
            w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Text weight (w_text)")
            w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Audio weight (w_audio)")
            btn = gr.Button("Predict")
        with gr.Column():
            final_label = gr.Markdown(label="Predicted Emotion")
            # One Plot per figure: a gr.Plot cannot render a list of figures.
            chart_output = gr.Plot(label="Emotion Scores")
            text_chart = gr.Plot(label="Text Emotion Scores")
            audio_chart = gr.Plot(label="Audio Emotion Scores")

    # Button click triggers prediction
    btn.click(
        fn=_analyze,
        inputs=[txt, aud, w1, w2],
        outputs=[final_label, text_chart, audio_chart, chart_output],
    )

# Only start the server when run as a script, so the helpers can be imported.
if __name__ == "__main__":
    demo.launch()