|
|
""" |
|
|
Financial AI Assistant - Direct Method Library (不依赖 HTTP) |
|
|
直接导入并调用 easy_financial_mcp.py 中的函数 |
|
|
支持本地和 HF Space 部署 |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
import os |
|
|
import json |
|
|
from dotenv import load_dotenv |
|
|
from huggingface_hub import InferenceClient |
|
|
import requests |
|
|
import warnings |
|
|
|
|
|
|
|
|
# Suppress DeprecationWarnings emitted in this process (e.g. by third-party deps).
warnings.filterwarnings('ignore', category=DeprecationWarning)
# NOTE(review): PYTHONWARNINGS is read at interpreter start-up, so this mainly
# affects Python subprocesses that inherit the environment, not this process.
os.environ['PYTHONWARNINGS'] = 'ignore'

# Load variables from a local .env file, if present (HF_TOKEN is read later
# by _init_llm_client).
load_dotenv()

# Project root = the parent of the directory containing this file. It is put
# first on sys.path so the `EasyFinancialAgent` / `MarketandStockMCP` packages
# imported further down resolve; this must run before those imports.
PROJECT_ROOT = Path(__file__).parent.parent.absolute()
sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
|
|
|
|
# Import the MCP tool functions directly (no HTTP). When the import fails,
# every name is bound to a stub returning {"error": "MCP not available"} so
# the rest of the module keeps working and reports the problem as data.
try:
    from EasyFinancialAgent.easy_financial_mcp import (
        search_company as _search_company,
        get_company_info as _get_company_info,
        get_company_filings as _get_company_filings,
        get_financial_data as _get_financial_data,
        extract_financial_metrics as _extract_financial_metrics,
        get_latest_financial_data as _get_latest_financial_data,
        advanced_search_company as _advanced_search_company
    )
    MCP_DIRECT_AVAILABLE = True
    print("[FinancialAI] ✓ Direct MCP functions imported successfully")
except ImportError as e:
    MCP_DIRECT_AVAILABLE = False
    print(f"[FinancialAI] ✗ Failed to import MCP functions: {e}")

    def _mcp_unavailable(*_args, **_kwargs):
        """Stand-in for every MCP function when the import above failed."""
        return {"error": "MCP not available"}

    # Bind every name the import would have provided (including
    # _search_company) so no later reference can raise NameError. The
    # catch-all signature also accepts keyword calls such as years=5.
    _search_company = _mcp_unavailable
    _advanced_search_company = _mcp_unavailable
    _get_company_info = _mcp_unavailable
    _get_company_filings = _mcp_unavailable
    _get_financial_data = _mcp_unavailable
    _get_latest_financial_data = _mcp_unavailable
    _extract_financial_metrics = _mcp_unavailable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def search_company_direct(company_input):
    """
    Search for a company by name, ticker or CIK (direct call).

    Delegates to ``_advanced_search_company`` and normalises the result into
    a list so callers can treat it as a batch.

    Args:
        company_input: Company name, ticker symbol or CIK code.

    Returns:
        A list of search results on success. On failure, a bare
        ``{"error": ...}`` dict -- the same shape returned when the MCP
        functions are unavailable -- so ``"error" in result`` detects it.

    Example:
        result = search_company_direct("Apple")
        result = search_company_direct("AAPL")
        result = search_company_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}

    try:
        result = _advanced_search_company(company_input)
    except Exception as e:
        return {"error": str(e)}

    # Surface error payloads unwrapped; inside a list they would be invisible
    # to the `"error" in result` checks used by every caller.
    if isinstance(result, dict) and "error" in result:
        return result
    # Already a batch: do not nest it inside another list.
    if isinstance(result, (list, tuple)):
        return list(result)
    return [result]
|
|
|
|
|
|
|
|
def get_company_info_direct(cik):
    """
    Fetch detailed company information for a CIK (direct call).

    Args:
        cik: Company CIK code.

    Returns:
        Whatever ``_get_company_info`` returns, or ``{"error": ...}`` when the
        MCP functions are unavailable or the call raises.

    Example:
        result = get_company_info_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_company_info(cik)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
|
|
|
|
|
|
|
|
def get_company_filings_direct(cik):
    """
    List a company's SEC filings (direct call, no form-type filter).

    Args:
        cik: Company CIK code.

    Returns:
        Whatever ``_get_company_filings`` returns, or ``{"error": ...}`` when
        the MCP functions are unavailable or the call raises.

    Example:
        result = get_company_filings_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_company_filings(cik)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
