sachin7777777 committed on
Commit
b1b1487
·
verified ·
1 Parent(s): 3b2fd42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -29
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
4
  import plotly.express as px
5
 
6
  # ------------------------------
7
- # Load pretrained text model
8
  # ------------------------------
9
  text_classifier = pipeline(
10
  "text-classification",
@@ -12,6 +12,11 @@ text_classifier = pipeline(
12
  top_k=None # returns all scores
13
  )
14
 
 
 
 
 
 
15
  # ------------------------------
16
  # Map emotion to emoji
17
  # ------------------------------
@@ -22,9 +27,40 @@ EMOJI_MAP = {
22
  "joy": "πŸ˜„",
23
  "neutral": "😐",
24
  "sadness": "😒",
25
- "surprise": "😲"
 
 
 
 
26
  }
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # ------------------------------
29
  # Create bar chart
30
  # ------------------------------
@@ -43,45 +79,46 @@ def make_bar_chart(scores_dict, title="Emotion Scores"):
43
  # ------------------------------
44
  # Prediction function
45
  # ------------------------------
46
def predict(text, w_text=1.0):
    """Classify the emotion of `text` and build the outputs shown in the UI.

    Args:
        text: Input string taken from the text box.
        w_text: Accepted for interface compatibility; not used in the body.

    Returns:
        A pair ``(html, chart)``: an HTML snippet showing an animated emoji
        with the top label and its score, and the result of
        ``make_bar_chart`` over all label scores. When `text` is empty the
        pair is ``("Please enter text.", None)``.
    """
    # Guard clause: nothing to classify.
    if not text:
        return "Please enter text.", None

    # First element holds the per-label scores for the input.
    scores = {}
    for pred in text_classifier(text)[0]:
        scores[pred['label']] = pred['score']

    best_label = max(scores, key=scores.get)
    top_score = scores[best_label]
    emoji = EMOJI_MAP.get(best_label, "")

    # Emoji with a simple CSS bounce animation, followed by label and score.
    final_emotion_html = f"""
    <div style="font-size:80px; text-align:center; animation: bounce 1s infinite;">
        {emoji}
    </div>
    <h3 style="text-align:center;">{best_label.upper()} (score: {top_score:.2f})</h3>
    <style>
    @keyframes bounce {{
        0%, 20%, 50%, 80%, 100% {{transform: translateY(0);}}
        40% {{transform: translateY(-20px);}}
        60% {{transform: translateY(-10px);}}
    }}
    </style>
    """
    return final_emotion_html, make_bar_chart(scores, "Text Emotion Scores")
70
 
71
  # ------------------------------
72
  # Build Gradio interface
73
  # ------------------------------
# Text-only UI: a text box and button on the left, the predicted emotion
# (HTML) and the score chart on the right.
with gr.Blocks() as demo:
    # Title emoji restored from its mis-encoded form.
    gr.Markdown("## 🎭 Text Emotion Classification with Emoji Animation")

    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
            btn = gr.Button("Predict")
        with gr.Column():
            final_label = gr.HTML(label="Predicted Emotion")
            chart_output = gr.Plot(label="Emotion Scores")

    # predict's two return values feed the HTML panel and the plot, in order.
    btn.click(fn=predict, inputs=[txt], outputs=[final_label, chart_output])

demo.launch()
 
4
  import plotly.express as px
5
 
6
  # ------------------------------
7
+ # Load pretrained models
8
  # ------------------------------
9
  text_classifier = pipeline(
10
  "text-classification",
 
12
  top_k=None # returns all scores
13
  )
14
 
15
# Speech emotion classifier, built with the same `pipeline` factory as the
# text classifier.
audio_classifier = pipeline(
    task="audio-classification",
    model="Dpngtm/wav2vec2-emotion-recognition",
)
19
+
20
  # ------------------------------
21
  # Map emotion to emoji
22
  # ------------------------------
 
27
  "joy": "πŸ˜„",
28
  "neutral": "😐",
29
  "sadness": "😒",
30
+ "surprise": "😲",
31
+ "hap": "πŸ˜„", # for audio model
32
+ "neu": "😐",
33
+ "sad": "😒",
34
+ "ang": "😑"
35
  }
36
 
