---
language:
- en
license: cc-by-nc-2.0
library_name: transformers
tags:
- citation-verification
- retrieval-augmented-generation
- rag
- cross-lingual
- deberta
- cross-encoder
- nli
- attribution
pipeline_tag: text-classification
datasets:
- fever
- din0s/asqa
- miracl/hagrid
metrics:
- f1
- precision
- recall
- accuracy
- roc_auc
base_model: microsoft/deberta-v3-base
model-index:
- name: dualtrack-alignment-module
  results:
  - task:
      type: text-classification
      name: Citation Verification
    metrics:
    - type: f1
      value: 0.89
      name: F1 Score
    - type: accuracy
      value: 0.87
      name: Accuracy
    - type: roc_auc
      value: 0.94
      name: ROC-AUC
---
# DualTrack Alignment Module
> **Anonymous submission to ACL 2026**
A cross-encoder model for detecting **citation drift** in Retrieval-Augmented Generation (RAG) systems. Given a user-facing claim, an evidence representation, and a source passage, the model predicts whether the citation is valid (the source supports the claim).
## Model Description
This model addresses a critical reliability problem in RAG systems: **citation drift**, where generated text diverges from its source documents in ways that break attribution. The problem is particularly severe in cross-lingual settings, where the answer language differs from the source-document language.
### Architecture
```
Input: "[CLS] User claim: {claim} [SEP] Evidence: {evidence} [SEP] Source passage: {context} [SEP]"
↓
DeBERTa-v3-base (184M parameters)
↓
[CLS] embedding (768-dim)
↓
Linear(768, 2) → Softmax
↓
Output: P(valid citation)
```
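The diagram above corresponds to a standard sequence-classification head on top of DeBERTa. The sketch below re-creates the same layout from raw `transformers` components purely for illustration; the released checkpoint itself is loaded with `AutoModelForSequenceClassification`, as shown under "How to Use".
```python
import torch
import torch.nn as nn
from transformers import AutoModel

class AlignmentHeadSketch(nn.Module):
    """Illustrative re-creation of the diagram: DeBERTa encoder -> [CLS] embedding -> Linear(768, 2) -> softmax."""
    def __init__(self, backbone: str = "microsoft/deberta-v3-base"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone)                # DeBERTa-v3-base
        self.classifier = nn.Linear(self.encoder.config.hidden_size, 2)   # 768 -> 2

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls = hidden[:, 0]                                   # [CLS] embedding (768-dim)
        return torch.softmax(self.classifier(cls), dim=-1)   # [P(invalid), P(valid)]
```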
### Why Cross-Encoder?
Unlike embedding-based approaches that encode texts separately, the cross-encoder sees all three components **together**, enabling:
- Cross-attention between claim and source
- Detection of subtle semantic mismatches
- Better handling of paraphrases vs. factual errors
## Intended Use
### Primary Use Cases
1. **Post-hoc citation verification**: Validate citations in RAG outputs before serving to users
2. **Citation drift detection**: Identify claims that have semantically drifted from their sources
3. **Training signal**: Provide rewards for citation-aware generation (see the sketch after this list)
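For the third use case, the verifier's probability can be used directly as a scalar reward. A minimal sketch, built on the `check_citation` helper defined under "Basic Usage" below; the aggregation shown here is illustrative and not part of the released model:
```python
def citation_reward(claims: list[str], sources: list[str]) -> float:
    """Illustrative reward for citation-aware generation: mean P(valid citation)
    over every (claim, cited source) pair in a generated answer."""
    probs = [check_citation(c, c, s)[1] for c, s in zip(claims, sources)]  # helper from "Basic Usage"
    return sum(probs) / len(probs) if probs else 0.0
```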
### Out of Scope
- General NLI/entailment (the model is specialized for RAG citation patterns)
- Fact-checking against world knowledge (the model requires an explicit source passage)
- Non-English source documents (trained on English sources only)
## How to Use
### Installation
```bash
pip install transformers torch
```
### Basic Usage
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Load model
model_name = "anonymous-acl2026/dualtrack-alignment" # Replace with actual path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()
def check_citation(user_claim: str, evidence: str, source: str, threshold: float = 0.5) -> tuple[bool, float]:
"""
Check if a citation is valid.
Args:
user_claim: The claim shown to the user
evidence: Evidence track representation (can be same as user_claim)
source: The source passage being cited
threshold: Classification threshold (default from training)
Returns:
(is_valid, probability)
"""
# Format input
text = f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"
# Tokenize
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
# Predict
with torch.no_grad():
outputs = model(**inputs)
prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
return prob >= threshold, prob
# Example: Valid citation
is_valid, prob = check_citation(
user_claim="Python was created by Guido van Rossum.",
evidence="Python was created by Guido van Rossum.",
source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: True, Probability: 0.95
# Example: Invalid citation (wrong date)
is_valid, prob = check_citation(
user_claim="Python was created in 1989.",
evidence="Python was created in 1989.",
source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: False, Probability: 0.12
```
### Batch Processing
```python
def batch_check_citations(examples: list[dict], batch_size: int = 16) -> list[float]:
"""
Check multiple citations efficiently.
Args:
examples: List of dicts with keys 'user', 'evidence', 'source'
batch_size: Batch size for inference
Returns:
List of probabilities
"""
all_probs = []
for i in range(0, len(examples), batch_size):
batch = examples[i:i + batch_size]
texts = [
f"User claim: {ex['user']}\n\nEvidence: {ex['evidence']}\n\nSource passage: {ex['source']}"
for ex in batch
]
inputs = tokenizer(
texts,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.softmax(outputs.logits, dim=-1)[:, 1].tolist()
all_probs.extend(probs)
return all_probs
```
### Integration with DualTrack
```python
class DualTrackAlignmentModule:
"""
Alignment module for the DualTrack RAG system.
Detects citation drift between user track and source documents.
"""
    def __init__(self, model_path: str, threshold: float | None = None, device: str | None = None):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
        # Load the tuned decision threshold from metadata unless one is given explicitly.
        # Compare against None so that an explicit threshold of 0.0 is not silently overridden.
        import json
        import os
        metadata_path = os.path.join(model_path, "metadata.json")
        if threshold is not None:
            self.threshold = threshold
        elif os.path.exists(metadata_path):
            with open(metadata_path) as f:
                metadata = json.load(f)
            self.threshold = metadata.get("optimal_threshold", 0.5)
        else:
            self.threshold = 0.5
def detect_drift(
self,
user_claims: list[str],
evidence_claims: list[str],
sources: list[str]
) -> list[dict]:
"""
Detect citation drift for multiple claim-source pairs.
Returns list of {is_valid, probability, drift_detected}.
"""
results = []
for user, evidence, source in zip(user_claims, evidence_claims, sources):
text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}"
inputs = self.tokenizer(
text, return_tensors="pt", truncation=True, max_length=512
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
results.append({
"is_valid": prob >= self.threshold,
"probability": prob,
"drift_detected": prob < self.threshold
})
return results
```
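A short usage example for the module above; the model path is a placeholder for the actual checkpoint location:
```python
aligner = DualTrackAlignmentModule("anonymous-acl2026/dualtrack-alignment")  # placeholder path

results = aligner.detect_drift(
    user_claims=["Python was created in 1989."],
    evidence_claims=["Python was created in 1989."],
    sources=["Python is a programming language created by Guido van Rossum in 1991."],
)
print(results[0])  # {'is_valid': ..., 'probability': ..., 'drift_detected': ...}
```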
## Training Details
### Training Data
The model was trained on a curated dataset combining multiple sources:
| Source | Examples | Description |
|--------|----------|-------------|
| FEVER | ~8,000 | Fact verification with SUPPORTS/REFUTES labels |
| HAGRID | ~2,000 | Attributed QA with quote-based evidence |
| ASQA | ~3,000 | Ambiguous questions with long-form answers |
**Label Generation (V3 - LLM-Supervised)**:
- Training labels were verified by GPT-4o-mini ("Does context support claim?")
- Evaluation uses an independent NLI model (DeBERTa-MNLI)
- This breaks circularity: the model learns from LLM judgments but is evaluated against independent NLI labels
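A hedged sketch of the LLM-supervision step, assuming the OpenAI chat-completions client and a simple yes/no prompt; the exact prompt wording and answer parsing used to build the dataset are not documented here, so treat this as illustrative only:
```python
from openai import OpenAI

client = OpenAI()  # requires OPENAI_API_KEY in the environment

def llm_supports(claim: str, context: str) -> bool:
    """Illustrative label-generation call: ask GPT-4o-mini whether the context
    supports the claim and map the answer to a binary label."""
    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{
            "role": "user",
            "content": f"Does the context support the claim? Answer yes or no.\n\n"
                       f"Claim: {claim}\n\nContext: {context}",
        }],
        temperature=0,
    )
    return resp.choices[0].message.content.strip().lower().startswith("yes")
```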
**Data Augmentation**:
- **Negative perturbations**: date_change, number_change, entity_swap, false_detail, negation, topic_drift
- **Positive perturbations**: paraphrase, synonym_swap, formal_informal register changes
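The names above are the augmentation categories; the sketch below shows one hypothetical negative perturbation (`date_change`) to illustrate the idea, not the actual augmentation code:
```python
import random
import re

def date_change(claim: str) -> str:
    """Hypothetical negative perturbation: shift any four-digit year in the claim,
    producing a source-contradicting variant for training."""
    def shift(match: re.Match) -> str:
        year = int(match.group(0))
        return str(year + random.choice([-3, -2, -1, 1, 2, 3]))
    return re.sub(r"\b(1[89]\d{2}|20\d{2})\b", shift, claim)

# Example: "Python was created in 1991." -> e.g. "Python was created in 1989."
```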
### Training Procedure
| Hyperparameter | Value |
|----------------|-------|
| Base model | `microsoft/deberta-v3-base` |
| Max sequence length | 512 |
| Batch size | 8 |
| Gradient accumulation | 2 |
| Effective batch size | 16 |
| Learning rate | 2e-5 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Epochs | 5 |
| Early stopping patience | 3 |
| FP16 training | Yes |
| Optimizer | AdamW |
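The table maps directly onto a standard HuggingFace `Trainer` configuration. A minimal sketch under those hyperparameters; dataset loading and tokenization are omitted, and `train_ds`/`val_ds` are placeholders for the tokenized splits:
```python
import numpy as np
from sklearn.metrics import f1_score
from transformers import (AutoModelForSequenceClassification, EarlyStoppingCallback,
                          Trainer, TrainingArguments)

model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base", num_labels=2
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    return {"f1": f1_score(labels, np.argmax(logits, axis=-1))}

args = TrainingArguments(
    output_dir="dualtrack-alignment",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,   # effective batch size 16
    num_train_epochs=5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True,
    eval_strategy="epoch",           # "evaluation_strategy" on older transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,          # placeholder: tokenized training split
    eval_dataset=val_ds,             # placeholder: tokenized validation split
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()
```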
**Training Infrastructure**:
- Single GPU (NVIDIA T4/V100)
- Training time: ~2-3 hours
- Framework: HuggingFace Transformers + PyTorch
### Evaluation
**Validation Set Performance** (15% held-out, stratified):
| Metric | Score |
|--------|-------|
| Accuracy | 0.87 |
| Precision | 0.88 |
| Recall | 0.90 |
| F1 | 0.89 |
| ROC-AUC | 0.94 |
**Optimal Threshold**: 0.50 (determined via F1 maximization on validation set)
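The threshold selection described above (F1 maximization on the validation set) can be reproduced with a simple sweep; the sketch assumes `val_probs` and `val_labels` arrays have already been computed with the model:
```python
import numpy as np
from sklearn.metrics import f1_score

def best_threshold(val_probs: np.ndarray, val_labels: np.ndarray) -> float:
    """Pick the decision threshold that maximizes F1 on the validation set."""
    candidates = np.linspace(0.05, 0.95, 91)                        # 0.01 steps
    scores = [f1_score(val_labels, val_probs >= t) for t in candidates]
    return float(candidates[int(np.argmax(scores))])
```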
**Performance by Perturbation Type**:
| Type | Accuracy | Notes |
|------|----------|-------|
| original | 0.91 | Clean examples |
| paraphrase | 0.88 | Meaning-preserving rewrites |
| entity_swap | 0.94 | Wrong person/place/org |
| date_change | 0.92 | Incorrect dates |
| negation | 0.89 | Reversed claims |
| topic_drift | 0.85 | Subtle semantic shifts |
## Limitations
1. **English only**: Trained on English source passages. Cross-lingual application requires translation or multilingual encoder.
2. **RAG-specific**: Optimized for RAG citation patterns; may not generalize to arbitrary NLI tasks.
3. **Passage length**: Max 512 tokens. Long documents require chunking or summarization (see the chunking sketch after this list).
4. **Threshold sensitivity**: Default threshold (0.5) may need tuning for specific applications. High-precision applications should use higher thresholds.
5. **Training data bias**: Performance may vary on domains not represented in FEVER/HAGRID/ASQA (e.g., legal, medical, code).
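For limitation 3, one simple mitigation is to score a long source in overlapping windows and keep the highest probability. A hypothetical sketch built on the `check_citation` helper from "Basic Usage"; the window and stride sizes are illustrative:
```python
def check_citation_long(user_claim: str, evidence: str, source: str,
                        window_words: int = 300, stride_words: int = 150) -> float:
    """Hypothetical chunking wrapper: split a long source into overlapping
    word windows, score each with check_citation, return the best probability."""
    words = source.split()
    best = 0.0
    for start in range(0, max(len(words) - window_words, 0) + 1, stride_words):
        chunk = " ".join(words[start:start + window_words])
        _, prob = check_citation(user_claim, evidence, chunk)
        best = max(best, prob)
    return best
```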
## Ethical Considerations
### Intended Benefits
- Improved reliability of AI-generated citations
- Reduced misinformation from RAG hallucinations
- Better transparency in AI-assisted research
### Potential Risks
- Over-reliance on automated verification (human review still recommended for high-stakes applications)
- False negatives (valid citations scored as invalid) may incorrectly flag correct attributions
- False positives (invalid citations scored as valid) may let genuine attribution errors pass unflagged
### Recommendations
- Use as one signal among many, not sole arbiter
- Monitor performance on domain-specific data
- Combine with human review for critical applications
*This model is part of an anonymous submission to ACL 2026. Author information will be added upon acceptance.*