frankai98 committed
Commit 076376a · verified · 1 Parent(s): daa50b0

Update app.py

Files changed (1):
  1. app.py (+62, -22)
app.py CHANGED
@@ -11,13 +11,13 @@ import random
 import time
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

-# Retrieve the token from environment variables
+# Retrieve the token from environment variables for huggingface login
 hf_token = os.environ.get("HF_TOKEN")
 if not hf_token:
     st.error("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
     st.stop()

-# Login with the token
+# Huggingface login with the token
 login(token=hf_token)

 # Timer component using HTML and JavaScript
@@ -52,6 +52,7 @@ st.header("𝕏/Twitter Tweets Sentiment Report Generator")
 # Concise introduction
 st.write("This model🎰 will score your tweets in your CSV file🗄️ based on their sentiment😀 and generate a report🗟 answering your query question❔ based on those results.")

+# Display VRAM status for debug
 def print_gpu_status(label):
     if torch.cuda.is_available():
         allocated = torch.cuda.memory_allocated() / 1024**3
@@ -78,6 +79,58 @@ def clear_gpu_memory():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+
+# Function to build appropriate prompt for text generation model
+def build_messages(query_input, sampled_docs):
+    docs_text = ""
+    for idx, doc in enumerate(sampled_docs):
+        docs_text += f"Tweet {idx+1} (Sentiment: {doc['sentiment']}): {doc['comment']}\n"
+
+    system_message = (
+        "You are an intelligent assistant. Your task is to generate a comprehensive business report "
+        "analyzing the provided tweets with sentiment scores. The report must include an introduction, "
+        "key insights, and a conclusion, and should be approximately 800 words long. "
+        "IMPORTANT: Do not include any introductory greetings, summary statements, or closing questions. "
+        "Output only the final report content."
+    )
+
+    user_content = (
+        f"**Tweets**:\n{docs_text}\n\n"
+        f"**Query Question**: \"{query_input}\"\n\n"
+        "Now, produce only the final report as instructed, without any extra commentary."
+    )
+
+    messages = [
+        {"role": "system", "content": system_message},
+        {"role": "user", "content": user_content}
+    ]
+
+    return messages
+
+# A helper to extract the assistant's response
+def extract_assistant_response(output):
+    """
+    Extract only the content from the assistant's response.
+    Handles nested structure from the pipeline output.
+    """
+    try:
+        # The output is expected to be a list containing a dict with 'generated_text'
+        if isinstance(output, list) and len(output) > 0 and 'generated_text' in output[0]:
+            messages = output[0]['generated_text']
+            if isinstance(messages, list):
+                for message in messages:
+                    if isinstance(message, dict) and message.get('role') == 'assistant':
+                        return message.get('content', '')
+        # Fallback: try to directly find 'assistant' role in output
+        if isinstance(output, list):
+            for item in output:
+                if isinstance(item, dict) and item.get('role') == 'assistant':
+                    return item.get('content', '')
+        print(f"DEBUG: Could not find assistant response in: {str(output)[:200]}...")
+        return ''
+    except Exception as e:
+        print(f"Error extracting assistant response: {e}")
+        return ''

 # Main Function Part:
 def main():
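For reference, the new helper is written for the nested chat format that a transformers text-generation pipeline returns when it is called with a list of role/content messages. The mocked output below is illustrative only and assumes the extract_assistant_response helper added in this commit is in scope:

    # Mocked, illustrative pipeline output in the nested chat format the helper unwraps.
    mock_result = [{
        "generated_text": [
            {"role": "system", "content": "You are an intelligent assistant. ..."},
            {"role": "user", "content": "**Tweets**: ..."},
            {"role": "assistant", "content": "Introduction: Overall sentiment is mixed ..."},
        ]
    }]

    print(extract_assistant_response(mock_result))
    # -> "Introduction: Overall sentiment is mixed ..."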
@@ -137,7 +190,7 @@ def main():
             docs_to_summarize.append(doc)
             docs_indices.append(i)

-    # If we have documents to summarize, load Llama model first
+    # If we have documents to summarize, load finetuned summarization model first
     if docs_to_summarize:
         status_text.markdown("**📝 Loading summarization model...**")
         t5_pipe = get_summary_model()
@@ -168,7 +221,7 @@ def main():
            except Exception as e:
                st.warning(f"Error summarizing document {i}: {str(e)}")

-        # Clear Llama model from memory
+        # Clear summarization model from memory
         del t5_pipe
         import gc
         gc.collect()
@@ -208,7 +261,9 @@ def main():
        status_text.markdown(f"**🔍 Scoring documents... ({i}/{len(candidate_docs)})**")

    # Pair documents with scores
-    scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
+    scored_docs = [{"comment": doc, "sentiment": result.get("score", 1)}
+                   for doc, result in zip(processed_docs, scored_results)]
+

    # Clear sentiment model from memory
    del score_pipe
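With this change each scored entry keeps the tweet text and its sentiment label together in one dict, which is the shape the new build_messages helper expects. A quick illustration with made-up data, assuming the scorer puts a 0/1/2 class in the "score" field as the prompt describes:

    processed_docs = ["Love the new update!", "The app keeps crashing."]
    scored_results = [{"score": 2}, {"score": 0}]  # made-up example results

    scored_docs = [{"comment": doc, "sentiment": result.get("score", 1)}
                   for doc, result in zip(processed_docs, scored_results)]
    # -> [{'comment': 'Love the new update!', 'sentiment': 2},
    #     {'comment': 'The app keeps crashing.', 'sentiment': 0}]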
@@ -242,21 +297,7 @@ def main():
     sampled_docs = scored_docs

     # Build prompt
-    messages = [
-        {"role": "user", "content": f"""
-Generate a well-structured business report based on tweets from twitter/X with sentiment score (0: negative, 1: neutral, 2: positive) that answers Query Question and meets following Requirements.
-**Requirements:**
-- Include an introduction, key insights, and a conclusion.
-- Ensure the analysis is concise and does not cut off abruptly.
-- Summarize major findings without repeating verbatim.
-- Cover both positive and negative aspects, highlighting trends in user sentiment.
-**Query Question:**
-"{query_input}"
-**Tweets with sentiment score:**
-{sampled_docs}
-Please ensure the report is complete and reaches approximately 800 words.
-"""}
-    ]
+    messages = build_messages(query_input, sampled_docs)

     # Create a process function to avoid the Triton registration issue
     def process_with_gemma(messages):
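The inline f-string prompt is replaced by a call to the build_messages helper defined earlier in this diff. For a sense of what it produces, with the same made-up data as above it returns a two-message chat:

    messages = build_messages("How do users feel about the latest release?",
                              [{"comment": "Love the new update!", "sentiment": 2},
                               {"comment": "The app keeps crashing.", "sentiment": 0}])
    # messages[0] is the system instruction; messages[1]["content"] contains:
    #   **Tweets**:
    #   Tweet 1 (Sentiment: 2): Love the new update!
    #   Tweet 2 (Sentiment: 0): The app keeps crashing.
    #
    #   **Query Question**: "How do users feel about the latest release?"
    #
    #   Now, produce only the final report as instructed, without any extra commentary.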
@@ -296,8 +337,7 @@ Please ensure the report is complete and reaches approximately 800 words.
            report = "Error generating report. Please try again with fewer tweets."
        else:
            # Extract content from successful Gemma result
-           report = raw_result[0]['generated_text']
-           #extract_assistant_content(raw_result)
+           report = extract_assistant_response(raw_result)

        progress_bar.progress(100)
        status_text.success("**✅ Generation complete!**")
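Taken together, the new helpers bracket the Gemma call: build_messages produces the chat prompt and extract_assistant_response pulls the report text back out of the pipeline result. A minimal end-to-end sketch, assuming both helpers are imported from app.py and using an illustrative model id and generation settings (the app loads its own Gemma pipeline):

    from transformers import pipeline

    generator = pipeline("text-generation", model="google/gemma-2-2b-it")  # illustrative model id

    sampled_docs = [
        {"comment": "Love the new update!", "sentiment": 2},
        {"comment": "The app keeps crashing.", "sentiment": 0},
    ]
    messages = build_messages("How do users feel about the latest release?", sampled_docs)

    raw_result = generator(messages, max_new_tokens=1200)  # illustrative token budget
    report = extract_assistant_response(raw_result)
    print(report)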
 