import gradio as gr import requests import json import os import warnings from huggingface_hub import InferenceClient # 抑制 asyncio 警告 warnings.filterwarnings('ignore', category=DeprecationWarning) os.environ['PYTHONWARNINGS'] = 'ignore' # 如果在 GPU 环境但不需要 GPU,禁用 CUDA if 'CUDA_VISIBLE_DEVICES' not in os.environ: os.environ['CUDA_VISIBLE_DEVICES'] = '' # ========== MCP 工具简化定义(符合MCP协议标准) ========== MCP_TOOLS = [ {"type": "function", "function": {"name": "advanced_search_company", "description": "Search US companies", "parameters": {"type": "object", "properties": {"company_input": {"type": "string"}}, "required": ["company_input"]}}}, {"type": "function", "function": {"name": "get_latest_financial_data", "description": "Get latest financial data", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}}, "required": ["cik"]}}}, {"type": "function", "function": {"name": "extract_financial_metrics", "description": "Get multi-year trends", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}, "years": {"type": "integer"}}, "required": ["cik", "years"]}}}, {"type": "function", "function": {"name": "get_quote", "description": "Get stock quote", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}, {"type": "function", "function": {"name": "get_market_news", "description": "Get market news", "parameters": {"type": "object", "properties": {"category": {"type": "string"}}, "required": ["category"]}}}, {"type": "function", "function": {"name": "get_company_news", "description": "Get company news", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}, "from_date": {"type": "string"}, "to_date": {"type": "string"}}, "required": ["symbol"]}}} ] # ========== MCP 服务配置 ========== MCP_SERVICES = { "financial": {"url": "http://localhost:7861/mcp", "type": "fastmcp"}, "market": {"url": "https://jc321-marketandstockmcp.hf.space", "type": "gradio"} } TOOL_ROUTING = { "advanced_search_company": MCP_SERVICES["financial"], "get_latest_financial_data": MCP_SERVICES["financial"], "extract_financial_metrics": MCP_SERVICES["financial"], "get_quote": MCP_SERVICES["market"], "get_market_news": MCP_SERVICES["market"], "get_company_news": MCP_SERVICES["market"] } # ========== 初始化 LLM 客户端 ========== hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") client = InferenceClient(api_key=hf_token) if hf_token else InferenceClient() print(f"✅ LLM initialized: Qwen/Qwen3-32B:groq") print(f"📊 MCP Services: {len(MCP_SERVICES)} services, {len(MCP_TOOLS)} tools") # ========== Token 限制配置 ========== # HuggingFace Inference API 实际限制约 8000-16000 tokens # 为了安全,设置更低的限制 MAX_TOTAL_TOKENS = 6000 # 总上下文限制 MAX_TOOL_RESULT_CHARS = 1500 # 工具返回最大字符数 (增加到1500) MAX_HISTORY_CHARS = 500 # 单条历史消息最大字符数 MAX_HISTORY_TURNS = 2 # 最大历史轮数 MAX_TOOL_ITERATIONS = 6 # 最大工具调用轮数 (增加到6,支持多工具调用) MAX_OUTPUT_TOKENS = 2000 # 最大输出 tokens (增加到2000) def estimate_tokens(text): """估算文本 token 数量(粗略:1 token ≈ 2 字符)""" return len(str(text)) // 2 def truncate_text(text, max_chars, suffix="...[truncated]"): """截断文本到指定长度""" text = str(text) if len(text) <= max_chars: return text return text[:max_chars] + suffix def get_system_prompt(): """生成包含当前日期的系统提示词(精简版)""" from datetime import datetime current_date = datetime.now().strftime("%Y-%m-%d") return f"""Financial analyst. Today: {current_date}. Use tools for company data, stock prices, news. Be concise.""" # ============================================================ # MCP 服务调用核心代码区 # 支持 FastMCP (JSON-RPC) 和 Gradio (SSE) 两种协议 # ============================================================ def call_mcp_tool(tool_name, arguments): """调用 MCP 工具""" service_config = TOOL_ROUTING.get(tool_name) if not service_config: return {"error": f"Unknown tool: {tool_name}"} try: if service_config["type"] == "fastmcp": return _call_fastmcp(service_config["url"], tool_name, arguments) elif service_config["type"] == "gradio": return _call_gradio_api(service_config["url"], tool_name, arguments) else: return {"error": "Unknown service type"} except Exception as e: return {"error": str(e)} def _call_fastmcp(service_url, tool_name, arguments): """FastMCP: 标准 MCP JSON-RPC""" response = requests.post( service_url, json={"jsonrpc": "2.0", "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}, "id": 1}, headers={"Content-Type": "application/json"}, timeout=30 ) if response.status_code != 200: return {"error": f"HTTP {response.status_code}"} data = response.json() # 解包 MCP 协议: jsonrpc -> result -> content[0].text -> JSON if isinstance(data, dict) and "result" in data: result = data["result"] if isinstance(result, dict) and "content" in result: content = result["content"] if isinstance(content, list) and len(content) > 0: first_item = content[0] if isinstance(first_item, dict) and "text" in first_item: try: return json.loads(first_item["text"]) except (json.JSONDecodeError, TypeError): return {"text": first_item["text"]} return result return data def _call_gradio_api(service_url, tool_name, arguments): """Gradio: SSE 流式协议""" tool_map = {"get_quote": "test_quote_tool", "get_market_news": "test_market_news_tool", "get_company_news": "test_company_news_tool"} gradio_fn = tool_map.get(tool_name) if not gradio_fn: return {"error": "No mapping"} # 构造参数 if tool_name == "get_quote": params = [arguments.get("symbol", "")] elif tool_name == "get_market_news": params = [arguments.get("category", "general")] elif tool_name == "get_company_news": params = [arguments.get("symbol", ""), arguments.get("from_date", ""), arguments.get("to_date", "")] else: params = [] # 提交请求 call_url = f"{service_url}/call/{gradio_fn}" resp = requests.post(call_url, json={"data": params}, timeout=10) if resp.status_code != 200: return {"error": f"HTTP {resp.status_code}"} event_id = resp.json().get("event_id") if not event_id: return {"error": "No event_id"} # 获取结果 (SSE) result_resp = requests.get(f"{call_url}/{event_id}", stream=True, timeout=20) if result_resp.status_code != 200: return {"error": f"HTTP {result_resp.status_code}"} # 解析 SSE for line in result_resp.iter_lines(): if line and line.decode('utf-8').startswith('data: '): try: result_data = json.loads(line.decode('utf-8')[6:]) if isinstance(result_data, list) and len(result_data) > 0: return {"text": result_data[0]} except json.JSONDecodeError: continue return {"error": "No result"} # ============================================================ # End of MCP 服务调用代码区 # ============================================================ def chatbot_response(message, history): """AI 助手主函数(流式输出,性能优化)""" try: messages = [{"role": "system", "content": get_system_prompt()}] # 添加历史(最近2轮) - 严格限制上下文长度 if history: for item in history[-MAX_HISTORY_TURNS:]: if isinstance(item, (list, tuple)) and len(item) == 2: # 用户消息(不截断) messages.append({"role": "user", "content": item[0]}) # 助手回复(严格截断) assistant_msg = str(item[1]) if len(assistant_msg) > MAX_HISTORY_CHARS: assistant_msg = truncate_text(assistant_msg, MAX_HISTORY_CHARS) messages.append({"role": "assistant", "content": assistant_msg}) messages.append({"role": "user", "content": message}) tool_calls_log = [] # LLM 调用循环(支持多轮工具调用) final_response_content = None for iteration in range(MAX_TOOL_ITERATIONS): response = client.chat.completions.create( model="Qwen/Qwen3-32B:groq", messages=messages, tools=MCP_TOOLS, max_tokens=MAX_OUTPUT_TOKENS, temperature=0.5, tool_choice="auto", stream=False ) choice = response.choices[0] if choice.message.tool_calls: messages.append(choice.message) for tool_call in choice.message.tool_calls: tool_name = tool_call.function.name try: tool_args = json.loads(tool_call.function.arguments) except json.JSONDecodeError: tool_args = {} # 调用 MCP 工具 tool_result = call_mcp_tool(tool_name, tool_args) # 检查错误 if isinstance(tool_result, dict) and "error" in tool_result: # 工具调用失败,记录错误 tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True}) result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False) else: # 限制返回结果大小 result_str = json.dumps(tool_result, ensure_ascii=False) if len(result_str) > MAX_TOOL_RESULT_CHARS: if isinstance(tool_result, dict) and "text" in tool_result: truncated_text = truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50) tool_result_truncated = {"text": truncated_text, "_truncated": True} elif isinstance(tool_result, dict): truncated = {} char_count = 0 for k, v in list(tool_result.items())[:8]: # 保留前8个字段 v_str = str(v)[:300] # 每个值最多300字符 truncated[k] = v_str char_count += len(k) + len(v_str) if char_count > MAX_TOOL_RESULT_CHARS: break tool_result_truncated = {**truncated, "_truncated": True} else: tool_result_truncated = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True} result_for_llm = json.dumps(tool_result_truncated, ensure_ascii=False) else: result_for_llm = result_str # 记录成功的工具调用 tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result}) messages.append({ "role": "tool", "name": tool_name, "content": result_for_llm, "tool_call_id": tool_call.id }) continue else: # 没有更多工具调用,保存最终答案 final_response_content = choice.message.content break # 构建响应前缀(简化版) response_prefix = "" # 显示工具调用(使用原生HTML details标签) if tool_calls_log: response_prefix += """