JC321's picture
Upload chat_direct.py
5eb0763 verified
"""
Financial AI Assistant - Direct Method Library (不依赖 HTTP)
直接导入并调用 easy_financial_mcp.py 中的函数
支持本地和 HF Space 部署
"""
import sys
from pathlib import Path
import os
import json
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
import requests
import warnings
# Suppress warning noise at import time. The original note says "asyncio
# warnings", but this filter silences every DeprecationWarning in the process.
warnings.filterwarnings('ignore', category=DeprecationWarning)
# PYTHONWARNINGS is read at interpreter start-up, so setting it here mainly
# affects Python child processes started later, not the running interpreter.
os.environ['PYTHONWARNINGS'] = 'ignore'
# Load the .env file first so variables read further down (e.g. HF_TOKEN)
# are present in os.environ.
load_dotenv()
# Put the parent of this file's directory at the front of sys.path,
# presumably so the EasyFinancialAgent / MarketandStockMCP packages imported
# below resolve — confirm against the project layout.
PROJECT_ROOT = Path(__file__).parent.parent.absolute()
sys.path.insert(0, str(PROJECT_ROOT))
# 直接导入 MCP 中定义的函数
try:
from EasyFinancialAgent.easy_financial_mcp import (
search_company as _search_company,
get_company_info as _get_company_info,
get_company_filings as _get_company_filings,
get_financial_data as _get_financial_data,
extract_financial_metrics as _extract_financial_metrics,
get_latest_financial_data as _get_latest_financial_data,
advanced_search_company as _advanced_search_company
)
MCP_DIRECT_AVAILABLE = True
print("[FinancialAI] ✓ Direct MCP functions imported successfully")
except ImportError as e:
MCP_DIRECT_AVAILABLE = False
print(f"[FinancialAI] ✗ Failed to import MCP functions: {e}")
# 定义占位符函数
def _advanced_search_company(x):
return {"error": "MCP not available"}
def _get_company_info(x):
return {"error": "MCP not available"}
def _get_company_filings(x, y=None):
return {"error": "MCP not available"}
def _get_financial_data(x, y):
return {"error": "MCP not available"}
def _get_latest_financial_data(x):
return {"error": "MCP not available"}
def _extract_financial_metrics(x, y=3):
return {"error": "MCP not available"}
# ============================================================
# 便捷方法 - 公司搜索相关
# ============================================================
def search_company_direct(company_input):
    """
    Search for a company by calling ``advanced_search_company`` in-process.

    Accepts a company name, ticker, or CIK code.

    Args:
        company_input: company name, ticker, or CIK code.

    Returns:
        - On success, a list of results. A single result is wrapped in a
          one-element list; a list/tuple result is returned as a flat list
          rather than nested inside another list.
        - On failure, a bare ``{"error": ...}`` dict. This covers MCP being
          unavailable, the call raising, or the backend itself returning an
          error dict. Returning it unwrapped means callers' ``"error" in result``
          checks detect it; inside a list that test would only check list
          membership and miss it.

    Example:
        result = search_company_direct("Apple")
        result = search_company_direct("AAPL")
        result = search_company_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        result = _advanced_search_company(company_input)
    except Exception as e:
        return {"error": str(e)}
    if isinstance(result, dict) and "error" in result:
        return result
    if isinstance(result, (list, tuple)):
        return list(result)
    return [result]
def get_company_info_direct(cik):
    """
    Fetch detailed company information by calling the MCP function in-process.

    Args:
        cik: the company's CIK code.

    Returns:
        Whatever ``get_company_info(cik)`` returns, or ``{"error": message}``
        when the MCP functions are unavailable or the call raises.

    Example:
        result = get_company_info_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_company_info(cik)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
def get_company_filings_direct(cik):
    """
    List a company's SEC filings by calling the MCP function in-process.

    Args:
        cik: the company's CIK code.

    Returns:
        The result of ``get_company_filings(cik)`` (called without a form-type
        filter), or ``{"error": message}`` when the MCP functions are
        unavailable or the call raises.

    Example:
        result = get_company_filings_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        filings = _get_company_filings(cik)
    except Exception as err:
        return {"error": str(err)}
    return filings
def advanced_search_company_detailed(company_input):
    """
    Search for a company by name, ticker, or CIK and return the raw result.

    Delegates to ``advanced_search_company`` (imported at the top of this
    module from ``EasyFinancialAgent.easy_financial_mcp``). Unlike
    ``search_company_direct``, which returns its hits wrapped in a list, the
    result here is returned as-is.

    Args:
        company_input: a company name ("Tesla", "Apple Inc"), a ticker
            ("TSLA", "AAPL", "MSFT"), or a CIK code ("0001318605", "0000320193").

    Returns:
        The search result as returned by the backend. Per the examples in this
        module it is a dict with keys such as ``cik``, ``name`` and ``tickers``
        (possibly also ``sic`` / ``sic_description`` — confirm against
        easy_financial_mcp). On failure it is ``{"error": message}``, plus a
        ``"traceback"`` string when the call raised.

    Example:
        result = advanced_search_company_detailed("Tesla")       # by name
        result = advanced_search_company_detailed("TSLA")        # by ticker
        result = advanced_search_company_detailed("0001318605")  # by CIK
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _advanced_search_company(company_input)
    except Exception as exc:
        import traceback
        return {"error": str(exc), "traceback": traceback.format_exc()}