|
|
|
|
|
|
|
|
def advanced_search_company_detailed(company_input):
    """
    Look up a company by name, ticker or CIK and return the raw search result.

    Unlike ``search_company_direct``, the result of ``_advanced_search_company``
    is returned unwrapped (not placed in a list). When the call raises, the
    error dict also carries the formatted traceback.

    Args:
        company_input: Company name ("Tesla", "Apple Inc"), ticker symbol
            ("TSLA", "AAPL", "MSFT") or CIK code ("0001318605", "0000320193").

    Returns:
        The search payload as produced by ``_advanced_search_company``
        (presumably a dict with keys such as ``cik``, ``name``, ``tickers``,
        ``sic``, ``sic_description`` -- confirm against the MCP server), or
        ``{"error": ...}`` / ``{"error": ..., "traceback": ...}`` on failure.

    Example:
        result = advanced_search_company_detailed("Tesla")       # by name
        result = advanced_search_company_detailed("TSLA")        # by ticker
        result = advanced_search_company_detailed("0001318605")  # by CIK
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _advanced_search_company(company_input)
        except Exception as exc:
            import traceback
            return {
                "error": str(exc),
                "traceback": traceback.format_exc()
            }
    return {"error": "MCP functions not available"}
|
|
|
|
|
|
|
|
def format_search_result(search_result):
    """
    Normalise an ``advanced_search_company`` result into a list of rows.

    Each row has the shape ``{'company_name': str, 'cik': str, 'ticker': str}``.

    Args:
        search_result: A single result dict such as
            ``{'cik': '...', 'name': '...', 'tickers': [...], ...}``, or a
            list/tuple of such results (nested sequences are flattened).

    Returns:
        list[dict]: One row per usable result. Error payloads (dicts with an
        ``'error'`` key) and values that are neither dicts nor lists/tuples
        are dropped. ``ticker`` is the first element of ``tickers`` when it
        is a list/tuple, the value itself when it is a plain string, and
        ``''`` otherwise.

    Example:
        search_result = {'cik': '0001577552', 'name': 'Alibaba Group Holding Ltd',
                         'tickers': ['BABA'], '_source': 'company_tickers_cache'}
        format_search_result(search_result)
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    # Batches: format each element and flatten.
    if isinstance(search_result, (list, tuple)):
        return [row for item in search_result for row in format_search_result(item)]

    if not isinstance(search_result, dict) or 'error' in search_result:
        return []

    tickers = search_result.get('tickers', [])
    if isinstance(tickers, str):
        ticker = tickers
    elif isinstance(tickers, (list, tuple)) and tickers:
        ticker = tickers[0]
    else:
        ticker = ''

    return [{
        'company_name': search_result.get('name', ''),
        'cik': search_result.get('cik', ''),
        'ticker': ticker
    }]
|
|
|
|
|
|
|
|
def _is_main_us_ticker(ticker):
    """Return True when *ticker* passes the primary-US-listing heuristic below."""
    if not ticker:
        return False

    # Class separators are ignored, e.g. "BRK.B" is treated as "BRKB".
    compact = ticker.upper().strip().replace('.', '')

    if len(compact) > 5:
        return False

    # Five-letter symbols ending in one of these letters are rejected.
    # NOTE(review): presumably these are exchange fifth-letter suffixes
    # (e.g. F = foreign, Y = ADR, Q = bankruptcy) -- confirm the intended list.
    if len(compact) == 5 and compact.endswith(('F', 'Y', 'Q', 'D', 'W', 'U', 'P')):
        return False

    return True


def format_search_result_for_display(search_result):
    """
    Format search results as display strings such as ``"Company (TICKER)"``.

    Rows with a primary US ticker are shown as ``"name (TICKER)"``; rows with
    no ticker are shown as the bare name; rows whose ticker fails the
    primary-listing heuristic are omitted. An empty company name is shown
    as ``'Unknown'``.

    Args:
        search_result: An ``advanced_search_company`` result (single dict or a
            list of them); see ``format_search_result``.

    Returns:
        list[str]: Display strings, e.g. ``['Alibaba Group (BABA)']``.

    Example:
        format_search_result_for_display({'cik': '0001577552', 'name': 'Alibaba Group', 'tickers': ['BABA']})
        # -> ['Alibaba Group (BABA)']
    """
    display_list = []
    for item in format_search_result(search_result):
        # format_search_result always emits 'company_name' (possibly ''), so a
        # plain .get default would never apply; fall back on falsy names too.
        company_name = item.get('company_name') or 'Unknown'
        ticker = item.get('ticker', '')

        if not ticker:
            display_list.append(company_name)
        elif _is_main_us_ticker(ticker):
            display_list.append(f"{company_name} ({ticker})")
        # Otherwise: a non-primary listing -- deliberately left out.

    return display_list
