frankai98 committed
Commit f62daa9 · verified · 1 Parent(s): e03ee92

Update app.py

Files changed (1):
  1. app.py +11 -21

app.py CHANGED
@@ -81,10 +81,10 @@ def clear_gpu_memory():
         torch.cuda.ipc_collect()
 
 # Function to build appropriate prompt for text generation model
-def build_messages(query_input, sampled_docs):
+def build_prompt(query_input, sampled_docs):
     docs_text = ""
-    for idx, (comment, sentiment) in enumerate(sampled_docs):
-        docs_text += f"Tweet {idx+1} (Sentiment: {sentiment}): {comment}\n"
+    for idx, doc in enumerate(sampled_docs):
+        docs_text += f"Tweet {idx+1} (Sentiment: {doc['sentiment']}): {doc['comment']}\n"
 
     system_message = (
         "You are an intelligent assistant. Your task is to generate a comprehensive business report "
@@ -100,12 +100,9 @@ def build_messages(query_input, sampled_docs):
         "Now, produce only the final report as instructed, without any extra commentary."
     )
 
-    messages = [
-        {"role": "system", "content": system_message},
-        {"role": "user", "content": user_content}
-    ]
-
-    return messages
+    prompt = system_message + "\n\n" + user_content
+    return prompt
+
 
 # A helper to extract the assistant's response
 def extract_assistant_response(output):
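For reference, here is a minimal standalone sketch of the new flat-prompt flow. It assumes sampled_docs is a list of dicts with "comment" and "sentiment" keys (matching the new loop); the exact wording of system_message and user_content lives outside this hunk, so the strings below are abbreviated placeholders, not the app's actual text.

# Illustrative sketch only; build_prompt_example is a hypothetical name.
def build_prompt_example(query_input, sampled_docs):
    docs_text = ""
    for idx, doc in enumerate(sampled_docs):
        docs_text += f"Tweet {idx+1} (Sentiment: {doc['sentiment']}): {doc['comment']}\n"
    system_message = "You are an intelligent assistant. Generate a business report from the tweets below."
    user_content = f"Query: {query_input}\n\n{docs_text}"  # assumed shape; not shown in this diff
    # The old version returned a [system, user] chat-message list; the new one returns a single flat string.
    return system_message + "\n\n" + user_content

sample = [{"comment": "Battery life is great", "sentiment": "positive"},
          {"comment": "Shipping took too long", "sentiment": "negative"}]
print(build_prompt_example("customer feedback on the new phone", sample))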
@@ -294,18 +291,13 @@ def main():
         sampled_docs = scored_docs
 
         # Build prompt
-        messages = build_messages(query_input, sampled_docs)
+        prompt = build_prompt(query_input, sampled_docs)
 
         # Create a process function to avoid the Triton registration issue
-        def process_with_gemma(messages):
-            # We'll define the pipeline here rather than using the cached version
-            # This ensures a clean library registration context
+        def process_with_gemma(prompt):
             from transformers import pipeline, AutoTokenizer
             import torch
 
-            # Set dtype explicitly
-            # torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
             try:
                 tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
                 pipe = pipeline(
@@ -313,21 +305,19 @@ def main():
                     model="unsloth/gemma-3-1b-it",
                     tokenizer=tokenizer,
                     device=0 if torch.cuda.is_available() else -1,
-                    # torch_dtype=torch_dtype
-
                 )
 
-                result = pipe(messages, max_new_tokens=1024, repetition_penalty=1.2, do_sample=True, temperature=0.7, return_full_text=False)
+                result = pipe(prompt, max_new_tokens=1024, repetition_penalty=1.2, do_sample=True, temperature=0.7, return_full_text=False)
                 return result, None
 
             except Exception as e:
                 return None, str(e)
-
+
         # Try to process with Gemma
         status_text.markdown("**📝 Generating report with Gemma...**")
         progress_bar.progress(80)
 
-        raw_result, error = process_with_gemma(messages)
+        raw_result, error = process_with_gemma(prompt)
 
         if error:
             st.error(f"Gemma processing failed: {str(error)}")
 