import gradio as gr
import requests
import json
import os
import warnings
from huggingface_hub import InferenceClient

warnings.filterwarnings('ignore', category=DeprecationWarning)
os.environ['PYTHONWARNINGS'] = 'ignore'

# Hide GPUs unless the caller has already pinned one; this app only needs CPU.
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

# Tool schemas in the OpenAI-compatible function-calling format expected by the chat API.
MCP_TOOLS = [
    {"type": "function", "function": {"name": "advanced_search_company", "description": "Search US companies", "parameters": {"type": "object", "properties": {"company_input": {"type": "string"}}, "required": ["company_input"]}}},
    {"type": "function", "function": {"name": "get_latest_financial_data", "description": "Get latest financial data", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}}, "required": ["cik"]}}},
    {"type": "function", "function": {"name": "extract_financial_metrics", "description": "Get multi-year trends", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}, "years": {"type": "integer"}}, "required": ["cik", "years"]}}},
    {"type": "function", "function": {"name": "get_quote", "description": "Get stock quote", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}},
    {"type": "function", "function": {"name": "get_market_news", "description": "Get market news", "parameters": {"type": "object", "properties": {"category": {"type": "string"}}, "required": ["category"]}}},
    {"type": "function", "function": {"name": "get_company_news", "description": "Get company news", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}, "from_date": {"type": "string"}, "to_date": {"type": "string"}}, "required": ["symbol"]}}}
]

MCP_SERVICES = {
    "financial": {"url": "http://localhost:7861/mcp", "type": "fastmcp"},
    "market": {"url": "https://jc321-marketandstockmcp.hf.space", "type": "gradio"}
}

TOOL_ROUTING = {
    "advanced_search_company": MCP_SERVICES["financial"],
    "get_latest_financial_data": MCP_SERVICES["financial"],
    "extract_financial_metrics": MCP_SERVICES["financial"],
    "get_quote": MCP_SERVICES["market"],
    "get_market_news": MCP_SERVICES["market"],
    "get_company_news": MCP_SERVICES["market"]
}

hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
client = InferenceClient(api_key=hf_token) if hf_token else InferenceClient()
print("✅ LLM initialized: Qwen/Qwen3-32B:groq")
print(f"📊 MCP Services: {len(MCP_SERVICES)} services, {len(MCP_TOOLS)} tools")

# Token/length budgets that keep the prompt small enough for the hosted model.
MAX_TOTAL_TOKENS = 6000        # overall prompt token budget
MAX_TOOL_RESULT_CHARS = 1500   # cap on each tool result fed back to the LLM
MAX_HISTORY_CHARS = 500        # cap on each past assistant message
MAX_HISTORY_TURNS = 2          # number of past turns kept in context
MAX_TOOL_ITERATIONS = 6        # max tool-call round trips per user message
MAX_OUTPUT_TOKENS = 2000       # completion budget


def estimate_tokens(text):
    """Rough token estimate (assumes ~1 token per 2 characters)."""
    return len(str(text)) // 2


def truncate_text(text, max_chars, suffix="...[truncated]"):
    """Truncate text to at most max_chars characters, appending a marker."""
    text = str(text)
    if len(text) <= max_chars:
        return text
    return text[:max_chars] + suffix
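
# Illustrative usage (not executed here):
#   truncate_text("x" * 2000, 10)  -> "xxxxxxxxxx...[truncated]"
#   estimate_tokens("hello world") -> 5   (11 chars // 2)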


def get_system_prompt():
    """Build the system prompt with today's date (kept short to save tokens)."""
    from datetime import datetime
    current_date = datetime.now().strftime("%Y-%m-%d")
    return f"""Financial analyst. Today: {current_date}. Use tools for company data, stock prices, news. Be concise."""
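
# Example of the generated prompt (the date will differ at runtime):
#   "Financial analyst. Today: 2025-06-01. Use tools for company data, stock prices, news. Be concise."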


def call_mcp_tool(tool_name, arguments):
    """Route a tool call to the MCP backend that owns it and return its result."""
    service_config = TOOL_ROUTING.get(tool_name)
    if not service_config:
        return {"error": f"Unknown tool: {tool_name}"}

    try:
        if service_config["type"] == "fastmcp":
            return _call_fastmcp(service_config["url"], tool_name, arguments)
        elif service_config["type"] == "gradio":
            return _call_gradio_api(service_config["url"], tool_name, arguments)
        else:
            return {"error": "Unknown service type"}
    except Exception as e:
        return {"error": str(e)}
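
# Illustrative call (assumes the market service is reachable):
#   call_mcp_tool("get_quote", {"symbol": "AAPL"})
#   -> {"text": "..."} on success, or {"error": "..."} on any failure,
# so callers only ever have to check for an "error" key.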


def _call_fastmcp(service_url, tool_name, arguments):
    """FastMCP backend: standard MCP JSON-RPC over HTTP."""
    response = requests.post(
        service_url,
        json={"jsonrpc": "2.0", "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}, "id": 1},
        headers={"Content-Type": "application/json"},
        timeout=30
    )

    if response.status_code != 200:
        return {"error": f"HTTP {response.status_code}"}

    data = response.json()

    # Unwrap the MCP envelope: result -> content -> first text item.
    if isinstance(data, dict) and "result" in data:
        result = data["result"]
        if isinstance(result, dict) and "content" in result:
            content = result["content"]
            if isinstance(content, list) and len(content) > 0:
                first_item = content[0]
                if isinstance(first_item, dict) and "text" in first_item:
                    try:
                        return json.loads(first_item["text"])
                    except (json.JSONDecodeError, TypeError):
                        return {"text": first_item["text"]}
        return result
    return data
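
# Request/response shape this helper assumes (abridged sketch of the MCP
# JSON-RPC tools/call exchange; the example CIK is illustrative):
#   POST body: {"jsonrpc": "2.0", "method": "tools/call",
#               "params": {"name": "get_latest_financial_data",
#                          "arguments": {"cik": "0000320193"}}, "id": 1}
#   Response:  {"jsonrpc": "2.0", "id": 1,
#               "result": {"content": [{"type": "text", "text": "{...json payload...}"}]}}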


def _call_gradio_api(service_url, tool_name, arguments):
    """Gradio backend: two-step REST call with an SSE result stream."""
    tool_map = {"get_quote": "test_quote_tool", "get_market_news": "test_market_news_tool", "get_company_news": "test_company_news_tool"}
    gradio_fn = tool_map.get(tool_name)
    if not gradio_fn:
        return {"error": "No mapping"}

    # Gradio endpoints take positional arguments, so order them per function.
    if tool_name == "get_quote":
        params = [arguments.get("symbol", "")]
    elif tool_name == "get_market_news":
        params = [arguments.get("category", "general")]
    elif tool_name == "get_company_news":
        params = [arguments.get("symbol", ""), arguments.get("from_date", ""), arguments.get("to_date", "")]
    else:
        params = []

    # Step 1: enqueue the call and receive an event_id.
    call_url = f"{service_url}/call/{gradio_fn}"
    resp = requests.post(call_url, json={"data": params}, timeout=10)
    if resp.status_code != 200:
        return {"error": f"HTTP {resp.status_code}"}

    event_id = resp.json().get("event_id")
    if not event_id:
        return {"error": "No event_id"}

    # Step 2: read the SSE stream until a "data:" line carries the result.
    result_resp = requests.get(f"{call_url}/{event_id}", stream=True, timeout=20)
    if result_resp.status_code != 200:
        return {"error": f"HTTP {result_resp.status_code}"}

    for line in result_resp.iter_lines():
        if line and line.decode('utf-8').startswith('data: '):
            try:
                result_data = json.loads(line.decode('utf-8')[6:])
                if isinstance(result_data, list) and len(result_data) > 0:
                    return {"text": result_data[0]}
            except json.JSONDecodeError:
                continue

    return {"error": "No result"}
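
# The two-step protocol above, sketched as raw HTTP (endpoint names taken from
# tool_map; the flow follows Gradio's /call REST API):
#   POST {service_url}/call/test_quote_tool   body: {"data": ["AAPL"]}
#     -> {"event_id": "abc123"}
#   GET  {service_url}/call/test_quote_tool/abc123   (SSE)
#     -> ... data: ["<quote text>"]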


def chatbot_response(message, history):
    """Main assistant handler: tool-calling loop with streamed output."""
    try:
        messages = [{"role": "system", "content": get_system_prompt()}]

        # Replay only a short, truncated slice of history to keep the prompt small.
        if history:
            for item in history[-MAX_HISTORY_TURNS:]:
                if isinstance(item, (list, tuple)) and len(item) == 2:
                    messages.append({"role": "user", "content": item[0]})
                    assistant_msg = str(item[1])
                    if len(assistant_msg) > MAX_HISTORY_CHARS:
                        assistant_msg = truncate_text(assistant_msg, MAX_HISTORY_CHARS)
                    messages.append({"role": "assistant", "content": assistant_msg})

        messages.append({"role": "user", "content": message})

        tool_calls_log = []
        final_response_content = None

        # Agent loop: let the model request tools until it returns a final answer.
        for iteration in range(MAX_TOOL_ITERATIONS):
            response = client.chat.completions.create(
                model="Qwen/Qwen3-32B:groq",
                messages=messages,
                tools=MCP_TOOLS,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.5,
                tool_choice="auto",
                stream=False
            )

            choice = response.choices[0]

            if choice.message.tool_calls:
                messages.append(choice.message)

                for tool_call in choice.message.tool_calls:
                    tool_name = tool_call.function.name
                    try:
                        tool_args = json.loads(tool_call.function.arguments)
                    except json.JSONDecodeError:
                        tool_args = {}

                    tool_result = call_mcp_tool(tool_name, tool_args)

                    if isinstance(tool_result, dict) and "error" in tool_result:
                        # Log the failure and pass a compact error back to the LLM.
                        tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True})
                        result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False)
                    else:
                        result_str = json.dumps(tool_result, ensure_ascii=False)

                        # Oversized results are truncated before going back to the LLM.
                        if len(result_str) > MAX_TOOL_RESULT_CHARS:
                            if isinstance(tool_result, dict) and "text" in tool_result:
                                truncated_text = truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50)
                                tool_result_truncated = {"text": truncated_text, "_truncated": True}
                            elif isinstance(tool_result, dict):
                                # Keep at most 8 keys, each value capped at 300 chars.
                                truncated = {}
                                char_count = 0
                                for k, v in list(tool_result.items())[:8]:
                                    v_str = str(v)[:300]
                                    truncated[k] = v_str
                                    char_count += len(k) + len(v_str)
                                    if char_count > MAX_TOOL_RESULT_CHARS:
                                        break
                                tool_result_truncated = {**truncated, "_truncated": True}
                            else:
                                tool_result_truncated = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
                            result_for_llm = json.dumps(tool_result_truncated, ensure_ascii=False)
                        else:
                            result_for_llm = result_str

                        # Log the full (untruncated) result for the UI panel.
                        tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})

                    messages.append({
                        "role": "tool",
                        "name": tool_name,
                        "content": result_for_llm,
                        "tool_call_id": tool_call.id
                    })

                continue
            else:
                # No tool calls: the model produced its final answer.
                final_response_content = choice.message.content
                break

response_prefix = "" |
|
|
|
|
|
|
|
|
if tool_calls_log: |
|
|
response_prefix += """<div style='margin-bottom: 15px;'> |
|
|
<div style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333;'> |
|
|
🛠️ Tools Used ({} calls) |
|
|
</div> |
|
|
""".format(len(tool_calls_log)) |
|
|
|
|
|
            for idx, tool_call in enumerate(tool_calls_log):
                args_json = json.dumps(tool_call['arguments'], ensure_ascii=False)
                result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2)
                result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')

                error_indicator = " ❌ Error" if tool_call.get('error') else ""

                response_prefix += f"""<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>
<summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>
<div style='display: flex; justify-content: space-between; align-items: center;'>
<div style='flex: 1;'>
<strong style='color: #2c5aa0;'>📌 {idx+1}. {tool_call['name']}{error_indicator}</strong>
<div style='font-size: 0.85em; color: #666; margin-top: 4px;'>📥 Input: <code style='background: #f5f5f5; padding: 2px 6px; border-radius: 3px;'>{args_json}</code></div>
</div>
<span style='font-size: 1.2em; color: #999; margin-left: 10px;'>▶</span>
</div>
</summary>
<div style='background: #f9f9f9; padding: 12px; border-top: 1px solid #eee;'>
<div style='font-size: 0.9em; color: #333;'>
<strong>📤 Output:</strong>
<pre style='background: #fff; padding: 10px; border-radius: 4px; overflow-x: auto; margin-top: 6px; font-size: 0.85em; border: 1px solid #e0e0e0; max-height: 400px; white-space: pre-wrap;'>{result_preview}</pre>
</div>
</div>
</details>
"""

response_prefix += """</div> |
|
|
|
|
|
--- |
|
|
|
|
|
""" |
|
|
response_prefix += "\n" |
|
|
|
|
|
|
|
|
yield response_prefix |
|
|
|
|
|
|
|
|
        if final_response_content:
            # The tool loop already produced a final answer; emit it in one piece.
            yield response_prefix + final_response_content
        else:
            # Tool budget exhausted without an answer: ask once more, streaming.
            try:
                stream = client.chat.completions.create(
                    model="Qwen/Qwen3-32B:groq",
                    messages=messages,
                    tools=None,
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.5,
                    stream=True
                )

                accumulated_text = ""
                for chunk in stream:
                    if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                        accumulated_text += chunk.choices[0].delta.content
                        yield response_prefix + accumulated_text
            except Exception as stream_error:
                # Streaming failed; fall back to a single non-streaming request.
                final_resp = client.chat.completions.create(
                    model="Qwen/Qwen3-32B:groq",
                    messages=messages,
                    tools=None,
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.5,
                    stream=False
                )
                yield response_prefix + final_resp.choices[0].message.content

    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: model server error, possibly an oversized payload or a timeout.\n\nDetails: {error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"


with gr.Blocks(title="Financial AI Assistant") as demo:
    gr.Markdown("# 💬 Financial AI Assistant")

    chat = gr.ChatInterface(
        fn=chatbot_response,
        examples=[
            "What's Apple's latest revenue and profit?",
            "Show me NVIDIA's 3-year financial trends",
            "How is Tesla's stock performing today?",
            "Get the latest market news about crypto",
            "Compare Microsoft's latest earnings with its current stock price",
        ],
        chatbot=gr.Chatbot(height=700),
        textbox=gr.Textbox(lines=4, placeholder="Ask me anything about finance, stocks, or company data...", show_label=False),
    )

if __name__ == "__main__":
    import sys

    # Ensure a default asyncio event loop policy on Linux before launching.
    if sys.platform == 'linux':
        try:
            import asyncio
            asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        except Exception:
            pass

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        ssr_mode=False,
        quiet=False
    )