def format_search_result(search_result):
    """
    Normalise ``advanced_search_company`` output to a flat list of records.

    Each record has the shape ``{'company_name': str, 'cik': str, 'ticker': str}``.

    Args:
        search_result: any of the following:
            - a single result dict such as
              ``{'cik': '...', 'name': '...', 'tickers': [...], ...}``;
            - a list of such dicts (nested lists are flattened recursively);
            - an error dict.

    Returns:
        list[dict]: one record per usable result dict, built as follows:
            - ``company_name`` comes from ``name`` (default ``''``).
            - ``cik`` comes from ``cik`` (default ``''``).
            - ``ticker`` is the first element of ``tickers`` when it is a
              non-empty list/tuple, or ``tickers`` itself when it is a string.
              Otherwise it is ``''``.
        Error dicts, and inputs that are neither a dict nor a list, contribute
        nothing.

    Example:
        search_result = {'cik': '0001577552', 'name': 'Alibaba Group Holding Ltd', 'tickers': ['BABA'], '_source': 'company_tickers_cache'}
        formatted = format_search_result(search_result)
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    if isinstance(search_result, list):
        formatted_list = []
        for item in search_result:
            formatted_list.extend(format_search_result(item))
        return formatted_list

    if not isinstance(search_result, dict) or 'error' in search_result:
        return []

    tickers = search_result.get('tickers', [])
    if isinstance(tickers, (list, tuple)):
        ticker = tickers[0] if tickers else ''
    elif isinstance(tickers, str):
        # Some sources may report a single ticker as a plain string.
        ticker = tickers
    else:
        ticker = ''

    return [{
        'company_name': search_result.get('name', ''),
        'cik': search_result.get('cik', ''),
        'ticker': ticker,
    }]
def format_search_result_for_display(search_result):
    """
    Turn search results into display strings of the form ``"Name (TICKER)"``.

    Entries without a ticker are shown by name only. Entries whose ticker does
    not look like a primary US listing are dropped.

    Args:
        search_result: output of ``advanced_search_company`` (anything
            ``format_search_result`` accepts).

    Returns:
        list[str]: e.g. ``['Alibaba Group (BABA)']``.

    Example:
        result = format_search_result_for_display({'cik': '0001577552', 'name': 'Alibaba Group', 'tickers': ['BABA']})
        # -> ['Alibaba Group (BABA)']
    """
    # A 5-character symbol ending in one of these letters is rejected as a
    # common OTC / warrant / unit style suffix.
    rejected_fifth_letters = ('F', 'Y', 'Q', 'D', 'W', 'U', 'P')

    def looks_like_primary_listing(symbol):
        """Heuristic: accept symbols of at most 5 characters (dots ignored, e.g. BRK.B), minus the suffixes above."""
        if not symbol:
            return False
        core = symbol.upper().strip().replace('.', '')
        if len(core) > 5:
            # 6+ characters are usually OTC listings or funds.
            return False
        return not (len(core) == 5 and core.endswith(rejected_fifth_letters))

    lines = []
    for entry in format_search_result(search_result):
        name = entry.get('company_name', 'Unknown')
        symbol = entry.get('ticker', '')
        if not symbol:
            lines.append(name)
        elif looks_like_primary_listing(symbol):
            lines.append(f"{name} ({symbol})")
    return lines
def search_and_format(company_input):
    """
    Search for a company and normalise the result in one step.

    Runs ``advanced_search_company_detailed`` and passes its output through
    ``format_search_result``.

    Args:
        company_input: company name, ticker, or CIK.

    Returns:
        list[dict]: formatted records, or ``[]`` when the search returned an
        error dict.

    Example:
        result = search_and_format('BABA')
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    raw = advanced_search_company_detailed(company_input)
    failed = isinstance(raw, dict) and 'error' in raw
    return [] if failed else format_search_result(raw)
