sachin7777777 committed
Commit 419d020 · verified · 1 Parent(s): 43e9a81

Update app.py

Files changed (1):
  app.py +30 -53
app.py CHANGED
@@ -1,36 +1,34 @@
  import gradio as gr
- from transformers import pipeline, Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
- import torch
+ from transformers import pipeline
  import pandas as pd
  import plotly.express as px
- import soundfile as sf

  # ------------------------------
  # Load pretrained models
  # ------------------------------
- # Text classifier
  text_classifier = pipeline(
      "text-classification",
      model="j-hartmann/emotion-english-distilroberta-base",
-     top_k=None  # returns all scores
+     return_all_scores=True
  )

- # Audio classifier (Wav2Vec2)
- audio_model_name = "Dpngtm/wav2vec2-emotion-recognition"
- audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
- audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(audio_model_name)
+ audio_classifier = pipeline(
+     "audio-classification",
+     model="superb/wav2vec2-base-superb-er"
+ )

  # ------------------------------
- # Map emotion to emoji
+ # Emotion to Emoji mapping
  # ------------------------------
  EMOJI_MAP = {
-     "anger": "😑",
-     "disgust": "🤢",
-     "fear": "😨",
-     "joy": "😄",
-     "neutral": "😐",
+     "joy": "😊",
      "sadness": "😢",
-     "surprise": "😲"
+     "anger": "😠",
+     "fear": "😨",
+     "love": "❤️",
+     "surprise": "😲",
+     "disgust": "🤢",
+     "neutral": "😐"
  }

  # ------------------------------
@@ -61,11 +59,11 @@ def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5)
      return {"fused_label": best[0], "fused_score": round(best[1], 3), "all_scores": scores}

  # ------------------------------
- # Bar chart function
+ # Create bar chart with emojis
  # ------------------------------
  def make_bar_chart(scores_dict, title="Emotion Scores"):
      df = pd.DataFrame({
-         "Emotion": list(scores_dict.keys()),
+         "Emotion": [f"{EMOJI_MAP.get(k, '')} {k}" for k in scores_dict.keys()],
          "Score": list(scores_dict.values())
      })
      fig = px.bar(df, x="Emotion", y="Score", text="Score",
@@ -76,63 +74,42 @@ def make_bar_chart(scores_dict, title="Emotion Scores"):
      return fig

  # ------------------------------
- # Audio prediction helper
- # ------------------------------
- def predict_audio(audio_file):
-     speech, sr = sf.read(audio_file)
-     inputs = audio_processor(speech, sampling_rate=sr, return_tensors="pt", padding=True)
-     with torch.no_grad():
-         logits = audio_model(**inputs).logits
-     probs = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
-     labels = [audio_model.config.id2label[i] for i in range(len(probs))]
-     return [{"label": l, "score": s} for l, s in zip(labels, probs)]
-
- # ------------------------------
- # Gradio prediction function
+ # Prediction function
  # ------------------------------
  def predict(text, audio, w_text, w_audio):
      text_preds, audio_preds = None, None
-
      if text:
-         text_preds = text_classifier(text)
+         text_preds = text_classifier(text)[0]
      if audio:
-         audio_preds = predict_audio(audio)
-
+         audio_preds = audio_classifier(audio)
      fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)

-     # Final emotion with animated emoji
-     label = fused['fused_label']
-     emoji = EMOJI_MAP.get(label, "❓")
-     final_emotion = f"### {label.upper()} {emoji} \nScore: {fused['fused_score']}"
-     animation = f"<div style='font-size:80px; animation: bounce 1s infinite;'>{emoji}</div>"
-
-     # Charts
+     # Bar charts
      charts = []
      if text_preds:
          charts.append(make_bar_chart({p['label']: p['score'] for p in text_preds}, "Text Emotion Scores"))
      if audio_preds:
          charts.append(make_bar_chart({p['label']: p['score'] for p in audio_preds}, "Audio Emotion Scores"))
-     charts.append(make_bar_chart(fused['all_scores'], "Fused Emotion Scores"))
+     charts.append(make_bar_chart(fused['all_scores'], f"Fused Emotion Scores\nPrediction: {EMOJI_MAP.get(fused['fused_label'], '')} {fused['fused_label']}"))

-     return final_emotion + animation, charts
+     return charts

  # ------------------------------
- # Build Gradio app
+ # Build Gradio interface with emojis
  # ------------------------------
  with gr.Blocks() as demo:
-     gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech)")
+     gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech) 😎")

      with gr.Row():
          with gr.Column():
-             txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
-             aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
-             w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Text weight")
-             w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Audio weight")
-             btn = gr.Button("Predict")
+             txt = gr.Textbox(label="📝 Text input", placeholder="Type something emotional...")
+             aud = gr.Audio(type="filepath", label="🎤 Upload speech (wav/mp3)")
+             w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="🔹 Text weight (w_text)")
+             w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="🔹 Audio weight (w_audio)")
+             btn = gr.Button("✨ Predict")
          with gr.Column():
-             final_label = gr.HTML(label="Predicted Emotion")
              chart_output = gr.Plot(label="Emotion Scores")

-     btn.click(fn=predict, inputs=[txt, aud, w1, w2], outputs=[final_label, chart_output]*3)
+     btn.click(fn=predict, inputs=[txt, aud, w1, w2], outputs=[chart_output]*3)  # 3 charts: text, audio, fused

      demo.launch()
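A note on the text-classifier change: with return_all_scores=True, the transformers text-classification pipeline wraps the per-label scores for a single string in an extra batch list, which is why predict() now indexes [0]. Recent transformers releases deprecate return_all_scores in favor of top_k=None, the form the previous revision used. A minimal sketch of the two output shapes, assuming a transformers 4.x install:

from transformers import pipeline

model_id = "j-hartmann/emotion-english-distilroberta-base"

# Legacy flag used in this revision: single-string calls come back
# batch-wrapped, e.g. [[{'label': 'joy', 'score': 0.97}, ...]]
clf_legacy = pipeline("text-classification", model=model_id, return_all_scores=True)
text_preds = clf_legacy("I love this!")[0]  # the extra [0] matches predict()

# Modern equivalent: top_k=None returns the flat per-label list directly,
# e.g. [{'label': 'joy', 'score': 0.97}, ...]
clf = pipeline("text-classification", model=model_id, top_k=None)
text_preds = clf("I love this!")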
 
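The audio side now goes through the high-level audio-classification pipeline, which accepts a filepath and returns a list of {'label', 'score'} dicts directly, replacing the manual Wav2Vec2 forward pass and softmax. One assumption worth checking: SUPERB emotion-recognition checkpoints are four-class IEMOCAP models whose labels are typically abbreviated (neu, hap, ang, sad), which would not key-match the text model's full-word labels inside fuse_predictions or the EMOJI_MAP lookups. A hypothetical shim, if that label set holds (verify against the model card):

# Hypothetical mapping -- assumes superb/wav2vec2-base-superb-er emits
# IEMOCAP-style abbreviations; adjust after checking the model card.
SUPERB_TO_FULL = {"neu": "neutral", "hap": "joy", "ang": "anger", "sad": "sadness"}

def normalize_audio_preds(preds):
    # Rekey audio labels onto the text model's label space so the weighted
    # fusion and emoji lookups operate on the same strings.
    return [{"label": SUPERB_TO_FULL.get(p["label"], p["label"]), "score": p["score"]}
            for p in preds]

# usage inside predict(): audio_preds = normalize_audio_preds(audio_classifier(audio))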
 
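On the Gradio wiring: outputs=[chart_output]*3 registers the same gr.Plot three times, while predict() returns between one and three figures depending on which inputs are filled in, so the handler's return count will not always match its declared outputs, and there is only a single panel to draw into. A sketch of one way to keep the arity fixed, using three separate plots and an empty placeholder figure (component names here are illustrative, not from app.py):

import gradio as gr
import plotly.graph_objects as go

def predict(text, audio, w_text, w_audio):
    # Dummy figures stand in for app.py's real charts in this sketch.
    empty = go.Figure()  # placeholder so the handler always returns 3 values
    text_fig = go.Figure(go.Bar(x=["joy"], y=[0.9])) if text else empty
    audio_fig = go.Figure(go.Bar(x=["neutral"], y=[0.8])) if audio else empty
    fused_fig = go.Figure(go.Bar(x=["joy"], y=[0.85]))
    return text_fig, audio_fig, fused_fig

with gr.Blocks() as demo:
    txt = gr.Textbox(label="Text input")
    aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
    w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Text weight")
    w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Audio weight")
    btn = gr.Button("Predict")
    text_plot = gr.Plot(label="Text Emotion Scores")
    audio_plot = gr.Plot(label="Audio Emotion Scores")
    fused_plot = gr.Plot(label="Fused Emotion Scores")
    btn.click(fn=predict, inputs=[txt, aud, w1, w2],
              outputs=[text_plot, audio_plot, fused_plot])

demo.launch()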