import gradio as gr
import requests
import json
import os
import warnings
from huggingface_hub import InferenceClient
# Suppress asyncio deprecation warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
os.environ['PYTHONWARNINGS'] = 'ignore'
# Disable CUDA when running in a GPU environment that doesn't need the GPU
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# ========== Simplified MCP tool definitions (function-calling schema sent to the LLM) ==========
MCP_TOOLS = [
{"type": "function", "function": {"name": "advanced_search_company", "description": "Search US companies", "parameters": {"type": "object", "properties": {"company_input": {"type": "string"}}, "required": ["company_input"]}}},
{"type": "function", "function": {"name": "get_latest_financial_data", "description": "Get latest financial data", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}}, "required": ["cik"]}}},
{"type": "function", "function": {"name": "extract_financial_metrics", "description": "Get multi-year trends", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}, "years": {"type": "integer"}}, "required": ["cik", "years"]}}},
{"type": "function", "function": {"name": "get_quote", "description": "Get stock quote", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}},
{"type": "function", "function": {"name": "get_market_news", "description": "Get market news", "parameters": {"type": "object", "properties": {"category": {"type": "string"}}, "required": ["category"]}}},
{"type": "function", "function": {"name": "get_company_news", "description": "Get company news", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}, "from_date": {"type": "string"}, "to_date": {"type": "string"}}, "required": ["symbol"]}}}
]
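# Illustrative shape of a tool call the model may emit against these schemas
# (values are hypothetical):
#   tool_call.function.name      -> "get_quote"
#   tool_call.function.arguments -> '{"symbol": "AAPL"}'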
# ========== MCP service configuration ==========
MCP_SERVICES = {
"financial": {"url": "http://localhost:7861/mcp", "type": "fastmcp"},
"market": {"url": "https://jc321-marketandstockmcp.hf.space", "type": "gradio"}
}
TOOL_ROUTING = {
"advanced_search_company": MCP_SERVICES["financial"],
"get_latest_financial_data": MCP_SERVICES["financial"],
"extract_financial_metrics": MCP_SERVICES["financial"],
"get_quote": MCP_SERVICES["market"],
"get_market_news": MCP_SERVICES["market"],
"get_company_news": MCP_SERVICES["market"]
}
# ========== Initialize the LLM client ==========
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
client = InferenceClient(api_key=hf_token) if hf_token else InferenceClient()
print(f"✅ LLM initialized: Qwen/Qwen3-32B:groq")
print(f"📊 MCP Services: {len(MCP_SERVICES)} services, {len(MCP_TOOLS)} tools")
# ========== Token limit configuration ==========
# The HuggingFace Inference API effectively caps context at roughly 8000-16000 tokens;
# set lower limits to stay safe.
MAX_TOTAL_TOKENS = 6000        # Total context budget
MAX_TOOL_RESULT_CHARS = 1500   # Max characters per tool result (raised to 1500)
MAX_HISTORY_CHARS = 500        # Max characters per single history message
MAX_HISTORY_TURNS = 2          # Max number of history turns to keep
MAX_TOOL_ITERATIONS = 6        # Max tool-calling rounds (raised to 6 for multi-tool workflows)
MAX_OUTPUT_TOKENS = 2000       # Max output tokens (raised to 2000)
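# Rough budget check with the 1 token ≈ 2 chars heuristic below: a 1500-char
# tool result is ~750 tokens, so six such results (~4500 tokens) plus history
# and the prompt stay near the 6000-token cap.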
def estimate_tokens(text):
"""估算文本 token 数量(粗略:1 token ≈ 2 字符)"""
return len(str(text)) // 2
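# e.g. estimate_tokens("hello world") -> 5  (11 chars // 2)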
def truncate_text(text, max_chars, suffix="...[truncated]"):
"""截断文本到指定长度"""
text = str(text)
if len(text) <= max_chars:
return text
return text[:max_chars] + suffix
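# e.g. truncate_text("x" * 2000, 10) -> "xxxxxxxxxx...[truncated]"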
def get_system_prompt():
"""生成包含当前日期的系统提示词(精简版)"""
from datetime import datetime
current_date = datetime.now().strftime("%Y-%m-%d")
return f"""Financial analyst. Today: {current_date}. Use tools for company data, stock prices, news. Be concise."""
# ============================================================
# Core MCP service invocation code
# Supports both the FastMCP (JSON-RPC) and Gradio (SSE) protocols
# ============================================================
def call_mcp_tool(tool_name, arguments):
"""调用 MCP 工具"""
service_config = TOOL_ROUTING.get(tool_name)
if not service_config:
return {"error": f"Unknown tool: {tool_name}"}
try:
if service_config["type"] == "fastmcp":
return _call_fastmcp(service_config["url"], tool_name, arguments)
elif service_config["type"] == "gradio":
return _call_gradio_api(service_config["url"], tool_name, arguments)
else:
return {"error": "Unknown service type"}
except Exception as e:
return {"error": str(e)}
def _call_fastmcp(service_url, tool_name, arguments):
"""FastMCP: 标准 MCP JSON-RPC"""
response = requests.post(
service_url,
json={"jsonrpc": "2.0", "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}, "id": 1},
headers={"Content-Type": "application/json"},
timeout=30
)
if response.status_code != 200:
return {"error": f"HTTP {response.status_code}"}
data = response.json()
    # Unwrap the MCP envelope: jsonrpc -> result -> content[0].text -> JSON
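    # Illustrative envelope (shape only; payload values are hypothetical):
    #   {"jsonrpc": "2.0", "id": 1,
    #    "result": {"content": [{"type": "text", "text": "{\"price\": 123.45}"}]}}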
if isinstance(data, dict) and "result" in data:
result = data["result"]
if isinstance(result, dict) and "content" in result:
content = result["content"]
if isinstance(content, list) and len(content) > 0:
first_item = content[0]
if isinstance(first_item, dict) and "text" in first_item:
try:
return json.loads(first_item["text"])
except (json.JSONDecodeError, TypeError):
return {"text": first_item["text"]}
return result
return data
def _call_gradio_api(service_url, tool_name, arguments):
"""Gradio: SSE 流式协议"""
tool_map = {"get_quote": "test_quote_tool", "get_market_news": "test_market_news_tool", "get_company_news": "test_company_news_tool"}
gradio_fn = tool_map.get(tool_name)
if not gradio_fn:
return {"error": "No mapping"}
    # Build the positional argument list
if tool_name == "get_quote":
params = [arguments.get("symbol", "")]
elif tool_name == "get_market_news":
params = [arguments.get("category", "general")]
elif tool_name == "get_company_news":
params = [arguments.get("symbol", ""), arguments.get("from_date", ""), arguments.get("to_date", "")]
else:
params = []
    # Submit the request
call_url = f"{service_url}/call/{gradio_fn}"
resp = requests.post(call_url, json={"data": params}, timeout=10)
if resp.status_code != 200:
return {"error": f"HTTP {resp.status_code}"}
event_id = resp.json().get("event_id")
if not event_id:
return {"error": "No event_id"}
    # Fetch the result (SSE)
result_resp = requests.get(f"{call_url}/{event_id}", stream=True, timeout=20)
if result_resp.status_code != 200:
return {"error": f"HTTP {result_resp.status_code}"}
    # Parse the SSE stream
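    # A completed event is delivered as a line like (illustrative):
    #   data: ["<tool output string>"]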
for line in result_resp.iter_lines():
if line and line.decode('utf-8').startswith('data: '):
try:
result_data = json.loads(line.decode('utf-8')[6:])
if isinstance(result_data, list) and len(result_data) > 0:
return {"text": result_data[0]}
except json.JSONDecodeError:
continue
return {"error": "No result"}
# ============================================================
# End of MCP service invocation code
# ============================================================
def chatbot_response(message, history):
"""AI 助手主函数(流式输出,性能优化)"""
try:
messages = [{"role": "system", "content": get_system_prompt()}]
        # Append recent history (last MAX_HISTORY_TURNS turns), strictly capping context length
if history:
for item in history[-MAX_HISTORY_TURNS:]:
if isinstance(item, (list, tuple)) and len(item) == 2:
                    # User message (left untruncated)
messages.append({"role": "user", "content": item[0]})
                    # Assistant reply (strictly truncated)
assistant_msg = str(item[1])
if len(assistant_msg) > MAX_HISTORY_CHARS:
assistant_msg = truncate_text(assistant_msg, MAX_HISTORY_CHARS)
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
tool_calls_log = []
        # LLM call loop (supports multiple rounds of tool calls)
final_response_content = None
for iteration in range(MAX_TOOL_ITERATIONS):
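            # stream=False on tool-selection rounds: tool_calls then arrive as
            # one complete object rather than incremental streaming deltas.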
response = client.chat.completions.create(
model="Qwen/Qwen3-32B:groq",
messages=messages,
tools=MCP_TOOLS,
max_tokens=MAX_OUTPUT_TOKENS,
temperature=0.5,
tool_choice="auto",
stream=False
)
choice = response.choices[0]
if choice.message.tool_calls:
messages.append(choice.message)
for tool_call in choice.message.tool_calls:
tool_name = tool_call.function.name
try:
tool_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
tool_args = {}
                    # Invoke the MCP tool
tool_result = call_mcp_tool(tool_name, tool_args)
                    # Check for errors
if isinstance(tool_result, dict) and "error" in tool_result:
                        # Tool call failed; log the error
tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True})
result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False)
else:
                        # Cap the size of the returned result
result_str = json.dumps(tool_result, ensure_ascii=False)
if len(result_str) > MAX_TOOL_RESULT_CHARS:
if isinstance(tool_result, dict) and "text" in tool_result:
truncated_text = truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50)
tool_result_truncated = {"text": truncated_text, "_truncated": True}
elif isinstance(tool_result, dict):
truncated = {}
char_count = 0
                                for k, v in list(tool_result.items())[:8]:  # keep the first 8 fields
                                    v_str = str(v)[:300]  # cap each value at 300 characters
truncated[k] = v_str
char_count += len(k) + len(v_str)
if char_count > MAX_TOOL_RESULT_CHARS:
break
tool_result_truncated = {**truncated, "_truncated": True}
else:
tool_result_truncated = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
result_for_llm = json.dumps(tool_result_truncated, ensure_ascii=False)
else:
result_for_llm = result_str
                        # Log the successful tool call
tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
messages.append({
"role": "tool",
"name": tool_name,
"content": result_for_llm,
"tool_call_id": tool_call.id
})
continue
else:
                # No more tool calls; store the final answer
final_response_content = choice.message.content
break
        # Build the response prefix (simplified)
response_prefix = ""
        # Render the tool calls (native HTML <details> tags)
if tool_calls_log:
response_prefix += """<div style='margin-bottom: 15px;'>
<div style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333;'>
🛠️ Tools Used ({} calls)
</div>
""".format(len(tool_calls_log))
for idx, tool_call in enumerate(tool_calls_log):
                # Precompute the JSON strings to avoid repeated serialization
args_json = json.dumps(tool_call['arguments'], ensure_ascii=False)
result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2)
result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
                # Flag error status
error_indicator = " ❌ Error" if tool_call.get('error') else ""
                # Use native HTML5 details/summary tags (no JavaScript required)
response_prefix += f"""<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>
<summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>
<div style='display: flex; justify-content: space-between; align-items: center;'>
<div style='flex: 1;'>
<strong style='color: #2c5aa0;'>📌 {idx+1}. {tool_call['name']}{error_indicator}</strong>
<div style='font-size: 0.85em; color: #666; margin-top: 4px;'>📥 Input: <code style='background: #f5f5f5; padding: 2px 6px; border-radius: 3px;'>{args_json}</code></div>
</div>
<span style='font-size: 1.2em; color: #999; margin-left: 10px;'>▶</span>
</div>
</summary>
<div style='background: #f9f9f9; padding: 12px; border-top: 1px solid #eee;'>
<div style='font-size: 0.9em; color: #333;'>
<strong>📤 Output:</strong>
<pre style='background: #fff; padding: 10px; border-radius: 4px; overflow-x: auto; margin-top: 6px; font-size: 0.85em; border: 1px solid #e0e0e0; max-height: 400px; white-space: pre-wrap;'>{result_preview}</pre>
</div>
</div>
</details>
"""
response_prefix += """</div>
---
"""
response_prefix += "\n"
        # Stream the final answer
yield response_prefix
        # If the loop produced a final answer, emit it directly
        if final_response_content:
            yield response_prefix + final_response_content
else:
            # The loop hit the iteration cap without a final answer; make one more call so the model can summarize
try:
stream = client.chat.completions.create(
model="Qwen/Qwen3-32B:groq",
messages=messages,
                    tools=None,  # disallow further tool calls
max_tokens=MAX_OUTPUT_TOKENS,
temperature=0.5,
stream=True
)
accumulated_text = ""
for chunk in stream:
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
accumulated_text += chunk.choices[0].delta.content
yield response_prefix + accumulated_text
except Exception as stream_error:
                # Streaming failed; fall back to a non-streaming call
final_resp = client.chat.completions.create(
model="Qwen/Qwen3-32B:groq",
messages=messages,
tools=None,
max_tokens=MAX_OUTPUT_TOKENS,
temperature=0.5,
stream=False
)
yield response_prefix + final_resp.choices[0].message.content
except Exception as e:
import traceback
error_detail = str(e)
if "500" in error_detail:
yield f"❌ Error: 模型服务器错误。可能是数据太大或请求超时。\n\n详细信息: {error_detail[:200]}"
else:
yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"
# ========== Gradio UI (minimal) ==========
with gr.Blocks(title="Financial AI Assistant") as demo:
gr.Markdown("# 💬 Financial AI Assistant")
chat = gr.ChatInterface(
fn=chatbot_response,
examples=[
"What's Apple's latest revenue and profit?",
"Show me NVIDIA's 3-year financial trends",
"How is Tesla's stock performing today?",
"Get the latest market news about crypto",
"Compare Microsoft's latest earnings with its current stock price",
],
chatbot=gr.Chatbot(height=700),
textbox=gr.Textbox(lines=4, placeholder="Ask me anything about finance, stocks, or company data...", show_label=False),
)
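# Each value yielded by chatbot_response replaces the in-progress chat message,
# so the tool-call summary renders before the streamed answer.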
# Launch the app
if __name__ == "__main__":
import sys
    # Work around asyncio event-loop issues on Linux
if sys.platform == 'linux':
try:
import asyncio
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        except Exception:
pass
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
ssr_mode=False,
quiet=False
)