# ============================================================
# 便捷方法 - 财务数据相关
# ============================================================
def get_latest_financial_data_direct(cik):
    """
    Fetch a company's most recent financial data by calling the MCP function in-process.

    Args:
        cik: the company's CIK code.

    Returns:
        The result of ``get_latest_financial_data(cik)``, or
        ``{"error": message}`` when the MCP functions are unavailable or the
        call raises.

    Example:
        result = get_latest_financial_data_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_latest_financial_data(cik)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
def extract_financial_metrics_direct(cik, years=5):
    """
    Extract multi-year financial metric trends by calling the MCP function in-process.

    Args:
        cik: the company's CIK code.
        years: number of years to request (default 5).

    Returns:
        The result of ``extract_financial_metrics(cik, years)``, or
        ``{"error": message}`` when the MCP functions are unavailable or the
        call raises.

    Example:
        result = extract_financial_metrics_direct("0000320193", years=5)
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        metrics = _extract_financial_metrics(cik, years)
    except Exception as failure:
        return {"error": str(failure)}
    return metrics
# ============================================================
# 高级方法 - 综合查询
# ============================================================
def query_company_direct(company_input, get_filings=True, get_metrics=True, metrics_years=3):
    """
    Run a combined company query with direct (in-process) calls.

    The query performs a search, then fetches basic info, the SEC filings list
    and financial metrics.

    Args:
        company_input: company name, ticker, or CIK.
        get_filings: whether to fetch the SEC filings list.
        get_metrics: whether to fetch multi-year financial metrics.
        metrics_years: years of metrics to request (default 3, the value that
            was previously hard-coded).

    Returns:
        dict with the following keys:
            - "timestamp": when the query ran.
            - "query_input": the original input.
            - "status": "success" or "error".
            - "data": holds company_search, company_info, filings and metrics.
            - "errors": list of str.
        A failed search (or a missing CIK) sets status to "error" and stops
        early. A failure in one of the later steps is only appended to
        "errors".

    Example:
        result = query_company_direct("Apple", get_filings=True, get_metrics=True)
    """
    from datetime import datetime
    import traceback

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None,
            "filings": None,
            "metrics": None
        },
        "errors": []
    }
    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    def record_optional(key, payload, label):
        # Later steps are best-effort: an error dict is logged in "errors"
        # without changing "status"; anything else is stored as data.
        if isinstance(payload, dict) and "error" in payload:
            result["errors"].append(f"Failed to get {label}: {payload.get('error')}")
        else:
            result["data"][key] = payload

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)
        # search_company_direct wraps hits in a list, so an error can arrive
        # either as a bare dict or as the first element of that list. A plain
        # `"error" in <list>` would only test list membership and miss it.
        first_hit = search_result
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result
        result["data"]["company_search"] = search_result

        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Basic company information.
        record_optional("company_info", get_company_info_direct(cik), "company info")

        # 3. SEC filings list.
        if get_filings:
            record_optional("filings", get_company_filings_direct(cik), "filings")

        # 4. Multi-year financial metrics.
        if get_metrics:
            record_optional("metrics", extract_financial_metrics_direct(cik, years=metrics_years), "metrics")
    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())
    return result
# ============================================================
# LLM 模型配置与初始化
# ============================================================
# 初始化 LLM 客户端
def _init_llm_client():
    """
    (Re)create the module-level Hugging Face ``InferenceClient``.

    The token is read from ``HF_TOKEN``, falling back to
    ``HUGGING_FACE_HUB_TOKEN``. The global ``llm_client`` is reset to None
    first. It is set to a client only when a token exists and construction
    succeeds.

    Returns:
        bool: True if the client was created, False otherwise.
    """
    global llm_client
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    llm_client = None
    if not token:
        print("[FinancialAI] ⚠ Warning: HF_TOKEN not found, LLM features disabled")
        return False
    try:
        llm_client = InferenceClient(api_key=token)
    except Exception as exc:
        print(f"[FinancialAI] ✗ Failed to initialize LLM client: {exc}")
        return False
    print("[FinancialAI] ✓ LLM client initialized with HF_TOKEN")
    return True


