frankai98 committed
Commit 82ade9a · verified · 1 Parent(s): bf47678

Update app.py

Files changed (1)
  1. app.py +201 -195
app.py CHANGED
@@ -78,163 +78,165 @@ def clear_gpu_memory():
      torch.cuda.empty_cache()
      torch.cuda.ipc_collect()

- # Let the user specify the column name for tweets text (defaulting to "content")
- tweets_column = st.text_input("Enter the column name for Tweets:", value="content")
-
- # Input: Query question for scoring and CSV file upload for candidate tweets
- query_input = st.text_area("Enter your query question for analysis (this does not need to be part of the CSV):")
- uploaded_file = st.file_uploader(f"Upload Tweets CSV File (must contain a '{tweets_column}' column)", type=["csv"])
-
- candidate_docs = []
- if uploaded_file is not None:
-     try:
-         df = pd.read_csv(uploaded_file)
-         if tweets_column not in df.columns:
              st.error(f"CSV must contain a '{tweets_column}' column.")
          else:
-             candidate_docs = df[tweets_column].dropna().astype(str).tolist()
-     except Exception as e:
-         st.error(f"Error reading CSV file: {e}")
-
- if st.button("Generate Report"):
-     # Reset timer state so that the timer always shows up
-     st.session_state.timer_started = False
-     st.session_state.timer_frozen = False
-     if uploaded_file is None:
-         st.error("Please upload a CSV file.")
-     elif not tweets_column.strip():
-         st.error("Please enter your column name")
-     elif not candidate_docs:
-         st.error(f"CSV must contain a '{tweets_column}' column.")
-     elif not query_input.strip():
-         st.error("Please enter a query question!")
-     else:
-         if not st.session_state.timer_started and not st.session_state.timer_frozen:
-             st.session_state.timer_started = True
-             html(timer(), height=50)
-         status_text = st.empty()
-         progress_bar = st.progress(0)
-
-
-         processed_docs = []
-         scored_results = []
-
-         # First, check which documents need summarization
-         docs_to_summarize = []
-         docs_indices = []
-
-         for i, doc in enumerate(candidate_docs):
-             if len(doc) > 280:
-                 docs_to_summarize.append(doc)
-                 docs_indices.append(i)
-
-         # If we have documents to summarize, load Llama model first
-         if docs_to_summarize:
-             status_text.markdown("**📝 Loading summarization model...**")
-             t5_pipe = get_summary_model()

-             status_text.markdown("**📝 Summarizing long tweets...**")

-             # Process documents that need summarization
-             for idx, (i, doc) in enumerate(zip(docs_indices, docs_to_summarize)):
-                 progress = int((idx / len(docs_to_summarize)) * 25)  # First quarter of progress
-                 progress_bar.progress(progress)

-                 input_text = "summarize: " + doc

                  try:
-                     summary_result = t5_pipe(
-                         input_text,
-                         max_length=128,
-                         min_length=10,
-                         no_repeat_ngram_size=2,
-                         num_beams=4,
-                         early_stopping=True,
-                         truncation=True
-                     )

-                     # Store the summary in place of the original text
-                     candidate_docs[i] = summary_result[0]['generated_text']

                  except Exception as e:
-                     st.warning(f"Error summarizing document {i}: {str(e)}")

-             # Clear Llama model from memory
-             del t5_pipe
              import gc
              gc.collect()
              torch.cuda.empty_cache()
-
-         # Now load sentiment model
-         status_text.markdown("**🔍 Loading sentiment analysis model...**")
-         progress_bar.progress(25)
-         score_pipe = get_sentiment_model()
-
-         status_text.markdown("**🔍 Scoring documents...**")
-
-         # Process each document with sentiment analysis
-         for i, doc in enumerate(candidate_docs):
-             progress_offset = 25 if docs_to_summarize else 0
-             progress = progress_offset + int((i / len(candidate_docs)) * (50 - progress_offset))
-             progress_bar.progress(progress)

-             try:
-                 # Process with sentiment analysis
-                 result = score_pipe(doc, truncation=True, max_length=512)
-
-                 # If it's a list, get the first element
-                 if isinstance(result, list):
-                     result = result[0]
-
-                 processed_docs.append(doc)
-                 scored_results.append(result)
-
-             except Exception as e:
-                 st.warning(f"Error scoring document {i}: {str(e)}")
-                 processed_docs.append("Error processing this document")
-                 scored_results.append({"label": "NEUTRAL", "score": 1})

-             # Display occasional status updates
-             if i % max(1, len(candidate_docs) // 10) == 0:
-                 status_text.markdown(f"**🔍 Scoring documents... ({i}/{len(candidate_docs)})**")
-
-         # Pair documents with scores
-         scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
-
-         # Clear sentiment model from memory
-         del score_pipe
-         import gc
-         gc.collect()
-         torch.cuda.empty_cache()
-
-         #print_gpu_status("After sentiment model deletion, VRAM")
-
-         # Load Gemma for final report generation
-         status_text.markdown("**📊 Loading report generation model...**")
-         progress_bar.progress(67)
-
-         # Make sure GPU memory is clear
-         clear_gpu_memory()
-         print_gpu_status("Before loading Gemma model, VRAM")
-
-         # Set memory optimization environment variable
-         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
-         # Sample or summarize the data for Gemma to avoid memory issues
-         status_text.markdown("**📝 Preparing data for report generation...**")
-         progress_bar.progress(75)
-
-         import random
-         max_tweets = 100
-         if len(scored_docs) > max_tweets:
-             sampled_docs = random.sample(scored_docs, max_tweets)
-             st.info(f"Sampling {max_tweets} out of {len(scored_docs)} tweets for report generation")
-         else:
-             sampled_docs = scored_docs
-
-         # Build prompt
-         messages = [
-             {"role": "user", "content": f"""
  Generate a well-structured business report based on tweets from twitter/X with sentiment score (0: negative, 1: neutral, 2: positive) that answers Query Question and meets following Requirements.
  **Requirements:**
  - Include an introduction, key insights, and a conclusion.
@@ -246,60 +248,64 @@ Generate a well-structured business report based on tweets from twitter/X with s
  **Tweets with sentiment score:**
  {sampled_docs}
  Please ensure the report is complete and reaches approximately 1000 words.
- """}
-         ]
-
-         # Create a process function to avoid the Triton registration issue
-         def process_with_gemma(messages):
-             # We'll define the pipeline here rather than using the cached version
-             # This ensures a clean library registration context
-             from transformers import pipeline, AutoTokenizer
-             import torch
-
-             # Set dtype explicitly
-             torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
-             try:
-                 tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
-                 pipe = pipeline(
-                     "text-generation",
-                     model="unsloth/gemma-3-1b-it",
-                     tokenizer=tokenizer,
-                     device=0 if torch.cuda.is_available() else -1,
-                     torch_dtype=torch_dtype
-                 )

-                 result = pipe(messages, max_new_tokens=1500, repetition_penalty=1.2, do_sample=True, temperature=0.7, return_full_text=False)
-                 return result, None

-             except Exception as e:
-                 return None, str(e)
-
-         # Try to process with Gemma
-         status_text.markdown("**📝 Generating report with Gemma...**")
-         progress_bar.progress(80)
-
-         raw_result, error = process_with_gemma(messages)
-
-         if error:
-             st.error(f"Gemma processing failed: {str(error)}")
-             report = "Error generating report. Please try again with fewer tweets."
-         else:
-             # Extract content from successful Gemma result
-             report = raw_result[0]['generated_text']
-             #extract_assistant_content(raw_result)
-
-         progress_bar.progress(100)
-         status_text.success("**✅ Generation complete!**")
-         html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
-         st.session_state.timer_frozen = True

-         # First, create the replacement separately
-         formatted_report = report.replace('\\n', '<br>')
-
-         # Display title separately with standard formatting
-         st.subheader("Generated Report:")
-
-         # Display the report content with normal styling
-         st.markdown(f"<div style='font-size: normal; font-weight: normal;'>{formatted_report}</div>", unsafe_allow_html=True)
 
      torch.cuda.empty_cache()
      torch.cuda.ipc_collect()

+ # Main Function Part:
+ def main():
+     # Let the user specify the column name for tweets text (defaulting to "content")
+     tweets_column = st.text_input("Enter the column name for Tweets:", value="content")
+
+     # Input: Query question for scoring and CSV file upload for candidate tweets
+     query_input = st.text_area("Enter your query question for analysis (this does not need to be part of the CSV):")
+     uploaded_file = st.file_uploader(f"Upload Tweets CSV File (must contain a '{tweets_column}' column)", type=["csv"])
+
+     candidate_docs = []
+     if uploaded_file is not None:
+         try:
+             df = pd.read_csv(uploaded_file)
+             if tweets_column not in df.columns:
+                 st.error(f"CSV must contain a '{tweets_column}' column.")
+             else:
+                 candidate_docs = df[tweets_column].dropna().astype(str).tolist()
+         except Exception as e:
+             st.error(f"Error reading CSV file: {e}")
+
+     if st.button("Generate Report"):
+         # Reset timer state so that the timer always shows up
+         st.session_state.timer_started = False
+         st.session_state.timer_frozen = False
+         if uploaded_file is None:
+             st.error("Please upload a CSV file.")
+         elif not tweets_column.strip():
+             st.error("Please enter your column name")
+         elif not candidate_docs:
              st.error(f"CSV must contain a '{tweets_column}' column.")
+         elif not query_input.strip():
+             st.error("Please enter a query question!")
          else:
+             if not st.session_state.timer_started and not st.session_state.timer_frozen:
+                 st.session_state.timer_started = True
+                 html(timer(), height=50)
+             status_text = st.empty()
+             progress_bar = st.progress(0)
+
+
+             processed_docs = []
+             scored_results = []
+
+             # First, check which documents need summarization
+             docs_to_summarize = []
+             docs_indices = []
+
+             for i, doc in enumerate(candidate_docs):
+                 if len(doc) > 280:
+                     docs_to_summarize.append(doc)
+                     docs_indices.append(i)
+
+             # If we have documents to summarize, load Llama model first
+             if docs_to_summarize:
+                 status_text.markdown("**📝 Loading summarization model...**")
+                 t5_pipe = get_summary_model()
+
+                 status_text.markdown("**📝 Summarizing long tweets...**")
+
+                 # Process documents that need summarization
+                 for idx, (i, doc) in enumerate(zip(docs_indices, docs_to_summarize)):
+                     progress = int((idx / len(docs_to_summarize)) * 25)  # First quarter of progress
+                     progress_bar.progress(progress)
+
+                     input_text = "summarize: " + doc
+
+                     try:
+                         summary_result = t5_pipe(
+                             input_text,
+                             max_length=128,
+                             min_length=10,
+                             no_repeat_ngram_size=2,
+                             num_beams=4,
+                             early_stopping=True,
+                             truncation=True
+                         )
+
+                         # Store the summary in place of the original text
+                         candidate_docs[i] = summary_result[0]['generated_text']
+
+                     except Exception as e:
+                         st.warning(f"Error summarizing document {i}: {str(e)}")

+                 # Clear Llama model from memory
+                 del t5_pipe
+                 import gc
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # Now load sentiment model
+             status_text.markdown("**🔍 Loading sentiment analysis model...**")
+             progress_bar.progress(25)
+             score_pipe = get_sentiment_model()
+
+             status_text.markdown("**🔍 Scoring documents...**")
+
+             # Process each document with sentiment analysis
+             for i, doc in enumerate(candidate_docs):
+                 progress_offset = 25 if docs_to_summarize else 0
+                 progress = progress_offset + int((i / len(candidate_docs)) * (50 - progress_offset))
+                 progress_bar.progress(progress)

                  try:
+                     # Process with sentiment analysis
+                     result = score_pipe(doc, truncation=True, max_length=512)

+                     # If it's a list, get the first element
+                     if isinstance(result, list):
+                         result = result[0]
+
+                     processed_docs.append(doc)
+                     scored_results.append(result)

                  except Exception as e:
+                     st.warning(f"Error scoring document {i}: {str(e)}")
+                     processed_docs.append("Error processing this document")
+                     scored_results.append({"label": "NEUTRAL", "score": 1})
+
+                 # Display occasional status updates
+                 if i % max(1, len(candidate_docs) // 10) == 0:
+                     status_text.markdown(f"**🔍 Scoring documents... ({i}/{len(candidate_docs)})**")

+             # Pair documents with scores
+             scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
+
+             # Clear sentiment model from memory
+             del score_pipe
              import gc
              gc.collect()
              torch.cuda.empty_cache()

+             #print_gpu_status("After sentiment model deletion, VRAM")

+             # Load Gemma for final report generation
+             status_text.markdown("**📊 Loading report generation model...**")
+             progress_bar.progress(67)
+
+             # Make sure GPU memory is clear
+             clear_gpu_memory()
+             print_gpu_status("Before loading Gemma model, VRAM")
+
+             # Set memory optimization environment variable
+             os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+             # Sample or summarize the data for Gemma to avoid memory issues
+             status_text.markdown("**📝 Preparing data for report generation...**")
+             progress_bar.progress(75)
+
+             import random
+             max_tweets = 1000
+             if len(scored_docs) > max_tweets:
+                 sampled_docs = random.sample(scored_docs, max_tweets)
+                 st.info(f"Sampling {max_tweets} out of {len(scored_docs)} tweets for report generation")
+             else:
+                 sampled_docs = scored_docs
+
+             # Build prompt
+             messages = [
+                 {"role": "user", "content": f"""
  Generate a well-structured business report based on tweets from twitter/X with sentiment score (0: negative, 1: neutral, 2: positive) that answers Query Question and meets following Requirements.
  **Requirements:**
  - Include an introduction, key insights, and a conclusion.
  **Tweets with sentiment score:**
  {sampled_docs}
  Please ensure the report is complete and reaches approximately 1000 words.
+ """}
+             ]
+
+             # Create a process function to avoid the Triton registration issue
+             def process_with_gemma(messages):
+                 # We'll define the pipeline here rather than using the cached version
+                 # This ensures a clean library registration context
+                 from transformers import pipeline, AutoTokenizer
+                 import torch

+                 # Set dtype explicitly
+                 torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

+                 try:
+                     tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
+                     pipe = pipeline(
+                         "text-generation",
+                         model="unsloth/gemma-3-1b-it",
+                         tokenizer=tokenizer,
+                         device=0 if torch.cuda.is_available() else -1,
+                         torch_dtype=torch_dtype
+                     )
+
+                     result = pipe(messages, max_new_tokens=1500, repetition_penalty=1.2, do_sample=True, temperature=0.7, return_full_text=False)
+                     return result, None
+
+                 except Exception as e:
+                     return None, str(e)
+
+             # Try to process with Gemma
+             status_text.markdown("**📝 Generating report with Gemma...**")
+             progress_bar.progress(80)
+
+             raw_result, error = process_with_gemma(messages)
+
+             if error:
+                 st.error(f"Gemma processing failed: {str(error)}")
+                 report = "Error generating report. Please try again with fewer tweets."
+             else:
+                 # Extract content from successful Gemma result
+                 report = raw_result[0]['generated_text']
+                 #extract_assistant_content(raw_result)
+
+             progress_bar.progress(100)
+             status_text.success("**✅ Generation complete!**")
+             html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
+             st.session_state.timer_frozen = True
+
+             # First, create the replacement separately
+             formatted_report = report.replace('\\n', '<br>')
+
+             # Display title separately with standard formatting
+             st.subheader("Generated Report:")
+
+             # Display the report content with normal styling
+             st.markdown(f"<div style='font-size: normal; font-weight: normal;'>{formatted_report}</div>", unsafe_allow_html=True)
+
+ # Run the Main Function
+ if __name__ == '__main__':
+     main()
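
The cached loader helpers referenced throughout this diff (get_summary_model() and get_sentiment_model()), along with clear_gpu_memory(), print_gpu_status() and the timer() component, are defined earlier in app.py and sit outside these hunks. As a rough illustration only (the @st.cache_resource decorator and the model checkpoints below are assumptions, not necessarily what this Space actually uses), such cached pipeline loaders typically look like this:

# Illustrative sketch only -- not the actual definitions from this Space.
import streamlit as st
import torch
from transformers import pipeline

@st.cache_resource  # assumption: cache the pipeline so it is loaded once per session
def get_summary_model():
    # A text2text-generation pipeline returns [{"generated_text": ...}], which matches
    # how the diff reads summary_result[0]['generated_text'] after a "summarize: " prefix.
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-small",  # placeholder checkpoint
        device=0 if torch.cuda.is_available() else -1,
    )

@st.cache_resource
def get_sentiment_model():
    # Placeholder 3-class Twitter sentiment checkpoint (0: negative, 1: neutral, 2: positive),
    # matching the scale described in the report prompt.
    return pipeline(
        "text-classification",
        model="cardiffnlp/twitter-roberta-base-sentiment",
        device=0 if torch.cuda.is_available() else -1,
    )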