gabrielchua committed on
Commit
0989743
·
1 Parent(s): d78aca9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -39
app.py CHANGED
@@ -5,19 +5,19 @@ import torch
5
  import sys
6
  import uuid
7
  from datetime import datetime
8
-
9
  import json
10
- import gspread
11
- from google.oauth2 import service_account
12
 
13
  from safetensors.torch import load_file
14
  from lionguard2 import LionGuard2, CATEGORIES
15
  from utils import get_embeddings
16
 
17
- # -- OpenAI Setup --
 
 
 
18
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
19
 
20
- # -- Model Loading --
21
  def load_lionguard2():
22
  model = LionGuard2()
23
  model.eval()
@@ -27,22 +27,26 @@ def load_lionguard2():
27
 
28
  model = load_lionguard2()
29
 
30
- # -- Google Sheets Config --
31
  GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
32
  GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
33
  RESULTS_SHEET_NAME = "results"
34
  VOTES_SHEET_NAME = "votes"
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def save_results_data(row):
37
  try:
38
- credentials = service_account.Credentials.from_service_account_info(
39
- json.loads(GOOGLE_CREDENTIALS),
40
- scopes=[
41
- "https://www.googleapis.com/auth/spreadsheets",
42
- "https://www.googleapis.com/auth/drive",
43
- ],
44
- )
45
- gc = gspread.authorize(credentials)
46
  sheet = gc.open_by_url(GOOGLE_SHEET_URL)
47
  ws = sheet.worksheet(RESULTS_SHEET_NAME)
48
  ws.append_row(list(row.values()))
@@ -51,14 +55,7 @@ def save_results_data(row):
51
 
52
  def save_vote_data(text_id, agree):
53
  try:
54
- credentials = service_account.Credentials.from_service_account_info(
55
- json.loads(GOOGLE_CREDENTIALS),
56
- scopes=[
57
- "https://www.googleapis.com/auth/spreadsheets",
58
- "https://www.googleapis.com/auth/drive",
59
- ],
60
- )
61
- gc = gspread.authorize(credentials)
62
  sheet = gc.open_by_url(GOOGLE_SHEET_URL)
63
  ws = sheet.worksheet(VOTES_SHEET_NAME)