|
|
|
|
|
|
|
|
def search_and_format(company_input):
    """
    Search for a company and immediately normalise the result (one-step helper).

    Args:
        company_input: Company name, ticker symbol or CIK code.

    Returns:
        list[dict]: Rows as produced by ``format_search_result``; an empty
        list when the search reports an error.

    Example:
        search_and_format('BABA')
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    raw = advanced_search_company_detailed(company_input)
    failed = isinstance(raw, dict) and 'error' in raw
    return [] if failed else format_search_result(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_latest_financial_data_direct(cik):
    """
    Fetch the most recent financial data for a company (direct call).

    Args:
        cik: Company CIK code.

    Returns:
        Whatever ``_get_latest_financial_data`` returns, or ``{"error": ...}``
        when the MCP functions are unavailable or the call raises.

    Example:
        result = get_latest_financial_data_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_latest_financial_data(cik)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
|
|
|
|
|
|
|
|
def extract_financial_metrics_direct(cik, years=5):
    """
    Extract multi-year financial metric trends for a company (direct call).

    Args:
        cik: Company CIK code.
        years: Number of years of metrics to request (default 5).

    Returns:
        Whatever ``_extract_financial_metrics`` returns, or ``{"error": ...}``
        when the MCP functions are unavailable or the call raises.

    Example:
        result = extract_financial_metrics_direct("0000320193", years=5)
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}

    try:
        return _extract_financial_metrics(cik, years)
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def query_company_direct(company_input, get_filings=True, get_metrics=True, years=3):
    """
    Run a combined company query (direct calls): search, info, filings, metrics.

    The company is first resolved to a CIK via ``search_company_direct``. If
    the search fails or yields no CIK, the function returns early with
    ``status == "error"``. Failures of the follow-up lookups are appended to
    ``errors`` without changing ``status``, so partial data is still returned.

    Args:
        company_input: Company name, ticker or CIK.
        get_filings: Whether to fetch the SEC filings list.
        get_metrics: Whether to fetch multi-year financial metrics.
        years: Years of metrics to request when ``get_metrics`` is true
            (default 3).

    Returns:
        dict with ``timestamp``, ``query_input``, ``status`` ("success" or
        "error"), ``errors`` (list of str) and ``data`` holding
        ``company_search``, ``company_info``, ``filings`` and ``metrics``
        (each ``None`` when not fetched or failed).

    Example:
        result = query_company_direct("Apple", get_filings=True, get_metrics=True)
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None,
            "filings": None,
            "metrics": None
        },
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    def _store(key, payload, failure_label):
        # Follow-up lookups signal failure with a dict carrying an "error" key.
        if isinstance(payload, dict) and "error" in payload:
            result["errors"].append(f"{failure_label}: {payload.get('error')}")
        else:
            result["data"][key] = payload

    try:
        search_result = search_company_direct(company_input)

        # search_company_direct may return a bare error dict or a list that
        # wraps the MCP payload -- and that payload can itself be an error
        # dict, which `"error" in search_result` on the list would miss.
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        else:
            first_hit = search_result

        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result

        result["data"]["company_search"] = search_result

        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        _store("company_info", get_company_info_direct(cik), "Failed to get company info")

        if get_filings:
            _store("filings", get_company_filings_direct(cik), "Failed to get filings")

        if get_metrics:
            _store("metrics", extract_financial_metrics_direct(cik, years=years), "Failed to get metrics")

    except Exception as e:
        import traceback
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_llm_client():
    """
    (Re)create the module-level ``llm_client``.

    Reads ``HF_TOKEN`` (falling back to ``HUGGING_FACE_HUB_TOKEN``). With a
    token, ``llm_client`` becomes an ``InferenceClient``; otherwise it stays
    ``None``. A status line is printed in every case.

    Returns:
        bool: True if the client was created, False otherwise.
    """
    global llm_client
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    llm_client = None

    try:
        if not token:
            print("[FinancialAI] ⚠ Warning: HF_TOKEN not found, LLM features disabled")
            return False
        llm_client = InferenceClient(api_key=token)
        print("[FinancialAI] ✓ LLM client initialized with HF_TOKEN")
        return True
    except Exception as exc:
        print(f"[FinancialAI] ✗ Failed to initialize LLM client: {exc}")
        return False
|
|
|
|
|
|
|
|
# Module-level LLM client. _init_llm_client() replaces it with an
# InferenceClient when HF_TOKEN / HUGGING_FACE_HUB_TOKEN is set and leaves it
# None otherwise; callers check `if not llm_client` before using it.
llm_client = None
_init_llm_client()
|
|
|
|
|
|
|
|
def get_system_prompt():
    """Build the system prompt for the financial assistant, stamped with today's date."""
    from datetime import datetime

    today = datetime.now().strftime("%Y-%m-%d")
    prompt_lines = [
        f"You are a financial analysis expert. Today is {today}.",
        "Your role:",
        "- Analyze company financial data, reports, and market news",
        "- Provide investment insights based on factual data",
        "- Be concise, objective, and data-driven",
        "- Always include disclaimers about market risks",
        "",
        "⚠️ IMPORTANT: You have a maximum of 5 tool calls. Choose the MOST RELEVANT tools carefully:",
        "- Use 'advanced_search_company' ONLY if you need to find a company's CIK",
        "- Use 'extract_financial_metrics' for comprehensive multi-year financial analysis (RECOMMENDED for most queries)",
        "- Use 'get_latest_financial_data' for quick recent snapshot",
        "- Use 'get_quote' for real-time stock price",
        "- Use 'get_company_news' for company-specific news",
        "- Use 'get_market_news' for general market trends",
        "",
        "Prioritize the most important tools for the user's question. Avoid redundant calls.",
        "Output should be in English.",
    ]
    return "\n".join(prompt_lines)