# Module-level LLM client shared by the functions below; None when unavailable.
llm_client = None
_init_llm_client()
def get_system_prompt():
    """
    Build the system prompt sent to the LLM.

    The prompt contains, in order:
        - today's local date (YYYY-MM-DD);
        - the assistant's role;
        - tool-usage guidance, including a 5-call budget;
        - a request for English output.

    Returns:
        str: the prompt text, lines separated by "\\n", with no trailing newline.
    """
    from datetime import date

    today = date.today().isoformat()
    prompt_lines = [
        f"You are a financial analysis expert. Today is {today}.",
        "Your role:",
        "- Analyze company financial data, reports, and market news",
        "- Provide investment insights based on factual data",
        "- Be concise, objective, and data-driven",
        "- Always include disclaimers about market risks",
        "⚠️ IMPORTANT: You have a maximum of 5 tool calls. Choose the MOST RELEVANT tools carefully:",
        "- Use 'advanced_search_company' ONLY if you need to find a company's CIK",
        "- Use 'extract_financial_metrics' for comprehensive multi-year financial analysis (RECOMMENDED for most queries)",
        "- Use 'get_latest_financial_data' for quick recent snapshot",
        "- Use 'get_quote' for real-time stock price",
        "- Use 'get_company_news' for company-specific news",
        "- Use 'get_market_news' for general market trends",
        "Prioritize the most important tools for the user's question. Avoid redundant calls.",
        "Output should be in English.",
    ]
    return "\n".join(prompt_lines)
# Prompt templates for analyze_company_with_llm. "{data}" receives the JSON dump
# of the company data. Any analysis_type not listed falls back to _SUMMARY_PROMPT.
_ANALYSIS_PROMPT_TEMPLATES = (
    ("investment", """
Based on the following company financial data, provide an investment recommendation:
{data}
Provide:
1. Investment Recommendation (Buy/Hold/Sell)
2. Key Strengths and Weaknesses
3. Price Target Range
4. Risk Assessment
5. Risk Disclaimer
"""),
    ("risks", """
Based on the following company data, analyze the key risks:
{data}
Identify:
1. Financial Risks
2. Market Risks
3. Operational Risks
4. Mitigation Strategies
5. Risk Disclaimer
"""),
)
_SUMMARY_PROMPT = """
Provide a financial summary of the following company:
{data}
Include:
1. Company Overview
2. Financial Health
3. Recent Performance
4. Investment Outlook
"""


def analyze_company_with_llm(company_input, analysis_type="summary"):
    """
    Ask the LLM to analyse a company using its summary data.

    Args:
        company_input: company name, ticker, or CIK.
        analysis_type: "summary", "investment", or "risks". Any other value is
            treated as "summary".

    Returns:
        On success, a dict with keys "company", "analysis_type", "analysis"
        (the model's text) and "data_used". Otherwise ``{"error": message}``.

    Example:
        result = analyze_company_with_llm("Apple", "investment")
    """
    if not llm_client:
        return {"error": "LLM client not available"}
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        # Fetch the lightweight summary first and bail out if it failed.
        summary = get_company_summary_direct(company_input)
        if summary["status"] == "error":
            return {"error": f"Failed to fetch company data: {summary['errors']}"}

        payload = json.dumps(summary["data"], ensure_ascii=False, indent=2)
        template = next(
            (text for kind, text in _ANALYSIS_PROMPT_TEMPLATES if analysis_type == kind),
            _SUMMARY_PROMPT,
        )
        user_prompt = template.format(data=payload)

        completion = llm_client.chat.completions.create(
            model="Qwen/Qwen2.5-72B-Instruct",
            messages=[
                {"role": "system", "content": get_system_prompt()},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=1500,
            temperature=0.7,
            top_p=0.95,
            stream=False
        )
        return {
            "company": company_input,
            "analysis_type": analysis_type,
            "analysis": completion.choices[0].message.content,
            "data_used": summary["data"]
        }
    except Exception as e:
        return {"error": f"LLM analysis failed: {str(e)}"}
# ============================================================
# 便捷方法 - 获取单一时期财务数据
# ============================================================
def get_financial_data_direct(cik, period):
    """
    Fetch financial data for one reporting period by calling the MCP function in-process.

    Args:
        cik: the company's CIK code.
        period: the period identifier, e.g. "2024" or "2024Q3".

    Returns:
        The result of ``get_financial_data(cik, period)``, or
        ``{"error": message}`` when the MCP functions are unavailable or the
        call raises.

    Example:
        result = get_financial_data_direct("0000320193", "2024")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_financial_data(cik, period)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
# ============================================================
# 便捷方法 - 获取文件列表
# ============================================================
def get_company_filings_with_form_direct(cik, form_types=None):
    """
    List a company's SEC filings, optionally filtered by form type.

    Calls the MCP function in-process.

    Args:
        cik: the company's CIK code.
        form_types: optional list of form types, e.g. ["10-K", "10-Q"];
            passed through unchanged.

    Returns:
        The result of ``get_company_filings(cik, form_types)``, or
        ``{"error": message}`` when the MCP functions are unavailable or the
        call raises.

    Example:
        result = get_company_filings_with_form_direct("0000320193", ["10-K"])
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        filings = _get_company_filings(cik, form_types)
    except Exception as err:
        return {"error": str(err)}
    return filings
# ============================================================
# 便捷方法 - 轻量级查询
# ============================================================
def get_company_summary_direct(company_input):
    """
    Fetch a lightweight company summary: the search result plus basic company info.

    Args:
        company_input: company name, ticker, or CIK.

    Returns:
        dict with the following keys:
            - "timestamp": when the query ran.
            - "query_input": the original input.
            - "status": "success" or "error".
            - "data": holds company_search and company_info.
            - "errors": list of str.
        A failed search (or a missing CIK) sets status to "error". A failure
        fetching company info is only appended to "errors".

    Example:
        result = get_company_summary_direct("Apple")
    """
    from datetime import datetime
    import traceback

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None
        },
        "errors": []
    }
    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)
        # search_company_direct wraps hits in a list, so an error can arrive
        # either as a bare dict or as the first element of that list. A plain
        # `"error" in <list>` would only test list membership and miss it.
        first_hit = search_result
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result
        result["data"]["company_search"] = search_result

        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Basic company information (best-effort: failure only logged).
        company_info = get_company_info_direct(cik)
        if isinstance(company_info, dict) and "error" in company_info:
            result["errors"].append(f"Failed to get company info: {company_info.get('error')}")
        else:
            result["data"]["company_info"] = company_info
    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())
    return result