64
  vote_row = {
@@ -70,6 +67,21 @@ def save_vote_data(text_id, agree):
70
  except Exception as e:
71
  print(f"Error saving vote data: {e}")
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # --- Classifier logic ---
74
 
75
  def format_score_with_style(score_str):
@@ -90,11 +102,11 @@ def format_score_with_style(score_str):
90
  def format_binary_score(score):
91
  percentage = int(score * 100)
92
  if score < 0.4:
93
- return f'<div style="color: #34d399; font-weight:700;">✅ Pass ({percentage}/100)</div>'
94
  elif 0.4 <= score < 0.7:
95
- return f'<div style="color: #fbbf24; font-weight:700;">⚠️ Warning ({percentage}/100)</div>'
96
  else:
97
- return f'<div style="color: #fca5a5; font-weight:700;">🚨 Fail ({percentage}/100)</div>'
98
 
99
  def analyze_text(text):
100
  if not text.strip():
@@ -108,6 +120,7 @@ def analyze_text(text):
108
 
109
  main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
110
  categories_html = []
 
111
  for category in main_categories:
112
  subcategories = CATEGORIES[category]
113
  category_name = category.replace('_', ' ').title()
@@ -122,6 +135,7 @@ def analyze_text(text):
122
  category_display = f"{category_emojis.get(category_name, '📝')} {category_name}"
123
  level_scores = [results.get(subcategory_key, [0.0])[0] for subcategory_key in subcategories]
124
  max_score = max(level_scores) if level_scores else 0.0
 
125
  categories_html.append(f'''
126
  <tr>
127
  <td>{category_display}</td>
@@ -147,12 +161,12 @@ def analyze_text(text):
147
  "text_id": text_id,
148
  "text": text,
149
  "binary_score": binary_score,
150
- # Add all category scores as before...
151
  }
 
 
152
  save_results_data(results_row)
153
 
154
  voting_html = '<div>Help improve LionGuard2! Rate the analysis below.</div>'
155
-
156
  return format_binary_score(binary_score), html_table, text_id, voting_html
157
 
158
  except Exception as e:
@@ -163,15 +177,15 @@ def vote_thumbs_up(text_id):
163
  if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
164
  save_vote_data(text_id, True)
165
  return '<div style="color: #34d399; font-weight:700;">🎉 Thank you!</div>'
166
- return '<div>Voting not available</div>'
167
 
168
  def vote_thumbs_down(text_id):
169
  if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
170
  save_vote_data(text_id, False)
171
  return '<div style="color: #fca5a5; font-weight:700;">📝 Thanks for the feedback!</div>'
172
- return '<div>Voting not available</div>'
173
 
174
- # --- Chatbot guardrail logic ---
175
  def get_openai_response(message, system_prompt="You are a helpful assistant."):
176
  try:
177
  response = client.chat.completions.create(
@@ -201,10 +215,10 @@ def lionguard_2(message, threshold=0.5):
201
  embeddings = get_embeddings([message])
202
  results = model.predict(embeddings)
203
  binary_prob = results['binary'][0]
204
- return binary_prob > threshold
205
  except Exception as e:
206
  print(f"Error in LionGuard 2: {e}")
207
- return False
208
 
209
  def process_message(message, history_no_mod, history_openai, history_lg):
210
  if not message.strip():
@@ -222,7 +236,7 @@ def process_message(message, history_no_mod, history_openai, history_lg):
222
  openai_response = get_openai_response(message)
223
  history_openai.append({"role": "assistant", "content": openai_response})
224
 
225
- lg_flagged = lionguard_2(message)
226
  history_lg.append({"role": "user", "content": message})
227
  if lg_flagged:
228
  lg_response = "🚫 This message has been flagged by LionGuard 2"
@@ -231,6 +245,41 @@ def process_message(message, history_no_mod, history_openai, history_lg):
231
  lg_response = get_openai_response(message)
232
  history_lg.append({"role": "assistant", "content": lg_response})
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  return history_no_mod, history_openai, history_lg, ""
235
 
236
  def clear_all_chats():
@@ -240,7 +289,7 @@ def clear_all_chats():
240
 
241
  DISCLAIMER = """
242
  <div style='background: #fbbf24; color: #1e293b; border-radius: 8px; padding: 14px; margin-bottom: 12px; font-size: 15px; font-weight:500;'>
243
- ⚠️ LionGuard 2 is an experimental ML model and may make mistakes. All entries are logged (anonymised) to improve the model.
244
  </div>
245
  """
246
 
@@ -262,7 +311,7 @@ with gr.Blocks(title="LionGuard 2 Demo", theme=gr.themes.Soft()) as demo:
262
  analyze_btn = gr.Button("Analyze", variant="primary")
263
  with gr.Column(scale=1, min_width=400):
264
  binary_output = gr.HTML(
265
- value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
266
  )
267
  category_table = gr.HTML(
268
  value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
@@ -292,18 +341,18 @@ with gr.Blocks(title="LionGuard 2 Demo", theme=gr.themes.Soft()) as demo:
292
  thumbs_up_btn.click(vote_thumbs_up, inputs=[current_text_id], outputs=[voting_feedback])
293
  thumbs_down_btn.click(vote_thumbs_down, inputs=[current_text_id], outputs=[voting_feedback])
294
 
295
- with gr.Tab("Chatbot Guardrail"):
296
  gr.HTML(DISCLAIMER)
297
  with gr.Row():
298
  with gr.Column(scale=1):
299
  gr.Markdown("#### 🔵 No Moderation")
300
- chatbot_no_mod = gr.Chatbot(height=400, label="No Moderation", show_label=False, bubble_full_width=False, type='messages')
301
  with gr.Column(scale=1):
302
  gr.Markdown("#### 🟠 OpenAI Moderation")
303
- chatbot_openai = gr.Chatbot(height=400, label="OpenAI Moderation", show_label=False, bubble_full_width=False, type='messages')
304
  with gr.Column(scale=1):
305
  gr.Markdown("#### 🛡️ LionGuard 2")
306
- chatbot_lg = gr.Chatbot(height=400, label="LionGuard 2", show_label=False, bubble_full_width=False, type='messages')
307
  gr.Markdown("##### 💬 Send Message to All Models")
308
  with gr.Row():
309
  message_input = gr.Textbox(
 
5
  import sys
6
  import uuid
7
  from datetime import datetime
 
8
  import json
 
 
9
 
10
  from safetensors.torch import load_file
11
  from lionguard2 import LionGuard2, CATEGORIES
12
  from utils import get_embeddings
13
 
14
+ import gspread
15
+ from google.oauth2 import service_account
16
+
17
+ # --- OpenAI Setup ---
18
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
19
 
20
+ # --- Model Loading ---
21
  def load_lionguard2():
22
  model = LionGuard2()
23
  model.eval()
 
27
 
28
  model = load_lionguard2()
29
 
30
+ # --- Google Sheets Config ---
31
  GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
32
  GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
33
  RESULTS_SHEET_NAME = "results"
34
  VOTES_SHEET_NAME = "votes"
35
+ CHATBOT_SHEET_NAME = "chatbot"
36
+
37
+ def get_gspread_client():
38
+ credentials = service_account.Credentials.from_service_account_info(
39
+ json.loads(GOOGLE_CREDENTIALS),
40
+ scopes=[
41
+ "https://www.googleapis.com/auth/spreadsheets",
42
+ "https://www.googleapis.com/auth/drive",
43
+ ],
44
+ )
45
+ return gspread.authorize(credentials)
46
 
47
  def save_results_data(row):
48
  try:
49
+ gc = get_gspread_client()
 
 
 
 
 
 
 
50
  sheet = gc.open_by_url(GOOGLE_SHEET_URL)
51
  ws = sheet.worksheet(RESULTS_SHEET_NAME)
52
  ws.append_row(list(row.values()))
 
55
 
56
  def save_vote_data(text_id, agree):
57
  try:
58
+ gc = get_gspread_client()
 
 
 
 
 
 
 
59
  sheet = gc.open_by_url(GOOGLE_SHEET_URL)
60
  ws = sheet.worksheet(VOTES_SHEET_NAME)
61
  vote_row = {
 
67
  except Exception as e:
68
  print(f"Error saving vote data: {e}")
69
 
70
+ def log_chatbot_data(row):
71
+ try:
72
+ gc = get_gspread_client()
73
+ sheet = gc.open_by_url(GOOGLE_SHEET_URL)
74
+ ws = sheet.worksheet(CHATBOT_SHEET_NAME)
75
+ ws.append_row([
76
+ row["datetime"], row["text_id"], row["text"], row["binary_score"],
77
+ row["hateful_l1_score"], row["hateful_l2_score"], row["insults_score"],
78
+ row["sexual_l1_score"], row["sexual_l2_score"], row["physical_violence_score"],
79
+ row["self_harm_l1_score"], row["self_harm_l2_score"], row["aom_l1_score"],
80
+ row["aom_l2_score"], row["openai_score"]
81
+ ])
82
+ except Exception as e:
83
+ print(f"Error saving chatbot data: {e}")
84
+
85
  # --- Classifier logic ---
86
 
87
  def format_score_with_style(score_str):
 
102
  def format_binary_score(score):
103
  percentage = int(score * 100)
104
  if score < 0.4:
105
+ return f'<div style="background:linear-gradient(135deg, #065f46 0%, #047857 100%); color:#34d399; padding:48px 0; border-radius:20px; text-align:center; font-weight:900; border:3px solid #10b981; font-size:48px; margin:24px 0; box-shadow:0 4px 24px rgba(0,0,0,0.3);">✅ Pass ({percentage}/100)</div>'
106
  elif 0.4 <= score < 0.7:
107
+ return f'<div style="background:linear-gradient(135deg, #92400e 0%, #b45309 100%); color:#fbbf24; padding:48px 0; border-radius:20px; text-align:center; font-weight:900; border:3px solid #f59e0b; font-size:48px; margin:24px 0; box-shadow:0 4px 24px rgba(0,0,0,0.3);">⚠️ Warning ({percentage}/100)</div>'
108
  else:
109
+ return f'<div style="background:linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color:#fca5a5; padding:48px 0; border-radius:20px; text-align:center; font-weight:900; border:3px solid #ef4444; font-size:48px; margin:24px 0; box-shadow:0 4px 24px rgba(0,0,0,0.3);">🚨 Fail ({percentage}/100)</div>'
110
 
111
  def analyze_text(text):
112
  if not text.strip():
 
120
 
121
  main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
122
  categories_html = []
123
+ max_scores = {}
124
  for category in main_categories:
125
  subcategories = CATEGORIES[category]
126
  category_name = category.replace('_', ' ').title()
 
135
  category_display = f"{category_emojis.get(category_name, '📝')} {category_name}"
136
  level_scores = [results.get(subcategory_key, [0.0])[0] for subcategory_key in subcategories]
137
  max_score = max(level_scores) if level_scores else 0.0
138
+ max_scores[category] = max_score
139
  categories_html.append(f'''
140
  <tr>
141
  <td>{category_display}</td>
 
161
  "text_id": text_id,
162
  "text": text,
163
  "binary_score": binary_score,
 
164
  }
165
+ for category in main_categories:
166
+ results_row[f"{category}_max"] = max_scores[category]
167
  save_results_data(results_row)
168
 
169
  voting_html = '<div>Help improve LionGuard2! Rate the analysis below.</div>'
 
170
  return format_binary_score(binary_score), html_table, text_id, voting_html
171
 
172
  except Exception as e:
 
177
  if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
178
  save_vote_data(text_id, True)
179
  return '<div style="color: #34d399; font-weight:700;">🎉 Thank you!</div>'
180
+ return '<div>Voting not available or analysis not yet run.</div>'
181
 
182
  def vote_thumbs_down(text_id):
183
  if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
184
  save_vote_data(text_id, False)
185
  return '<div style="color: #fca5a5; font-weight:700;">📝 Thanks for the feedback!</div>'
186
+ return '<div>Voting not available or analysis not yet run.</div>'
187
 
188
+ # --- Guardrail Comparison logic ---
189
  def get_openai_response(message, system_prompt="You are a helpful assistant."):
190
  try:
191
  response = client.chat.completions.create(
 
215
  embeddings = get_embeddings([message])
216
  results = model.predict(embeddings)
217
  binary_prob = results['binary'][0]
218
+ return binary_prob > threshold, binary_prob
219
  except Exception as e:
220
  print(f"Error in LionGuard 2: {e}")
221
+ return False, 0.0
222
 
223
  def process_message(message, history_no_mod, history_openai, history_lg):
224
  if not message.strip():
 
236
  openai_response = get_openai_response(message)
237
  history_openai.append({"role": "assistant", "content": openai_response})
238
 
239
+ lg_flagged, lg_score = lionguard_2(message)
240
  history_lg.append({"role": "user", "content": message})
241
  if lg_flagged:
242
  lg_response = "🚫 This message has been flagged by LionGuard 2"
 
245
  lg_response = get_openai_response(message)
246
  history_lg.append({"role": "assistant", "content": lg_response})
247
 
248
+ # --- Logging for chatbot worksheet ---
249
+ if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
250
+ try:
251
+ embeddings = get_embeddings([message])
252
+ results = model.predict(embeddings)
253
+ now = datetime.now().isoformat()
254
+ text_id = str(uuid.uuid4())
255
+ row = {
256
+ "datetime": now,
257
+ "text_id": text_id,
258
+ "text": message,
259
+ "binary_score": results.get("binary", [None])[0],
260
+ "hateful_l1_score": results.get(CATEGORIES['hateful'][0], [None])[0],
261
+ "hateful_l2_score": results.get(CATEGORIES['hateful'][1], [None])[0],
262
+ "insults_score": results.get(CATEGORIES['insults'][0], [None])[0],
263
+ "sexual_l1_score": results.get(CATEGORIES['sexual'][0], [None])[0],
264
+ "sexual_l2_score": results.get(CATEGORIES['sexual'][1], [None])[0],
265
+ "physical_violence_score": results.get(CATEGORIES['physical_violence'][0], [None])[0],
266
+ "self_harm_l1_score": results.get(CATEGORIES['self_harm'][0], [None])[0],
267
+ "self_harm_l2_score": results.get(CATEGORIES['self_harm'][1], [None])[0],
268
+ "aom_l1_score": results.get(CATEGORIES['all_other_misconduct'][0], [None])[0],
269
+ "aom_l2_score": results.get(CATEGORIES['all_other_misconduct'][1], [None])[0],
270
+ "openai_score": None
271
+ }
272
+ try:
273
+ openai_result = client.moderations.create(input=message)
274
+ # Using the "hate" category score as a demonstration. You may customize as needed.
275
+ row["openai_score"] = float(openai_result.results[0].category_scores.get("hate", 0.0))
276
+ except Exception:
277
+ row["openai_score"] = None
278
+
279
+ log_chatbot_data(row)
280
+ except Exception as e:
281
+ print(f"Chatbot logging failed: {e}")
282
+
283
  return history_no_mod, history_openai, history_lg, ""
284
 
285
  def clear_all_chats():
 
289
 
290
  DISCLAIMER = """
291
  <div style='background: #fbbf24; color: #1e293b; border-radius: 8px; padding: 14px; margin-bottom: 12px; font-size: 15px; font-weight:500;'>
292
+ ⚠️ LionGuard 2 may make mistakes. All entries are logged (anonymised) to improve the model.
293
  </div>
294
  """
295
 
 
311
  analyze_btn = gr.Button("Analyze", variant="primary")
312
  with gr.Column(scale=1, min_width=400):
313
  binary_output = gr.HTML(
314
+ value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic; font-size:36px;">Enter text to analyze</div>'
315
  )
316
  category_table = gr.HTML(
317
  value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
 
341
  thumbs_up_btn.click(vote_thumbs_up, inputs=[current_text_id], outputs=[voting_feedback])
342
  thumbs_down_btn.click(vote_thumbs_down, inputs=[current_text_id], outputs=[voting_feedback])
343
 
344
+ with gr.Tab("Guardrail Comparison"):
345
  gr.HTML(DISCLAIMER)
346
  with gr.Row():
347
  with gr.Column(scale=1):
348
  gr.Markdown("#### 🔵 No Moderation")
349
+ chatbot_no_mod = gr.Chatbot(height=650, label="No Moderation", show_label=False, bubble_full_width=False, type='messages')
350
  with gr.Column(scale=1):
351
  gr.Markdown("#### 🟠 OpenAI Moderation")
352
+ chatbot_openai = gr.Chatbot(height=650, label="OpenAI Moderation", show_label=False, bubble_full_width=False, type='messages')
353
  with gr.Column(scale=1):
354
  gr.Markdown("#### 🛡️ LionGuard 2")
355
+ chatbot_lg = gr.Chatbot(height=650, label="LionGuard 2", show_label=False, bubble_full_width=False, type='messages')
356
  gr.Markdown("##### 💬 Send Message to All Models")
357
  with gr.Row():
358
  message_input = gr.Textbox(