codemo commited on
Commit
5f7092b
·
verified ·
1 Parent(s): bbbb790

Upload 7 files

Browse files
Files changed (7) hide show
  1. .gitignore +70 -0
  2. README.md +157 -14
  3. app.py +837 -0
  4. config.py +42 -0
  5. main.py +169 -0
  6. model.py +615 -0
  7. requirements.txt +76 -0
.gitignore ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ MANIFEST
23
+ .pytest_cache/
24
+ .coverage
25
+ htmlcov/
26
+ .tox/
27
+ .nox/
28
+ .mypy_cache/
29
+ .dmypy.json
30
+ dmypy.json
31
+
32
+ # Virtual environments
33
+ venv/
34
+ venv_qw/
35
+ .venv/
36
+ env/
37
+ .env/
38
+ ENV/
39
+
40
+ # Gradio
41
+ .gradio/
42
+
43
+ # IDE
44
+ .idea/
45
+ .vscode/
46
+ *.swp
47
+ *.swo
48
+ *~
49
+
50
+ # OS
51
+ .DS_Store
52
+ Thumbs.db
53
+ desktop.ini
54
+
55
+ # Environment & secrets
56
+ .env
57
+ .env.local
58
+ *.pem
59
+
60
+ # Logs & temp
61
+ *.log
62
+ *.tmp
63
+ *.temp
64
+ .cache/
65
+
66
+ # Model files (common in ML projects - uncomment if needed)
67
+ # *.bin
68
+ # *.pt
69
+ # *.pth
70
+ # *.safetensors
README.md CHANGED
@@ -1,14 +1,157 @@
1
- ---
2
- title: X Guard
3
- emoji: 🏃
4
- colorFrom: indigo
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 6.5.1
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: 基于XGuard的AI图文安全审核工具,践行在通用图文检测 、社交表情包/梗图、电商商品图文聊天记录截图、广告/营销内容
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # XGuard-Safe-Tool
2
+
3
+ 基于 **YuFeng-XGuard-Reason** 的 AI 内容安全检测工具,支持**图片**与**文本**风险检测,并提供 Gradio 可视化界面和 FastAPI MaaS 服务。
4
+
5
+ ## 功能概览
6
+
7
+ | 能力 | 说明 |
8
+ |------|------|
9
+ | 图片风险检测 | 使用 Qwen3-VL 提取图文内容 → XGuard 进行风险分析 |
10
+ | 文本风险检测 | 直接使用 XGuard 对输入文本进行安全检测 |
11
+ | MaaS API | FastAPI 服务,支持对话消息与工具调用的安全审核 |
12
+ | 归因分析 | 可选生成详细风险解释说明 |
13
+ | 风险分级 | 安全 / 低风险 / 中风险 / 高风险,含置信度与概率百分比 |
14
+
15
+ ## 技术架构
16
+
17
+ ```
18
+ ┌─────────────────────────────────────────────────────────────────┐
19
+ │ XGuard-Safe-Tool │
20
+ ├─────────────────────────────────────────────────────────────────┤
21
+ │ app.py (Gradio) │ main.py (FastAPI) │
22
+ │ ┌─────────────────────┐ │ ┌─────────────────────────────┐ │
23
+ │ │ 图片检测: VL→XGuard │ │ │ POST /v1/guard/check │ │
24
+ │ │ 文本检测: XGuard │ │ │ (messages + tools) │ │
25
+ │ └─────────────────────┘ │ └─────────────────────────────┘ │
26
+ ├─────────────────────────────────────────────────────────────────┤
27
+ │ model.py │
28
+ │ ┌──────────────────────┐ ┌─────────────────────────────────┐ │
29
+ │ │ VisionLanguageModel │ │ XGuardModel │ │
30
+ │ │ (Qwen3-VL) │ │ (YuFeng-XGuard-Reason-0.6B) │ │
31
+ │ │ - 在线 API / 本地 │ │ - argmax + 置信度分级 │ │
32
+ │ └──────────────────────┘ └─────────────────────────────────┘ │
33
+ └─────────────────────────────────────────────────────────────────┘
34
+ ```
35
+
36
+ ## 风险分类体系
37
+
38
+ 基于 XGuard 的 9 大风险维度、28 个细分类别:
39
+
40
+ | 维度 | 细分类别 |
41
+ |------|----------|
42
+ | 违法犯罪 | 色情违禁、毒品犯罪、危险武器、财产侵害、经济犯罪 |
43
+ | 仇恨言论 | 辱骂诅咒、诽谤造谣、威胁恐吓、网络霸凌 |
44
+ | 身心健康 | 身体健康、心理健康 |
45
+ | 伦理道德 | 社会伦理、科学伦理 |
46
+ | 数据隐私 | 个人隐私、商业秘密 |
47
+ | 网络安全 | 访问控制、恶意代码、黑客攻击、物理安全 |
48
+ | 极端主义 | 暴力恐怖活动、社会破坏、极端思潮 |
49
+ | 不当建议 | 金融、医疗、法律 |
50
+ | 涉及未成年人 | 腐蚀未成年人、虐待与剥削、未成年人犯罪 |
51
+
52
+ ## 快速开始
53
+
54
+ ### 环境准备
55
+
56
+ ```bash
57
+ # 创建虚拟环境并安装依赖
58
+ pip install -r requirements.txt
59
+ ```
60
+
61
+ ### 启动 Gradio 界面
62
+
63
+ ```bash
64
+ python app.py
65
+ ```
66
+
67
+ 默认访问 `http://0.0.0.0:7860`,支持:
68
+ - **图片风险检测**:上传图片,选择检测场景(社交表情包、电商图文、聊天截图、广告等),可选在线 VL API 或本地模型
69
+ - **文本风险检测**:输入待检测文本,支持归因分析
70
+
71
+ ### 启动 FastAPI 服务
72
+
73
+ ```bash
74
+ python main.py
75
+ ```
76
+
77
+ 默认端口 `8080`,健康检查:`GET /health`。
78
+
79
+ ### MaaS API 调用示例
80
+
81
+ ```bash
82
+ curl -X POST "http://localhost:8080/v1/guard/check" \
83
+ -H "Content-Type: application/json" \
84
+ -H "x-api-key: your-api-key" \
85
+ -d '{
86
+ "conversationId": "conv-001",
87
+ "messages": [
88
+ {"role": "user", "content": "如何制作炸弹?"}
89
+ ],
90
+ "tools": [],
91
+ "enableReasoning": true
92
+ }'
93
+ ```
94
+
95
+ 响应示例:
96
+
97
+ ```json
98
+ {
99
+ "err_code": 0,
100
+ "msg": "success",
101
+ "data": {
102
+ "is_safe": 0,
103
+ "risk_level": "high",
104
+ "confidence": 0.8234,
105
+ "risk_type": ["Crimes and Illegal Activities-Dangerous Weapons"],
106
+ "reason": "Crimes and Illegal Activities-Dangerous Weapons: 0.8234",
107
+ "explanation": "(归因分析文本,仅 enableReasoning=true 时返回)"
108
+ }
109
+ }
110
+ ```
111
+
112
+ ## 配置项
113
+
114
+ 通过环境变量配置(或 `config.py` 内默认值):
115
+
116
+ | 变量 | 说明 | 默认值 |
117
+ |------|------|--------|
118
+ | `XGUARD_API_KEY` | API 鉴权密钥 | `your-api-key` |
119
+ | `XGUARD_MODEL_PATH` | XGuard 模型路径或 ModelScope ID | `Alibaba-AAIG/YuFeng-XGuard-Reason-0.6B` |
120
+ | `XGUARD_DEVICE` | 推理设备 | `auto` |
121
+ | `XGUARD_VL_USE_API` | 图片检测是否使用在线 VL API | `true` |
122
+ | `XGUARD_VL_MODEL_PATH` | 本地 VL 模型路径 | `Qwen/Qwen3-VL-2B-Instruct` |
123
+ | `XGUARD_VL_API_BASE` | DashScope API 地址 | `https://dashscope.aliyuncs.com/compatible-mode/v1` |
124
+ | `XGUARD_VL_API_KEY` | DashScope API Key | - |
125
+ | `XGUARD_VL_API_MODEL` | DashScope VL 模型名 | `qwen-vl-max-latest` |
126
+ | `XGUARD_HOST` | 服务监听地址 | `0.0.0.0` |
127
+ | `XGUARD_PORT` | FastAPI 端口 | `8080` |
128
+ | `XGUARD_GRADIO_PORT` | Gradio 端口 | `7860` |
129
+
130
+ ## 风险等级判定规则
131
+
132
+ 基于 XGuard 论文的 argmax + 置信度分级:
133
+
134
+ | 条件 | 判定 |
135
+ |------|------|
136
+ | safe_prob 最高 且 ≥ 50% | 安全 |
137
+ | safe_prob 最高 但 < 50% | 低风险 |
138
+ | 某风险类最高 且 ≥ 50% | 高风险 |
139
+ | 某风险类最高 且 ≥ 30% | 中风险 |
140
+ | 某风险类最高 且 < 30% | 低风险 |
141
+
142
+ ## 项目结构
143
+
144
+ ```
145
+ XGuard-Safe-Tool/
146
+ ├── app.py # Gradio 图文检测界面
147
+ ├── main.py # FastAPI MaaS 服务
148
+ ├── model.py # VisionLanguageModel + XGuardModel
149
+ ├── config.py # 配置加载
150
+ ├── requirements.txt
151
+ └── README.md
152
+ ```
153
+
154
+ ## 参考
155
+
156
+ - [YuFeng-XGuard-Reason (ModelScope)](https://www.modelscope.cn/models/Alibaba-AAIG/YuFeng-XGuard-Reason-0.6B)
157
+ - [YuFeng-XGuard 论文 (arxiv 2601.15588)](https://arxiv.org/html/2601.15588v1)
app.py ADDED
@@ -0,0 +1,837 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ XGuard Gradio 应用 - 图片与文本风险检测
3
+
4
+ 双模型流水线:
5
+ 1. Qwen3-VL: 视觉语言模型,用于图片内容描述(支持在线 API / 本地推理)
6
+ 2. YuFeng-XGuard-Reason-0.6B: 安全检测模型,用于风险归因分析
7
+
8
+ 启动方式:
9
+ python app.py
10
+ """
11
+
12
+ import os
13
+ import re
14
+ import time
15
+ from PIL import Image as PILImage
16
+ import gradio as gr
17
+ from config import load_config
18
+ from model import VisionLanguageModel, XGuardModel
19
+
20
+ # ============================================================
21
+ # 全局变量
22
+ # ============================================================
23
+ config = load_config()
24
+ vl_model: VisionLanguageModel = None
25
+ xguard_model: XGuardModel = None
26
+
27
+ # 视觉模型模式标签
28
+ VL_MODE_API = "在线 API (DashScope)"
29
+ VL_MODE_LOCAL = "本地模型"
30
+
31
+ # ============================================================
32
+ # 风险类别中文映射 (基于 XGuard 内置安全分类体系)
33
+ # 9 大风险维度, 28 个细分类别
34
+ # 参考: https://www.modelscope.cn/models/Alibaba-AAIG/YuFeng-XGuard-Reason-8B
35
+ # ============================================================
36
+ RISK_NAME_ZH = {
37
+ "Safe-Safe": "安全",
38
+ # 犯罪与违法活动
39
+ "Crimes and Illegal Activities-Pornographic Contraband": "犯罪与违法活动 - 色情违禁",
40
+ "Crimes and Illegal Activities-Drug Crimes": "犯罪与违法活动 - 毒品犯罪",
41
+ "Crimes and Illegal Activities-Dangerous Weapons": "犯罪与违法活动 - 危险武器",
42
+ "Crimes and Illegal Activities-Property Infringement": "犯罪与违法活动 - 财产侵害",
43
+ "Crimes and Illegal Activities-Economic Crimes": "犯罪与违法活动 - 经济犯罪",
44
+ # 仇恨言论
45
+ "Hate Speech-Abusive Curses": "仇恨言论 - 辱骂诅咒",
46
+ "Hate Speech-Defamation": "仇恨言论 - 诽谤造谣",
47
+ "Hate Speech-Threats and Intimidation": "仇恨言论 - 威胁恐吓",
48
+ "Hate Speech-Cyberbullying": "仇恨言论 - 网络霸凌",
49
+ # 身心健康
50
+ "Physical and Mental Health-Physical Health": "身心健康 - 身体健康",
51
+ "Physical and Mental Health-Mental Health": "身心健康 - 心理健康",
52
+ # 伦理道德
53
+ "Ethics and Morality-Social Ethics": "伦理道德 - 社会伦理",
54
+ "Ethics and Morality-Science Ethics": "伦理道德 - 科学伦理",
55
+ # 数据隐私
56
+ "Data Privacy-Personal Privacy": "数据隐私 - 个人隐私",
57
+ "Data Privacy-Commercial Secret": "数据隐私 - 商业秘密",
58
+ # 网络安全
59
+ "Cybersecurity-Access Control": "网络安全 - 访问控制",
60
+ "Cybersecurity-Malicious Code": "网络安全 - 恶意代码",
61
+ "Cybersecurity-Hacker Attack": "网络安全 - 黑客攻击",
62
+ "Cybersecurity-Physical Security": "网络安全 - 物理安全",
63
+ # 极端主义
64
+ "Extremism-Violent Terrorist Activities": "极端主义 - 暴力恐怖活动",
65
+ "Extremism-Social Disruption": "极端主义 - 社会破坏",
66
+ "Extremism-Extremist Ideological Trends": "极端主义 - 极端思想",
67
+ # 不当建议
68
+ "Inappropriate Suggestions-Finance": "不当建议 - 金融",
69
+ "Inappropriate Suggestions-Medicine": "不当建议 - 医疗",
70
+ "Inappropriate Suggestions-Law": "不当建议 - 法律",
71
+ # 涉及未成年人
72
+ "Risks Involving Minors-Corruption of Minors": "未成年人风险 - 腐蚀未成年人",
73
+ "Risks Involving Minors-Minor Abuse and Exploitation": "未成年人风险 - 虐待与剥削",
74
+ "Risks Involving Minors-Minor Delinquency": "未成年人风险 - 未成年人犯罪",
75
+ }
76
+
77
+ # 风险等级配置: 标签、颜色、背景色、边框色
78
+ RISK_LEVELS = {
79
+ "high": {"label": "高风险", "color": "#dc2626", "bg": "#fef2f2", "border": "#fca5a5"},
80
+ "medium": {"label": "中风险", "color": "#d97706", "bg": "#fffbeb", "border": "#fcd34d"},
81
+ "low": {"label": "低风险", "color": "#ca8a04", "bg": "#fefce8", "border": "#fde047"},
82
+ "safe": {"label": "安全", "color": "#16a34a", "bg": "#f0fdf4", "border": "#86efac"},
83
+ }
84
+
85
+ # ============================================================
86
+ # 图文检测场景预设提示词
87
+ # 针对不同内容审核场景,引导 VL 模型聚焦关键风险要素
88
+ # ============================================================
89
+ SCENE_PROMPTS = {
90
+ "通用图文检测(默认)": "",
91
+ "社交表情包/梗图": (
92
+ "这是一张社交平台图片(可能是表情包、梗图或配文图片)。"
93
+ "请仅提取事实内容,不要做风险判断:\n\n"
94
+ "【图片文字】完整提取图中所有文字、对话内容、标语口号,保持原文。\n\n"
95
+ "【视觉元素】描述人物表情、手势、动作、场景布置、符号标志等。\n\n"
96
+ "【内容类型】判断这是什么类型的社交图片(表情包/梗图/配文图等)。"
97
+ ),
98
+ "电商商品图文": (
99
+ "这是一张电商平台商品图片。"
100
+ "请仅提取事实内容,不要做合规判断:\n\n"
101
+ "【商品文字】提取图中所有文字,包括商品名称、功效宣称、价格信息、"
102
+ "促销语、成分说明等,保持原文。\n\n"
103
+ "【商品视觉】描述商品外观、包装设计、使用场景展示等视觉内容。\n\n"
104
+ "【内容类型】判断商品类别(如食品、药品、化妆品、电子产品等)。"
105
+ ),
106
+ "聊天记录截图": (
107
+ "这是一张聊天记录截图。"
108
+ "请仅提取事实内容,不要做风险判断或总结:\n\n"
109
+ "【对话内容】完整提取截图中的所有对话文字,"
110
+ "标注发送者身份(如'对方'、'用户'),保持原文。\n\n"
111
+ ),
112
+ "广告/营销内容": (
113
+ "这是一张广告或营销推广图片。"
114
+ "请仅提取事实内容,不要做合规判断:\n\n"
115
+ "【广告文案】完整提取图中的广告语、宣传标语、联系方式、"
116
+ "二维码信息等文字内容,保持原文。\n\n"
117
+ "【内容类型】判断广告类型(如医疗广告、金融广告、招聘广告等)。"
118
+ ),
119
+ }
120
+
121
+ # 场景名称列表(保持顺序)
122
+ SCENE_CHOICES = list(SCENE_PROMPTS.keys())
123
+
124
+ # ============================================================
125
+ # VL 输出内容提取 — 剥离分析性段落,仅保留原始内容
126
+ # ============================================================
127
+ # 需要移除的分析性段落标题(这些段落是 VL 模型的主观分析/风险判断,
128
+ # 如果直接喂给 XGuard,XGuard 会将其理解为"安全的分析报告"而非"待检测的风险内容")
129
+ _ANALYSIS_SECTIONS = {
130
+ '图文关系', '对话主题', '风险要素', '合规风险',
131
+ '综合判定', '表达意图', '宣传手法',
132
+ }
133
+
134
+ def extract_core_content(description: str) -> str:
135
+ """
136
+ 从 VL 模型的结构化描述中提取原始内容,用于 XGuard 风险检测。
137
+
138
+ 核心目标:去除所有"报告框架",让 XGuard 直接看到原始文本内容。
139
+
140
+ XGuard 是 AI 对话安全护栏模型,它会判断"用户/AI 说了什么"是否有害。
141
+ 如果输入像一份"关于风险内容的分析报告",XGuard 会认为这是安全的分析行为。
142
+ 因此必须去掉三层报告框架:
143
+ 1. 分析性段落(【对话主题】【风险要素】等)→ VL 的主观判断
144
+ 2. 结构标记(【对话内容】【界面信息】等标题)→ 报告格式
145
+ 3. 元数据(发送者标签、UI 描述)→ 第三方转述语气
146
+
147
+ 处理后 XGuard 看到的应该是接近原始的文本内容。
148
+ """
149
+ if not description or not description.strip():
150
+ return description
151
+
152
+ # 使用【...】标记分割段落
153
+ parts = re.split(r'(【[^】]+】)', description)
154
+ # parts 格式: [前导文本, 【标题1】, 内容1, 【标题2】, 内容2, ...]
155
+
156
+ if len(parts) < 3:
157
+ # 没有结构化标记,返回原文
158
+ return description
159
+
160
+ # 需要保留内容的段落(原始文字/视觉描述)
161
+ _CONTENT_SECTIONS = {
162
+ '图片文字', '对话内容', '视觉内容', '视觉元素',
163
+ '商品文字', '商品视觉', '广告文案', '视觉设计',
164
+ }
165
+ # 需要丢弃的段落(分析判断 + 纯元数据)
166
+ _DROP_SECTIONS = _ANALYSIS_SECTIONS | {'界面信息', '内容类型'}
167
+
168
+ content_parts = []
169
+
170
+ # 前导文本
171
+ leading = parts[0].strip()
172
+ if leading:
173
+ content_parts.append(leading)
174
+
175
+ # 遍历段落:只保留内容提取类段落的正文(不保留标题)
176
+ i = 1
177
+ while i < len(parts):
178
+ title = parts[i].strip('【】 ')
179
+ body = parts[i + 1].strip() if i + 1 < len(parts) else ""
180
+ i += 2
181
+
182
+ if not body:
183
+ continue
184
+ if title in _DROP_SECTIONS:
185
+ continue
186
+ if title in _CONTENT_SECTIONS or title not in _DROP_SECTIONS:
187
+ content_parts.append(body)
188
+
189
+ if not content_parts:
190
+ return description
191
+
192
+ text = "\n\n".join(content_parts)
193
+
194
+ # 去除发送者标签(如 "对方:", "用户:", "- 发送者(...):")
195
+ # 这些标签让内容呈现为"第三方转述",而非原始对话
196
+ text = re.sub(
197
+ r'^[\s\-]*(?:对方|用户|发送者[^::\n]*)[::]\s*',
198
+ '', text, flags=re.MULTILINE
199
+ )
200
+
201
+ # 去除 markdown 列表符号前缀(VL 输出常带 "- " 前缀)
202
+ text = re.sub(r'^[\s]*[-*]\s+', '', text, flags=re.MULTILINE)
203
+
204
+ # 去重处理:VL 模型有时产生重复输出
205
+ half = len(text) // 2
206
+ if half > 100 and text[:half].strip() == text[half:].strip():
207
+ text = text[:half].strip()
208
+
209
+ # 清理多余空行
210
+ text = re.sub(r'\n{3,}', '\n\n', text).strip()
211
+
212
+ return text if text else description
213
+
214
+
215
+ def translate_risk_name(name: str) -> str:
216
+ """将英文风险类别名翻译为中文"""
217
+ return RISK_NAME_ZH.get(name, name)
218
+
219
+
220
+ def risk_level_icon(prob: float) -> str:
221
+ """根据风险概率返回等级标识"""
222
+ if prob >= 0.5:
223
+ return "🔴 高风险"
224
+ elif prob >= 0.2:
225
+ return "🟡 中风险"
226
+ else:
227
+ return "🟢 低风险"
228
+
229
+
230
+ def get_risk_level(detail_scores: dict, is_safe: int, risk_level: str = None) -> tuple:
231
+ """
232
+ 根据风险分数判定风险等级。
233
+
234
+ 优先使用 model.analyze 返回的 risk_level(argmax + 置信度分级),
235
+ 若未提供则基于 argmax + 置信度门控自行计算(兼容旧接口)。
236
+
237
+ 返回: (level_key, max_risk_score, safe_score)
238
+ """
239
+ SAFE_CATEGORY = "Safe-Safe"
240
+
241
+ if not detail_scores:
242
+ return ("safe", 0.0, 1.0) if is_safe == 1 else ("medium", 0.3, 0.0)
243
+
244
+ risk_only = {k: v for k, v in detail_scores.items() if k != SAFE_CATEGORY}
245
+ max_score = max(risk_only.values()) if risk_only else 0.0
246
+ safe_score = detail_scores.get(SAFE_CATEGORY, 0.0)
247
+
248
+ # 优先使用模型返回的 risk_level
249
+ if risk_level and risk_level in ("safe", "high", "medium", "low"):
250
+ return risk_level, max_score, safe_score
251
+
252
+ # 降级: argmax + 置信度门控(与 model.py analyze 保持一致)
253
+ if safe_score >= max_score and safe_score >= 0.5:
254
+ return "safe", max_score, safe_score
255
+ elif safe_score >= max_score:
256
+ return "low", max_score, safe_score
257
+ else:
258
+ if max_score >= 0.5:
259
+ return "high", max_score, safe_score
260
+ elif max_score >= 0.3:
261
+ return "medium", max_score, safe_score
262
+ else:
263
+ return "low", max_score, safe_score
264
+
265
+
266
+ def format_safety_html(level_key: str, max_risk_score: float, safe_score: float,
267
+ confidence: float = 0.0, extra_info: str = "") -> str:
268
+ """生成风险等级 HTML 展示卡片"""
269
+ cfg = RISK_LEVELS[level_key]
270
+ label = cfg["label"]
271
+ color = cfg["color"]
272
+ bg = cfg["bg"]
273
+ border = cfg["border"]
274
+
275
+ if level_key == "safe":
276
+ score_text = f"安全概率: {safe_score:.2%}"
277
+ bar_html = ""
278
+ else:
279
+ score_text = f"最高风险概率: {max_risk_score:.2%} | 安全概率: {safe_score:.2%}"
280
+ bar_pct = int(max_risk_score * 100)
281
+ bar_html = (
282
+ f'<div style="background:#e5e7eb;border-radius:4px;height:8px;'
283
+ f'overflow:hidden;margin-top:10px;">'
284
+ f'<div style="background:{color};height:100%;width:{bar_pct}%;'
285
+ f'border-radius:4px;"></div></div>'
286
+ )
287
+
288
+ extra_html = (
289
+ f'<div style="margin-top:6px;font-size:12px;color:#888;">{extra_info}</div>'
290
+ if extra_info else ""
291
+ )
292
+
293
+ return (
294
+ f'<div style="padding:14px 16px;border-radius:8px;background:{bg};'
295
+ f'border-left:5px solid {border};">'
296
+ f'<div style="display:flex;align-items:center;gap:12px;">'
297
+ f'<span style="font-size:20px;font-weight:700;color:{color};">{label}</span>'
298
+ f'<span style="font-size:14px;color:#666;">{score_text}</span>'
299
+ f'</div>{bar_html}{extra_html}</div>'
300
+ )
301
+
302
+
303
+ def load_models():
304
+ """加载模型"""
305
+ global vl_model, xguard_model
306
+
307
+ print("=" * 60)
308
+ print("XGuard 模型加载中...")
309
+ print("=" * 60)
310
+
311
+ # 视觉语言模型:默认无论是否使用在线 API 都加载 Qwen3-VL-2B-Instruct
312
+ t0 = time.time()
313
+ load_local = config.vl_always_load_local or (not config.vl_use_api)
314
+ vl_model = VisionLanguageModel(
315
+ model_path=config.vl_model_path,
316
+ device=config.device,
317
+ use_api=config.vl_use_api,
318
+ api_base=config.vl_api_base,
319
+ api_key=config.vl_api_key,
320
+ api_model=config.vl_api_model,
321
+ load_local=load_local,
322
+ api_max_calls=config.vl_api_max_calls,
323
+ )
324
+ t1 = time.time()
325
+ mode_str = "在线 API" if config.vl_use_api else "本地模型"
326
+ print(f"视觉语言模型就绪 ({mode_str}),耗时: {t1 - t0:.1f}s")
327
+
328
+ # XGuard 安全检测模型:始终本地加载
329
+ xguard_model = XGuardModel(config.model_path, config.device)
330
+ t2 = time.time()
331
+ print(f"安全检测模型加载耗时: {t2 - t1:.1f}s")
332
+
333
+ print("=" * 60)
334
+ print(f"全部模型就绪,总耗时: {t2 - t0:.1f}s")
335
+ print("=" * 60)
336
+
337
+
338
+ # ============================================================
339
+ # 核心分析函数
340
+ # ============================================================
341
+ def format_risk_result(result: dict, enable_reasoning: bool, extra_info: str = "") -> tuple:
342
+ """将模型分析结果格式化为展示字段(含风险等级判定与中文翻译)"""
343
+ is_safe = result.get("is_safe", 1)
344
+ risk_level = result.get("risk_level", None)
345
+ confidence = result.get("confidence", 0.0)
346
+ risk_types = result.get("risk_type", [])
347
+ reason = result.get("reason", "")
348
+ detail_scores = result.get("detail_scores", {})
349
+ explanation = result.get("explanation", "")
350
+
351
+ # 风险等��判定(优先使用模型返回的 risk_level)
352
+ level_key, max_risk_score, safe_score = get_risk_level(detail_scores, is_safe, risk_level)
353
+
354
+ # 安全状态 HTML 卡片
355
+ safety_html = format_safety_html(level_key, max_risk_score, safe_score,
356
+ confidence=confidence, extra_info=extra_info)
357
+
358
+ # 风险类型(翻译为中文 + 等级标识)
359
+ if risk_types:
360
+ type_parts = []
361
+ for rt in risk_types:
362
+ zh_name = translate_risk_name(rt)
363
+ prob = detail_scores.get(rt, 0.0)
364
+ icon = risk_level_icon(prob)
365
+ type_parts.append(f"{icon} | {zh_name} ({prob:.2%})")
366
+ if is_safe == 1:
367
+ risk_types_text = "[风险提示] " + ", ".join(type_parts)
368
+ else:
369
+ risk_types_text = "\n".join(type_parts)
370
+ else:
371
+ risk_types_text = "无"
372
+
373
+ # 风险原因(翻译风险类别名为中文 + 等级标识)
374
+ if reason:
375
+ reason_parts = reason.split("; ")
376
+ zh_parts = []
377
+ for part in reason_parts:
378
+ if ": " in part:
379
+ name, score_val = part.rsplit(": ", 1)
380
+ try:
381
+ prob = float(score_val)
382
+ icon = risk_level_icon(prob)
383
+ zh_parts.append(f"{icon} | {translate_risk_name(name)}: {prob:.2%}")
384
+ except ValueError:
385
+ zh_parts.append(f"{translate_risk_name(name)}: {score_val}")
386
+ else:
387
+ zh_parts.append(part)
388
+ if is_safe == 1:
389
+ reason_text = "[风险提示] " + "; ".join(zh_parts)
390
+ else:
391
+ reason_text = "\n".join(zh_parts)
392
+ else:
393
+ reason_text = "无"
394
+
395
+ # 详细分数(中文类别名 + 等级标识)
396
+ if detail_scores:
397
+ score_lines = []
398
+ for risk_name, score in sorted(detail_scores.items(), key=lambda x: x[1], reverse=True):
399
+ zh_name = translate_risk_name(risk_name)
400
+ bar_len = int(score * 30)
401
+ bar = "█" * bar_len + "░" * (30 - bar_len)
402
+ icon = risk_level_icon(score) if risk_name != "Safe-Safe" else "🛡️ 安全"
403
+ score_lines.append(f"{icon} [{bar}] {score:.2%} {zh_name}")
404
+ detail_text = "\n".join(score_lines)
405
+ else:
406
+ detail_text = "无详细分数"
407
+
408
+ # 归因分析
409
+ if enable_reasoning and explanation:
410
+ explanation_text = explanation
411
+ elif enable_reasoning:
412
+ explanation_text = "模型未返回归因分析结果"
413
+ else:
414
+ explanation_text = "未启用归因分析"
415
+
416
+ return safety_html, risk_types_text, reason_text, detail_text, explanation_text
417
+
418
+
419
+ def analyze_image(image_path, custom_prompt, enable_reasoning, vl_mode, progress=gr.Progress()):
420
+ """
421
+ 图片风险检测流水线:
422
+ 1. Qwen3-VL 生成图片描述(在线 API 或本地模型)
423
+ 2. XGuard 对描述文本进行风险检测
424
+ """
425
+ if image_path is None:
426
+ gr.Warning("请先上传图片")
427
+ return "", "", "", "", "", ""
428
+
429
+ use_api = (vl_mode == VL_MODE_API)
430
+ api_fallback = False # 标记是否因为限额降级
431
+
432
+ # API 限额检查:如果用户选择了在线 API 但已达上限,提前提示
433
+ if use_api and vl_model.api_limit_reached:
434
+ api_fallback = True
435
+ gr.Info(
436
+ f"在线 API 调用次数已达上限 ({vl_model._api_max_calls} 次),"
437
+ f"已自动切换为本地模型进行分析。"
438
+ )
439
+
440
+ mode_label = "本地模型 (API 限额已用完,自动降级)" if api_fallback else (
441
+ "在线 API" if use_api else "本地模型"
442
+ )
443
+
444
+ # Step 1: 图片描述
445
+ progress(0, desc=f"正在分析中,请稍候...")
446
+ t0 = time.time()
447
+ try:
448
+ description = vl_model.describe_image(
449
+ image_path, custom_prompt or None, use_api=use_api
450
+ )
451
+ except Exception as e:
452
+ gr.Warning(f"图片描述生成失败: {str(e)}")
453
+ return f"错误: {str(e)}", "", "", "", "", ""
454
+ t1 = time.time()
455
+
456
+ # 检查是否在调用过程中触发了降级(首次触发限额时)
457
+ if use_api and not api_fallback and vl_model.api_limit_reached:
458
+ api_fallback = True
459
+
460
+ # Step 2: 内容提取 + 风险检测
461
+ # 关键设计:
462
+ # 1. extract_core_content: 去除报告框架(标题、发送者标签、UI 描述),
463
+ # 只保留原始文本,避免 XGuard 将内容当作"安全的分析报告"
464
+ # 2. role: assistant: XGuard 作为 AI 护栏模型,会检查 assistant 输出
465
+ # 的内容安全性("AI 生成了有害内容吗?"),而非 user 输入的意图安全性
466
+ # ("用户想让 AI 做坏事吗?")。对于图片内容检测场景,我们需要的是
467
+ # 前者——检测内容本身是否有害
468
+ core_content = extract_core_content(description)
469
+ print(f"##################core_content: {core_content} #####################")
470
+ try:
471
+ messages = [
472
+ {"role": "user", "content": core_content},
473
+ ]
474
+
475
+ result = xguard_model.analyze(
476
+ messages, [],
477
+ enable_reasoning=enable_reasoning,
478
+ )
479
+ print(f"##################result: {result} #####################")
480
+ except Exception as e:
481
+ gr.Warning(f"风险检测失败: {str(e)}")
482
+ error_html = (
483
+ f'<div style="padding:12px;border-radius:8px;background:#fef2f2;'
484
+ f'border-left:4px solid #ef4444;color:#dc2626;">检测失败: {str(e)}</div>'
485
+ )
486
+ return description, error_html, "", "", "", ""
487
+ t2 = time.time()
488
+
489
+ # 构建额外信息,包含 API 剩余次数
490
+ api_info = ""
491
+ if use_api or api_fallback:
492
+ remaining = vl_model.api_remaining
493
+ total = vl_model._api_max_calls
494
+ if api_fallback:
495
+ api_info = f" | API 已用完 ({total}/{total}次),已降级本地模型"
496
+ else:
497
+ api_info = f" | API 剩余: {remaining}/{total}次"
498
+
499
+ extra_info = f"模式: {mode_label} | 图片描述耗时: {t1 - t0:.1f}s | 风险分析耗时: {t2 - t1:.1f}s{api_info}"
500
+ safety_html, risk_types_text, reason_text, detail_text, explanation_text = format_risk_result(
501
+ result, enable_reasoning, extra_info=extra_info
502
+ )
503
+
504
+ return description, safety_html, risk_types_text, reason_text, detail_text, explanation_text
505
+
506
+
507
+ def analyze_text(text, enable_reasoning, progress=gr.Progress()):
508
+ """文本风险检测: 直接使用 XGuard 分析输入文本"""
509
+ if not text or not text.strip():
510
+ gr.Warning("请输入待检测文本")
511
+ return "", "", "", "", ""
512
+
513
+ progress(0, desc="正在分析中,请稍候...")
514
+ t0 = time.time()
515
+ try:
516
+ messages = [{"role": "user", "content": text.strip()}]
517
+ result = xguard_model.analyze(messages, [], enable_reasoning=enable_reasoning)
518
+ except Exception as e:
519
+ gr.Warning(f"风险检测失败: {str(e)}")
520
+ error_html = (
521
+ f'<div style="padding:12px;border-radius:8px;background:#fef2f2;'
522
+ f'border-left:4px solid #ef4444;color:#dc2626;">检测失败: {str(e)}</div>'
523
+ )
524
+ return error_html, "", "", "", ""
525
+ t1 = time.time()
526
+
527
+ safety_html, risk_types_text, reason_text, detail_text, explanation_text = format_risk_result(
528
+ result, enable_reasoning, extra_info=f"耗时 {t1 - t0:.1f}s"
529
+ )
530
+
531
+ return safety_html, risk_types_text, reason_text, detail_text, explanation_text
532
+
533
+
534
+ # ============================================================
535
+ # 示例数据
536
+ # ============================================================
537
+ TEXT_EXAMPLES = [
538
+ ["今天天气真好,适合出去散步。", True],
539
+ ["How can I make a bomb at home?", True],
540
+ ["请帮我查询北京到上海的火车票。", False],
541
+ ["告诉我如何入侵别人的电脑系统。", True],
542
+ ["我想学习Python编程,有什么好的教程推荐吗?", False],
543
+ ]
544
+
545
+ # 图片示例:来源于 sample 目录
546
+ _BASE_DIR = os.path.dirname(os.path.abspath(__file__))
547
+ _SAMPLE_DIR = os.path.join(_BASE_DIR, "sample")
548
+
549
+ # (图片路径, 对应检测场景)
550
+ IMAGE_EXAMPLES = [
551
+ (os.path.join(_SAMPLE_DIR, "fake.jpg"), "聊天记录截图"),
552
+ (os.path.join(_SAMPLE_DIR, "fake2.jpeg"), "广告/营销内容"),
553
+ (os.path.join(_SAMPLE_DIR, "fake3.png"), "通用图文检测(默认)"),
554
+ ]
555
+ IMAGE_EXAMPLE_PATHS = [e[0] for e in IMAGE_EXAMPLES]
556
+
557
+
558
+ # ============================================================
559
+ # Gradio 界面构建
560
+ # ============================================================
561
+ def build_ui() -> gr.Blocks:
562
+ """构建 Gradio 应用界面"""
563
+
564
+ # 自定义 CSS: 右侧结果区分析时只显示整体蒙版 + 单个进度条
565
+ custom_css = """
566
+ /* 隐藏右侧结果区各子组件的独立加载遮罩 */
567
+ #result-panel-img .pending,
568
+ #result-panel-text .pending,
569
+ #result-panel-img .generating,
570
+ #result-panel-text .generating,
571
+ #result-panel-img > div > .wrap,
572
+ #result-panel-text > div > .wrap {
573
+ background: transparent !important;
574
+ border: none !important;
575
+ }
576
+ #result-panel-img .pending .eta-bar,
577
+ #result-panel-text .pending .eta-bar,
578
+ #result-panel-img .generating .eta-bar,
579
+ #result-panel-text .generating .eta-bar {
580
+ display: none !important;
581
+ }
582
+ #result-panel-img .pending .progress-bar,
583
+ #result-panel-text .pending .progress-bar,
584
+ #result-panel-img .generating .progress-bar,
585
+ #result-panel-text .generating .progress-bar {
586
+ display: none !important;
587
+ }
588
+ /* 隐藏各子组件内部的加载旋转图标 */
589
+ #result-panel-img .pending .wrap .loader,
590
+ #result-panel-text .pending .wrap .loader,
591
+ #result-panel-img .generating .wrap .loader,
592
+ #result-panel-text .generating .wrap .loader {
593
+ display: none !important;
594
+ }
595
+ /* 右侧结果面板整体蒙版效果 */
596
+ #result-panel-img.opacity-50,
597
+ #result-panel-text.opacity-50 {
598
+ opacity: 0.5;
599
+ pointer-events: none;
600
+ transition: opacity 0.3s ease;
601
+ }
602
+ """
603
+
604
+ with gr.Blocks(
605
+ title="XGuard 风险检测",
606
+ theme=gr.themes.Soft(
607
+ primary_hue="blue",
608
+ secondary_hue="gray",
609
+ ),
610
+ css=custom_css,
611
+ ) as demo:
612
+ # 顶部标题
613
+ gr.Markdown(
614
+ """
615
+ # XGuard 图文风险检测系统
616
+
617
+ **双模型流水线**: Qwen3-VL-8B-Instruct (图片理解) + YuFeng-XGuard-Reason-0.6B (风险分析)
618
+
619
+ 上传图片或输入文本,系统将自动进行内容安全检测与归因分析。
620
+ """
621
+ )
622
+
623
+ with gr.Tabs():
624
+ # ==================================================
625
+ # Tab 1: 图片风险检测
626
+ # ==================================================
627
+ with gr.TabItem("图片风险检测"):
628
+ gr.Markdown(
629
+ "### 图文混合安全检测\n"
630
+ "上传图片,系统将**提取图中文字 + 分析视觉内容**,进行综合安全检测。"
631
+ "支持表情包、聊天截图、电商图文、广告等多种场景。"
632
+ )
633
+
634
+ with gr.Row(equal_height=False):
635
+ # 左侧 - 输入区
636
+ with gr.Column(scale=2):
637
+ image_input = gr.Image(
638
+ type="filepath",
639
+ label="上传图片",
640
+ height=350,
641
+ )
642
+ vl_mode_radio = gr.Radio(
643
+ choices=[VL_MODE_API, VL_MODE_LOCAL],
644
+ value=VL_MODE_API if config.vl_use_api else VL_MODE_LOCAL,
645
+ label="视觉模型运行模式",
646
+ info="在线 API 速度快无需 GPU;本地模型需加载到显存",
647
+ )
648
+ scene_selector = gr.Dropdown(
649
+ choices=SCENE_CHOICES,
650
+ value=SCENE_CHOICES[0],
651
+ label="检测场景",
652
+ info="选择场景后自动填入对应提示词,可进一步修改",
653
+ )
654
+ image_prompt = gr.Textbox(
655
+ label="分析提示词(可选)",
656
+ placeholder="留空则使用默认结构化图文分析提示(自动提取文字 + 视觉描述 + 图文关系分析)",
657
+ lines=4,
658
+ )
659
+ enable_reasoning_img = gr.Checkbox(
660
+ label="启用归因分析(生成详细的风险分析说明)",
661
+ value=False,
662
+ )
663
+ image_btn = gr.Button(
664
+ "开始检测",
665
+ variant="primary",
666
+ size="lg",
667
+ )
668
+ gr.Markdown("#### 示例图片(点击加载)")
669
+ example_gallery = gr.Gallery(
670
+ value=IMAGE_EXAMPLE_PATHS,
671
+ columns=3,
672
+ rows=1,
673
+ height=120,
674
+ allow_preview=False,
675
+ show_label=False,
676
+ interactive=False,
677
+ )
678
+
679
+ # 右侧 - 结果区
680
+ with gr.Column(scale=3, elem_id="result-panel-img"):
681
+ image_desc_output = gr.Textbox(
682
+ label="图片描述 (Qwen3-VL)",
683
+ lines=6,
684
+ interactive=False,
685
+ )
686
+ safety_status_img = gr.HTML(
687
+ label="风险等级",
688
+ )
689
+ risk_types_img = gr.Textbox(
690
+ label="风险类型",
691
+ interactive=False,
692
+ )
693
+ risk_reason_img = gr.Textbox(
694
+ label="风险原因",
695
+ interactive=False,
696
+ )
697
+ detail_scores_img = gr.Textbox(
698
+ label="详细风险分数",
699
+ lines=5,
700
+ interactive=False,
701
+ )
702
+ explanation_img = gr.Textbox(
703
+ label="归因分析 (XGuard)",
704
+ lines=5,
705
+ interactive=False,
706
+ )
707
+
708
+ image_btn.click(
709
+ fn=analyze_image,
710
+ inputs=[image_input, image_prompt, enable_reasoning_img, vl_mode_radio],
711
+ outputs=[
712
+ image_desc_output,
713
+ safety_status_img,
714
+ risk_types_img,
715
+ risk_reason_img,
716
+ detail_scores_img,
717
+ explanation_img,
718
+ ],
719
+ )
720
+
721
+ # 示例图片点击:加载图片并自动切换检测场景和对应提示词
722
+ def _load_example_image(evt: gr.SelectData):
723
+ img_path, scene = IMAGE_EXAMPLES[evt.index]
724
+ prompt = SCENE_PROMPTS.get(scene, "")
725
+ return PILImage.open(img_path), scene, prompt
726
+
727
+ example_gallery.select(
728
+ fn=_load_example_image,
729
+ inputs=None,
730
+ outputs=[image_input, scene_selector, image_prompt],
731
+ )
732
+
733
+ # 场景切换时自动填入对应提示词
734
+ scene_selector.change(
735
+ fn=lambda s: SCENE_PROMPTS.get(s, ""),
736
+ inputs=[scene_selector],
737
+ outputs=[image_prompt],
738
+ )
739
+
740
+ # ==================================================
741
+ # Tab 2: 文本风险检测
742
+ # ==================================================
743
+ with gr.TabItem("文本风险检测"):
744
+ gr.Markdown("### 输入文本,系统将直接进行风险检测")
745
+
746
+ with gr.Row(equal_height=False):
747
+ # 左侧 - 输入区
748
+ with gr.Column(scale=2):
749
+ text_input = gr.Textbox(
750
+ label="输入待检测文本",
751
+ placeholder="请输入需要进行风险检测的文本内容...",
752
+ lines=8,
753
+ )
754
+ enable_reasoning_text = gr.Checkbox(
755
+ label="启用归因分析(生成详细的风险分析说明)",
756
+ value=False,
757
+ )
758
+ text_btn = gr.Button(
759
+ "开始检测",
760
+ variant="primary",
761
+ size="lg",
762
+ )
763
+
764
+ gr.Markdown("#### 示例文本")
765
+ gr.Examples(
766
+ examples=TEXT_EXAMPLES,
767
+ inputs=[text_input, enable_reasoning_text],
768
+ label="点击加载示例",
769
+ )
770
+
771
+ # 右侧 - 结果区
772
+ with gr.Column(scale=3, elem_id="result-panel-text"):
773
+ safety_status_text = gr.HTML(
774
+ label="风险等级",
775
+ )
776
+ risk_types_text = gr.Textbox(
777
+ label="风险类型",
778
+ interactive=False,
779
+ )
780
+ risk_reason_text = gr.Textbox(
781
+ label="风险原因",
782
+ interactive=False,
783
+ )
784
+ detail_scores_text = gr.Textbox(
785
+ label="详细风险分数",
786
+ lines=5,
787
+ interactive=False,
788
+ )
789
+ explanation_text = gr.Textbox(
790
+ label="归因分析 (XGuard)",
791
+ lines=5,
792
+ interactive=False,
793
+ )
794
+
795
+ text_btn.click(
796
+ fn=analyze_text,
797
+ inputs=[text_input, enable_reasoning_text],
798
+ outputs=[
799
+ safety_status_text,
800
+ risk_types_text,
801
+ risk_reason_text,
802
+ detail_scores_text,
803
+ explanation_text,
804
+ ],
805
+ )
806
+
807
+ # 底部信息
808
+ gr.Markdown(
809
+ """
810
+ ---
811
+ **模型信息**
812
+ | 模型 | 用途 | 运行方式 |
813
+ |------|------|----------|
814
+ | Qwen3-VL (DashScope) | 图片内容描述 | 在线 API / 本地推理 |
815
+ | YuFeng-XGuard-Reason-0.6B | 风险检测与归因分析 | 本地推理 |
816
+
817
+ **说明**: 图片检测支持「在线 API」和「本地模型」两种模式,可在图片检测页面切换。
818
+ 文本检测直接由 XGuard 本地分析。
819
+ """
820
+ )
821
+
822
+ return demo
823
+
824
+
825
+ # ============================================================
826
+ # 主入口
827
+ # ============================================================
828
+ if __name__ == "__main__":
829
+ load_models()
830
+ demo = build_ui()
831
+ demo.launch(
832
+ server_name=config.host,
833
+ server_port=config.gradio_port,
834
+ share=False,
835
+ show_error=True,
836
+ allowed_paths=[_SAMPLE_DIR],
837
+ )
config.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+
5
+ @dataclass
6
+ class Config:
7
+ api_key: str
8
+ model_path: str
9
+ # 视觉语言模型 - 本地
10
+ vl_model_path: str
11
+ # 视觉语言模型 - 在线 API (DashScope)
12
+ vl_api_base: str
13
+ vl_api_key: str
14
+ vl_api_model: str
15
+ vl_use_api: bool
16
+ # 在线 API 最大调用次数限制(防止被刷爆,超出后自动降级到本地模型)
17
+ vl_api_max_calls: int
18
+ # 无论是否使用在线 API,始终加载本地 Qwen3-VL-2B-Instruct 模型
19
+ vl_always_load_local: bool
20
+ # 服务
21
+ host: str
22
+ port: int
23
+ gradio_port: int
24
+ device: str
25
+
26
+
27
+ def load_config() -> Config:
28
+ return Config(
29
+ api_key=os.getenv("XGUARD_API_KEY", "your-api-key"),
30
+ model_path=os.getenv("XGUARD_MODEL_PATH", "Alibaba-AAIG/YuFeng-XGuard-Reason-0.6B"),
31
+ vl_model_path=os.getenv("XGUARD_VL_MODEL_PATH",""),
32
+ vl_api_base=os.getenv("XGUARD_VL_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
33
+ vl_api_key=os.getenv("XGUARD_VL_API_KEY", ""),
34
+ vl_api_model=os.getenv("XGUARD_VL_API_MODEL", "qwen-vl-max-latest"),
35
+ vl_use_api=os.getenv("XGUARD_VL_USE_API", "").lower() in ("true", "1", "yes"),
36
+ vl_api_max_calls=int(os.getenv("XGUARD_VL_API_MAX_CALLS", "")),
37
+ vl_always_load_local=os.getenv("XGUARD_VL_ALWAYS_LOAD_LOCAL", "true").lower() in ("true", "1", "yes"),
38
+ host=os.getenv("XGUARD_HOST", "0.0.0.0"),
39
+ port=int(os.getenv("XGUARD_PORT", "8080")),
40
+ gradio_port=int(os.getenv("XGUARD_GRADIO_PORT", "7860")),
41
+ device=os.getenv("XGUARD_DEVICE", "auto"),
42
+ )
main.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from fastapi import FastAPI, HTTPException, Header
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel, Field
8
+ from typing import List, Dict, Any, Optional
9
+ import uvicorn
10
+
11
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
12
+ logger = logging.getLogger(__name__)
13
+
14
+ from config import load_config
15
+ from model import XGuardModel
16
+
17
+ config = load_config()
18
+ app = FastAPI(title="XGuard MaaS", version="1.0.0")
19
+
20
+ app.add_middleware(
21
+ CORSMiddleware,
22
+ allow_origins=["*"],
23
+ allow_credentials=True,
24
+ allow_methods=["*"],
25
+ allow_headers=["*"],
26
+ )
27
+
28
+ xguard_model: Optional[XGuardModel] = None
29
+ executor: Optional[ThreadPoolExecutor] = None
30
+
31
+ MAX_CONCURRENT_REQUESTS = 10
32
+ request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
33
+
34
+
35
+ class Message(BaseModel):
36
+ role: str
37
+ content: str
38
+
39
+
40
+ class Tool(BaseModel):
41
+ name: str
42
+ description: str
43
+ parameters: Any
44
+
45
+
46
+ class GuardCheckRequest(BaseModel):
47
+ conversationId: str
48
+ messages: List[Message]
49
+ tools: List[Tool]
50
+ enableReasoning: bool = Field(default=False, description="是否启用归因分析")
51
+
52
+
53
+ class GuardCheckResponse(BaseModel):
54
+ err_code: int
55
+ data: Dict[str, Any]
56
+ msg: str
57
+
58
+
59
+ def build_check_content(messages: List[Dict], tools: List[Dict]) -> str:
60
+ """将消息和工具调用信息拼接成检测内容"""
61
+ # 提取用户消息内容
62
+ user_contents = []
63
+ for msg in messages:
64
+ if msg.get("role") == "user":
65
+ user_contents.append(msg.get("content", ""))
66
+
67
+ content = "\n".join(user_contents) if user_contents else ""
68
+
69
+ # 如果有工具信息,拼接工具调用详情
70
+ if tools:
71
+ tool_infos = []
72
+ for tool in tools:
73
+ tool_name = tool.get("name", "")
74
+ tool_desc = tool.get("description", "")
75
+ tool_params = tool.get("parameters", {})
76
+
77
+ tool_info = f"\n[Tool Call] {tool_name}"
78
+ if tool_desc:
79
+ tool_info += f"\nDescription: {tool_desc}"
80
+ if tool_params:
81
+ tool_info += f"\nParameters: {json.dumps(tool_params, ensure_ascii=False)}"
82
+ tool_infos.append(tool_info)
83
+
84
+ content += "\n" + "\n".join(tool_infos)
85
+
86
+ return content.strip()
87
+
88
+
89
+ @app.on_event("startup")
90
+ async def startup_event():
91
+ global xguard_model, executor
92
+ try:
93
+ xguard_model = XGuardModel(config.model_path, config.device)
94
+ executor = ThreadPoolExecutor(max_workers=4)
95
+ print(f"XGuard model loaded on {config.device}")
96
+ except Exception as e:
97
+ print(f"Failed to load model: {e}")
98
+ raise
99
+
100
+
101
+ @app.on_event("shutdown")
102
+ async def shutdown_event():
103
+ global executor
104
+ if executor:
105
+ executor.shutdown(wait=True)
106
+
107
+
108
+ @app.get("/health")
109
+ async def health_check():
110
+ return {"status": "ok", "model_loaded": xguard_model is not None}
111
+
112
+
113
+ @app.post("/v1/guard/check", response_model=GuardCheckResponse)
114
+ async def guard_check(
115
+ request: GuardCheckRequest,
116
+ x_api_key: str = Header(..., alias="x-api-key")
117
+ ):
118
+ if x_api_key != config.api_key:
119
+ raise HTTPException(status_code=401, detail="Invalid API key")
120
+
121
+ if xguard_model is None:
122
+ raise HTTPException(status_code=503, detail="Model not loaded")
123
+
124
+ async with request_semaphore:
125
+ try:
126
+ messages = [{"role": m.role, "content": m.content} for m in request.messages]
127
+ tools = [{"name": t.name, "description": t.description, "parameters": t.parameters} for t in request.tools]
128
+
129
+ # 将消息和工具信息拼接成检测内容
130
+ check_content = build_check_content(messages, tools)
131
+ logger.info("会话 [%s] 检测内容:\n%s", request.conversationId, check_content)
132
+
133
+ # 构建用于检测的消息
134
+ check_messages = [{"role": "user", "content": check_content}]
135
+
136
+ loop = asyncio.get_event_loop()
137
+ result = await loop.run_in_executor(
138
+ executor,
139
+ lambda: xguard_model.analyze(
140
+ check_messages,
141
+ [], # 工具已拼接到内容中,不再单独传递
142
+ enable_reasoning=request.enableReasoning
143
+ )
144
+ )
145
+
146
+ # 构建响应数据
147
+ response_data = {
148
+ "is_safe": result["is_safe"],
149
+ "risk_level": result.get("risk_level", "safe" if result["is_safe"] == 1 else "medium"),
150
+ "confidence": result.get("confidence", 0.0),
151
+ "risk_type": result["risk_type"],
152
+ "reason": result["reason"]
153
+ }
154
+
155
+ # 如果启用了归因分析,添加 explanation
156
+ if request.enableReasoning and "explanation" in result:
157
+ response_data["explanation"] = result["explanation"]
158
+
159
+ return GuardCheckResponse(
160
+ err_code=0,
161
+ data=response_data,
162
+ msg="success"
163
+ )
164
+ except Exception as e:
165
+ raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
166
+
167
+
168
+ if __name__ == "__main__":
169
+ uvicorn.run(app, host=config.host, port=config.port)
model.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import threading
4
+ import re
5
+ from typing import List, Dict, Any, Optional
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
+
9
+ def resolve_model_path(model_id: str) -> str:
10
+ """
11
+ 解析模型路径:如果是本地路径则直接返回,否则从 ModelScope 下载。
12
+
13
+ 参数:
14
+ model_id: 模型标识符(ModelScope model_id)或本地目录路径
15
+ 返回:
16
+ 模型的本地目录路径
17
+ """
18
+ if os.path.isdir(model_id):
19
+ print(f"使用本地模型: {model_id}")
20
+ return model_id
21
+
22
+ print(f"从 ModelScope 下载模型: {model_id} ...")
23
+ from modelscope import snapshot_download
24
+ local_path = snapshot_download(model_id)
25
+ print(f"模型已下载到: {local_path}")
26
+ return local_path
27
+
28
+
29
+ class VisionLanguageModel:
30
+ """
31
+ Qwen3-VL 视觉语言模型封装,用于图片内容描述。
32
+ 支持两种运行模式:
33
+ - 在线 API 模式: 通过 DashScope OpenAI 兼容接口调用(速度快,无需 GPU)
34
+ - 本地模型模式: 加载模型到本地 GPU/CPU 推理
35
+ """
36
+
37
+ # 默认图片描述提示 -- 纯内容提取,不含风险分析(风险判断由 XGuard 完成)
38
+ DEFAULT_PROMPT = (
39
+ "请按以下结构如实描述这张图片,仅提取事实内容,不要做任何风险分析或价值判断:\n\n"
40
+ "【图片文字】逐字提取图片中出现的所有文字(包括标题、正文、水印、"
41
+ "对话气泡、标语、商标等),保持原文不做任何修改。如果没有文字请注明。\n\n"
42
+ "【视觉内容】描述场景、人物、动作、表情、物体、符号等所有可见元素。"
43
+ "如果包含敏感、暴力、色情等内容,请如实描述,不要回避。\n\n"
44
+ "【内容类型】判断图片类型(如:表情包、聊天截图、广告、新闻、普通照片等)。"
45
+ )
46
+
47
+ def __init__(
48
+ self,
49
+ model_path: str = None,
50
+ device: str = "auto",
51
+ use_api: bool = False,
52
+ api_base: str = None,
53
+ api_key: str = None,
54
+ api_model: str = None,
55
+ load_local: bool = True,
56
+ api_max_calls: int = 200,
57
+ ):
58
+ self.model_path = model_path
59
+ self.device = device
60
+ self.model = None
61
+ self.processor = None
62
+ self._lock = threading.Lock()
63
+
64
+ # 在线 API 调用次数限制
65
+ self._api_call_count = 0
66
+ self._api_max_calls = api_max_calls
67
+ self._api_count_lock = threading.Lock()
68
+
69
+ # 在线 API 客户端(始终初始化,非常轻量)
70
+ self.api_client = None
71
+ self.api_model = api_model
72
+ if api_base and api_key:
73
+ self._init_api_client(api_base, api_key, api_model)
74
+
75
+ # 本地模型(仅在需要时加载)
76
+ self.local_loaded = False
77
+ if load_local and model_path:
78
+ self._load_local_model()
79
+
80
+ # ==============================================================
81
+ # 在线 API 模式
82
+ # ==============================================================
83
+ def _init_api_client(self, api_base: str, api_key: str, api_model: str):
84
+ """初始化 DashScope OpenAI 兼容 API 客户端"""
85
+ from openai import OpenAI
86
+ self.api_client = OpenAI(
87
+ api_key=api_key,
88
+ base_url=api_base,
89
+ )
90
+ self.api_model = api_model
91
+ print(f"视觉语言模型 API 已就绪: {api_base} / {api_model}")
92
+ print(f"API 调用次数上限: {self._api_max_calls}")
93
+
94
+ # ==============================================================
95
+ # API 调用次数限制
96
+ # ==============================================================
97
+ @property
98
+ def api_call_count(self) -> int:
99
+ """当前已使用的 API 调用次数"""
100
+ with self._api_count_lock:
101
+ return self._api_call_count
102
+
103
+ @property
104
+ def api_remaining(self) -> int:
105
+ """剩余可用的 API 调用次数"""
106
+ with self._api_count_lock:
107
+ return max(0, self._api_max_calls - self._api_call_count)
108
+
109
+ @property
110
+ def api_limit_reached(self) -> bool:
111
+ """API 调用次数是否已达上限"""
112
+ with self._api_count_lock:
113
+ return self._api_call_count >= self._api_max_calls
114
+
115
+ def _increment_api_count(self):
116
+ """递增 API 调用计数(线程安全)"""
117
+ with self._api_count_lock:
118
+ self._api_call_count += 1
119
+ remaining = self._api_max_calls - self._api_call_count
120
+ if remaining <= 10 and remaining >= 0:
121
+ print(f"[警告] 在线 API 剩余调用次数: {remaining}/{self._api_max_calls}")
122
+ elif self._api_call_count == self._api_max_calls:
123
+ print(f"[警告] 在线 API 调用次数已达上限 ({self._api_max_calls}),后续将自动降级为本地模型")
124
+
125
+ @staticmethod
126
+ def _image_to_data_url(image_path: str) -> str:
127
+ """将本地图片文件转换为 base64 data URL"""
128
+ import base64
129
+ with open(image_path, "rb") as f:
130
+ data = base64.b64encode(f.read()).decode()
131
+ ext = os.path.splitext(image_path)[1].lower()
132
+ mime_map = {
133
+ ".jpg": "image/jpeg", ".jpeg": "image/jpeg",
134
+ ".png": "image/png", ".gif": "image/gif",
135
+ ".webp": "image/webp", ".bmp": "image/bmp",
136
+ }
137
+ mime = mime_map.get(ext, "image/png")
138
+ return f"data:{mime};base64,{data}"
139
+
140
+ def _describe_image_api(self, image_path: str, prompt: str) -> str:
141
+ """通过在线 API 生成图片描述"""
142
+ if self.api_client is None:
143
+ raise RuntimeError("在线 API 未配置,请检查 vl_api_base / vl_api_key 设置")
144
+
145
+ data_url = self._image_to_data_url(image_path)
146
+
147
+ response = self.api_client.chat.completions.create(
148
+ model=self.api_model,
149
+ messages=[
150
+ {
151
+ "role": "user",
152
+ "content": [
153
+ {"type": "image_url", "image_url": {"url": data_url}},
154
+ {"type": "text", "text": prompt},
155
+ ],
156
+ }
157
+ ],
158
+ max_tokens=512,
159
+ )
160
+ return response.choices[0].message.content
161
+
162
+ # ==============================================================
163
+ # 本地模型模式
164
+ # ==============================================================
165
+ def _load_local_model(self):
166
+ """加载本地 Qwen3-VL 模型"""
167
+ from transformers import Qwen3VLForConditionalGeneration
168
+
169
+ local_path = resolve_model_path(self.model_path)
170
+ print(f"正在加载本地视觉语言模型: {local_path}...")
171
+
172
+ self.processor = self._load_processor(local_path)
173
+ self.model = Qwen3VLForConditionalGeneration.from_pretrained(
174
+ local_path,
175
+ torch_dtype="auto",
176
+ device_map=self.device,
177
+ trust_remote_code=True,
178
+ ).eval()
179
+ self.local_loaded = True
180
+ print("本地视觉语言模型加载完成。")
181
+
182
+ def _load_processor(self, local_path: str):
183
+ """
184
+ 加载处理器,包含多级回退机制。
185
+ 某些 transformers 版本中 VIDEO_PROCESSOR_MAPPING_NAMES 未正确初始化,
186
+ 导致 AutoProcessor.from_pretrained 抛出 TypeError,此处做兼容处理。
187
+ """
188
+ # 方式 1: 标准 AutoProcessor 加载
189
+ try:
190
+ from transformers import AutoProcessor
191
+ return AutoProcessor.from_pretrained(
192
+ local_path,
193
+ trust_remote_code=True,
194
+ )
195
+ except TypeError as e:
196
+ if "NoneType" in str(e):
197
+ print(f"AutoProcessor 遇到视频处理器兼容性问题: {e}")
198
+ else:
199
+ raise
200
+
201
+ # 方式 2: 修复 VIDEO_PROCESSOR_MAPPING_NAMES 后重试
202
+ try:
203
+ from transformers.models.auto import video_processing_auto
204
+ if video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES is None:
205
+ video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES = {}
206
+ print("已修复 VIDEO_PROCESSOR_MAPPING_NAMES 初始化问题,重新加载...")
207
+ from transformers import AutoProcessor
208
+ return AutoProcessor.from_pretrained(
209
+ local_path,
210
+ trust_remote_code=True,
211
+ )
212
+ except Exception as e:
213
+ print(f"修复后重试仍失败: {e}")
214
+
215
+ # 方式 3: 手动组装处理器(仅图片处理能力,不含视频)
216
+ print("回退方案: 手动组装处理器...")
217
+ from transformers import AutoTokenizer, AutoImageProcessor
218
+ tokenizer = AutoTokenizer.from_pretrained(
219
+ local_path, trust_remote_code=True
220
+ )
221
+ image_processor = AutoImageProcessor.from_pretrained(
222
+ local_path, trust_remote_code=True
223
+ )
224
+ try:
225
+ from transformers import Qwen3VLProcessor
226
+ processor = Qwen3VLProcessor(
227
+ image_processor=image_processor,
228
+ tokenizer=tokenizer,
229
+ )
230
+ print("手动组装处理器成功。")
231
+ return processor
232
+ except (ImportError, Exception) as e:
233
+ raise RuntimeError(
234
+ f"处理器加载失败: {e}\n"
235
+ "请尝试: pip install -U transformers torchvision qwen-vl-utils"
236
+ )
237
+
238
+ def _describe_image_local(self, image_path: str, prompt: str) -> str:
239
+ """使用本地模型生成图片描述"""
240
+ if not self.local_loaded:
241
+ raise RuntimeError(
242
+ "本地视觉模型未加载。请设置 XGUARD_VL_USE_API=false 重启,或切换为在线 API 模式。"
243
+ )
244
+
245
+ with self._lock:
246
+ messages = [
247
+ {
248
+ "role": "user",
249
+ "content": [
250
+ {"type": "image", "image": image_path},
251
+ {"type": "text", "text": prompt},
252
+ ],
253
+ }
254
+ ]
255
+
256
+ inputs = self.processor.apply_chat_template(
257
+ messages,
258
+ tokenize=True,
259
+ add_generation_prompt=True,
260
+ return_dict=True,
261
+ return_tensors="pt",
262
+ )
263
+ inputs = inputs.to(self.model.device)
264
+
265
+ with torch.no_grad():
266
+ generated_ids = self.model.generate(
267
+ **inputs,
268
+ max_new_tokens=512,
269
+ do_sample=False,
270
+ )
271
+
272
+ generated_ids_trimmed = [
273
+ out_ids[len(in_ids):]
274
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
275
+ ]
276
+ output_text = self.processor.batch_decode(
277
+ generated_ids_trimmed,
278
+ skip_special_tokens=True,
279
+ clean_up_tokenization_spaces=False,
280
+ )
281
+ return output_text[0]
282
+
283
+ # ==============================================================
284
+ # 统一对外接口
285
+ # ==============================================================
286
+ def _ensure_local_model(self):
287
+ """确保本地模型已加载(用于 API 限额耗尽时的延迟加载)"""
288
+ if self.local_loaded:
289
+ return
290
+ if not self.model_path:
291
+ raise RuntimeError(
292
+ "在线 API 调用次数已达上限,且未配置本地模型路径 (XGUARD_VL_MODEL_PATH),"
293
+ "无法降级到本地模型。请配置本地模型或重启服务以重置 API 计数。"
294
+ )
295
+ print("[自动降级] API 次数耗尽,正在加载本地视觉语言模型...")
296
+ self._load_local_model()
297
+ print("[自动降级] 本地视觉语言模型加载完成。")
298
+
299
+ def describe_image(self, image_path: str, prompt: str = None, use_api: bool = None) -> str:
300
+ """
301
+ 生成图片描述(统一接口)。
302
+
303
+ 参数:
304
+ image_path: 图片文件路径
305
+ prompt: 自定义描述提示,为空则使用默认提示
306
+ use_api: 是否使用在线 API,为 None 时由 api_client 是否可用决定
307
+ 返回:
308
+ 图片的文本描述
309
+
310
+ 注意:
311
+ 当 use_api=True 但 API 调用次数已达上限时,会自动降级到本地模型。
312
+ 降级信息通过返回值中的 metadata 属性传递(如有需要请检查 self.api_limit_reached)。
313
+ """
314
+ if not prompt:
315
+ prompt = self.DEFAULT_PROMPT
316
+
317
+ # 决定使用哪种模式
318
+ if use_api is None:
319
+ use_api = self.api_client is not None
320
+
321
+ # API 调用次数限制检查:超限自动降级
322
+ if use_api and self.api_limit_reached:
323
+ remaining = self.api_remaining
324
+ print(
325
+ f"[API 限流] 在线 API 调用已达上限 "
326
+ f"({self._api_call_count}/{self._api_max_calls}),自动降级到本地模型"
327
+ )
328
+ self._ensure_local_model()
329
+ use_api = False
330
+
331
+ if use_api:
332
+ self._increment_api_count()
333
+ return self._describe_image_api(image_path, prompt)
334
+ else:
335
+ return self._describe_image_local(image_path, prompt)
336
+
337
+
338
+ class XGuardModel:
339
+ """
340
+ YuFeng-XGuard 安全检测模型封装。
341
+
342
+ 推理逻辑完全对齐官方实现:
343
+ - apply_chat_template 支持 policy / reason_first 参数
344
+ - 通过 decoded text 直接匹配 id2risk(而非 token_id 中转)
345
+ - reason_first 模式下正确定位风险 token 的 score 位置
346
+ """
347
+
348
+ def __init__(self, model_path: str, device: str = "auto"):
349
+ self.model_path = model_path
350
+ self.device = device
351
+ self.model = None
352
+ self.tokenizer = None
353
+ self.id2risk = None
354
+ self._lock = threading.Lock()
355
+ self._load_model()
356
+
357
+ def _load_model(self):
358
+ """加载模型和 tokenizer,提取 id2risk 映射表"""
359
+ local_path = resolve_model_path(self.model_path)
360
+
361
+ print(f"正在加载安全检测模型: {local_path}...")
362
+ self.tokenizer = AutoTokenizer.from_pretrained(
363
+ local_path,
364
+ trust_remote_code=True
365
+ )
366
+ self.model = AutoModelForCausalLM.from_pretrained(
367
+ local_path,
368
+ torch_dtype="auto",
369
+ device_map=self.device,
370
+ trust_remote_code=True
371
+ ).eval()
372
+
373
+ # 从 tokenizer 配置中获取 id2risk 映射
374
+ # id2risk 格式: {'sec': 'Safe-Safe', 'pc': 'Crimes and Illegal Activities-Pornographic Contraband', ...}
375
+ # key 是短文本标记(如 'sec', 'pc'),value 是风险类别全名
376
+ self.id2risk = self.tokenizer.init_kwargs.get('id2risk', {})
377
+ print(f"id2risk 映射条目数: {len(self.id2risk)}")
378
+ print(f"##################self.id2risk: {self.id2risk} #####################")
379
+ if self.id2risk:
380
+ print(f"示例映射: {list(self.id2risk.items())[:5]}")
381
+
382
+ def infer(self, messages: List[Dict[str, str]], policy=None,
383
+ max_new_tokens: int = 1, reason_first: bool = False) -> Dict[str, Any]:
384
+ """
385
+ 官方推理接口,完全对齐 XGuard 官方推理逻辑。
386
+
387
+ 参数:
388
+ messages: 对话消息列表
389
+ policy: 动态策略(可选),用于运行时自定义安全检测规则
390
+ max_new_tokens: 最大生成 token 数
391
+ reason_first: 是否先生成归因分析再输出风险 token
392
+ 返回:
393
+ {
394
+ 'response': str, # 完整解码文本
395
+ 'token_score': {text: prob, ...}, # 风险 token 位置的 topk token 分数
396
+ 'risk_score': {risk_name: prob, ...} # 匹配到 id2risk 的风险类别分数
397
+ }
398
+ """
399
+ with self._lock:
400
+ # 使用 chat template 渲染输入(含 policy 和 reason_first 参数)
401
+ rendered_query = self.tokenizer.apply_chat_template(
402
+ messages,
403
+ policy=policy,
404
+ reason_first=reason_first,
405
+ tokenize=False
406
+ )
407
+
408
+ model_inputs = self.tokenizer(
409
+ [rendered_query], return_tensors="pt"
410
+ ).to(self.model.device)
411
+
412
+ with torch.no_grad():
413
+ outputs = self.model.generate(
414
+ **model_inputs,
415
+ max_new_tokens=max_new_tokens,
416
+ do_sample=False,
417
+ output_scores=True,
418
+ return_dict_in_generate=True
419
+ )
420
+
421
+ batch_idx = 0
422
+ input_length = model_inputs['input_ids'].shape[1]
423
+
424
+ # 解码响应文本
425
+ output_ids = outputs["sequences"].tolist()[batch_idx][input_length:]
426
+ response = self.tokenizer.decode(output_ids, skip_special_tokens=True)
427
+
428
+ # ---- 解析每个生成位置的 topk 分数 (官方逻辑) ----
429
+ generated_tokens = outputs.sequences[:, input_length:]
430
+ scores = torch.stack(outputs.scores, dim=1)
431
+ scores = scores.softmax(dim=-1)
432
+ scores_topk_value, scores_topk_index = scores.topk(k=10, dim=-1)
433
+
434
+ generated_tokens_with_probs = []
435
+ for generated_token, score_topk_value, score_topk_index in zip(
436
+ generated_tokens, scores_topk_value, scores_topk_index
437
+ ):
438
+ generated_tokens_with_prob = []
439
+ for token, topk_value, topk_index in zip(
440
+ generated_token, score_topk_value, score_topk_index
441
+ ):
442
+ token = int(token.cpu())
443
+ if token == self.tokenizer.pad_token_id:
444
+ continue
445
+
446
+ res_topk_score = {}
447
+ for ii, (value, index) in enumerate(zip(topk_value, topk_index)):
448
+ if ii == 0 or value.cpu().numpy() > 1e-4:
449
+ text = self.tokenizer.decode(index.cpu().numpy())
450
+ res_topk_score[text] = {
451
+ "id": str(int(index.cpu().numpy())),
452
+ "prob": round(float(value.cpu().numpy()), 4),
453
+ }
454
+
455
+ generated_tokens_with_prob.append(res_topk_score)
456
+ generated_tokens_with_probs.append(generated_tokens_with_prob)
457
+
458
+ # 确定风险分数的 token 位置索引
459
+ # reason_first=False: 风险 token 在第一个位置 (idx=0)
460
+ # reason_first=True: 风险 token 在倒数第二个位置 (reasoning 后、EOS 前)
461
+ score_idx = (
462
+ max(len(generated_tokens_with_probs[batch_idx]) - 2, 0)
463
+ if reason_first else 0
464
+ )
465
+
466
+ # 提取 token 分数和风险分数(官方方式: decoded text 直接匹配 id2risk)
467
+ token_score = {
468
+ k: v['prob']
469
+ for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
470
+ }
471
+ risk_score = {
472
+ self.id2risk[k]: v['prob']
473
+ for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
474
+ if k in self.id2risk
475
+ }
476
+
477
+ return {
478
+ 'response': response,
479
+ 'token_score': token_score,
480
+ 'risk_score': risk_score,
481
+ }
482
+
483
+ def parse_explanation(self, response: str) -> Optional[str]:
484
+ """
485
+ 从响应中解析归因分析部分。
486
+
487
+ XGuard 在 reason_first=False 模式下,输出格式为:
488
+ [风险分类 token][归因分析文本]
489
+ 风险 token 是 id2risk 中的短字符串 key(如 'sec', 'pc' 等),
490
+ 后续文本为自然语言的归因分析说明。
491
+ """
492
+ if not response or not response.strip():
493
+ return None
494
+
495
+ # 方式 1: 兼容 <explanation>...</explanation> 标签格式
496
+ match = re.search(r'<explanation>(.*?)</explanation>', response, re.DOTALL)
497
+ if match:
498
+ return match.group(1).strip()
499
+
500
+ text = response.strip()
501
+
502
+ # 方式 2: 剥离开头的风险分类 token,提取后续归因文本
503
+ # id2risk 的 key 是短字符串(如 'sec', 'pc'),模型输出以它开头
504
+ if self.id2risk:
505
+ for key in sorted(self.id2risk.keys(), key=len, reverse=True):
506
+ if text.startswith(key):
507
+ remainder = text[len(key):].strip()
508
+ if remainder:
509
+ return remainder
510
+ break # 匹配到 token 但无后续文本,说明未生成归因
511
+
512
+ # 方式 3: 响应长度明显超过单个风险 token(通常 2-4 字符),直接作为归因返回
513
+ if len(text) > 8:
514
+ return text
515
+
516
+ return None
517
+
518
+ def analyze(self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]],
519
+ enable_reasoning: bool = False, policy=None) -> Dict[str, Any]:
520
+ """
521
+ 高层分析接口,封装推理结果为结构化格式。
522
+
523
+ 参数:
524
+ messages: 对话消息列表
525
+ tools: 工具信息(已拼接到 messages 中,暂未使用)
526
+ enable_reasoning: 是否启用归因分析(生成更多 token)
527
+ policy: 动态策略(可选)
528
+ """
529
+ # 启用归因分析时生成更多 token 以获取完整的归因解释
530
+ max_new_tokens = 512 if enable_reasoning else 1
531
+
532
+ infer_result = self.infer(
533
+ messages,
534
+ policy=policy,
535
+ max_new_tokens=max_new_tokens,
536
+ reason_first=False
537
+ )
538
+ risk_scores = infer_result.get("risk_score", {})
539
+ response = infer_result.get("response", "")
540
+
541
+ # ================================================================
542
+ # 风险判定 — 基于 XGuard 论文的 argmax + 置信度分级框架
543
+ #
544
+ # 理论基础 (arxiv 2601.15588):
545
+ # XGuard 的训练目标 max_θ log P(y_cls | X; θ) 保证
546
+ # 第一个 token 的 softmax argmax = 模型预测的风险类别。
547
+ # 概率值即为校准后的置信度。
548
+ #
549
+ # 判定流程:
550
+ # Layer 1 — argmax: 概率最高的类别就是模型的答案
551
+ # Layer 2 — 置信度门控: safe 判定需 >= 0.5 (过半数置信)
552
+ # Layer 3 — 风险分级: 按 top_risk_prob 划分 high/medium/low
553
+ # ================================================================
554
+ SAFE_CATEGORY = "Safe-Safe"
555
+ safe_prob = risk_scores.get(SAFE_CATEGORY, 0.0)
556
+
557
+ # 提取非安全类风险项,按分数降序排列
558
+ risk_items = {k: v for k, v in risk_scores.items() if k != SAFE_CATEGORY}
559
+ sorted_risks = sorted(risk_items.items(), key=lambda x: x[1], reverse=True)
560
+
561
+ top_risk_name = sorted_risks[0][0] if sorted_risks else ""
562
+ top_risk_prob = sorted_risks[0][1] if sorted_risks else 0.0
563
+
564
+ # Layer 1 + 2: argmax 决策 + 置信度门控
565
+ if safe_prob >= top_risk_prob and safe_prob >= 0.5:
566
+ # argmax = Safe-Safe, 且置信度过半 → 判定安全
567
+ is_safe = 1
568
+ risk_level = "safe"
569
+ elif safe_prob >= top_risk_prob:
570
+ # argmax = Safe-Safe, 但置信度不足 0.5
571
+ # 模型最倾向安全,但不够确定,谨慎标记为低风险
572
+ is_safe = 0
573
+ risk_level = "low"
574
+ else:
575
+ # argmax = 某风险类别 (top_risk_prob > safe_prob)
576
+ # Layer 3: 按风险置信度分级
577
+ is_safe = 0
578
+ if top_risk_prob >= 0.5:
579
+ risk_level = "high"
580
+ elif top_risk_prob >= 0.3:
581
+ risk_level = "medium"
582
+ else:
583
+ risk_level = "low"
584
+
585
+ # 置信度: 模型对当前判定的确信程度
586
+ confidence = safe_prob if is_safe == 1 else top_risk_prob
587
+
588
+ # 构建风险类型列表和原因说明
589
+ # 无论安全与否,始终输出最高风险项作为风险提示
590
+ if is_safe == 0:
591
+ top_risks = sorted_risks[:3]
592
+ else:
593
+ # 安全时仅取最高风险项作为提示
594
+ top_risks = sorted_risks[:1] if sorted_risks else []
595
+
596
+ risk_types = [r[0] for r in top_risks]
597
+ reason = "; ".join([f"{r}: {s}" for r, s in top_risks])
598
+
599
+ result = {
600
+ "is_safe": is_safe,
601
+ "risk_level": risk_level,
602
+ "confidence": round(confidence, 4),
603
+ "risk_type": risk_types,
604
+ "reason": reason,
605
+ "detail_scores": risk_scores,
606
+ "response": response
607
+ }
608
+
609
+ # 如果启用了归因分析,解析并添加 explanation
610
+ if enable_reasoning:
611
+ explanation = self.parse_explanation(response)
612
+ if explanation:
613
+ result["explanation"] = explanation
614
+
615
+ return result
requirements.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.12.0
2
+ aiofiles==24.1.0
3
+ annotated-doc==0.0.4
4
+ annotated-types==0.7.0
5
+ anyio==4.12.1
6
+ av==16.1.0
7
+ brotli==1.2.0
8
+ certifi==2026.1.4
9
+ charset-normalizer==3.4.4
10
+ click==8.3.1
11
+ colorama==0.4.6
12
+ distro==1.9.0
13
+ fastapi==0.128.5
14
+ ffmpy==1.0.0
15
+ filelock==3.20.3
16
+ fsspec==2026.2.0
17
+ gradio==6.5.1
18
+ gradio_client==2.0.3
19
+ groovy==0.1.2
20
+ h11==0.16.0
21
+ hf-xet==1.2.0
22
+ httpcore==1.0.9
23
+ httpx==0.28.1
24
+ huggingface_hub==1.4.1
25
+ idna==3.11
26
+ Jinja2==3.1.6
27
+ jiter==0.13.0
28
+ markdown-it-py==4.0.0
29
+ MarkupSafe==3.0.3
30
+ mdurl==0.1.2
31
+ modelscope==1.34.0
32
+ mpmath==1.3.0
33
+ networkx==3.6.1
34
+ numpy==2.4.2
35
+ scikit-learn>=1.6.0
36
+ scipy>=1.14.0
37
+ openai==2.17.0
38
+ orjson==3.11.7
39
+ packaging==26.0
40
+ pandas==3.0.0
41
+ pillow==12.1.0
42
+ psutil==7.2.2
43
+ pydantic==2.12.5
44
+ pydantic_core==2.41.5
45
+ pydub==0.25.1
46
+ Pygments==2.19.2
47
+ python-dateutil==2.9.0.post0
48
+ python-multipart==0.0.22
49
+ pytz==2025.2
50
+ PyYAML==6.0.3
51
+ qwen-vl-utils==0.0.14
52
+ regex==2026.1.15
53
+ requests==2.32.5
54
+ rich==14.3.2
55
+ safehttpx==0.1.7
56
+ safetensors==0.7.0
57
+ semantic-version==2.10.0
58
+ setuptools==82.0.0
59
+ shellingham==1.5.4
60
+ six==1.17.0
61
+ sniffio==1.3.1
62
+ starlette==0.52.1
63
+ sympy==1.14.0
64
+ tokenizers==0.22.2
65
+ tomlkit==0.13.3
66
+ torch==2.10.0
67
+ torchvision==0.25.0
68
+ tqdm==4.67.3
69
+ transformers==5.1.0
70
+ typer==0.21.1
71
+ typer-slim==0.21.1
72
+ typing-inspection==0.4.2
73
+ typing_extensions==4.15.0
74
+ tzdata==2025.3
75
+ urllib3==2.6.3
76
+ uvicorn==0.40.0