|
|
|
|
|
|
|
|
def analyze_company_with_llm(company_input, analysis_type="summary"):
    """
    Analyse a company with the LLM using its summary data.

    Fetches data via ``get_company_summary_direct``, embeds it as JSON in a
    prompt selected by ``analysis_type`` and asks the model for an analysis.

    Args:
        company_input: Company name, ticker or CIK.
        analysis_type: "summary", "investment" or "risks"; any other value
            falls back to "summary".

    Returns:
        ``{"company", "analysis_type", "analysis", "data_used"}`` on success
        (``analysis`` is always a string, ``""`` if the model returned no
        content), or ``{"error": ...}`` on failure.

    Example:
        result = analyze_company_with_llm("Apple", "investment")
    """
    if not llm_client:
        return {"error": "LLM client not available"}

    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}

    # analysis_type -> (opening line, list heading, numbered items).
    prompt_specs = {
        "investment": (
            "Based on the following company financial data, provide an investment recommendation:",
            "Provide:",
            ["Investment Recommendation (Buy/Hold/Sell)", "Key Strengths and Weaknesses",
             "Price Target Range", "Risk Assessment", "Risk Disclaimer"],
        ),
        "risks": (
            "Based on the following company data, analyze the key risks:",
            "Identify:",
            ["Financial Risks", "Market Risks", "Operational Risks",
             "Mitigation Strategies", "Risk Disclaimer"],
        ),
        "summary": (
            "Provide a financial summary of the following company:",
            "Include:",
            ["Company Overview", "Financial Health", "Recent Performance", "Investment Outlook"],
        ),
    }

    try:
        company_data = get_company_summary_direct(company_input)
        if company_data["status"] == "error":
            return {"error": f"Failed to fetch company data: {company_data['errors']}"}

        # default=str keeps non-JSON-native values (dates, decimals, ...) from
        # aborting the whole analysis.
        data_str = json.dumps(company_data["data"], ensure_ascii=False, indent=2, default=str)

        intro, heading, items = prompt_specs.get(analysis_type, prompt_specs["summary"])
        numbered = "\n".join(f"{n}. {item}" for n, item in enumerate(items, start=1))
        prompt = f"\n{intro}\n\n{data_str}\n\n{heading}\n{numbered}\n"

        response = llm_client.chat.completions.create(
            model="Qwen/Qwen2.5-72B-Instruct",
            messages=[
                {"role": "system", "content": get_system_prompt()},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1500,
            temperature=0.7,
            top_p=0.95,
            stream=False
        )

        return {
            "company": company_input,
            "analysis_type": analysis_type,
            # The model can return None content; callers slice this string.
            "analysis": response.choices[0].message.content or "",
            "data_used": company_data["data"]
        }

    except Exception as e:
        return {"error": f"LLM analysis failed: {str(e)}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_financial_data_direct(cik, period):
    """
    Fetch financial data for a specific period (direct call).

    Args:
        cik: Company CIK code.
        period: Period identifier, e.g. "2024" or "2024Q3".

    Returns:
        Whatever ``_get_financial_data`` returns, or ``{"error": ...}`` when
        the MCP functions are unavailable or the call raises.

    Example:
        result = get_financial_data_direct("0000320193", "2024")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_financial_data(cik, period)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_company_filings_with_form_direct(cik, form_types=None):
    """
    List a company's SEC filings, optionally filtered by form type (direct call).

    Args:
        cik: Company CIK code.
        form_types: Optional list of form types, e.g. ["10-K", "10-Q"];
            passed straight through to ``_get_company_filings``.

    Returns:
        Whatever ``_get_company_filings`` returns, or ``{"error": ...}`` when
        the MCP functions are unavailable or the call raises.

    Example:
        result = get_company_filings_with_form_direct("0000320193", ["10-K"])
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            return _get_company_filings(cik, form_types)
        except Exception as exc:
            return {"error": str(exc)}
    return {"error": "MCP functions not available"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_company_summary_direct(company_input):
    """
    Fetch a lightweight company summary: search result plus basic info.

    Args:
        company_input: Company name, ticker or CIK.

    Returns:
        dict with ``timestamp``, ``query_input``, ``status`` ("success" or
        "error"), ``errors`` (list of str) and ``data`` holding
        ``company_search`` and ``company_info``. A failed info lookup is
        recorded in ``errors`` without changing ``status``.

    Example:
        result = get_company_summary_direct("Apple")
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None
        },
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        search_result = search_company_direct(company_input)

        # The search may come back as a bare error dict or as a list wrapping
        # the payload (itself possibly an error dict); inspect the first hit.
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        else:
            first_hit = search_result

        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result

        result["data"]["company_search"] = search_result

        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        company_info = get_company_info_direct(cik)
        if isinstance(company_info, dict) and "error" in company_info:
            result["errors"].append(f"Failed to get company info: {company_info.get('error')}")
        else:
            result["data"]["company_info"] = company_info

    except Exception as e:
        import traceback
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())

    return result
|
|
|
|
|
|
|
|
def get_financial_metrics_only_direct(company_input, years=5):
    """
    Fetch multi-year financial metric trends only (no filings list).

    Args:
        company_input: Company name, ticker or CIK.
        years: Number of years of metrics to request (default 5).

    Returns:
        dict with ``timestamp``, ``query_input``, ``years``, ``status``
        ("success" or "error"), ``errors`` (list of str) and ``data`` (the
        metrics payload, or ``None`` on failure).

    Example:
        result = get_financial_metrics_only_direct("Apple", years=5)
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "years": years,
        "status": "success",
        "data": None,
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        search_result = search_company_direct(company_input)

        # The search may come back as a bare error dict or as a list wrapping
        # the payload (itself possibly an error dict); inspect the first hit.
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        else:
            first_hit = search_result

        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result

        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        metrics = extract_financial_metrics_direct(cik, years=years)
        if isinstance(metrics, dict) and "error" in metrics:
            result["errors"].append(f"Failed to get metrics: {metrics['error']}")
            result["status"] = "error"
        else:
            result["data"] = metrics

    except Exception as e:
        import traceback
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test of the direct-call helpers and the LLM analyses.
    banner = "=" * 60
    print("\n" + banner)
    print("Financial AI Assistant - Direct Method Test")
    print(banner)

    # 1. Raw search
    print("\n1. 搜索公司 (Apple)...")
    search_demo = search_company_direct("Apple")
    print(f" 结果: {search_demo}")

    # 2. Lightweight summary
    print("\n2. 获取公司摘要信息 (Tesla)...")
    summary_demo = get_company_summary_direct("Tesla")
    print(f" 状态: {summary_demo['status']}")
    print(f" 数据: {summary_demo['data']}")
    print(f" 错误: {summary_demo['errors']}")

    # 3. Metrics only
    print("\n3. 获取财务指标 (Microsoft)...")
    metrics_demo = get_financial_metrics_only_direct("Microsoft", years=3)
    print(f" 状态: {metrics_demo['status']}")
    if metrics_demo['status'] != 'success':
        print(f" 错误: {metrics_demo['errors']}")
    else:
        print(f" 指标数据: {metrics_demo['data']}")

    # 4. Full combined query
    print("\n4. 获取 Amazon 完整信息...")
    full_demo = query_company_direct("Amazon", get_filings=True, get_metrics=True)
    print(f" 状态: {full_demo['status']}")
    print(f" 错误: {full_demo['errors']}")

    # 5-7. LLM analyses (heading, company, analysis type)
    llm_demos = [
        ("\n5. LLM 分析 - 公司摘要(Google)...", "Google", "summary"),
        ("\n6. LLM 分析 - 投资建议(NVIDIA)...", "NVIDIA", "investment"),
        ("\n7. LLM 分析 - 风险评估(Meta)...", "Meta", "risks"),
    ]
    for heading, company, kind in llm_demos:
        print(heading)
        if not llm_client:
            print(" LLM 客户端不可用")
            continue
        outcome = analyze_company_with_llm(company, kind)
        if "error" in outcome:
            print(f" 错误: {outcome['error']}")
        else:
            print(f" 分析结果: {outcome['analysis'][:200]}...")

    print("\n" + banner)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): not referenced by the code visible in this file -- confirm it is still used.
MAX_TOTAL_TOKENS = 6000
# Character budget for a single tool result forwarded to the LLM in chatbot_response.
MAX_TOOL_RESULT_CHARS = 1500
# Each prior assistant message in the history is truncated to this many characters.
MAX_HISTORY_CHARS = 500
# Only the last N history items are replayed to the model.
MAX_HISTORY_TURNS = 2
# Maximum number of model/tool rounds in chatbot_response's tool loop.
MAX_TOOL_ITERATIONS = 5
# max_tokens passed to each chat-completion request in chatbot_response.
MAX_OUTPUT_TOKENS = 2000
|
|
|
|
|
|
|
|
# OpenAI-style function/tool schemas offered to the LLM in chatbot_response.
# Each "name" must match a branch in call_mcp_tool, which performs the call.
MCP_TOOLS = [
    # SEC / company data tools (dispatched to the EasyFinancialAgent MCP functions).
    {"type": "function", "function": {"name": "advanced_search_company", "description": "Search US companies by name, ticker, or CIK. Returns company information including CIK, name, tickers, and industry classification.", "parameters": {"type": "object", "properties": {"company_input": {"type": "string", "description": "Company name (e.g., 'Tesla'), ticker symbol (e.g., 'TSLA'), or CIK code (e.g., '0001318605')"}}, "required": ["company_input"]}}},
    {"type": "function", "function": {"name": "get_latest_financial_data", "description": "Get the most recent financial data for a company including revenue, net income, EPS, operating expenses, and cash flow.", "parameters": {"type": "object", "properties": {"cik": {"type": "string", "description": "Company CIK code (10-digit format, e.g., '0001318605')"}}, "required": ["cik"]}}},
    # NOTE(review): "years" is listed as required yet also declares a default of 3;
    # call_mcp_tool itself falls back to 3 when the argument is missing.
    {"type": "function", "function": {"name": "extract_financial_metrics", "description": "Extract multi-year financial metrics trends showing historical performance over specified years.", "parameters": {"type": "object", "properties": {"cik": {"type": "string", "description": "Company CIK code (10-digit format)"}, "years": {"type": "integer", "description": "Number of years of data to retrieve (e.g., 3 or 5)", "default": 3}}, "required": ["cik", "years"]}}},

    # Market quote / news tools (imported lazily from MarketandStockMCP.news_quote_mcp by call_mcp_tool).
    {"type": "function", "function": {"name": "get_quote", "description": "Get real-time stock quote data including current price, daily change, high/low, and previous close. Use when users ask about current stock prices or market performance.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "Stock ticker symbol (e.g., 'AAPL', 'TSLA', 'MSFT')"}}, "required": ["symbol"]}}},
    {"type": "function", "function": {"name": "get_market_news", "description": "Get latest market news by category. Use when users ask about general market trends, forex, crypto, or M&A news.", "parameters": {"type": "object", "properties": {"category": {"type": "string", "enum": ["general", "forex", "crypto", "merger"], "description": "News category: general (stocks/economy), forex (currency), crypto (cryptocurrency), merger (M&A)", "default": "general"}, "min_id": {"type": "integer", "description": "Minimum news ID for pagination (default: 0)", "default": 0}}, "required": ["category"]}}},
    {"type": "function", "function": {"name": "get_company_news", "description": "Get company-specific news and announcements. Only available for North American companies. Use when users ask about specific company news.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "Company stock ticker symbol (e.g., 'AAPL', 'TSLA')"}, "from_date": {"type": "string", "description": "Start date in YYYY-MM-DD format (optional, defaults to 7 days ago)"}, "to_date": {"type": "string", "description": "End date in YYYY-MM-DD format (optional, defaults to today)"}}, "required": ["symbol"]}}}
]
|
|
|
|
|
|
|
|
def truncate_text(text, max_chars, suffix="...[truncated]"):
    """
    Return ``str(text)``, cut to ``max_chars`` characters plus ``suffix`` if longer.

    Note that a truncated result is ``max_chars + len(suffix)`` characters
    long; callers that need a hard limit must subtract the suffix length.
    """
    as_str = str(text)
    if len(as_str) > max_chars:
        return as_str[:max_chars] + suffix
    return as_str
|
|
|
|
|
|
|
|
def call_mcp_tool(tool_name, arguments):
    """
    Dispatch an LLM tool call to the matching Python function (no HTTP).

    Args:
        tool_name: One of the function names declared in ``MCP_TOOLS``.
        arguments: Decoded tool arguments. Anything that is not a dict
            (e.g. ``None`` or a list from malformed model output) is treated
            as "no arguments". Integer parameters (``years``, ``min_id``) are
            coerced with ``int()`` because models often send them as strings;
            unparsable values fall back to the schema defaults.

    Returns:
        The tool's result; ``{"error": "Unknown tool: ..."}`` for unknown
        names; or ``{"error": ..., "traceback": ...}`` if the call raises
        (the traceback keeps its last 500 characters, where the exception
        message is).
    """
    def _as_int(value, default):
        # Models frequently emit numbers as strings ("5"); never crash on junk.
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    try:
        args = arguments if isinstance(arguments, dict) else {}

        if tool_name == "advanced_search_company":
            return _advanced_search_company(args.get("company_input", ""))

        if tool_name == "get_latest_financial_data":
            return _get_latest_financial_data(args.get("cik", ""))

        if tool_name == "extract_financial_metrics":
            return _extract_financial_metrics(args.get("cik", ""), _as_int(args.get("years", 3), 3))

        # Market tools live in a separate package, imported on first use.
        if tool_name == "get_quote":
            from MarketandStockMCP.news_quote_mcp import get_quote
            return get_quote(args.get("symbol", ""))

        if tool_name == "get_market_news":
            from MarketandStockMCP.news_quote_mcp import get_market_news
            return get_market_news(args.get("category", "general"), _as_int(args.get("min_id", 0), 0))

        if tool_name == "get_company_news":
            from MarketandStockMCP.news_quote_mcp import get_company_news
            return get_company_news(args.get("symbol", ""), args.get("from_date"), args.get("to_date"))

        return {"error": f"Unknown tool: {tool_name}"}

    except Exception as e:
        import traceback
        return {
            "error": f"{str(e)}",
            # Keep the tail: the final lines name the exception that occurred.
            "traceback": traceback.format_exc()[-500:]
        }
|
|
|
|
|
|
|
|
def _clean_assistant_history(text):
    """Drop a leading tool-log HTML panel (as emitted by chatbot_response) and cap length."""
    # When a UI echoes chatbot_response's output back as history, the reply
    # starts with the CSS/HTML tool panel; keep only the answer after the
    # "---" separator so the character budget is spent on real content.
    if "tools-container" in text and "\n---\n" in text:
        text = text.split("\n---\n", 1)[1].strip()
    return truncate_text(text, MAX_HISTORY_CHARS)


def _history_to_messages(history):
    """Convert recent chat history (pair or role-dict format) into chat messages."""
    if not history:
        return []

    messages = []

    # "messages" format: [{"role": ..., "content": ...}, ...] -- one dict per message.
    if all(isinstance(item, dict) for item in history):
        for item in history[-2 * MAX_HISTORY_TURNS:]:
            role, content = item.get("role"), item.get("content")
            if role not in ("user", "assistant") or not isinstance(content, str):
                continue  # skip system/tool entries and non-text (e.g. file) content
            if role == "assistant":
                content = _clean_assistant_history(content)
            messages.append({"role": role, "content": content})
        return messages

    # Pair format: [(user_msg, assistant_msg), ...]
    for item in history[-MAX_HISTORY_TURNS:]:
        if not (isinstance(item, (list, tuple)) and len(item) == 2):
            continue
        user_msg, assistant_msg = item
        if not isinstance(user_msg, str) or assistant_msg is None:
            continue  # placeholder / non-text turns carry nothing usable
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": _clean_assistant_history(str(assistant_msg))})
    return messages


def _parse_tool_arguments(raw_arguments):
    """Return tool-call arguments as a dict (providers send a JSON string or a decoded dict)."""
    if isinstance(raw_arguments, dict):
        return raw_arguments
    try:
        parsed = json.loads(raw_arguments)
    except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _tool_result_for_llm(tool_result):
    """Serialise a successful tool result for the LLM, shrinking it past MAX_TOOL_RESULT_CHARS."""
    # default=str: a non-JSON-native value must not abort the whole chat turn.
    result_str = json.dumps(tool_result, ensure_ascii=False, default=str)
    if len(result_str) <= MAX_TOOL_RESULT_CHARS:
        return result_str

    if isinstance(tool_result, dict) and "text" in tool_result:
        # Leave headroom for the JSON wrapper and the truncation suffix.
        shrunk = {"text": truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50), "_truncated": True}
    elif isinstance(tool_result, dict):
        # Keep at most 8 keys, each value previewed to 300 chars, until the budget is spent.
        shrunk = {}
        char_count = 0
        for key, value in list(tool_result.items())[:8]:
            value_str = str(value)[:300]
            shrunk[key] = value_str
            char_count += len(str(key)) + len(value_str)
            if char_count > MAX_TOOL_RESULT_CHARS:
                break
        shrunk["_truncated"] = True
    else:
        shrunk = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
    return json.dumps(shrunk, ensure_ascii=False, default=str)


def _render_tool_log_html(tool_calls_log):
    """Render the collapsible "Tools Used" HTML panel shown above the answer."""
    import html

    parts = [
        "<style>\n"
        "details.tools-container > summary::before {\n"
        "content: '▶';\n"
        "display: inline-block;\n"
        "margin-right: 8px;\n"
        "transition: transform 0.2s;\n"
        "}\n"
        "details.tools-container[open] > summary::before {\n"
        "transform: rotate(90deg);\n"
        "}\n"
        "details.tools-container > summary {\n"
        "list-style: none;\n"
        "}\n"
        "details.tools-container > summary::-webkit-details-marker {\n"
        "display: none;\n"
        "}\n"
        "</style>\n",
        "<div style='margin-bottom: 15px;'>\n"
        "<details class='tools-container'>\n"
        "<summary style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333; cursor: pointer; user-select: none;'>\n"
        f"<span>🛠️ Tools Used ({len(tool_calls_log)}/{MAX_TOOL_ITERATIONS} calls)</span>\n"
        "</summary>\n"
        "<div style='margin-top: 8px;'>\n",
    ]

    for idx, tool_call in enumerate(tool_calls_log, start=1):
        result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2, default=str)
        result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
        error_indicator = " ❌ Error" if tool_call.get('error') else ""
        # Tool output (news text, API payloads) is untrusted: escape before
        # embedding it in HTML so it cannot inject markup or scripts.
        parts.append(
            "<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>\n"
            "<summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>\n"
            f"<strong style='color: #2c5aa0;'>📋 {idx}. {html.escape(str(tool_call['name']))}{error_indicator}</strong>\n"
            "</summary>\n"
            "<div style='background: #f9f9f9; padding: 12px;'>\n"
            f"<pre style='background: #fff; padding: 10px; overflow-x: auto; font-size: 0.85em;'>{html.escape(result_preview)}</pre>\n"
            "</div>\n"
            "</details>\n"
        )

    # Blank lines around "---": in Markdown an HTML block only ends at a blank
    # line, so without them the rule and the answer's first paragraph would be
    # swallowed into the raw-HTML block.
    parts.append(" </div>\n</details>\n</div>\n\n---\n\n")
    return "".join(parts)


def chatbot_response(message, history=None):
    """
    Main assistant entry point: multi-turn chat with dynamic tool calls.

    Up to ``MAX_TOOL_ITERATIONS`` rounds of model calls are made with
    ``MCP_TOOLS`` available; each requested tool runs via ``call_mcp_tool``
    and its (size-capped) result is fed back. The answer is then yielded,
    preceded by an HTML panel listing the tools used. If the loop ends
    without a final answer, one more tool-free completion is streamed.

    Args:
        message: The user's message.
        history: Prior conversation, either ``[(user_msg, assistant_msg), ...]``
            or ``[{"role": ..., "content": ...}, ...]``; ``None`` for none.

    Yields:
        str: Progressively longer snapshots of the full reply (tool panel +
        answer so far), or a single "❌ Error: ..." message.

    Example:
        for response in chatbot_response("What's Apple's revenue?", []):
            print(response)
    """
    if not llm_client:
        yield "❌ Error: LLM client not available"
        return

    if not MCP_DIRECT_AVAILABLE:
        yield "❌ Error: MCP functions not available"
        return

    try:
        messages = [{"role": "system", "content": get_system_prompt()}]
        messages.extend(_history_to_messages(history))
        messages.append({"role": "user", "content": message})

        tool_calls_log = []
        final_response_content = None

        for _ in range(MAX_TOOL_ITERATIONS):
            response = llm_client.chat.completions.create(
                model="Qwen/Qwen2.5-72B-Instruct",
                messages=messages,
                tools=MCP_TOOLS,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                tool_choice="auto",
                stream=False
            )
            choice = response.choices[0]

            if not choice.message.tool_calls:
                final_response_content = choice.message.content
                break

            # The assistant turn that requested the tools must precede their results.
            messages.append(choice.message)

            for tool_call in choice.message.tool_calls:
                tool_name = tool_call.function.name
                tool_args = _parse_tool_arguments(tool_call.function.arguments)
                tool_result = call_mcp_tool(tool_name, tool_args)

                if isinstance(tool_result, dict) and "error" in tool_result:
                    tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True})
                    result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False, default=str)
                else:
                    tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
                    result_for_llm = _tool_result_for_llm(tool_result)

                messages.append({
                    "role": "tool",
                    "name": tool_name,
                    "content": result_for_llm,
                    "tool_call_id": tool_call.id
                })

        response_prefix = _render_tool_log_html(tool_calls_log) if tool_calls_log else ""
        yield response_prefix

        if final_response_content:
            yield response_prefix + final_response_content
        else:
            # Tool budget exhausted (or empty answer): ask once more without tools.
            try:
                stream = llm_client.chat.completions.create(
                    model="Qwen/Qwen2.5-72B-Instruct",
                    messages=messages,
                    tools=None,
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.7,
                    stream=True
                )
                accumulated_text = ""
                for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta.content:
                        accumulated_text += chunk.choices[0].delta.content
                        yield response_prefix + accumulated_text
            except Exception:
                # Streaming unsupported or interrupted: fall back to a single call.
                final_resp = llm_client.chat.completions.create(
                    model="Qwen/Qwen2.5-72B-Instruct",
                    messages=messages,
                    tools=None,
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.7,
                    stream=False
                )
                yield response_prefix + (final_resp.choices[0].message.content or "")

    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误\n\n{error_detail[:200]}"
        else:
            # Keep the tail of the traceback, where the actual exception is.
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[-500:]}"
|
|
|