tecuts committed
Commit 13046df · verified · 1 Parent(s): db547a3

Update app.py

Files changed (1)
  app.py  +158 -159
app.py CHANGED
@@ -2,6 +2,7 @@ import os
 import json
 import asyncio
 import requests
+import re
 from datetime import datetime
 from typing import List, Dict, Optional
 from fastapi import FastAPI, Request, HTTPException, Depends
@@ -48,22 +49,27 @@ GOOGLE_CX = os.getenv("GOOGLE_CX")
 LLM_API_KEY = os.getenv("LLM_API_KEY")
 LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api-15i2e8ze256bvfn6.aistudio-app.com/v1")
 
-# --- Simplified System Prompts ---
+# --- Enhanced System Prompts ---
 SYSTEM_PROMPT_WITH_SEARCH = """You are an intelligent AI assistant with access to real-time web search capabilities.
 
-When search tools are available, use them for queries that need current, recent, or specific factual information.
+IMPORTANT: When you need current information, recent events, or specific facts that might be outdated, you should explicitly request a search by including the phrase "SEARCH_NEEDED:" followed by your search query in your response.
+
+For example:
+- If asked about recent news: "SEARCH_NEEDED: latest news about [topic]"
+- If asked about current events: "SEARCH_NEEDED: current status of [event]"
+- If asked about recent developments: "SEARCH_NEEDED: recent developments in [field]"
 
 **Response Guidelines:**
-1. Use search tools when available and relevant
-2. Synthesize information from multiple sources
-3. Clearly indicate when information comes from search results
-4. Provide comprehensive, well-structured answers
-5. Cite sources appropriately
+1. Use search for queries that need current, recent, or specific factual information
+2. Be proactive in identifying when search is needed
+3. Synthesize information from multiple sources when search results are provided
+4. Clearly indicate when information comes from search results
+5. Provide comprehensive, well-structured answers
+6. Cite sources appropriately
 
 Current date: {current_date}"""
 
 SYSTEM_PROMPT_NO_SEARCH = """You are an intelligent AI assistant. Provide helpful, accurate, and comprehensive responses based on your training data.
-
 Current date: {current_date}"""
 
 # --- Optimized Web Search Tool ---
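A note on the mechanism this prompt sets up: the "SEARCH_NEEDED:" marker the model is asked to emit is recovered later in the generator with a regex (see the hunks below). A minimal round-trip sketch, using an illustrative model reply:

    import re

    reply = "SEARCH_NEEDED: latest news about quantum computing\nLet me check."
    queries = re.findall(r'SEARCH_NEEDED:\s*(.+?)(?:\n|$)', reply)
    # queries == ['latest news about quantum computing']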
@@ -132,6 +138,40 @@ def format_search_results_compact(search_results: List[Dict]) -> str:
 
     return "\n".join(formatted)
 
+# --- Check if query needs search ---
+def should_search(query: str, use_search: bool) -> Optional[str]:
+    """Determine if a query needs search and extract search terms"""
+    if not use_search:
+        return None
+
+    # Keywords that typically require current information
+    current_keywords = [
+        'today', 'recent', 'latest', 'current', 'now', 'this year', '2024', '2025',
+        'news', 'happening', 'update', 'development', 'status', 'price', 'stock',
+        'weather', 'score', 'result', 'election', 'covid', 'pandemic'
+    ]
+
+    query_lower = query.lower()
+
+    # Check for current-info keywords
+    if any(keyword in query_lower for keyword in current_keywords):
+        return query
+
+    # Check for questions about specific companies, products, or events
+    question_patterns = [
+        r'what.*happened.*',
+        r'when.*did.*',
+        r'how.*is.*doing',
+        r'what.*the.*status',
+        r'is.*still.*',
+        r'has.*been.*',
+    ]
+
+    if any(re.search(pattern, query_lower) for pattern in question_patterns):
+        return query
+
+    return None
+
 # --- FastAPI Application Setup ---
 app = FastAPI(title="Streaming AI Chatbot", version="2.1.0")
 
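For orientation, should_search returns the query itself when one of the heuristics fires and None otherwise; a few illustrative calls (the queries are hypothetical):

    should_search("What's the latest news on AI?", use_search=True)   # -> query ('latest' keyword)
    should_search("When did the merger close?", use_search=True)      # -> query (r'when.*did.*')
    should_search("Explain how quicksort works", use_search=True)     # -> None (no trigger)
    should_search("today's weather in Paris", use_search=False)       # -> None (search disabled)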
@@ -156,7 +196,7 @@ else:
     client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
     logger.info("OpenAI client initialized successfully")
 
-# --- Tool Definition ---
+# --- Tool Definition (keeping for potential future use) ---
 available_tools = [
     {
         "type": "function",
@@ -177,11 +217,40 @@ available_tools = [
     }
 ]
 
-# --- Streaming Response Generator ---
-async def generate_streaming_response(messages: List[Dict], use_search: bool, temperature: float):
-    """Generate streaming response with optional search"""
+# --- Enhanced Streaming Response Generator ---
+async def generate_streaming_response(messages: List[Dict], use_search: bool, temperature: float, original_query: str):
+    """Generate streaming response with intelligent search triggering"""
 
     try:
+        source_links = []
+        search_performed = False
+
+        # Check if we should proactively search
+        proactive_search_query = should_search(original_query, use_search)
+        if proactive_search_query:
+            logger.info(f"Proactive search triggered for: {proactive_search_query}")
+            yield f"data: {json.dumps({'type': 'status', 'data': 'Searching for current information...'})}\n\n"
+
+            search_results = await google_search_tool_async(proactive_search_query, 4)
+            if search_results:
+                search_context = format_search_results_compact(search_results)
+
+                # Add search context to messages
+                enhanced_messages = messages + [{
+                    "role": "system",
+                    "content": f"Recent search results for your reference:\n\n{search_context}\n\nPlease use this information to provide a comprehensive and up-to-date response."
+                }]
+
+                for result in search_results:
+                    source_links.append({
+                        "title": result["source_title"],
+                        "url": result["url"],
+                        "domain": result["domain"]
+                    })
+
+                search_performed = True
+                messages = enhanced_messages
+
         # Initial LLM call with streaming
         llm_kwargs = {
             "model": "unsloth/Qwen3-30B-A3B-GGUF",
@@ -191,15 +260,17 @@ async def generate_streaming_response(messages: List[Dict], use_search: bool, te
             "stream": True
         }
 
-        if use_search:
+        # Try function calling as backup (in case model supports it)
+        if use_search and not search_performed:
             llm_kwargs["tools"] = available_tools
             llm_kwargs["tool_choice"] = "auto"
 
-        source_links = []
         response_content = ""
        tool_calls_data = []
 
-        # First streaming call
+        yield f"data: {json.dumps({'type': 'status', 'data': 'Generating response...'})}\n\n"
+
+        # Stream the response
         stream = client.chat.completions.create(**llm_kwargs)
 
         for chunk in stream:
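In the proactive path added above, the search context is appended as an extra system message after the user turn before this first streaming call; the second argument in google_search_tool_async(query, 4) is presumably a result count, since the helper's signature sits earlier in app.py, outside this diff. Illustrative shape of the rewritten list:

    # messages after the proactive-search rewrite (contents illustrative):
    # [
    #     {"role": "system", "content": "You are an intelligent AI assistant..."},
    #     {"role": "user",   "content": "latest AI news"},
    #     {"role": "system", "content": "Recent search results for your reference:\n\n1. ..."},
    # ]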
@@ -209,9 +280,21 @@ async def generate_streaming_response(messages: List[Dict], use_search: bool, te
             if delta.content:
                 content_chunk = delta.content
                 response_content += content_chunk
+
+                # Check for search requests in the content
+                if use_search and not search_performed and "SEARCH_NEEDED:" in content_chunk:
+                    # Extract search query from the content
+                    search_match = re.search(r'SEARCH_NEEDED:\s*(.+?)(?:\n|$)', content_chunk)
+                    if search_match:
+                        search_query = search_match.group(1).strip()
+                        logger.info(f"Search requested by model: {search_query}")
+
+                        # Don't yield this chunk yet, we'll search first
+                        continue
+
                 yield f"data: {json.dumps({'type': 'content', 'data': content_chunk})}\n\n"
 
-            # Handle tool calls
+            # Handle tool calls (backup method)
             if delta.tool_calls:
                 for tool_call in delta.tool_calls:
                     if len(tool_calls_data) <= tool_call.index:
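The tool-call branch assembles streamed fragments keyed by tool_call.index; a typical delta sequence and its effect (values illustrative):

    # delta 1: index=0, function.name='google_search', function.arguments=''
    # delta 2: index=0, function.arguments='{"que'
    # delta 3: index=0, function.arguments='ry": "latest ai news"}'
    # after the stream ends:
    # tool_calls_data[0]["function"]["arguments"] == '{"query": "latest ai news"}'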
@@ -225,11 +308,56 @@ async def generate_streaming_response(messages: List[Dict], use_search: bool, te
                     if tool_call.function.arguments:
                         tool_calls_data[tool_call.index]["function"]["arguments"] += tool_call.function.arguments
 
-        # Process tool calls if any
-        if tool_calls_data and any(tc["function"]["name"] for tc in tool_calls_data):
-            yield f"data: {json.dumps({'type': 'status', 'data': 'Searching...'})}\n\n"
+        # Handle model-requested search
+        if use_search and not search_performed and "SEARCH_NEEDED:" in response_content:
+            search_matches = re.findall(r'SEARCH_NEEDED:\s*(.+?)(?:\n|$)', response_content)
+            if search_matches:
+                yield f"data: {json.dumps({'type': 'status', 'data': 'Performing requested search...'})}\n\n"
+
+                # Execute all requested searches
+                search_tasks = [google_search_tool_async(query.strip()) for query in search_matches]
+                search_results_list = await asyncio.gather(*search_tasks, return_exceptions=True)
+
+                all_results = []
+                for results in search_results_list:
+                    if isinstance(results, list):
+                        all_results.extend(results)
+
+                if all_results:
+                    search_context = format_search_results_compact(all_results)
+
+                    for result in all_results:
+                        source_links.append({
+                            "title": result["source_title"],
+                            "url": result["url"],
+                            "domain": result["domain"]
+                        })
+
+                    # Generate new response with search results
+                    search_messages = messages + [{
+                        "role": "system",
+                        "content": f"Search Results:\n\n{search_context}\n\nPlease provide a comprehensive response based on these search results."
+                    }]
+
+                    final_stream = client.chat.completions.create(
+                        model="unsloth/Qwen3-30B-A3B-GGUF",
+                        temperature=temperature,
+                        messages=search_messages,
+                        max_tokens=2000,
+                        stream=True
+                    )
+
+                    for chunk in final_stream:
+                        if chunk.choices[0].delta.content:
+                            content = chunk.choices[0].delta.content
+                            yield f"data: {json.dumps({'type': 'content', 'data': content})}\n\n"
+
+                    search_performed = True
+
+        # Process function-based tool calls (backup method)
+        elif tool_calls_data and any(tc["function"]["name"] for tc in tool_calls_data):
+            yield f"data: {json.dumps({'type': 'status', 'data': 'Executing search tools...'})}\n\n"
 
-            # Execute searches concurrently for speed
             search_tasks = []
             for tool_call in tool_calls_data:
                 if tool_call["function"]["name"] == "google_search":
@@ -238,14 +366,13 @@ async def generate_streaming_response(messages: List[Dict], use_search: bool, te
                         query = args.get("query", "").strip()
                         if query:
                             search_tasks.append(google_search_tool_async(query))
+                            logger.info(f"Function call search: {query}")
                     except json.JSONDecodeError:
                         continue
 
-            # Run searches concurrently
             if search_tasks:
                 search_results_list = await asyncio.gather(*search_tasks, return_exceptions=True)
 
-                # Combine all search results
                 all_results = []
                 for results in search_results_list:
                     if isinstance(results, list):
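Because the searches run under asyncio.gather(..., return_exceptions=True), a failed search comes back as an exception object instead of raising, and the isinstance(results, list) guard silently drops it. A self-contained sketch of that behaviour (both coroutines are hypothetical stand-ins):

    import asyncio

    async def ok():
        return [{"source_title": "t", "url": "u", "domain": "d"}]

    async def boom():
        raise TimeoutError("search timed out")

    async def demo():
        outcomes = await asyncio.gather(ok(), boom(), return_exceptions=True)
        # outcomes == [[{...}], TimeoutError('search timed out')]
        return [r for out in outcomes if isinstance(out, list) for r in out]

    print(asyncio.run(demo()))  # only the successful search's results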
@@ -257,19 +384,14 @@ async def generate_streaming_response(messages: List[Dict], use_search: bool, te
                                 "domain": result["domain"]
                             })
 
-                # Format search results
                 if all_results:
                     search_context = format_search_results_compact(all_results)
 
-                    # Create new message with search context
                     search_messages = messages + [{
                         "role": "system",
                         "content": f"{search_context}\n\nPlease provide a comprehensive response based on the search results above."
                     }]
 
-                    yield f"data: {json.dumps({'type': 'status', 'data': 'Generating response...'})}\n\n"
-
-                    # Generate final response with search context
                     final_stream = client.chat.completions.create(
                         model="unsloth/Qwen3-30B-A3B-GGUF",
                         temperature=temperature,
@@ -282,12 +404,14 @@ async def generate_streaming_response(messages: List[Dict], use_search: bool, te
                         if chunk.choices[0].delta.content:
                             content = chunk.choices[0].delta.content
                             yield f"data: {json.dumps({'type': 'content', 'data': content})}\n\n"
+
+                    search_performed = True
 
         # Send sources and completion
         if source_links:
             yield f"data: {json.dumps({'type': 'sources', 'data': source_links})}\n\n"
 
-        yield f"data: {json.dumps({'type': 'done', 'data': {'search_used': bool(source_links)}})}\n\n"
+        yield f"data: {json.dumps({'type': 'done', 'data': {'search_used': search_performed}})}\n\n"
 
     except Exception as e:
         logger.error(f"Streaming error: {e}")
@@ -302,8 +426,8 @@ async def chat_stream_endpoint(request: Request, _: None = Depends(verify_origin
     try:
         data = await request.json()
         user_message = data.get("message", "").strip()
-        use_search = data.get("use_search", False)  # Default: False
-        temperature = max(0, min(2, data.get("temperature", 0.7)))  # Clamp to valid range
+        use_search = data.get("use_search", False)
+        temperature = max(0, min(2, data.get("temperature", 0.7)))
         conversation_history = data.get("history", [])
 
         if not user_message:
@@ -314,15 +438,15 @@ async def chat_stream_endpoint(request: Request, _: None = Depends(verify_origin
         system_content = (SYSTEM_PROMPT_WITH_SEARCH if use_search else SYSTEM_PROMPT_NO_SEARCH).format(current_date=current_date)
         messages = [{"role": "system", "content": system_content}] + conversation_history + [{"role": "user", "content": user_message}]
 
-        logger.info(f"Stream request - search: {use_search}, temp: {temperature}")
+        logger.info(f"Stream request - search: {use_search}, temp: {temperature}, query: {user_message[:50]}...")
 
         return StreamingResponse(
-            generate_streaming_response(messages, use_search, temperature),
+            generate_streaming_response(messages, use_search, temperature, user_message),
             media_type="text/plain",
             headers={
                 "Cache-Control": "no-cache",
                 "Connection": "keep-alive",
-                "X-Accel-Buffering": "no"  # Disable nginx buffering
+                "X-Accel-Buffering": "no"
             }
         )
 
@@ -330,129 +454,4 @@ async def chat_stream_endpoint(request: Request, _: None = Depends(verify_origin
         raise HTTPException(status_code=400, detail="Invalid JSON")
     except Exception as e:
         logger.error(f"Stream endpoint error: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
-
-# --- Regular Chat Endpoint (for backward compatibility) ---
-@app.post("/chat")
-async def chat_endpoint(request: Request, _: None = Depends(verify_origin)):
-    if not client:
-        raise HTTPException(status_code=500, detail="LLM client not configured")
-
-    try:
-        data = await request.json()
-        user_message = data.get("message", "").strip()
-        use_search = data.get("use_search", False)  # Default: False
-        temperature = max(0, min(2, data.get("temperature", 0.7)))
-        conversation_history = data.get("history", [])
-
-        if not user_message:
-            raise HTTPException(status_code=400, detail="No message provided")
-
-        # Prepare messages
-        current_date = datetime.now().strftime("%Y-%m-%d")
-        system_content = (SYSTEM_PROMPT_WITH_SEARCH if use_search else SYSTEM_PROMPT_NO_SEARCH).format(current_date=current_date)
-        messages = [{"role": "system", "content": system_content}] + conversation_history + [{"role": "user", "content": user_message}]
-
-        source_links = []
-
-        if use_search:
-            # Search-enabled flow (non-streaming for compatibility)
-            llm_response = client.chat.completions.create(
-                model="unsloth/Qwen3-30B-A3B-GGUF",
-                temperature=temperature,
-                messages=messages,
-                tools=available_tools,
-                tool_choice="auto",
-                max_tokens=2000
-            )
-
-            tool_calls = llm_response.choices[0].message.tool_calls
-
-            if tool_calls:
-                # Execute searches
-                search_tasks = []
-                for tool_call in tool_calls:
-                    if tool_call.function.name == "google_search":
-                        try:
-                            args = json.loads(tool_call.function.arguments)
-                            query = args.get("query", "").strip()
-                            if query:
-                                search_tasks.append(google_search_tool_async(query))
-                        except json.JSONDecodeError:
-                            continue
-
-                if search_tasks:
-                    search_results_list = await asyncio.gather(*search_tasks, return_exceptions=True)
-                    all_results = []
-                    for results in search_results_list:
-                        if isinstance(results, list):
-                            all_results.extend(results)
-                            for result in results:
-                                source_links.append({
-                                    "title": result["source_title"],
-                                    "url": result["url"],
-                                    "domain": result["domain"]
-                                })
-
-                    if all_results:
-                        search_context = format_search_results_compact(all_results)
-                        search_messages = messages + [{
-                            "role": "system",
-                            "content": f"{search_context}\n\nPlease provide a comprehensive response based on the search results above."
-                        }]
-
-                        final_response = client.chat.completions.create(
-                            model="unsloth/Qwen3-30B-A3B-GGUF",
-                            temperature=temperature,
-                            messages=search_messages,
-                            max_tokens=2000
-                        )
-                        final_content = final_response.choices[0].message.content
-                    else:
-                        final_content = llm_response.choices[0].message.content
-                else:
-                    final_content = llm_response.choices[0].message.content
-            else:
-                final_content = llm_response.choices[0].message.content
-        else:
-            # No search - direct response
-            llm_response = client.chat.completions.create(
-                model="unsloth/Qwen3-30B-A3B-GGUF",
-                temperature=temperature,
-                messages=messages,
-                max_tokens=2000
-            )
-            final_content = llm_response.choices[0].message.content
-
-        return {
-            "response": final_content,
-            "sources": source_links,
-            "search_used": bool(source_links),
-            "temperature": temperature,
-            "timestamp": datetime.now().isoformat()
-        }
-
-    except Exception as e:
-        logger.error(f"Chat endpoint error: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
-
-# --- Health Check Endpoints ---
-@app.get("/")
-async def root():
-    return {
-        "message": "Streaming AI Chatbot API",
-        "version": "2.1.0",
-        "endpoints": ["/chat", "/chat/stream"],
-        "timestamp": datetime.now().isoformat()
-    }
-
-@app.get("/health")
-async def health_check():
-    return {
-        "status": "healthy",
-        "timestamp": datetime.now().isoformat(),
-        "services": {
-            "llm_client": client is not None,
-            "google_search": bool(GOOGLE_API_KEY and GOOGLE_CX)
-        }
-    }
+        raise HTTPException(status_code=500, detail=str(e))
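For reference, a minimal client sketch for the streaming endpoint; the /chat/stream path, request fields, and event types all appear in this file, while the host and port are assumptions:

    import json
    import requests

    resp = requests.post(
        "http://localhost:8000/chat/stream",  # host/port assumed
        json={"message": "latest AI news", "use_search": True,
              "temperature": 0.7, "history": []},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            if event["type"] == "content":
                print(event["data"], end="", flush=True)
            elif event["type"] == "done":
                break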