37
+ # ------------------------------
38
+ # Fusion function
39
+ # ------------------------------
40
def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5):
    """Combine text and audio emotion predictions into one weighted score per label.

    Each modality's scores are rescaled to sum to 1, multiplied by that
    modality's weight, and added up per label. Labels are matched by exact
    string, so a label present in only one modality receives only that
    modality's contribution.

    Args:
        text_preds: List of ``{"label": str, "score": float}`` dicts from the
            text model, or ``None``/empty when there is no text input.
        audio_preds: Same structure for the audio model, or ``None``/empty.
        w_text: Weight applied to the normalized text scores.
        w_audio: Weight applied to the normalized audio scores.

    Returns:
        A dict with ``"fused_label"`` (highest-scoring label),
        ``"fused_score"`` (its score rounded to 3 decimals) and
        ``"all_scores"`` (label -> fused score, in first-seen order). With no
        predictions at all the result is label ``"none"``, score ``0`` and an
        empty ``"all_scores"``.
    """
    def _normalized(preds):
        total = sum(p["score"] for p in preds)
        if not total:
            # All-zero scores would otherwise raise ZeroDivisionError.
            return {p["label"]: 0.0 for p in preds}
        return {p["label"]: p["score"] / total for p in preds}

    # Only modalities that actually produced predictions take part.
    active = [
        (preds, weight)
        for preds, weight in ((text_preds, w_text), (audio_preds, w_audio))
        if preds
    ]

    # A dict (not a set) keeps labels in first-seen order, so the chart order
    # and the winner among tied labels are the same on every run.
    scores = {}
    for preds, _ in active:
        for p in preds:
            scores.setdefault(p["label"], 0.0)

    for preds, weight in active:
        for label, share in _normalized(preds).items():
            scores[label] += weight * share

    if scores:
        best_label = max(scores, key=scores.get)
        best_score = scores[best_label]
    else:
        best_label, best_score = "none", 0

    return {
        "fused_label": best_label,
        "fused_score": round(best_score, 3),
        "all_scores": scores,
    }
63
+
64
  # ------------------------------
65
  # Create bar chart
66
  # ------------------------------
 
79
  # ------------------------------
80
  # Prediction function
81
  # ------------------------------
82
def predict(text, audio, w_text, w_audio):
    """Run the text and/or audio classifiers and return the fused result for the UI.

    Args:
        text: Input string from the text box; skipped when empty.
        audio: Audio input passed straight to ``audio_classifier`` (the UI
            supplies a file path); skipped when empty.
        w_text: Weight for the text scores, forwarded to ``fuse_predictions``.
        w_audio: Weight for the audio scores, forwarded to ``fuse_predictions``.

    Returns:
        A pair ``(markdown, chart)``: a Markdown heading naming the fused
        emotion with its emoji and score, and the bar chart of fused scores
        from ``make_bar_chart``. When neither input is given the pair is a
        prompt message and ``None``.
    """
    # Nothing to classify: skip the models instead of charting an empty result.
    if not text and not audio:
        return "Please enter text or upload audio.", None

    # First element of the text result holds the per-label scores.
    text_preds = text_classifier(text)[0] if text else None
    audio_preds = audio_classifier(audio) if audio else None
    fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)

    # Final predicted emotion with emoji.
    label = fused["fused_label"]
    emoji = EMOJI_MAP.get(label, "")
    final_emotion = f"### Final Predicted Emotion: {label.upper()} {emoji} (score: {fused['fused_score']})"

    # The click handler maps this value to a single gr.Plot, which takes one
    # figure and not a list, so only the fused chart is returned.
    chart = make_bar_chart(fused["all_scores"], "Fused Emotion Scores")
    return final_emotion, chart
 
 
104
 
105
  # ------------------------------
106
  # Build Gradio interface
107
  # ------------------------------
# Multimodal UI: text, audio and the two fusion weights on the left, the
# predicted emotion (Markdown) and the score chart on the right.
with gr.Blocks() as demo:
    # Title emoji restored from its mis-encoded form.
    gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech)")

    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
            aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
            w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Text weight (w_text)")
            w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Audio weight (w_audio)")
            btn = gr.Button("Predict")
        with gr.Column():
            final_label = gr.Markdown(label="Predicted Emotion")
            chart_output = gr.Plot(label="Emotion Scores")

    # Inputs are passed to predict positionally: text, audio, w_text, w_audio.
    btn.click(fn=predict, inputs=[txt, aud, w1, w2], outputs=[final_label, chart_output])

demo.launch()