Update app.py
app.py CHANGED

@@ -81,10 +81,10 @@ def clear_gpu_memory():
     torch.cuda.ipc_collect()
 
 # Function to build appropriate prompt for text generation model
-def build_messages(query_input, sampled_docs):
+def build_prompt(query_input, sampled_docs):
     docs_text = ""
-    for idx,
-        docs_text += f"Tweet {idx+1} (Sentiment: {sentiment}): {comment}\n"
+    for idx, doc in enumerate(sampled_docs):
+        docs_text += f"Tweet {idx+1} (Sentiment: {doc['sentiment']}): {doc['comment']}\n"
 
     system_message = (
         "You are an intelligent assistant. Your task is to generate a comprehensive business report "
@@ -100,12 +100,9 @@ def build_messages(query_input, sampled_docs):
         "Now, produce only the final report as instructed, without any extra commentary."
     )
 
-
-
-
-    ]
-
-    return messages
+    prompt = system_message + "\n\n" + user_content
+    return prompt
+
 
 # A helper to extract the assistant's response
 def extract_assistant_response(output):
@@ -294,18 +291,13 @@ def main():
         sampled_docs = scored_docs
 
         # Build prompt
-
+        prompt = build_prompt(query_input, sampled_docs)
 
         # Create a process function to avoid the Triton registration issue
-        def process_with_gemma(
-            # We'll define the pipeline here rather than using the cached version
-            # This ensures a clean library registration context
+        def process_with_gemma(prompt):
            from transformers import pipeline, AutoTokenizer
            import torch
 
-            # Set dtype explicitly
-            # torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
            try:
                tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
                pipe = pipeline(
@@ -313,21 +305,19 @@ def main():
                    model="unsloth/gemma-3-1b-it",
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1,
-                    # torch_dtype=torch_dtype
-
                )
 
-                result = pipe(
+                result = pipe(prompt, max_new_tokens=1024, repetition_penalty=1.2, do_sample=True, temperature=0.7, return_full_text=False)
                return result, None
 
            except Exception as e:
                return None, str(e)
-
+
        # Try to process with Gemma
        status_text.markdown("**📝 Generating report with Gemma...**")
        progress_bar.progress(80)
 
-        raw_result, error = process_with_gemma(
+        raw_result, error = process_with_gemma(prompt)
 
        if error:
            st.error(f"Gemma processing failed: {str(error)}")