def get_financial_metrics_only_direct(company_input, years=5):
    """
    Search for a company and fetch only its multi-year financial metrics.

    No filings list or company info is fetched.

    Args:
        company_input: company name, ticker, or CIK.
        years: number of years of metrics to request (default 5).

    Returns:
        dict with the following keys:
            - "timestamp": when the query ran.
            - "query_input": the original input.
            - "years": the requested number of years.
            - "status": "success" or "error".
            - "data": the metrics, or None.
            - "errors": list of str.

    Example:
        result = get_financial_metrics_only_direct("Apple", years=5)
    """
    from datetime import datetime
    import traceback

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "years": years,
        "status": "success",
        "data": None,
        "errors": []
    }
    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)
        # search_company_direct wraps hits in a list, so an error can arrive
        # either as a bare dict or as the first element of that list. A plain
        # `"error" in <list>` would only test list membership and miss it.
        first_hit = search_result
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result

        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Financial metrics: the only payload of this query, so failure is fatal.
        metrics = extract_financial_metrics_direct(cik, years=years)
        if isinstance(metrics, dict) and "error" in metrics:
            result["errors"].append(f"Failed to get metrics: {metrics['error']}")
            result["status"] = "error"
        else:
            result["data"] = metrics
    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())
    return result
# ============================================================
# 测试函数
# ============================================================
if __name__ == "__main__":
    # Manual smoke test: exercises the direct helpers and, if an LLM client
    # is configured, three LLM analyses.
    divider = "=" * 60
    print("\n" + divider)
    print("Financial AI Assistant - Direct Method Test")
    print(divider)

    # Test 1: company search
    print("\n1. 搜索公司 (Apple)...")
    search_output = search_company_direct("Apple")
    print(f" 结果: {search_output}")

    # Test 2: lightweight company summary
    print("\n2. 获取公司摘要信息 (Tesla)...")
    tesla_summary = get_company_summary_direct("Tesla")
    for label, key in ((" 状态", "status"), (" 数据", "data"), (" 错误", "errors")):
        print(f"{label}: {tesla_summary[key]}")

    # Test 3: financial metrics only
    print("\n3. 获取财务指标 (Microsoft)...")
    msft_metrics = get_financial_metrics_only_direct("Microsoft", years=3)
    print(f" 状态: {msft_metrics['status']}")
    if msft_metrics['status'] == 'success':
        print(f" 指标数据: {msft_metrics['data']}")
    else:
        print(f" 错误: {msft_metrics['errors']}")

    # Test 4: full combined query
    print("\n4. 获取 Amazon 完整信息...")
    amazon_query = query_company_direct("Amazon", get_filings=True, get_metrics=True)
    print(f" 状态: {amazon_query['status']}")
    print(f" 错误: {amazon_query['errors']}")

    # Tests 5-7: LLM analyses (summary / investment / risks)
    llm_cases = (
        ("\n5. LLM 分析 - 公司摘要(Google)...", "Google", "summary"),
        ("\n6. LLM 分析 - 投资建议(NVIDIA)...", "NVIDIA", "investment"),
        ("\n7. LLM 分析 - 风险评估(Meta)...", "Meta", "risks"),
    )
    for banner, company, kind in llm_cases:
        print(banner)
        if not llm_client:
            print(" LLM 客户端不可用")
            continue
        outcome = analyze_company_with_llm(company, kind)
        if "error" in outcome:
            print(f" 错误: {outcome['error']}")
        else:
            print(f" 分析结果: {outcome['analysis'][:200]}...")

    print("\n" + divider)
