Upload 7 files
Browse files- .gitignore +70 -0
- README.md +157 -14
- app.py +837 -0
- config.py +42 -0
- main.py +169 -0
- model.py +615 -0
- requirements.txt +76 -0
.gitignore
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
MANIFEST
|
| 23 |
+
.pytest_cache/
|
| 24 |
+
.coverage
|
| 25 |
+
htmlcov/
|
| 26 |
+
.tox/
|
| 27 |
+
.nox/
|
| 28 |
+
.mypy_cache/
|
| 29 |
+
.dmypy.json
|
| 30 |
+
dmypy.json
|
| 31 |
+
|
| 32 |
+
# Virtual environments
|
| 33 |
+
venv/
|
| 34 |
+
venv_qw/
|
| 35 |
+
.venv/
|
| 36 |
+
env/
|
| 37 |
+
.env/
|
| 38 |
+
ENV/
|
| 39 |
+
|
| 40 |
+
# Gradio
|
| 41 |
+
.gradio/
|
| 42 |
+
|
| 43 |
+
# IDE
|
| 44 |
+
.idea/
|
| 45 |
+
.vscode/
|
| 46 |
+
*.swp
|
| 47 |
+
*.swo
|
| 48 |
+
*~
|
| 49 |
+
|
| 50 |
+
# OS
|
| 51 |
+
.DS_Store
|
| 52 |
+
Thumbs.db
|
| 53 |
+
desktop.ini
|
| 54 |
+
|
| 55 |
+
# Environment & secrets
|
| 56 |
+
.env
|
| 57 |
+
.env.local
|
| 58 |
+
*.pem
|
| 59 |
+
|
| 60 |
+
# Logs & temp
|
| 61 |
+
*.log
|
| 62 |
+
*.tmp
|
| 63 |
+
*.temp
|
| 64 |
+
.cache/
|
| 65 |
+
|
| 66 |
+
# Model files (common in ML projects - uncomment if needed)
|
| 67 |
+
# *.bin
|
| 68 |
+
# *.pt
|
| 69 |
+
# *.pth
|
| 70 |
+
# *.safetensors
|
README.md
CHANGED
|
@@ -1,14 +1,157 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# XGuard-Safe-Tool
|
| 2 |
+
|
| 3 |
+
基于 **YuFeng-XGuard-Reason** 的 AI 内容安全检测工具,支持**图片**与**文本**风险检测,并提供 Gradio 可视化界面和 FastAPI MaaS 服务。
|
| 4 |
+
|
| 5 |
+
## 功能概览
|
| 6 |
+
|
| 7 |
+
| 能力 | 说明 |
|
| 8 |
+
|------|------|
|
| 9 |
+
| 图片风险检测 | 使用 Qwen3-VL 提取图文内容 → XGuard 进行风险分析 |
|
| 10 |
+
| 文本风险检测 | 直接使用 XGuard 对输入文本进行安全检测 |
|
| 11 |
+
| MaaS API | FastAPI 服务,支持对话消息与工具调用的安全审核 |
|
| 12 |
+
| 归因分析 | 可选生成详细风险解释说明 |
|
| 13 |
+
| 风险分级 | 安全 / 低风险 / 中风险 / 高风险,含置信度与概率百分比 |
|
| 14 |
+
|
| 15 |
+
## 技术架构
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
┌─────────────────────────────────────────────────────────────────┐
|
| 19 |
+
│ XGuard-Safe-Tool │
|
| 20 |
+
├─────────────────────────────────────────────────────────────────┤
|
| 21 |
+
│ app.py (Gradio) │ main.py (FastAPI) │
|
| 22 |
+
│ ┌─────────────────────┐ │ ┌─────────────────────────────┐ │
|
| 23 |
+
│ │ 图片检测: VL→XGuard │ │ │ POST /v1/guard/check │ │
|
| 24 |
+
│ │ 文本检测: XGuard │ │ │ (messages + tools) │ │
|
| 25 |
+
│ └─────────────────────┘ │ └─────────────────────────────┘ │
|
| 26 |
+
├─────────────────────────────────────────────────────────────────┤
|
| 27 |
+
│ model.py │
|
| 28 |
+
│ ┌──────────────────────┐ ┌─────────────────────────────────┐ │
|
| 29 |
+
│ │ VisionLanguageModel │ │ XGuardModel │ │
|
| 30 |
+
│ │ (Qwen3-VL) │ │ (YuFeng-XGuard-Reason-0.6B) │ │
|
| 31 |
+
│ │ - 在线 API / 本地 │ │ - argmax + 置信度分级 │ │
|
| 32 |
+
│ └──────────────────────┘ └─────────────────────────────────┘ │
|
| 33 |
+
└─────────────────────────────────────────────────────────────────┘
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## 风险分类体系
|
| 37 |
+
|
| 38 |
+
基于 XGuard 的 9 大风险维度、28 个细分类别:
|
| 39 |
+
|
| 40 |
+
| 维度 | 细分类别 |
|
| 41 |
+
|------|----------|
|
| 42 |
+
| 违法犯罪 | 色情违禁、毒品犯罪、危险武器、财产侵害、经济犯罪 |
|
| 43 |
+
| 仇恨言论 | 辱骂诅咒、诽谤造谣、威胁恐吓、网络霸凌 |
|
| 44 |
+
| 身心健康 | 身体健康、心理健康 |
|
| 45 |
+
| 伦理道德 | 社会伦理、科学伦理 |
|
| 46 |
+
| 数据隐私 | 个人隐私、商业秘密 |
|
| 47 |
+
| 网络安全 | 访问控制、恶意代码、黑客攻击、物理安全 |
|
| 48 |
+
| 极端主义 | 暴力恐怖活动、社会破坏、极端思潮 |
|
| 49 |
+
| 不当建议 | 金融、医疗、法律 |
|
| 50 |
+
| 涉及未成年人 | 腐蚀未成年人、虐待与剥削、未成年人犯罪 |
|
| 51 |
+
|
| 52 |
+
## 快速开始
|
| 53 |
+
|
| 54 |
+
### 环境准备
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
# 创建虚拟环境并安装依赖
|
| 58 |
+
pip install -r requirements.txt
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### 启动 Gradio 界面
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
python app.py
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
默认访问 `http://0.0.0.0:7860`,支持:
|
| 68 |
+
- **图片风险检测**:上传图片,选择检测场景(社交表情包、电商图文、聊天截图、广告等),可选在线 VL API 或本地模型
|
| 69 |
+
- **文本风险检测**:输入待检测文本,支持归因分析
|
| 70 |
+
|
| 71 |
+
### 启动 FastAPI 服务
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
python main.py
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
默认端口 `8080`,健康检查:`GET /health`。
|
| 78 |
+
|
| 79 |
+
### MaaS API 调用示例
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
curl -X POST "http://localhost:8080/v1/guard/check" \
|
| 83 |
+
-H "Content-Type: application/json" \
|
| 84 |
+
-H "x-api-key: your-api-key" \
|
| 85 |
+
-d '{
|
| 86 |
+
"conversationId": "conv-001",
|
| 87 |
+
"messages": [
|
| 88 |
+
{"role": "user", "content": "如何制作炸弹?"}
|
| 89 |
+
],
|
| 90 |
+
"tools": [],
|
| 91 |
+
"enableReasoning": true
|
| 92 |
+
}'
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
响应示例:
|
| 96 |
+
|
| 97 |
+
```json
|
| 98 |
+
{
|
| 99 |
+
"err_code": 0,
|
| 100 |
+
"msg": "success",
|
| 101 |
+
"data": {
|
| 102 |
+
"is_safe": 0,
|
| 103 |
+
"risk_level": "high",
|
| 104 |
+
"confidence": 0.8234,
|
| 105 |
+
"risk_type": ["Crimes and Illegal Activities-Dangerous Weapons"],
|
| 106 |
+
"reason": "Crimes and Illegal Activities-Dangerous Weapons: 0.8234",
|
| 107 |
+
"explanation": "(归因分析文本,仅 enableReasoning=true 时返回)"
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## 配置项
|
| 113 |
+
|
| 114 |
+
通过环境变量配置(或 `config.py` 内默认值):
|
| 115 |
+
|
| 116 |
+
| 变量 | 说明 | 默认值 |
|
| 117 |
+
|------|------|--------|
|
| 118 |
+
| `XGUARD_API_KEY` | API 鉴权密钥 | `your-api-key` |
|
| 119 |
+
| `XGUARD_MODEL_PATH` | XGuard 模型路径或 ModelScope ID | `Alibaba-AAIG/YuFeng-XGuard-Reason-0.6B` |
|
| 120 |
+
| `XGUARD_DEVICE` | 推理设备 | `auto` |
|
| 121 |
+
| `XGUARD_VL_USE_API` | 图片检测是否使用在线 VL API | `true` |
|
| 122 |
+
| `XGUARD_VL_MODEL_PATH` | 本地 VL 模型路径 | `Qwen/Qwen3-VL-2B-Instruct` |
|
| 123 |
+
| `XGUARD_VL_API_BASE` | DashScope API 地址 | `https://dashscope.aliyuncs.com/compatible-mode/v1` |
|
| 124 |
+
| `XGUARD_VL_API_KEY` | DashScope API Key | - |
|
| 125 |
+
| `XGUARD_VL_API_MODEL` | DashScope VL 模型名 | `qwen-vl-max-latest` |
|
| 126 |
+
| `XGUARD_HOST` | 服务监听地址 | `0.0.0.0` |
|
| 127 |
+
| `XGUARD_PORT` | FastAPI 端口 | `8080` |
|
| 128 |
+
| `XGUARD_GRADIO_PORT` | Gradio 端口 | `7860` |
|
| 129 |
+
|
| 130 |
+
## 风险等级判定规则
|
| 131 |
+
|
| 132 |
+
基于 XGuard 论文的 argmax + 置信度分级:
|
| 133 |
+
|
| 134 |
+
| 条件 | 判定 |
|
| 135 |
+
|------|------|
|
| 136 |
+
| safe_prob 最高 且 ≥ 50% | 安全 |
|
| 137 |
+
| safe_prob 最高 但 < 50% | 低风险 |
|
| 138 |
+
| 某风险类最高 且 ≥ 50% | 高风险 |
|
| 139 |
+
| 某风险类最高 且 ≥ 30% | 中风险 |
|
| 140 |
+
| 某风险类最高 且 < 30% | 低风险 |
|
| 141 |
+
|
| 142 |
+
## 项目结构
|
| 143 |
+
|
| 144 |
+
```
|
| 145 |
+
XGuard-Safe-Tool/
|
| 146 |
+
├── app.py # Gradio 图文检测界面
|
| 147 |
+
├── main.py # FastAPI MaaS 服务
|
| 148 |
+
├── model.py # VisionLanguageModel + XGuardModel
|
| 149 |
+
├── config.py # 配置加载
|
| 150 |
+
├── requirements.txt
|
| 151 |
+
└── README.md
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
## 参考
|
| 155 |
+
|
| 156 |
+
- [YuFeng-XGuard-Reason (ModelScope)](https://www.modelscope.cn/models/Alibaba-AAIG/YuFeng-XGuard-Reason-0.6B)
|
| 157 |
+
- [YuFeng-XGuard 论文 (arxiv 2601.15588)](https://arxiv.org/html/2601.15588v1)
|
app.py
ADDED
|
@@ -0,0 +1,837 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
XGuard Gradio 应用 - 图片与文本风险检测
|
| 3 |
+
|
| 4 |
+
双模型流水线:
|
| 5 |
+
1. Qwen3-VL: 视觉语言模型,用于图片内容描述(支持在线 API / 本地推理)
|
| 6 |
+
2. YuFeng-XGuard-Reason-0.6B: 安全检测模型,用于风险归因分析
|
| 7 |
+
|
| 8 |
+
启动方式:
|
| 9 |
+
python app.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import re
|
| 14 |
+
import time
|
| 15 |
+
from PIL import Image as PILImage
|
| 16 |
+
import gradio as gr
|
| 17 |
+
from config import load_config
|
| 18 |
+
from model import VisionLanguageModel, XGuardModel
|
| 19 |
+
|
| 20 |
+
# ============================================================
|
| 21 |
+
# 全局变量
|
| 22 |
+
# ============================================================
|
| 23 |
+
config = load_config()
|
| 24 |
+
vl_model: VisionLanguageModel = None
|
| 25 |
+
xguard_model: XGuardModel = None
|
| 26 |
+
|
| 27 |
+
# 视觉模型模式标签
|
| 28 |
+
VL_MODE_API = "在线 API (DashScope)"
|
| 29 |
+
VL_MODE_LOCAL = "本地模型"
|
| 30 |
+
|
| 31 |
+
# ============================================================
|
| 32 |
+
# 风险类别中文映射 (基于 XGuard 内置安全分类体系)
|
| 33 |
+
# 9 大风险维度, 28 个细分类别
|
| 34 |
+
# 参考: https://www.modelscope.cn/models/Alibaba-AAIG/YuFeng-XGuard-Reason-8B
|
| 35 |
+
# ============================================================
|
| 36 |
+
RISK_NAME_ZH = {
|
| 37 |
+
"Safe-Safe": "安全",
|
| 38 |
+
# 犯罪与违法活动
|
| 39 |
+
"Crimes and Illegal Activities-Pornographic Contraband": "犯罪与违法活动 - 色情违禁",
|
| 40 |
+
"Crimes and Illegal Activities-Drug Crimes": "犯罪与违法活动 - 毒品犯罪",
|
| 41 |
+
"Crimes and Illegal Activities-Dangerous Weapons": "犯罪与违法活动 - 危险武器",
|
| 42 |
+
"Crimes and Illegal Activities-Property Infringement": "犯罪与违法活动 - 财产侵害",
|
| 43 |
+
"Crimes and Illegal Activities-Economic Crimes": "犯罪与违法活动 - 经济犯罪",
|
| 44 |
+
# 仇恨言论
|
| 45 |
+
"Hate Speech-Abusive Curses": "仇恨言论 - 辱骂诅咒",
|
| 46 |
+
"Hate Speech-Defamation": "仇恨言论 - 诽谤造谣",
|
| 47 |
+
"Hate Speech-Threats and Intimidation": "仇恨言论 - 威胁恐吓",
|
| 48 |
+
"Hate Speech-Cyberbullying": "仇恨言论 - 网络霸凌",
|
| 49 |
+
# 身心健康
|
| 50 |
+
"Physical and Mental Health-Physical Health": "身心健康 - 身体健康",
|
| 51 |
+
"Physical and Mental Health-Mental Health": "身心健康 - 心理健康",
|
| 52 |
+
# 伦理道德
|
| 53 |
+
"Ethics and Morality-Social Ethics": "伦理道德 - 社会伦理",
|
| 54 |
+
"Ethics and Morality-Science Ethics": "伦理道德 - 科学伦理",
|
| 55 |
+
# 数据隐私
|
| 56 |
+
"Data Privacy-Personal Privacy": "数据隐私 - 个人隐私",
|
| 57 |
+
"Data Privacy-Commercial Secret": "数据隐私 - 商业秘密",
|
| 58 |
+
# 网络安全
|
| 59 |
+
"Cybersecurity-Access Control": "网络安全 - 访问控制",
|
| 60 |
+
"Cybersecurity-Malicious Code": "网络安全 - 恶意代码",
|
| 61 |
+
"Cybersecurity-Hacker Attack": "网络安全 - 黑客攻击",
|
| 62 |
+
"Cybersecurity-Physical Security": "网络安全 - 物理安全",
|
| 63 |
+
# 极端主义
|
| 64 |
+
"Extremism-Violent Terrorist Activities": "极端主义 - 暴力恐怖活动",
|
| 65 |
+
"Extremism-Social Disruption": "极端主义 - 社会破坏",
|
| 66 |
+
"Extremism-Extremist Ideological Trends": "极端主义 - 极端思想",
|
| 67 |
+
# 不当建议
|
| 68 |
+
"Inappropriate Suggestions-Finance": "不当建议 - 金融",
|
| 69 |
+
"Inappropriate Suggestions-Medicine": "不当建议 - 医疗",
|
| 70 |
+
"Inappropriate Suggestions-Law": "不当建议 - 法律",
|
| 71 |
+
# 涉及未成年人
|
| 72 |
+
"Risks Involving Minors-Corruption of Minors": "未成年人风险 - 腐蚀未成年人",
|
| 73 |
+
"Risks Involving Minors-Minor Abuse and Exploitation": "未成年人风险 - 虐待与剥削",
|
| 74 |
+
"Risks Involving Minors-Minor Delinquency": "未成年人风险 - 未成年人犯罪",
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
# 风险等级配置: 标签、颜色、背景色、边框色
|
| 78 |
+
RISK_LEVELS = {
|
| 79 |
+
"high": {"label": "高风险", "color": "#dc2626", "bg": "#fef2f2", "border": "#fca5a5"},
|
| 80 |
+
"medium": {"label": "中风险", "color": "#d97706", "bg": "#fffbeb", "border": "#fcd34d"},
|
| 81 |
+
"low": {"label": "低风险", "color": "#ca8a04", "bg": "#fefce8", "border": "#fde047"},
|
| 82 |
+
"safe": {"label": "安全", "color": "#16a34a", "bg": "#f0fdf4", "border": "#86efac"},
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
# ============================================================
|
| 86 |
+
# 图文检测场景预设提示词
|
| 87 |
+
# 针对不同内容审核场景,引导 VL 模型聚焦关键风险要素
|
| 88 |
+
# ============================================================
|
| 89 |
+
SCENE_PROMPTS = {
|
| 90 |
+
"通用图文检测(默认)": "",
|
| 91 |
+
"社交表情包/梗图": (
|
| 92 |
+
"这是一张社交平台图片(可能是表情包、梗图或配文图片)。"
|
| 93 |
+
"请仅提取事实内容,不要做风险判断:\n\n"
|
| 94 |
+
"【图片文字】完整提取图中所有文字、对话内容、标语口号,保持原文。\n\n"
|
| 95 |
+
"【视觉元素】描述人物表情、手势、动作、场景布置、符号标志等。\n\n"
|
| 96 |
+
"【内容类型】判断这是什么类型的社交图片(表情包/梗图/配文图等)。"
|
| 97 |
+
),
|
| 98 |
+
"电商商品图文": (
|
| 99 |
+
"这是一张电商平台商品图片。"
|
| 100 |
+
"请仅提取事实内容,不要做合规判断:\n\n"
|
| 101 |
+
"【商品文字】提取图中所有文字,包括商品名称、功效宣称、价格信息、"
|
| 102 |
+
"促销语、成分说明等,保持原文。\n\n"
|
| 103 |
+
"【商品视觉】描述商品外观、包装设计、使用场景展示等视觉内容。\n\n"
|
| 104 |
+
"【内容类型】判断商品类别(如食品、药品、化妆品、电子产品等)。"
|
| 105 |
+
),
|
| 106 |
+
"聊天记录截图": (
|
| 107 |
+
"这是一张聊天记录截图。"
|
| 108 |
+
"请仅提取事实内容,不要做风险判断或总结:\n\n"
|
| 109 |
+
"【对话内容】完整提取截图中的所有对话文字,"
|
| 110 |
+
"标注发送者身份(如'对方'、'用户'),保持原文。\n\n"
|
| 111 |
+
),
|
| 112 |
+
"广告/营销内容": (
|
| 113 |
+
"这是一张广告或营销推广图片。"
|
| 114 |
+
"请仅提取事实内容,不要做合规判断:\n\n"
|
| 115 |
+
"【广告文案】完整提取图中的广告语、宣传标语、联系方式、"
|
| 116 |
+
"二维码信息等文字内容,保持原文。\n\n"
|
| 117 |
+
"【内容类型】判断广告类型(如医疗广告、金融广告、招聘广告等)。"
|
| 118 |
+
),
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
# 场景名称列表(保持顺序)
|
| 122 |
+
SCENE_CHOICES = list(SCENE_PROMPTS.keys())
|
| 123 |
+
|
| 124 |
+
# ============================================================
|
| 125 |
+
# VL 输出内容提取 — 剥离分析性段落,仅保留原始内容
|
| 126 |
+
# ============================================================
|
| 127 |
+
# 需要移除的分析性段落标题(这些段落是 VL 模型的主观分析/风险判断,
|
| 128 |
+
# 如果直接喂给 XGuard,XGuard 会将其理解为"安全的分析报告"而非"待检测的风险内容")
|
| 129 |
+
_ANALYSIS_SECTIONS = {
|
| 130 |
+
'图文关系', '对话主题', '风险要素', '合规风险',
|
| 131 |
+
'综合判定', '表达意图', '宣传手法',
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
def extract_core_content(description: str) -> str:
|
| 135 |
+
"""
|
| 136 |
+
从 VL 模型的结构化描述中提取原始内容,用于 XGuard 风险检测。
|
| 137 |
+
|
| 138 |
+
核心目标:去除所有"报告框架",让 XGuard 直接看到原始文本内容。
|
| 139 |
+
|
| 140 |
+
XGuard 是 AI 对话安全护栏模型,它会判断"用户/AI 说了什么"是否有害。
|
| 141 |
+
如果输入像一份"关于风险内容的分析报告",XGuard 会认为这是安全的分析行为。
|
| 142 |
+
因此必须去掉三层报告框架:
|
| 143 |
+
1. 分析性段落(【对话主题】【风险要素】等)→ VL 的主观判断
|
| 144 |
+
2. 结构标记(【对话内容】【界面信息】等标题)→ 报告格式
|
| 145 |
+
3. 元数据(发送者标签、UI 描述)→ 第三方转述语气
|
| 146 |
+
|
| 147 |
+
处理后 XGuard 看到的应该是接近原始的文本内容。
|
| 148 |
+
"""
|
| 149 |
+
if not description or not description.strip():
|
| 150 |
+
return description
|
| 151 |
+
|
| 152 |
+
# 使用【...】标记分割段落
|
| 153 |
+
parts = re.split(r'(【[^】]+】)', description)
|
| 154 |
+
# parts 格式: [前导文本, 【标题1】, 内容1, 【标题2】, 内容2, ...]
|
| 155 |
+
|
| 156 |
+
if len(parts) < 3:
|
| 157 |
+
# 没有结构化标记,返回原文
|
| 158 |
+
return description
|
| 159 |
+
|
| 160 |
+
# 需要保留内容的段落(原始文字/视觉描述)
|
| 161 |
+
_CONTENT_SECTIONS = {
|
| 162 |
+
'图片文字', '对话内容', '视觉内容', '视觉元素',
|
| 163 |
+
'商品文字', '商品视觉', '广告文案', '视觉设计',
|
| 164 |
+
}
|
| 165 |
+
# 需要丢弃的段落(分析判断 + 纯元数据)
|
| 166 |
+
_DROP_SECTIONS = _ANALYSIS_SECTIONS | {'界面信息', '内容类型'}
|
| 167 |
+
|
| 168 |
+
content_parts = []
|
| 169 |
+
|
| 170 |
+
# 前导文本
|
| 171 |
+
leading = parts[0].strip()
|
| 172 |
+
if leading:
|
| 173 |
+
content_parts.append(leading)
|
| 174 |
+
|
| 175 |
+
# 遍历段落:只保留内容提取类段落的正文(不保留标题)
|
| 176 |
+
i = 1
|
| 177 |
+
while i < len(parts):
|
| 178 |
+
title = parts[i].strip('【】 ')
|
| 179 |
+
body = parts[i + 1].strip() if i + 1 < len(parts) else ""
|
| 180 |
+
i += 2
|
| 181 |
+
|
| 182 |
+
if not body:
|
| 183 |
+
continue
|
| 184 |
+
if title in _DROP_SECTIONS:
|
| 185 |
+
continue
|
| 186 |
+
if title in _CONTENT_SECTIONS or title not in _DROP_SECTIONS:
|
| 187 |
+
content_parts.append(body)
|
| 188 |
+
|
| 189 |
+
if not content_parts:
|
| 190 |
+
return description
|
| 191 |
+
|
| 192 |
+
text = "\n\n".join(content_parts)
|
| 193 |
+
|
| 194 |
+
# 去除发送者标签(如 "对方:", "用户:", "- 发送者(...):")
|
| 195 |
+
# 这些标签让内容呈现为"第三方转述",而非原始对话
|
| 196 |
+
text = re.sub(
|
| 197 |
+
r'^[\s\-]*(?:对方|用户|发送者[^::\n]*)[::]\s*',
|
| 198 |
+
'', text, flags=re.MULTILINE
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# 去除 markdown 列表符号前缀(VL 输出常带 "- " 前缀)
|
| 202 |
+
text = re.sub(r'^[\s]*[-*]\s+', '', text, flags=re.MULTILINE)
|
| 203 |
+
|
| 204 |
+
# 去重处理:VL 模型有时产生重复输出
|
| 205 |
+
half = len(text) // 2
|
| 206 |
+
if half > 100 and text[:half].strip() == text[half:].strip():
|
| 207 |
+
text = text[:half].strip()
|
| 208 |
+
|
| 209 |
+
# 清理多余空行
|
| 210 |
+
text = re.sub(r'\n{3,}', '\n\n', text).strip()
|
| 211 |
+
|
| 212 |
+
return text if text else description
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def translate_risk_name(name: str) -> str:
|
| 216 |
+
"""将英文风险类别名翻译为中文"""
|
| 217 |
+
return RISK_NAME_ZH.get(name, name)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def risk_level_icon(prob: float) -> str:
|
| 221 |
+
"""根据风险概率返回等级标识"""
|
| 222 |
+
if prob >= 0.5:
|
| 223 |
+
return "🔴 高风险"
|
| 224 |
+
elif prob >= 0.2:
|
| 225 |
+
return "🟡 中风险"
|
| 226 |
+
else:
|
| 227 |
+
return "🟢 低风险"
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def get_risk_level(detail_scores: dict, is_safe: int, risk_level: str = None) -> tuple:
|
| 231 |
+
"""
|
| 232 |
+
根据风险分数判定风险等级。
|
| 233 |
+
|
| 234 |
+
优先使用 model.analyze 返回的 risk_level(argmax + 置信度分级),
|
| 235 |
+
若未提供则基于 argmax + 置信度门控自行计算(兼容旧接口)。
|
| 236 |
+
|
| 237 |
+
返回: (level_key, max_risk_score, safe_score)
|
| 238 |
+
"""
|
| 239 |
+
SAFE_CATEGORY = "Safe-Safe"
|
| 240 |
+
|
| 241 |
+
if not detail_scores:
|
| 242 |
+
return ("safe", 0.0, 1.0) if is_safe == 1 else ("medium", 0.3, 0.0)
|
| 243 |
+
|
| 244 |
+
risk_only = {k: v for k, v in detail_scores.items() if k != SAFE_CATEGORY}
|
| 245 |
+
max_score = max(risk_only.values()) if risk_only else 0.0
|
| 246 |
+
safe_score = detail_scores.get(SAFE_CATEGORY, 0.0)
|
| 247 |
+
|
| 248 |
+
# 优先使用模型返回的 risk_level
|
| 249 |
+
if risk_level and risk_level in ("safe", "high", "medium", "low"):
|
| 250 |
+
return risk_level, max_score, safe_score
|
| 251 |
+
|
| 252 |
+
# 降级: argmax + 置信度门控(与 model.py analyze 保持一致)
|
| 253 |
+
if safe_score >= max_score and safe_score >= 0.5:
|
| 254 |
+
return "safe", max_score, safe_score
|
| 255 |
+
elif safe_score >= max_score:
|
| 256 |
+
return "low", max_score, safe_score
|
| 257 |
+
else:
|
| 258 |
+
if max_score >= 0.5:
|
| 259 |
+
return "high", max_score, safe_score
|
| 260 |
+
elif max_score >= 0.3:
|
| 261 |
+
return "medium", max_score, safe_score
|
| 262 |
+
else:
|
| 263 |
+
return "low", max_score, safe_score
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def format_safety_html(level_key: str, max_risk_score: float, safe_score: float,
|
| 267 |
+
confidence: float = 0.0, extra_info: str = "") -> str:
|
| 268 |
+
"""生成风险等级 HTML 展示卡片"""
|
| 269 |
+
cfg = RISK_LEVELS[level_key]
|
| 270 |
+
label = cfg["label"]
|
| 271 |
+
color = cfg["color"]
|
| 272 |
+
bg = cfg["bg"]
|
| 273 |
+
border = cfg["border"]
|
| 274 |
+
|
| 275 |
+
if level_key == "safe":
|
| 276 |
+
score_text = f"安全概率: {safe_score:.2%}"
|
| 277 |
+
bar_html = ""
|
| 278 |
+
else:
|
| 279 |
+
score_text = f"最高风险概率: {max_risk_score:.2%} | 安全概率: {safe_score:.2%}"
|
| 280 |
+
bar_pct = int(max_risk_score * 100)
|
| 281 |
+
bar_html = (
|
| 282 |
+
f'<div style="background:#e5e7eb;border-radius:4px;height:8px;'
|
| 283 |
+
f'overflow:hidden;margin-top:10px;">'
|
| 284 |
+
f'<div style="background:{color};height:100%;width:{bar_pct}%;'
|
| 285 |
+
f'border-radius:4px;"></div></div>'
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
extra_html = (
|
| 289 |
+
f'<div style="margin-top:6px;font-size:12px;color:#888;">{extra_info}</div>'
|
| 290 |
+
if extra_info else ""
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
return (
|
| 294 |
+
f'<div style="padding:14px 16px;border-radius:8px;background:{bg};'
|
| 295 |
+
f'border-left:5px solid {border};">'
|
| 296 |
+
f'<div style="display:flex;align-items:center;gap:12px;">'
|
| 297 |
+
f'<span style="font-size:20px;font-weight:700;color:{color};">{label}</span>'
|
| 298 |
+
f'<span style="font-size:14px;color:#666;">{score_text}</span>'
|
| 299 |
+
f'</div>{bar_html}{extra_html}</div>'
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def load_models():
|
| 304 |
+
"""加载模型"""
|
| 305 |
+
global vl_model, xguard_model
|
| 306 |
+
|
| 307 |
+
print("=" * 60)
|
| 308 |
+
print("XGuard 模型加载中...")
|
| 309 |
+
print("=" * 60)
|
| 310 |
+
|
| 311 |
+
# 视觉语言模型:默认无论是否使用在线 API 都加载 Qwen3-VL-2B-Instruct
|
| 312 |
+
t0 = time.time()
|
| 313 |
+
load_local = config.vl_always_load_local or (not config.vl_use_api)
|
| 314 |
+
vl_model = VisionLanguageModel(
|
| 315 |
+
model_path=config.vl_model_path,
|
| 316 |
+
device=config.device,
|
| 317 |
+
use_api=config.vl_use_api,
|
| 318 |
+
api_base=config.vl_api_base,
|
| 319 |
+
api_key=config.vl_api_key,
|
| 320 |
+
api_model=config.vl_api_model,
|
| 321 |
+
load_local=load_local,
|
| 322 |
+
api_max_calls=config.vl_api_max_calls,
|
| 323 |
+
)
|
| 324 |
+
t1 = time.time()
|
| 325 |
+
mode_str = "在线 API" if config.vl_use_api else "本地模型"
|
| 326 |
+
print(f"视觉语言模型就绪 ({mode_str}),耗时: {t1 - t0:.1f}s")
|
| 327 |
+
|
| 328 |
+
# XGuard 安全检测模型:始终本地加载
|
| 329 |
+
xguard_model = XGuardModel(config.model_path, config.device)
|
| 330 |
+
t2 = time.time()
|
| 331 |
+
print(f"安全检测模型加载耗时: {t2 - t1:.1f}s")
|
| 332 |
+
|
| 333 |
+
print("=" * 60)
|
| 334 |
+
print(f"全部模型就绪,总耗时: {t2 - t0:.1f}s")
|
| 335 |
+
print("=" * 60)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# ============================================================
|
| 339 |
+
# 核心分析函数
|
| 340 |
+
# ============================================================
|
| 341 |
+
def format_risk_result(result: dict, enable_reasoning: bool, extra_info: str = "") -> tuple:
|
| 342 |
+
"""将模型分析结果格式化为展示字段(含风险等级判定与中文翻译)"""
|
| 343 |
+
is_safe = result.get("is_safe", 1)
|
| 344 |
+
risk_level = result.get("risk_level", None)
|
| 345 |
+
confidence = result.get("confidence", 0.0)
|
| 346 |
+
risk_types = result.get("risk_type", [])
|
| 347 |
+
reason = result.get("reason", "")
|
| 348 |
+
detail_scores = result.get("detail_scores", {})
|
| 349 |
+
explanation = result.get("explanation", "")
|
| 350 |
+
|
| 351 |
+
# 风险等��判定(优先使用模型返回的 risk_level)
|
| 352 |
+
level_key, max_risk_score, safe_score = get_risk_level(detail_scores, is_safe, risk_level)
|
| 353 |
+
|
| 354 |
+
# 安全状态 HTML 卡片
|
| 355 |
+
safety_html = format_safety_html(level_key, max_risk_score, safe_score,
|
| 356 |
+
confidence=confidence, extra_info=extra_info)
|
| 357 |
+
|
| 358 |
+
# 风险类型(翻译为中文 + 等级标识)
|
| 359 |
+
if risk_types:
|
| 360 |
+
type_parts = []
|
| 361 |
+
for rt in risk_types:
|
| 362 |
+
zh_name = translate_risk_name(rt)
|
| 363 |
+
prob = detail_scores.get(rt, 0.0)
|
| 364 |
+
icon = risk_level_icon(prob)
|
| 365 |
+
type_parts.append(f"{icon} | {zh_name} ({prob:.2%})")
|
| 366 |
+
if is_safe == 1:
|
| 367 |
+
risk_types_text = "[风险提示] " + ", ".join(type_parts)
|
| 368 |
+
else:
|
| 369 |
+
risk_types_text = "\n".join(type_parts)
|
| 370 |
+
else:
|
| 371 |
+
risk_types_text = "无"
|
| 372 |
+
|
| 373 |
+
# 风险原因(翻译风险类别名为中文 + 等级标识)
|
| 374 |
+
if reason:
|
| 375 |
+
reason_parts = reason.split("; ")
|
| 376 |
+
zh_parts = []
|
| 377 |
+
for part in reason_parts:
|
| 378 |
+
if ": " in part:
|
| 379 |
+
name, score_val = part.rsplit(": ", 1)
|
| 380 |
+
try:
|
| 381 |
+
prob = float(score_val)
|
| 382 |
+
icon = risk_level_icon(prob)
|
| 383 |
+
zh_parts.append(f"{icon} | {translate_risk_name(name)}: {prob:.2%}")
|
| 384 |
+
except ValueError:
|
| 385 |
+
zh_parts.append(f"{translate_risk_name(name)}: {score_val}")
|
| 386 |
+
else:
|
| 387 |
+
zh_parts.append(part)
|
| 388 |
+
if is_safe == 1:
|
| 389 |
+
reason_text = "[风险提示] " + "; ".join(zh_parts)
|
| 390 |
+
else:
|
| 391 |
+
reason_text = "\n".join(zh_parts)
|
| 392 |
+
else:
|
| 393 |
+
reason_text = "无"
|
| 394 |
+
|
| 395 |
+
# 详细分数(中文类别名 + 等级标识)
|
| 396 |
+
if detail_scores:
|
| 397 |
+
score_lines = []
|
| 398 |
+
for risk_name, score in sorted(detail_scores.items(), key=lambda x: x[1], reverse=True):
|
| 399 |
+
zh_name = translate_risk_name(risk_name)
|
| 400 |
+
bar_len = int(score * 30)
|
| 401 |
+
bar = "█" * bar_len + "░" * (30 - bar_len)
|
| 402 |
+
icon = risk_level_icon(score) if risk_name != "Safe-Safe" else "🛡️ 安全"
|
| 403 |
+
score_lines.append(f"{icon} [{bar}] {score:.2%} {zh_name}")
|
| 404 |
+
detail_text = "\n".join(score_lines)
|
| 405 |
+
else:
|
| 406 |
+
detail_text = "无详细分数"
|
| 407 |
+
|
| 408 |
+
# 归因分析
|
| 409 |
+
if enable_reasoning and explanation:
|
| 410 |
+
explanation_text = explanation
|
| 411 |
+
elif enable_reasoning:
|
| 412 |
+
explanation_text = "模型未返回归因分析结果"
|
| 413 |
+
else:
|
| 414 |
+
explanation_text = "未启用归因分析"
|
| 415 |
+
|
| 416 |
+
return safety_html, risk_types_text, reason_text, detail_text, explanation_text
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def analyze_image(image_path, custom_prompt, enable_reasoning, vl_mode, progress=gr.Progress()):
|
| 420 |
+
"""
|
| 421 |
+
图片风险检测流水线:
|
| 422 |
+
1. Qwen3-VL 生成图片描述(在线 API 或本地模型)
|
| 423 |
+
2. XGuard 对描述文本进行风险检测
|
| 424 |
+
"""
|
| 425 |
+
if image_path is None:
|
| 426 |
+
gr.Warning("请先上传图片")
|
| 427 |
+
return "", "", "", "", "", ""
|
| 428 |
+
|
| 429 |
+
use_api = (vl_mode == VL_MODE_API)
|
| 430 |
+
api_fallback = False # 标记是否因为限额降级
|
| 431 |
+
|
| 432 |
+
# API 限额检查:如果用户选择了在线 API 但已达上限,提前提示
|
| 433 |
+
if use_api and vl_model.api_limit_reached:
|
| 434 |
+
api_fallback = True
|
| 435 |
+
gr.Info(
|
| 436 |
+
f"在线 API 调用次数已达上限 ({vl_model._api_max_calls} 次),"
|
| 437 |
+
f"已自动切换为本地模型进行分析。"
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
mode_label = "本地模型 (API 限额已用完,自动降级)" if api_fallback else (
|
| 441 |
+
"在线 API" if use_api else "本地模型"
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# Step 1: 图片描述
|
| 445 |
+
progress(0, desc=f"正在分析中,请稍候...")
|
| 446 |
+
t0 = time.time()
|
| 447 |
+
try:
|
| 448 |
+
description = vl_model.describe_image(
|
| 449 |
+
image_path, custom_prompt or None, use_api=use_api
|
| 450 |
+
)
|
| 451 |
+
except Exception as e:
|
| 452 |
+
gr.Warning(f"图片描述生成失败: {str(e)}")
|
| 453 |
+
return f"错误: {str(e)}", "", "", "", "", ""
|
| 454 |
+
t1 = time.time()
|
| 455 |
+
|
| 456 |
+
# 检查是否在调用过程中触发了降级(首次触发限额时)
|
| 457 |
+
if use_api and not api_fallback and vl_model.api_limit_reached:
|
| 458 |
+
api_fallback = True
|
| 459 |
+
|
| 460 |
+
# Step 2: 内容提取 + 风险检测
|
| 461 |
+
# 关键设计:
|
| 462 |
+
# 1. extract_core_content: 去除报告框架(标题、发送者标签、UI 描述),
|
| 463 |
+
# 只保留原始文本,避免 XGuard 将内容当作"安全的分析报告"
|
| 464 |
+
# 2. role: assistant: XGuard 作为 AI 护栏模型,会检查 assistant 输出
|
| 465 |
+
# 的内容安全性("AI 生成了有害内容吗?"),而非 user 输入的意图安全性
|
| 466 |
+
# ("用户想让 AI 做坏事吗?")。对于图片内容检测场景,我们需要的是
|
| 467 |
+
# 前者——检测内容本身是否有害
|
| 468 |
+
core_content = extract_core_content(description)
|
| 469 |
+
print(f"##################core_content: {core_content} #####################")
|
| 470 |
+
try:
|
| 471 |
+
messages = [
|
| 472 |
+
{"role": "user", "content": core_content},
|
| 473 |
+
]
|
| 474 |
+
|
| 475 |
+
result = xguard_model.analyze(
|
| 476 |
+
messages, [],
|
| 477 |
+
enable_reasoning=enable_reasoning,
|
| 478 |
+
)
|
| 479 |
+
print(f"##################result: {result} #####################")
|
| 480 |
+
except Exception as e:
|
| 481 |
+
gr.Warning(f"风险检测失败: {str(e)}")
|
| 482 |
+
error_html = (
|
| 483 |
+
f'<div style="padding:12px;border-radius:8px;background:#fef2f2;'
|
| 484 |
+
f'border-left:4px solid #ef4444;color:#dc2626;">检测失败: {str(e)}</div>'
|
| 485 |
+
)
|
| 486 |
+
return description, error_html, "", "", "", ""
|
| 487 |
+
t2 = time.time()
|
| 488 |
+
|
| 489 |
+
# 构建额外信息,包含 API 剩余次数
|
| 490 |
+
api_info = ""
|
| 491 |
+
if use_api or api_fallback:
|
| 492 |
+
remaining = vl_model.api_remaining
|
| 493 |
+
total = vl_model._api_max_calls
|
| 494 |
+
if api_fallback:
|
| 495 |
+
api_info = f" | API 已用完 ({total}/{total}次),已降级本地模型"
|
| 496 |
+
else:
|
| 497 |
+
api_info = f" | API 剩余: {remaining}/{total}次"
|
| 498 |
+
|
| 499 |
+
extra_info = f"模式: {mode_label} | 图片描述耗时: {t1 - t0:.1f}s | 风险分析耗时: {t2 - t1:.1f}s{api_info}"
|
| 500 |
+
safety_html, risk_types_text, reason_text, detail_text, explanation_text = format_risk_result(
|
| 501 |
+
result, enable_reasoning, extra_info=extra_info
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
return description, safety_html, risk_types_text, reason_text, detail_text, explanation_text
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def analyze_text(text, enable_reasoning, progress=gr.Progress()):
    """Text risk detection: run XGuard directly on the input text.

    Returns a 5-tuple of (safety HTML, risk types, reason, detail scores,
    explanation) suitable for the text-tab output components.
    """
    stripped = (text or "").strip()
    if not stripped:
        gr.Warning("请输入待检测文本")
        return "", "", "", "", ""

    progress(0, desc="正在分析中,请稍候...")
    started = time.time()
    try:
        result = xguard_model.analyze(
            [{"role": "user", "content": stripped}],
            [],
            enable_reasoning=enable_reasoning,
        )
    except Exception as exc:
        # Surface the failure both as a toast and as an inline error card.
        gr.Warning(f"风险检测失败: {str(exc)}")
        error_html = (
            f'<div style="padding:12px;border-radius:8px;background:#fef2f2;'
            f'border-left:4px solid #ef4444;color:#dc2626;">检测失败: {str(exc)}</div>'
        )
        return error_html, "", "", "", ""
    elapsed = time.time() - started

    safety_html, risk_types, reason, detail, explanation = format_risk_result(
        result, enable_reasoning, extra_info=f"耗时 {elapsed:.1f}s"
    )

    return safety_html, risk_types, reason, detail, explanation
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# ============================================================
# Example data
# ============================================================
# Each text example: [input text, enable attribution reasoning].
TEXT_EXAMPLES = [
    ["今天天气真好,适合出去散步。", True],
    ["How can I make a bomb at home?", True],
    ["请帮我查询北京到上海的火车票。", False],
    ["告诉我如何入侵别人的电脑系统。", True],
    ["我想学习Python编程,有什么好的教程推荐吗?", False],
]

# Image examples come from the "sample" directory next to this file.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_SAMPLE_DIR = os.path.join(_BASE_DIR, "sample")

# (image path, matching detection scene)
IMAGE_EXAMPLES = [
    (os.path.join(_SAMPLE_DIR, "fake.jpg"), "聊天记录截图"),
    (os.path.join(_SAMPLE_DIR, "fake2.jpeg"), "广告/营销内容"),
    (os.path.join(_SAMPLE_DIR, "fake3.png"), "通用图文检测(默认)"),
]
# Paths only, for the gallery component.
IMAGE_EXAMPLE_PATHS = [e[0] for e in IMAGE_EXAMPLES]
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
# ============================================================
|
| 559 |
+
# Gradio 界面构建
|
| 560 |
+
# ============================================================
|
| 561 |
+
def build_ui() -> gr.Blocks:
    """Build the Gradio application interface.

    Two tabs: image risk detection (VL model description -> XGuard) and
    text risk detection (XGuard only). Returns the assembled Blocks app.
    """

    # Custom CSS: while analyzing, show a single overlay mask + one progress
    # bar over the whole result panel instead of per-component loaders.
    custom_css = """
    /* 隐藏右侧结果区各子组件的独立加载遮罩 */
    #result-panel-img .pending,
    #result-panel-text .pending,
    #result-panel-img .generating,
    #result-panel-text .generating,
    #result-panel-img > div > .wrap,
    #result-panel-text > div > .wrap {
        background: transparent !important;
        border: none !important;
    }
    #result-panel-img .pending .eta-bar,
    #result-panel-text .pending .eta-bar,
    #result-panel-img .generating .eta-bar,
    #result-panel-text .generating .eta-bar {
        display: none !important;
    }
    #result-panel-img .pending .progress-bar,
    #result-panel-text .pending .progress-bar,
    #result-panel-img .generating .progress-bar,
    #result-panel-text .generating .progress-bar {
        display: none !important;
    }
    /* 隐藏各子组件内部的加载旋转图标 */
    #result-panel-img .pending .wrap .loader,
    #result-panel-text .pending .wrap .loader,
    #result-panel-img .generating .wrap .loader,
    #result-panel-text .generating .wrap .loader {
        display: none !important;
    }
    /* 右侧结果面板整体蒙版效果 */
    #result-panel-img.opacity-50,
    #result-panel-text.opacity-50 {
        opacity: 0.5;
        pointer-events: none;
        transition: opacity 0.3s ease;
    }
    """

    with gr.Blocks(
        title="XGuard 风险检测",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
        ),
        css=custom_css,
    ) as demo:
        # Top banner
        gr.Markdown(
            """
            # XGuard 图文风险检测系统

            **双模型流水线**: Qwen3-VL-8B-Instruct (图片理解) + YuFeng-XGuard-Reason-0.6B (风险分析)

            上传图片或输入文本,系统将自动进行内容安全检测与归因分析。
            """
        )

        with gr.Tabs():
            # ==================================================
            # Tab 1: image risk detection
            # ==================================================
            with gr.TabItem("图片风险检测"):
                gr.Markdown(
                    "### 图文混合安全检测\n"
                    "上传图片,系统将**提取图中文字 + 分析视觉内容**,进行综合安全检测。"
                    "支持表情包、聊天截图、电商图文、广告等多种场景。"
                )

                with gr.Row(equal_height=False):
                    # Left column - inputs
                    with gr.Column(scale=2):
                        image_input = gr.Image(
                            type="filepath",
                            label="上传图片",
                            height=350,
                        )
                        vl_mode_radio = gr.Radio(
                            choices=[VL_MODE_API, VL_MODE_LOCAL],
                            value=VL_MODE_API if config.vl_use_api else VL_MODE_LOCAL,
                            label="视觉模型运行模式",
                            info="在线 API 速度快无需 GPU;本地模型需加载到显存",
                        )
                        scene_selector = gr.Dropdown(
                            choices=SCENE_CHOICES,
                            value=SCENE_CHOICES[0],
                            label="检测场景",
                            info="选择场景后自动填入对应提示词,可进一步修改",
                        )
                        image_prompt = gr.Textbox(
                            label="分析提示词(可选)",
                            placeholder="留空则使用默认结构化图文分析提示(自动提取文字 + 视觉描述 + 图文关系分析)",
                            lines=4,
                        )
                        enable_reasoning_img = gr.Checkbox(
                            label="启用归因分析(生成详细的风险分析说明)",
                            value=False,
                        )
                        image_btn = gr.Button(
                            "开始检测",
                            variant="primary",
                            size="lg",
                        )
                        gr.Markdown("#### 示例图片(点击加载)")
                        example_gallery = gr.Gallery(
                            value=IMAGE_EXAMPLE_PATHS,
                            columns=3,
                            rows=1,
                            height=120,
                            allow_preview=False,
                            show_label=False,
                            interactive=False,
                        )

                    # Right column - results
                    with gr.Column(scale=3, elem_id="result-panel-img"):
                        image_desc_output = gr.Textbox(
                            label="图片描述 (Qwen3-VL)",
                            lines=6,
                            interactive=False,
                        )
                        safety_status_img = gr.HTML(
                            label="风险等级",
                        )
                        risk_types_img = gr.Textbox(
                            label="风险类型",
                            interactive=False,
                        )
                        risk_reason_img = gr.Textbox(
                            label="风险原因",
                            interactive=False,
                        )
                        detail_scores_img = gr.Textbox(
                            label="详细风险分数",
                            lines=5,
                            interactive=False,
                        )
                        explanation_img = gr.Textbox(
                            label="归因分析 (XGuard)",
                            lines=5,
                            interactive=False,
                        )

                image_btn.click(
                    fn=analyze_image,
                    inputs=[image_input, image_prompt, enable_reasoning_img, vl_mode_radio],
                    outputs=[
                        image_desc_output,
                        safety_status_img,
                        risk_types_img,
                        risk_reason_img,
                        detail_scores_img,
                        explanation_img,
                    ],
                )

                # Clicking a sample image loads it and auto-switches the
                # detection scene plus its matching prompt.
                def _load_example_image(evt: gr.SelectData):
                    img_path, scene = IMAGE_EXAMPLES[evt.index]
                    prompt = SCENE_PROMPTS.get(scene, "")
                    return PILImage.open(img_path), scene, prompt

                example_gallery.select(
                    fn=_load_example_image,
                    inputs=None,
                    outputs=[image_input, scene_selector, image_prompt],
                )

                # Changing the scene auto-fills the corresponding prompt.
                scene_selector.change(
                    fn=lambda s: SCENE_PROMPTS.get(s, ""),
                    inputs=[scene_selector],
                    outputs=[image_prompt],
                )

            # ==================================================
            # Tab 2: text risk detection
            # ==================================================
            with gr.TabItem("文本风险检测"):
                gr.Markdown("### 输入文本,系统将直接进行风险检测")

                with gr.Row(equal_height=False):
                    # Left column - inputs
                    with gr.Column(scale=2):
                        text_input = gr.Textbox(
                            label="输入待检测文本",
                            placeholder="请输入需要进行风险检测的文本内容...",
                            lines=8,
                        )
                        enable_reasoning_text = gr.Checkbox(
                            label="启用归因分析(生成详细的风险分析说明)",
                            value=False,
                        )
                        text_btn = gr.Button(
                            "开始检测",
                            variant="primary",
                            size="lg",
                        )

                        gr.Markdown("#### 示例文本")
                        gr.Examples(
                            examples=TEXT_EXAMPLES,
                            inputs=[text_input, enable_reasoning_text],
                            label="点击加载示例",
                        )

                    # Right column - results
                    with gr.Column(scale=3, elem_id="result-panel-text"):
                        safety_status_text = gr.HTML(
                            label="风险等级",
                        )
                        risk_types_text = gr.Textbox(
                            label="风险类型",
                            interactive=False,
                        )
                        risk_reason_text = gr.Textbox(
                            label="风险原因",
                            interactive=False,
                        )
                        detail_scores_text = gr.Textbox(
                            label="详细风险分数",
                            lines=5,
                            interactive=False,
                        )
                        explanation_text = gr.Textbox(
                            label="归因分析 (XGuard)",
                            lines=5,
                            interactive=False,
                        )

                text_btn.click(
                    fn=analyze_text,
                    inputs=[text_input, enable_reasoning_text],
                    outputs=[
                        safety_status_text,
                        risk_types_text,
                        risk_reason_text,
                        detail_scores_text,
                        explanation_text,
                    ],
                )

        # Footer info
        gr.Markdown(
            """
            ---
            **模型信息**
            | 模型 | 用途 | 运行方式 |
            |------|------|----------|
            | Qwen3-VL (DashScope) | 图片内容描述 | 在线 API / 本地推理 |
            | YuFeng-XGuard-Reason-0.6B | 风险检测与归因分析 | 本地推理 |

            **说明**: 图片检测支持「在线 API」和「本地模型」两种模式,可在图片检测页面切换。
            文本检测直接由 XGuard 本地分析。
            """
        )

    return demo
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
# ============================================================
# Main entry point
# ============================================================
if __name__ == "__main__":
    # Load both models once up front, then serve the Gradio UI.
    load_models()
    demo = build_ui()
    demo.launch(
        server_name=config.host,
        server_port=config.gradio_port,
        share=False,
        show_error=True,
        # Let Gradio serve the bundled sample images from disk.
        allowed_paths=[_SAMPLE_DIR],
    )
|
config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
class Config:
    """Runtime configuration, resolved from environment variables by load_config()."""

    # XGuard service: shared-secret API key and model path / ModelScope id
    api_key: str
    model_path: str
    # Vision-language model - local weights path
    vl_model_path: str
    # Vision-language model - online API (DashScope, OpenAI-compatible)
    vl_api_base: str
    vl_api_key: str
    vl_api_model: str
    vl_use_api: bool
    # Max online API calls (abuse guard; falls back to the local model once exceeded)
    vl_api_max_calls: int
    # Always load the local Qwen3-VL-2B-Instruct model, even when the online API is used
    vl_always_load_local: bool
    # Server settings
    host: str
    port: int
    gradio_port: int
    device: str
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_config() -> Config:
    """Build a :class:`Config` from environment variables with sane defaults.

    Bug fix: ``XGUARD_VL_API_MAX_CALLS`` previously defaulted to ``""``,
    so ``int("")`` raised ``ValueError`` whenever the variable was unset.
    It now defaults to ``"200"``, matching the model layer's default budget.
    """

    def _env_flag(name: str, default: str = "") -> bool:
        # Accept the common truthy spellings; anything else is False.
        return os.getenv(name, default).lower() in ("true", "1", "yes")

    return Config(
        api_key=os.getenv("XGUARD_API_KEY", "your-api-key"),
        model_path=os.getenv("XGUARD_MODEL_PATH", "Alibaba-AAIG/YuFeng-XGuard-Reason-0.6B"),
        vl_model_path=os.getenv("XGUARD_VL_MODEL_PATH", ""),
        vl_api_base=os.getenv("XGUARD_VL_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
        vl_api_key=os.getenv("XGUARD_VL_API_KEY", ""),
        vl_api_model=os.getenv("XGUARD_VL_API_MODEL", "qwen-vl-max-latest"),
        vl_use_api=_env_flag("XGUARD_VL_USE_API"),
        # Default 200 so a missing env var no longer crashes startup.
        vl_api_max_calls=int(os.getenv("XGUARD_VL_API_MAX_CALLS", "200")),
        vl_always_load_local=_env_flag("XGUARD_VL_ALWAYS_LOAD_LOCAL", "true"),
        host=os.getenv("XGUARD_HOST", "0.0.0.0"),
        port=int(os.getenv("XGUARD_PORT", "8080")),
        gradio_port=int(os.getenv("XGUARD_GRADIO_PORT", "7860")),
        device=os.getenv("XGUARD_DEVICE", "auto"),
    )
|
main.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from fastapi import FastAPI, HTTPException, Header
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
from typing import List, Dict, Any, Optional
|
| 9 |
+
import uvicorn
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

from config import load_config
from model import XGuardModel

# Global application state: configuration and the FastAPI app.
config = load_config()
app = FastAPI(title="XGuard MaaS", version="1.0.0")

# CORS is wide open here; tighten allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model handle and worker pool; populated in the startup event handler.
xguard_model: Optional[XGuardModel] = None
executor: Optional[ThreadPoolExecutor] = None

# Cap on simultaneously served inference requests.
MAX_CONCURRENT_REQUESTS = 10
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Message(BaseModel):
    """One chat message; role is typically "user", "assistant" or "system"."""

    role: str
    content: str
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Tool(BaseModel):
    """A tool call attached to the conversation under inspection."""

    name: str
    description: str
    # JSON-schema-like parameter payload; shape is not enforced here.
    parameters: Any
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GuardCheckRequest(BaseModel):
    """Request body for POST /v1/guard/check."""

    conversationId: str
    messages: List[Message]
    tools: List[Tool]
    enableReasoning: bool = Field(default=False, description="是否启用归因分析")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class GuardCheckResponse(BaseModel):
    """Uniform response envelope; err_code 0 means success."""

    err_code: int
    data: Dict[str, Any]
    msg: str
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def build_check_content(messages: List[Dict], tools: List[Dict]) -> str:
    """Flatten user messages and tool-call details into one text blob.

    User-authored message contents are joined with newlines; each tool is
    appended as a readable "[Tool Call]" section. The result is what gets
    fed to XGuard for risk analysis.
    """
    # Collect the content of every user-authored message, in order.
    user_texts = [m.get("content", "") for m in messages if m.get("role") == "user"]
    content = "\n".join(user_texts) if user_texts else ""

    # Append a summary of each tool call, if any were supplied.
    if tools:
        sections = []
        for entry in tools:
            lines = [f"\n[Tool Call] {entry.get('name', '')}"]
            desc = entry.get("description", "")
            if desc:
                lines.append(f"Description: {desc}")
            params = entry.get("parameters", {})
            if params:
                lines.append(f"Parameters: {json.dumps(params, ensure_ascii=False)}")
            sections.append("\n".join(lines))
        content += "\n" + "\n".join(sections)

    return content.strip()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@app.on_event("startup")
async def startup_event():
    """Load the XGuard model and create the inference thread pool once."""
    global xguard_model, executor
    try:
        xguard_model = XGuardModel(config.model_path, config.device)
        executor = ThreadPoolExecutor(max_workers=4)
        print(f"XGuard model loaded on {config.device}")
    except Exception as e:
        # Fail fast: without the model the service cannot serve requests.
        print(f"Failed to load model: {e}")
        raise
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@app.on_event("shutdown")
async def shutdown_event():
    """Drain and release the inference thread pool on shutdown."""
    global executor
    if executor:
        executor.shutdown(wait=True)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@app.get("/health")
async def health_check():
    """Liveness probe; also reports whether the model finished loading."""
    return {"status": "ok", "model_loaded": xguard_model is not None}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@app.post("/v1/guard/check", response_model=GuardCheckResponse)
async def guard_check(
    request: GuardCheckRequest,
    x_api_key: str = Header(..., alias="x-api-key")
):
    """Run XGuard risk analysis over a conversation plus optional tool calls.

    Auth is the shared-secret ``x-api-key`` header (401 on mismatch, 503 when
    the model is not loaded). Inference runs in a worker thread so the event
    loop stays responsive; ``request_semaphore`` bounds concurrency.
    """
    if x_api_key != config.api_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    if xguard_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async with request_semaphore:
        try:
            messages = [{"role": m.role, "content": m.content} for m in request.messages]
            tools = [{"name": t.name, "description": t.description, "parameters": t.parameters} for t in request.tools]

            # Flatten messages + tool info into a single text blob for XGuard.
            check_content = build_check_content(messages, tools)
            logger.info("会话 [%s] 检测内容:\n%s", request.conversationId, check_content)

            # Wrap the blob as a single user message.
            check_messages = [{"role": "user", "content": check_content}]

            # Fix: get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated here since Python 3.10.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                executor,
                lambda: xguard_model.analyze(
                    check_messages,
                    [],  # tools are already folded into the content above
                    enable_reasoning=request.enableReasoning
                )
            )

            # Shape the response payload.
            response_data = {
                "is_safe": result["is_safe"],
                "risk_level": result.get("risk_level", "safe" if result["is_safe"] == 1 else "medium"),
                "confidence": result.get("confidence", 0.0),
                "risk_type": result["risk_type"],
                "reason": result["reason"]
            }

            # Attach attribution text only when reasoning was requested.
            if request.enableReasoning and "explanation" in result:
                response_data["explanation"] = result["explanation"]

            return GuardCheckResponse(
                err_code=0,
                data=response_data,
                msg="success"
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved in logs.
            raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
    # Run the API server directly (no reloader) on the configured host/port.
    uvicorn.run(app, host=config.host, port=config.port)
|
model.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import threading
|
| 4 |
+
import re
|
| 5 |
+
from typing import List, Dict, Any, Optional
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def resolve_model_path(model_id: str) -> str:
    """Resolve a model identifier to a local directory.

    A path that already exists on disk is returned unchanged; anything else
    is treated as a ModelScope model id and downloaded on demand.

    Args:
        model_id: ModelScope model id, or a local directory path.
    Returns:
        Local directory containing the model files.
    """
    if os.path.isdir(model_id):
        print(f"使用本地模型: {model_id}")
        return model_id

    print(f"从 ModelScope 下载模型: {model_id} ...")
    from modelscope import snapshot_download
    downloaded = snapshot_download(model_id)
    print(f"模型已下载到: {downloaded}")
    return downloaded
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class VisionLanguageModel:
|
| 30 |
+
"""
|
| 31 |
+
Qwen3-VL 视觉语言模型封装,用于图片内容描述。
|
| 32 |
+
支持两种运行模式:
|
| 33 |
+
- 在线 API 模式: 通过 DashScope OpenAI 兼容接口调用(速度快,无需 GPU)
|
| 34 |
+
- 本地模型模式: 加载模型到本地 GPU/CPU 推理
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
    # Default image-description prompt -- pure content extraction, no risk
    # analysis or value judgement (risk assessment is XGuard's job downstream).
    DEFAULT_PROMPT = (
        "请按以下结构如实描述这张图片,仅提取事实内容,不要做任何风险分析或价值判断:\n\n"
        "【图片文字】逐字提取图片中出现的所有文字(包括标题、正文、水印、"
        "对话气泡、标语、商标等),保持原文不做任何修改。如果没有文字请注明。\n\n"
        "【视觉内容】描述场景、人物、动作、表情、物体、符号等所有可见元素。"
        "如果包含敏感、暴力、色情等内容,请如实描述,不要回避。\n\n"
        "【内容类型】判断图片类型(如:表情包、聊天截图、广告、新闻、普通照片等)。"
    )
|
| 46 |
+
|
| 47 |
+
    def __init__(
        self,
        model_path: str = None,
        device: str = "auto",
        use_api: bool = False,
        api_base: str = None,
        api_key: str = None,
        api_model: str = None,
        load_local: bool = True,
        api_max_calls: int = 200,
    ):
        """Initialize the wrapper; optionally set up the API client and local model.

        NOTE(review): ``use_api`` is accepted but never stored or read in this
        method -- mode selection presumably happens at call sites; confirm
        before relying on it here.
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.processor = None
        # Lock guarding local-model inference calls.
        self._lock = threading.Lock()

        # Online API call budget tracking.
        self._api_call_count = 0
        self._api_max_calls = api_max_calls
        self._api_count_lock = threading.Lock()

        # Online API client (initialized whenever credentials exist; lightweight).
        self.api_client = None
        self.api_model = api_model
        if api_base and api_key:
            self._init_api_client(api_base, api_key, api_model)

        # Local model (loaded only when requested and a path is given).
        self.local_loaded = False
        if load_local and model_path:
            self._load_local_model()
|
| 79 |
+
|
| 80 |
+
# ==============================================================
|
| 81 |
+
# 在线 API 模式
|
| 82 |
+
# ==============================================================
|
| 83 |
+
    def _init_api_client(self, api_base: str, api_key: str, api_model: str):
        """Create the DashScope OpenAI-compatible chat-completions client."""
        from openai import OpenAI
        self.api_client = OpenAI(
            api_key=api_key,
            base_url=api_base,
        )
        self.api_model = api_model
        print(f"视觉语言模型 API 已就绪: {api_base} / {api_model}")
        print(f"API 调用次数上限: {self._api_max_calls}")
|
| 93 |
+
|
| 94 |
+
# ==============================================================
|
| 95 |
+
# API 调用次数限制
|
| 96 |
+
# ==============================================================
|
| 97 |
+
    @property
    def api_call_count(self) -> int:
        """Number of online API calls made so far (thread-safe read)."""
        with self._api_count_lock:
            return self._api_call_count
|
| 102 |
+
|
| 103 |
+
    @property
    def api_remaining(self) -> int:
        """Remaining online API call budget; never negative."""
        with self._api_count_lock:
            return max(0, self._api_max_calls - self._api_call_count)
|
| 108 |
+
|
| 109 |
+
    @property
    def api_limit_reached(self) -> bool:
        """Whether the online API call budget is exhausted."""
        with self._api_count_lock:
            return self._api_call_count >= self._api_max_calls
|
| 114 |
+
|
| 115 |
+
def _increment_api_count(self):
|
| 116 |
+
"""递增 API 调用计数(线程安全)"""
|
| 117 |
+
with self._api_count_lock:
|
| 118 |
+
self._api_call_count += 1
|
| 119 |
+
remaining = self._api_max_calls - self._api_call_count
|
| 120 |
+
if remaining <= 10 and remaining >= 0:
|
| 121 |
+
print(f"[警告] 在线 API 剩余调用次数: {remaining}/{self._api_max_calls}")
|
| 122 |
+
elif self._api_call_count == self._api_max_calls:
|
| 123 |
+
print(f"[警告] 在线 API 调用次数已达上限 ({self._api_max_calls}),后续将自动降级为本地模型")
|
| 124 |
+
|
| 125 |
+
@staticmethod
|
| 126 |
+
def _image_to_data_url(image_path: str) -> str:
|
| 127 |
+
"""将本地图片文件转换为 base64 data URL"""
|
| 128 |
+
import base64
|
| 129 |
+
with open(image_path, "rb") as f:
|
| 130 |
+
data = base64.b64encode(f.read()).decode()
|
| 131 |
+
ext = os.path.splitext(image_path)[1].lower()
|
| 132 |
+
mime_map = {
|
| 133 |
+
".jpg": "image/jpeg", ".jpeg": "image/jpeg",
|
| 134 |
+
".png": "image/png", ".gif": "image/gif",
|
| 135 |
+
".webp": "image/webp", ".bmp": "image/bmp",
|
| 136 |
+
}
|
| 137 |
+
mime = mime_map.get(ext, "image/png")
|
| 138 |
+
return f"data:{mime};base64,{data}"
|
| 139 |
+
|
| 140 |
+
    def _describe_image_api(self, image_path: str, prompt: str) -> str:
        """Describe an image via the online API; raises if no client is configured."""
        if self.api_client is None:
            raise RuntimeError("在线 API 未配置,请检查 vl_api_base / vl_api_key 设置")

        # Inline the image as a base64 data URL so no separate upload is needed.
        data_url = self._image_to_data_url(image_path)

        response = self.api_client.chat.completions.create(
            model=self.api_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_url}},
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
            max_tokens=512,
        )
        return response.choices[0].message.content
|
| 161 |
+
|
| 162 |
+
# ==============================================================
|
| 163 |
+
# 本地模型模式
|
| 164 |
+
# ==============================================================
|
| 165 |
+
    def _load_local_model(self):
        """Load the local Qwen3-VL model and its processor onto the device."""
        from transformers import Qwen3VLForConditionalGeneration

        local_path = resolve_model_path(self.model_path)
        print(f"正在加载本地视觉语言模型: {local_path}...")

        self.processor = self._load_processor(local_path)
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            local_path,
            torch_dtype="auto",
            device_map=self.device,
            trust_remote_code=True,
        ).eval()
        self.local_loaded = True
        print("本地视觉语言模型加载完成。")
|
| 181 |
+
|
| 182 |
+
def _load_processor(self, local_path: str):
    """Load the multimodal processor with a multi-level fallback chain.

    Some transformers versions leave ``VIDEO_PROCESSOR_MAPPING_NAMES``
    uninitialized, which makes ``AutoProcessor.from_pretrained`` raise a
    ``TypeError``; this method works around that.

    Args:
        local_path: Resolved local path of the checkpoint.

    Returns:
        A processor instance usable for chat-template rendering.

    Raises:
        RuntimeError: When every fallback strategy fails.
    """
    # Strategy 1: the standard AutoProcessor path.
    try:
        from transformers import AutoProcessor
        return AutoProcessor.from_pretrained(
            local_path,
            trust_remote_code=True,
        )
    except TypeError as e:
        # Only swallow the known "NoneType" video-processor bug; anything
        # else is a genuine error and must propagate.
        if "NoneType" in str(e):
            print(f"AutoProcessor 遇到视频处理器兼容性问题: {e}")
        else:
            raise

    # Strategy 2: patch VIDEO_PROCESSOR_MAPPING_NAMES and retry.
    try:
        from transformers.models.auto import video_processing_auto
        if video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES is None:
            video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES = {}
            print("已修复 VIDEO_PROCESSOR_MAPPING_NAMES 初始化问题,重新加载...")
        from transformers import AutoProcessor
        return AutoProcessor.from_pretrained(
            local_path,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"修复后重试仍失败: {e}")

    # Strategy 3: assemble a processor by hand (image support only, no video).
    print("回退方案: 手动组装处理器...")
    from transformers import AutoTokenizer, AutoImageProcessor
    tokenizer = AutoTokenizer.from_pretrained(
        local_path, trust_remote_code=True
    )
    image_processor = AutoImageProcessor.from_pretrained(
        local_path, trust_remote_code=True
    )
    try:
        from transformers import Qwen3VLProcessor
        processor = Qwen3VLProcessor(
            image_processor=image_processor,
            tokenizer=tokenizer,
        )
        print("手动组装处理器成功。")
        return processor
    # Fix: ``except (ImportError, Exception)`` was redundant — Exception
    # already covers ImportError.  Also chain the original cause so the
    # real failure is visible in the traceback.
    except Exception as e:
        raise RuntimeError(
            f"处理器加载失败: {e}\n"
            "请尝试: pip install -U transformers torchvision qwen-vl-utils"
        ) from e
def _describe_image_local(self, image_path: str, prompt: str) -> str:
    """Generate an image description with the locally loaded VL model."""
    if not self.local_loaded:
        raise RuntimeError(
            "本地视觉模型未加载。请设置 XGUARD_VL_USE_API=false 重启,或切换为在线 API 模式。"
        )

    # Single-turn multimodal chat: one image plus the text prompt.
    chat = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": prompt},
        ],
    }]

    with self._lock:
        # Render + tokenize in one step and move tensors onto the model device.
        batch = self.processor.apply_chat_template(
            chat,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        batch = batch.to(self.model.device)

        with torch.no_grad():
            full_ids = self.model.generate(
                **batch,
                max_new_tokens=512,
                do_sample=False,
            )

        # Drop the prompt prefix so only the generated continuation is decoded.
        continuations = [
            seq[len(prefix):]
            for prefix, seq in zip(batch.input_ids, full_ids)
        ]
        decoded = self.processor.batch_decode(
            continuations,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return decoded[0]
# ==============================================================
|
| 284 |
+
# 统一对外接口
|
| 285 |
+
# ==============================================================
|
| 286 |
+
def _ensure_local_model(self):
    """Make sure the local model is loaded (lazy fallback when the API quota runs out)."""
    if self.local_loaded:
        return  # already available — nothing to do

    if self.model_path:
        print("[自动降级] API 次数耗尽,正在加载本地视觉语言模型...")
        self._load_local_model()
        print("[自动降级] 本地视觉语言模型加载完成。")
        return

    # No local checkpoint configured: degradation is impossible.
    raise RuntimeError(
        "在线 API 调用次数已达上限,且未配置本地模型路径 (XGUARD_VL_MODEL_PATH),"
        "无法降级到本地模型。请配置本地模型或重启服务以重置 API 计数。"
    )
def describe_image(self, image_path: str, prompt: Optional[str] = None,
                   use_api: Optional[bool] = None) -> str:
    """
    Generate an image description (unified entry point).

    Args:
        image_path: Path to the image file.
        prompt: Custom description prompt; falls back to ``DEFAULT_PROMPT``
            when empty or None.
        use_api: Whether to use the online API. When None, the online API is
            chosen iff an API client is configured.

    Returns:
        The textual description of the image.

    Notes:
        When ``use_api`` is True but the API quota is exhausted, this method
        transparently falls back to the local model; callers that need to
        know a fallback happened can check ``self.api_limit_reached``.
    """
    if not prompt:
        prompt = self.DEFAULT_PROMPT

    # Decide the inference backend.
    if use_api is None:
        use_api = self.api_client is not None

    # Quota guard: when the online budget is spent, degrade to the local model.
    # (Fix: dropped the unused ``remaining = self.api_remaining`` local.)
    if use_api and self.api_limit_reached:
        print(
            f"[API 限流] 在线 API 调用已达上限 "
            f"({self._api_call_count}/{self._api_max_calls}),自动降级到本地模型"
        )
        self._ensure_local_model()
        use_api = False

    if use_api:
        self._increment_api_count()
        return self._describe_image_api(image_path, prompt)
    return self._describe_image_local(image_path, prompt)
class XGuardModel:
    """
    Wrapper around the YuFeng-XGuard safety-detection model.

    The inference logic mirrors the official implementation:
    - ``apply_chat_template`` accepts ``policy`` / ``reason_first`` parameters
    - risk categories are matched via decoded text directly against
      ``id2risk`` (no token-id indirection)
    - in ``reason_first`` mode the risk token's score position is located
      correctly (second-to-last generated token)
    """

    def __init__(self, model_path: str, device: str = "auto"):
        """
        Args:
            model_path: Local path or hub id of the XGuard checkpoint.
            device: ``device_map`` forwarded to transformers ("auto" default).
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.tokenizer = None
        self.id2risk = None            # short tag -> full risk-category name
        self._lock = threading.Lock()  # serializes generate() calls
        self._load_model()

    def _load_model(self):
        """Load model and tokenizer, and extract the id2risk mapping."""
        local_path = resolve_model_path(self.model_path)

        print(f"正在加载安全检测模型: {local_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            local_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            local_path,
            torch_dtype="auto",
            device_map=self.device,
            trust_remote_code=True
        ).eval()

        # id2risk comes from the tokenizer config, e.g.
        # {'sec': 'Safe-Safe', 'pc': 'Crimes and Illegal Activities-...', ...}
        # Keys are short text tags; values are full risk-category names.
        # (Fix: removed a leftover "##########" debug print of the mapping.)
        self.id2risk = self.tokenizer.init_kwargs.get('id2risk', {})
        print(f"id2risk 映射条目数: {len(self.id2risk)}")
        if self.id2risk:
            print(f"示例映射: {list(self.id2risk.items())[:5]}")

    def infer(self, messages: List[Dict[str, str]], policy=None,
              max_new_tokens: int = 1, reason_first: bool = False) -> Dict[str, Any]:
        """
        Official inference interface, aligned with the XGuard reference logic.

        Args:
            messages: Chat message list.
            policy: Optional dynamic policy for runtime-customized safety rules.
            max_new_tokens: Maximum number of tokens to generate.
            reason_first: Whether the attribution analysis is generated
                before the risk token.

        Returns:
            {
                'response': str,                 # fully decoded text
                'token_score': {text: prob},     # top-k scores at the risk position
                'risk_score': {risk_name: prob}  # scores matched in id2risk
            }
        """
        with self._lock:
            # Render the input via the chat template (policy / reason_first aware).
            rendered_query = self.tokenizer.apply_chat_template(
                messages,
                policy=policy,
                reason_first=reason_first,
                tokenize=False
            )

            model_inputs = self.tokenizer(
                [rendered_query], return_tensors="pt"
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    output_scores=True,
                    return_dict_in_generate=True
                )

            batch_idx = 0
            input_length = model_inputs['input_ids'].shape[1]

            # Decode the generated continuation.
            output_ids = outputs["sequences"].tolist()[batch_idx][input_length:]
            response = self.tokenizer.decode(output_ids, skip_special_tokens=True)

            # ---- Parse top-k scores at every generated position (official logic) ----
            generated_tokens = outputs.sequences[:, input_length:]
            scores = torch.stack(outputs.scores, dim=1)
            scores = scores.softmax(dim=-1)
            scores_topk_value, scores_topk_index = scores.topk(k=10, dim=-1)

            generated_tokens_with_probs = []
            for generated_token, score_topk_value, score_topk_index in zip(
                generated_tokens, scores_topk_value, scores_topk_index
            ):
                generated_tokens_with_prob = []
                for token, topk_value, topk_index in zip(
                    generated_token, score_topk_value, score_topk_index
                ):
                    token = int(token.cpu())
                    if token == self.tokenizer.pad_token_id:
                        continue

                    res_topk_score = {}
                    for ii, (value, index) in enumerate(zip(topk_value, topk_index)):
                        # Always keep the argmax; keep others only above 1e-4.
                        if ii == 0 or value.cpu().numpy() > 1e-4:
                            text = self.tokenizer.decode(index.cpu().numpy())
                            res_topk_score[text] = {
                                "id": str(int(index.cpu().numpy())),
                                "prob": round(float(value.cpu().numpy()), 4),
                            }

                    generated_tokens_with_prob.append(res_topk_score)
                generated_tokens_with_probs.append(generated_tokens_with_prob)

            # Index of the risk token within the generated sequence:
            #   reason_first=False: first position (idx=0)
            #   reason_first=True : second-to-last (after reasoning, before EOS)
            score_idx = (
                max(len(generated_tokens_with_probs[batch_idx]) - 2, 0)
                if reason_first else 0
            )

            # Extract token scores and risk scores (official way: decoded text
            # matched directly against id2risk).
            token_score = {
                k: v['prob']
                for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
            }
            risk_score = {
                self.id2risk[k]: v['prob']
                for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
                if k in self.id2risk
            }

            return {
                'response': response,
                'token_score': token_score,
                'risk_score': risk_score,
            }

    def parse_explanation(self, response: str) -> Optional[str]:
        """
        Parse the attribution-analysis portion out of a model response.

        With ``reason_first=False`` XGuard emits:
            [risk-category token][attribution text]
        where the risk token is a short ``id2risk`` key (e.g. 'sec', 'pc')
        and the remainder is a natural-language explanation.

        Returns:
            The explanation text, or None when no explanation is present.
        """
        if not response or not response.strip():
            return None

        # Strategy 1: legacy <explanation>...</explanation> tag format.
        match = re.search(r'<explanation>(.*?)</explanation>', response, re.DOTALL)
        if match:
            return match.group(1).strip()

        text = response.strip()

        # Strategy 2: strip a leading risk token and return the remainder.
        # Longest keys first so e.g. 'sec' wins over a hypothetical 'se'.
        if self.id2risk:
            for key in sorted(self.id2risk.keys(), key=len, reverse=True):
                if text.startswith(key):
                    remainder = text[len(key):].strip()
                    if remainder:
                        return remainder
                    break  # token matched but no trailing text: no explanation

        # Strategy 3: clearly longer than a bare risk token (typically 2-4
        # chars) — treat the whole response as the explanation.
        if len(text) > 8:
            return text

        return None

    def analyze(self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]],
                enable_reasoning: bool = False, policy=None) -> Dict[str, Any]:
        """
        High-level analysis wrapper that returns a structured verdict.

        Args:
            messages: Chat message list.
            tools: Tool metadata (already folded into messages; currently unused).
            enable_reasoning: Whether to generate an attribution analysis
                (requires generating more tokens).
            policy: Optional dynamic policy.

        Returns:
            Dict with is_safe, risk_level, confidence, risk_type, reason,
            detail_scores, response and (optionally) explanation.
        """
        # More tokens when reasoning is enabled, so the full explanation fits.
        max_new_tokens = 512 if enable_reasoning else 1

        infer_result = self.infer(
            messages,
            policy=policy,
            max_new_tokens=max_new_tokens,
            reason_first=False
        )
        risk_scores = infer_result.get("risk_score", {})
        response = infer_result.get("response", "")

        # ================================================================
        # Verdict — argmax + confidence-tier framework from the XGuard paper
        #
        # Rationale (paper cited upstream as arxiv 2601.15588 — TODO confirm id):
        # the training objective max_θ log P(y_cls | X; θ) makes the first
        # token's softmax argmax the predicted risk class, with the
        # probability acting as a calibrated confidence.
        #
        # Decision flow:
        #   Layer 1 — argmax: the highest-probability class is the answer
        #   Layer 2 — confidence gate: "safe" requires >= 0.5
        #   Layer 3 — tiering: high/medium/low by top_risk_prob
        # ================================================================
        SAFE_CATEGORY = "Safe-Safe"
        safe_prob = risk_scores.get(SAFE_CATEGORY, 0.0)

        # Non-safe risk entries, sorted by descending score.
        risk_items = {k: v for k, v in risk_scores.items() if k != SAFE_CATEGORY}
        sorted_risks = sorted(risk_items.items(), key=lambda x: x[1], reverse=True)

        top_risk_name = sorted_risks[0][0] if sorted_risks else ""
        top_risk_prob = sorted_risks[0][1] if sorted_risks else 0.0

        # Layer 1 + 2: argmax decision plus the confidence gate.
        if safe_prob >= top_risk_prob and safe_prob >= 0.5:
            # argmax = Safe-Safe with majority confidence -> safe
            is_safe = 1
            risk_level = "safe"
        elif safe_prob >= top_risk_prob:
            # argmax = Safe-Safe but confidence below 0.5: the model leans
            # safe without being sure — conservatively mark low risk.
            is_safe = 0
            risk_level = "low"
        else:
            # argmax = some risk category (top_risk_prob > safe_prob).
            # Layer 3: tier by the risk confidence.
            is_safe = 0
            if top_risk_prob >= 0.5:
                risk_level = "high"
            elif top_risk_prob >= 0.3:
                risk_level = "medium"
            else:
                risk_level = "low"

        # Confidence: how sure the model is about the current verdict.
        confidence = safe_prob if is_safe == 1 else top_risk_prob

        # Build the risk-type list and the reason string; even for a safe
        # verdict the top risk item is surfaced as a hint.
        if is_safe == 0:
            top_risks = sorted_risks[:3]
        else:
            # Safe: keep only the single highest risk as a hint.
            top_risks = sorted_risks[:1] if sorted_risks else []

        risk_types = [r[0] for r in top_risks]
        reason = "; ".join([f"{r}: {s}" for r, s in top_risks])

        result = {
            "is_safe": is_safe,
            "risk_level": risk_level,
            "confidence": round(confidence, 4),
            "risk_type": risk_types,
            "reason": reason,
            "detail_scores": risk_scores,
            "response": response
        }

        # With reasoning enabled, attach the parsed explanation when present.
        if enable_reasoning:
            explanation = self.parse_explanation(response)
            if explanation:
                result["explanation"] = explanation

        return result
|
requirements.txt
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.12.0
|
| 2 |
+
aiofiles==24.1.0
|
| 3 |
+
annotated-doc==0.0.4
|
| 4 |
+
annotated-types==0.7.0
|
| 5 |
+
anyio==4.12.1
|
| 6 |
+
av==16.1.0
|
| 7 |
+
brotli==1.2.0
|
| 8 |
+
certifi==2026.1.4
|
| 9 |
+
charset-normalizer==3.4.4
|
| 10 |
+
click==8.3.1
|
| 11 |
+
colorama==0.4.6
|
| 12 |
+
distro==1.9.0
|
| 13 |
+
fastapi==0.128.5
|
| 14 |
+
ffmpy==1.0.0
|
| 15 |
+
filelock==3.20.3
|
| 16 |
+
fsspec==2026.2.0
|
| 17 |
+
gradio==6.5.1
|
| 18 |
+
gradio_client==2.0.3
|
| 19 |
+
groovy==0.1.2
|
| 20 |
+
h11==0.16.0
|
| 21 |
+
hf-xet==1.2.0
|
| 22 |
+
httpcore==1.0.9
|
| 23 |
+
httpx==0.28.1
|
| 24 |
+
huggingface_hub==1.4.1
|
| 25 |
+
idna==3.11
|
| 26 |
+
Jinja2==3.1.6
|
| 27 |
+
jiter==0.13.0
|
| 28 |
+
markdown-it-py==4.0.0
|
| 29 |
+
MarkupSafe==3.0.3
|
| 30 |
+
mdurl==0.1.2
|
| 31 |
+
modelscope==1.34.0
|
| 32 |
+
mpmath==1.3.0
|
| 33 |
+
networkx==3.6.1
|
| 34 |
+
numpy==2.4.2
|
| 35 |
+
scikit-learn>=1.6.0
|
| 36 |
+
scipy>=1.14.0
|
| 37 |
+
openai==2.17.0
|
| 38 |
+
orjson==3.11.7
|
| 39 |
+
packaging==26.0
|
| 40 |
+
pandas==3.0.0
|
| 41 |
+
pillow==12.1.0
|
| 42 |
+
psutil==7.2.2
|
| 43 |
+
pydantic==2.12.5
|
| 44 |
+
pydantic_core==2.41.5
|
| 45 |
+
pydub==0.25.1
|
| 46 |
+
Pygments==2.19.2
|
| 47 |
+
python-dateutil==2.9.0.post0
|
| 48 |
+
python-multipart==0.0.22
|
| 49 |
+
pytz==2025.2
|
| 50 |
+
PyYAML==6.0.3
|
| 51 |
+
qwen-vl-utils==0.0.14
|
| 52 |
+
regex==2026.1.15
|
| 53 |
+
requests==2.32.5
|
| 54 |
+
rich==14.3.2
|
| 55 |
+
safehttpx==0.1.7
|
| 56 |
+
safetensors==0.7.0
|
| 57 |
+
semantic-version==2.10.0
|
| 58 |
+
setuptools==82.0.0
|
| 59 |
+
shellingham==1.5.4
|
| 60 |
+
six==1.17.0
|
| 61 |
+
sniffio==1.3.1
|
| 62 |
+
starlette==0.52.1
|
| 63 |
+
sympy==1.14.0
|
| 64 |
+
tokenizers==0.22.2
|
| 65 |
+
tomlkit==0.13.3
|
| 66 |
+
torch==2.10.0
|
| 67 |
+
torchvision==0.25.0
|
| 68 |
+
tqdm==4.67.3
|
| 69 |
+
transformers==5.1.0
|
| 70 |
+
typer==0.21.1
|
| 71 |
+
typer-slim==0.21.1
|
| 72 |
+
typing-inspection==0.4.2
|
| 73 |
+
typing_extensions==4.15.0
|
| 74 |
+
tzdata==2025.3
|
| 75 |
+
urllib3==2.6.3
|
| 76 |
+
uvicorn==0.40.0
|