# ============================================================
# 完整对话引擎 - chatbot_response
# ============================================================
# Size / token limits used by the chat engine below.
MAX_TOTAL_TOKENS = 6000       # not referenced by the code visible in this module
MAX_TOOL_RESULT_CHARS = 1500  # cap on the serialized tool result forwarded to the LLM
MAX_HISTORY_CHARS = 500       # cap on each replayed assistant message
MAX_HISTORY_TURNS = 2         # how many trailing history items are replayed
MAX_TOOL_ITERATIONS = 5       # max LLM rounds that may request tools per reply
MAX_OUTPUT_TOKENS = 2000      # max_tokens for each chat completion call


def _function_tool(name, description, properties, required):
    """Build one OpenAI-style ``{"type": "function", ...}`` tool schema entry."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool schemas advertised to the LLM: financial-data tools and market/news tools.
MCP_TOOLS = [
    # Financial data tools (EasyReportDataMCP)
    _function_tool(
        "advanced_search_company",
        "Search US companies by name, ticker, or CIK. Returns company information including CIK, name, tickers, and industry classification.",
        {
            "company_input": {"type": "string", "description": "Company name (e.g., 'Tesla'), ticker symbol (e.g., 'TSLA'), or CIK code (e.g., '0001318605')"},
        },
        ["company_input"],
    ),
    _function_tool(
        "get_latest_financial_data",
        "Get the most recent financial data for a company including revenue, net income, EPS, operating expenses, and cash flow.",
        {
            "cik": {"type": "string", "description": "Company CIK code (10-digit format, e.g., '0001318605')"},
        },
        ["cik"],
    ),
    _function_tool(
        "extract_financial_metrics",
        "Extract multi-year financial metrics trends showing historical performance over specified years.",
        {
            "cik": {"type": "string", "description": "Company CIK code (10-digit format)"},
            "years": {"type": "integer", "description": "Number of years of data to retrieve (e.g., 3 or 5)", "default": 3},
        },
        ["cik", "years"],
    ),
    # Market and news tools (MarketandStockMCP)
    _function_tool(
        "get_quote",
        "Get real-time stock quote data including current price, daily change, high/low, and previous close. Use when users ask about current stock prices or market performance.",
        {
            "symbol": {"type": "string", "description": "Stock ticker symbol (e.g., 'AAPL', 'TSLA', 'MSFT')"},
        },
        ["symbol"],
    ),
    _function_tool(
        "get_market_news",
        "Get latest market news by category. Use when users ask about general market trends, forex, crypto, or M&A news.",
        {
            "category": {"type": "string", "enum": ["general", "forex", "crypto", "merger"], "description": "News category: general (stocks/economy), forex (currency), crypto (cryptocurrency), merger (M&A)", "default": "general"},
            "min_id": {"type": "integer", "description": "Minimum news ID for pagination (default: 0)", "default": 0},
        },
        ["category"],
    ),
    _function_tool(
        "get_company_news",
        "Get company-specific news and announcements. Only available for North American companies. Use when users ask about specific company news.",
        {
            "symbol": {"type": "string", "description": "Company stock ticker symbol (e.g., 'AAPL', 'TSLA')"},
            "from_date": {"type": "string", "description": "Start date in YYYY-MM-DD format (optional, defaults to 7 days ago)"},
            "to_date": {"type": "string", "description": "End date in YYYY-MM-DD format (optional, defaults to today)"},
        },
        ["symbol"],
    ),
]
def truncate_text(text, max_chars, suffix="...[truncated]"):
    """
    Stringify *text* and cut it to *max_chars* characters.

    When the text is longer than *max_chars*, *suffix* is appended after the
    cut, so the result can exceed *max_chars* by ``len(suffix)``. Shorter text
    comes back unchanged (as a str).
    """
    as_str = str(text)
    if len(as_str) > max_chars:
        return as_str[:max_chars] + suffix
    return as_str
def call_mcp_tool(tool_name, arguments):
    """
    Execute one MCP tool in-process (no HTTP) and return its raw result.

    Financial tools call the functions imported from easy_financial_mcp, or
    their placeholders. Market/news tools import from
    ``MarketandStockMCP.news_quote_mcp`` at call time.

    Args:
        tool_name: tool name as advertised in ``MCP_TOOLS``.
        arguments: mapping of tool arguments parsed from the LLM's tool call.

    Returns:
        One of the following:
            - the tool's result;
            - ``{"error": "Unknown tool: <name>"}`` for an unrecognised name;
            - ``{"error": ..., "traceback": ...}`` (traceback cut to 500
              characters) if anything raises, including a failed import or an
              ``arguments`` value without ``.get``.
    """
    def run_advanced_search():
        return _advanced_search_company(arguments.get("company_input", ""))

    def run_latest_financials():
        return _get_latest_financial_data(arguments.get("cik", ""))

    def run_metrics():
        return _extract_financial_metrics(arguments.get("cik", ""), arguments.get("years", 3))

    def run_quote():
        from MarketandStockMCP.news_quote_mcp import get_quote
        return get_quote(arguments.get("symbol", ""))

    def run_market_news():
        from MarketandStockMCP.news_quote_mcp import get_market_news
        return get_market_news(arguments.get("category", "general"), arguments.get("min_id", 0))

    def run_company_news():
        from MarketandStockMCP.news_quote_mcp import get_company_news
        return get_company_news(arguments.get("symbol", ""), arguments.get("from_date"), arguments.get("to_date"))

    handlers = (
        ("advanced_search_company", run_advanced_search),
        ("get_latest_financial_data", run_latest_financials),
        ("extract_financial_metrics", run_metrics),
        ("get_quote", run_quote),
        ("get_market_news", run_market_news),
        ("get_company_news", run_company_news),
    )
    try:
        for known_name, handler in handlers:
            if tool_name == known_name:
                return handler()
        return {"error": f"Unknown tool: {tool_name}"}
    except Exception as exc:
        import traceback
        return {
            "error": f"{str(exc)}",
            "traceback": traceback.format_exc()[:500]
        }
# CSS for the collapsible "Tools Used" panel: a rotating triangle marker
# replaces the browser's default <summary> disclosure marker.
_TOOLS_PANEL_CSS = """<style>
details.tools-container > summary::before {
content: '▶';
display: inline-block;
margin-right: 8px;
transition: transform 0.2s;
}
details.tools-container[open] > summary::before {
transform: rotate(90deg);
}
details.tools-container > summary {
list-style: none;
}
details.tools-container > summary::-webkit-details-marker {
display: none;
}
</style>
"""


def _build_chat_messages(message, history):
    """Assemble the system prompt, the last MAX_HISTORY_TURNS history pairs and the new user message."""
    messages = [{"role": "system", "content": get_system_prompt()}]
    for item in (history or [])[-MAX_HISTORY_TURNS:]:
        # Only (user, assistant) pairs are replayed; other item shapes are skipped.
        if isinstance(item, (list, tuple)) and len(item) == 2:
            messages.append({"role": "user", "content": item[0]})
            # Keep the context small: past assistant replies are truncated.
            messages.append({"role": "assistant", "content": truncate_text(item[1], MAX_HISTORY_CHARS)})
    messages.append({"role": "user", "content": message})
    return messages


def _parse_tool_arguments(raw_arguments):
    """Return tool-call arguments as a dict, accepting a JSON string or an already-decoded dict."""
    # Some inference backends return `arguments` as a decoded object rather than
    # a JSON string; json.loads(dict) would raise TypeError and abort the reply.
    if isinstance(raw_arguments, dict):
        return raw_arguments
    if isinstance(raw_arguments, str):
        try:
            parsed = json.loads(raw_arguments)
        except json.JSONDecodeError:
            return {}
        return parsed if isinstance(parsed, dict) else {}
    return {}


def _tool_result_for_llm(tool_result):
    """Serialize a tool result for the LLM, shrinking it when it exceeds MAX_TOOL_RESULT_CHARS."""
    # default=str: a non-JSON value (e.g. a date) is sent as text instead of
    # raising TypeError and aborting the whole reply.
    if isinstance(tool_result, dict) and "error" in tool_result:
        return json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False, default=str)

    result_str = json.dumps(tool_result, ensure_ascii=False, default=str)
    if len(result_str) <= MAX_TOOL_RESULT_CHARS:
        return result_str

    if isinstance(tool_result, dict) and "text" in tool_result:
        shrunk = {"text": truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50), "_truncated": True}
    elif isinstance(tool_result, dict):
        # Keep at most 8 keys, each value stringified and cut to 300 chars,
        # stopping once the running size passes MAX_TOOL_RESULT_CHARS.
        shrunk = {}
        char_count = 0
        for key, value in list(tool_result.items())[:8]:
            value_str = str(value)[:300]
            shrunk[key] = value_str
            char_count += len(key) + len(value_str)
            if char_count > MAX_TOOL_RESULT_CHARS:
                break
        shrunk["_truncated"] = True
    else:
        shrunk = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
    return json.dumps(shrunk, ensure_ascii=False, default=str)


def _render_tool_calls_html(tool_calls_log):
    """Render the collapsible "Tools Used" panel (HTML followed by a markdown rule) shown above the answer."""
    import html

    parts = [
        _TOOLS_PANEL_CSS,
        f"""<div style='margin-bottom: 15px;'>
<details class='tools-container'>
<summary style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333; cursor: pointer; user-select: none;'>
<span>🛠️ Tools Used ({len(tool_calls_log)}/{MAX_TOOL_ITERATIONS} calls)</span>
</summary>
<div style='margin-top: 8px;'>
""",
    ]
    for idx, call in enumerate(tool_calls_log, start=1):
        result_json = json.dumps(call.get('result', {}), ensure_ascii=False, indent=2, default=str)
        result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
        error_indicator = " ❌ Error" if call.get('error') else ""
        # Tool output (e.g. news text) and LLM-chosen tool names are interpolated
        # into HTML; escape them so '<', '>' or '&' cannot break the panel markup.
        safe_name = html.escape(str(call['name']))
        safe_preview = html.escape(result_preview, quote=False)
        parts.append(f"""<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>
<summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>
<strong style='color: #2c5aa0;'>📋 {idx}. {safe_name}{error_indicator}</strong>
</summary>
<div style='background: #f9f9f9; padding: 12px;'>
<pre style='background: #fff; padding: 10px; overflow-x: auto; font-size: 0.85em;'>{safe_preview}</pre>
</div>
</details>
""")
    # Close the inner div, the outer <details> and the wrapper div.
    parts.append(""" </div>
</details>
</div>
---
""")
    return "".join(parts)


def chatbot_response(message, history=None):
    """
    Main assistant entry point: a multi-turn chat engine with tool calling.

    How it works:
        1. The LLM is queried for up to MAX_TOOL_ITERATIONS rounds. Each
           requested tool is executed in-process and its (size-capped) result
           is fed back to the model.
        2. The final answer is taken from the first round that requests no
           tools.
        3. If every round requested tools, the answer is streamed in a final
           call made without tools.

    Args:
        message: the user's message.
        history: prior conversation as ``[(user_msg, assistant_msg), ...]``.
            Only the last MAX_HISTORY_TURNS items are used, and items that are
            not 2-element lists/tuples are ignored.

    Yields:
        str: the progressively built response. When tools were used, each
        yield is prefixed with an HTML panel listing them. Errors are yielded
        as "❌ Error: ..." strings rather than raised.

    Example:
        for response in chatbot_response("What's Apple's revenue?", []):
            print(response)
    """
    if not llm_client:
        yield "❌ Error: LLM client not available"
        return
    if not MCP_DIRECT_AVAILABLE:
        yield "❌ Error: MCP functions not available"
        return

    try:
        messages = _build_chat_messages(message, history)
        tool_calls_log = []
        final_response_content = None

        # LLM loop: keep executing requested tools until the model answers directly.
        for _ in range(MAX_TOOL_ITERATIONS):
            response = llm_client.chat.completions.create(
                model="Qwen/Qwen2.5-72B-Instruct",
                messages=messages,
                tools=MCP_TOOLS,  # type: ignore
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                tool_choice="auto",
                stream=False
            )
            choice = response.choices[0]
            if not choice.message.tool_calls:
                final_response_content = choice.message.content
                break

            messages.append(choice.message)
            for tool_call in choice.message.tool_calls:
                tool_name = tool_call.function.name
                tool_args = _parse_tool_arguments(tool_call.function.arguments)
                tool_result = call_mcp_tool(tool_name, tool_args)

                log_entry = {"name": tool_name, "arguments": tool_args, "result": tool_result}
                if isinstance(tool_result, dict) and "error" in tool_result:
                    log_entry["error"] = True
                tool_calls_log.append(log_entry)

                messages.append({
                    "role": "tool",
                    "name": tool_name,
                    "content": _tool_result_for_llm(tool_result),
                    "tool_call_id": tool_call.id
                })

        response_prefix = _render_tool_calls_html(tool_calls_log) if tool_calls_log else ""
        yield response_prefix

        if final_response_content:
            yield response_prefix + final_response_content
            return

        # Tool budget exhausted (or empty answer): ask for a final answer without
        # tools, streaming when possible and falling back to a single call.
        try:
            stream = llm_client.chat.completions.create(
                model="Qwen/Qwen2.5-72B-Instruct",
                messages=messages,
                tools=None,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                stream=True
            )
            accumulated_text = ""
            for chunk in stream:
                if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                    accumulated_text += chunk.choices[0].delta.content
                    yield response_prefix + accumulated_text
        except Exception:
            final_resp = llm_client.chat.completions.create(
                model="Qwen/Qwen2.5-72B-Instruct",
                messages=messages,
                tools=None,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                stream=False
            )
            yield response_prefix + (final_resp.choices[0].message.content or "")
    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误\n\n{error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"