---
language:
- en
license: cc-by-nc-2.0
library_name: transformers
tags:
- citation-verification
- retrieval-augmented-generation
- rag
- cross-lingual
- deberta
- cross-encoder
- nli
- attribution
pipeline_tag: text-classification
datasets:
- fever
- din0s/asqa
- miracl/hagrid
metrics:
- f1
- precision
- recall
- accuracy
- roc_auc
base_model: microsoft/deberta-v3-base
model-index:
- name: dualtrack-alignment-module
  results:
  - task:
      type: text-classification
      name: Citation Verification
    metrics:
    - type: f1
      value: 0.89
      name: F1 Score
    - type: accuracy
      value: 0.87
      name: Accuracy
    - type: roc_auc
      value: 0.94
      name: ROC-AUC
---

# DualTrack Alignment Module

> **Anonymous submission to ACL 2026**

A cross-encoder model for detecting **citation drift** in Retrieval-Augmented Generation (RAG) systems. Given a user-facing claim, an evidence representation, and a source passage, the model predicts whether the citation is valid (the source supports the claim).

## Model Description

This model addresses a critical reliability problem in RAG systems: **citation drift**, where generated text diverges from source documents in ways that break attribution. The problem is particularly severe in cross-lingual settings where the answer language differs from source document language.

### Architecture

```
Input: "[CLS] User claim: {claim} [SEP] Evidence: {evidence} [SEP] Source passage: {context} [SEP]"
         ↓
    DeBERTa-v3-base (184M parameters)
         ↓
    [CLS] embedding (768-dim)
         ↓
    Linear(768, 2) → Softmax
         ↓
    Output: P(valid citation)
```

### Why Cross-Encoder?

Unlike embedding-based approaches that encode texts separately, the cross-encoder sees all three components **together**, enabling:
- Cross-attention between claim and source
- Detection of subtle semantic mismatches
- Better handling of paraphrases vs. factual errors (the sketch below contrasts the two styles)
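
To make the contrast concrete, here is a minimal sketch (not the released model) of the two scoring styles, using `microsoft/deberta-v3-base` as a stand-in encoder; the example strings and the bare cosine comparison are illustrative assumptions.

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
enc = AutoModel.from_pretrained("microsoft/deberta-v3-base")

claim = "Python was created in 1989."
source = "Python is a programming language created by Guido van Rossum in 1991."

# Bi-encoder style: two independent passes; the encoder never attends across
# the pair, so a wrong date can hide behind high topical similarity.
with torch.no_grad():
    e_claim = enc(**tok(claim, return_tensors="pt")).last_hidden_state[:, 0]
    e_source = enc(**tok(source, return_tensors="pt")).last_hidden_state[:, 0]
bi_score = torch.cosine_similarity(e_claim, e_source).item()

# Cross-encoder style (used here): one joint pass over the concatenated pair,
# so every claim token attends to every source token. The released model adds
# a Linear(768, 2) head on this [CLS] state to produce P(valid citation).
with torch.no_grad():
    joint_cls = enc(**tok(claim, source, return_tensors="pt")).last_hidden_state[:, 0]
```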

## Intended Use

### Primary Use Cases

1. **Post-hoc citation verification**: Validate citations in RAG outputs before serving to users
2. **Citation drift detection**: Identify claims that have semantically drifted from their sources
3. **Training signal**: Provide rewards for citation-aware generation (a reward-shaping sketch follows)
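
For the third use case, the sketch below shows one way the verifier's probability could be shaped into a scalar reward; the shaping rule is an illustrative assumption, not part of this release.

```python
def citation_reward(p_valid: float, drift_penalty: float = 1.0) -> float:
    """Map the verifier's P(valid citation) to a training reward.

    Illustrative shaping: confident support earns positive reward, while
    likely drift is penalized in proportion to the verifier's confidence.
    """
    return p_valid if p_valid >= 0.5 else -drift_penalty * (1.0 - p_valid)
```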

### Out of Scope

- General NLI/entailment (model is specialized for RAG citation patterns)
- Fact-checking against world knowledge (requires source passage)
- Non-English source documents (trained on English sources only)

## How to Use

### Installation

```bash
pip install transformers torch sentencepiece
```

### Basic Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
model_name = "anonymous-acl2026/dualtrack-alignment"  # Replace with actual path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

def check_citation(user_claim: str, evidence: str, source: str, threshold: float = 0.5) -> tuple[bool, float]:
    """
    Check if a citation is valid.
    
    Args:
        user_claim: The claim shown to the user
        evidence: Evidence track representation (can be same as user_claim)
        source: The source passage being cited
        threshold: Classification threshold (default from training)
    
    Returns:
        (is_valid, probability)
    """
    # Format as a single text segment; the labeled fields mirror the
    # [SEP]-delimited template shown in the architecture diagram
    text = f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"
    
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    
    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
    
    return prob >= threshold, prob

# Example: Valid citation
is_valid, prob = check_citation(
    user_claim="Python was created by Guido van Rossum.",
    evidence="Python was created by Guido van Rossum.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: True, Probability: 0.95

# Example: Invalid citation (wrong date)
is_valid, prob = check_citation(
    user_claim="Python was created in 1989.",
    evidence="Python was created in 1989.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: False, Probability: 0.12
```

### Batch Processing

```python
def batch_check_citations(examples: list[dict], batch_size: int = 16) -> list[float]:
    """
    Check multiple citations efficiently, reusing the tokenizer and model loaded above.
    
    Args:
        examples: List of dicts with keys 'user', 'evidence', 'source'
        batch_size: Batch size for inference
    
    Returns:
        List of probabilities
    """
    all_probs = []
    
    for i in range(0, len(examples), batch_size):
        batch = examples[i:i + batch_size]
        
        texts = [
            f"User claim: {ex['user']}\n\nEvidence: {ex['evidence']}\n\nSource passage: {ex['source']}"
            for ex in batch
        ]
        
        inputs = tokenizer(
            texts, 
            return_tensors="pt", 
            truncation=True, 
            max_length=512, 
            padding=True
        )
        
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[:, 1].tolist()
        
        all_probs.extend(probs)
    
    return all_probs
```

### Integration with DualTrack

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class DualTrackAlignmentModule:
    """
    Alignment module for the DualTrack RAG system.
    
    Detects citation drift between user track and source documents.
    """
    
    def __init__(self, model_path: str, threshold: float | None = None, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        
        # Load the tuned decision threshold from the model's metadata.json if
        # present; an explicitly passed threshold always takes precedence.
        import json
        import os
        metadata_path = os.path.join(model_path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path) as f:
                metadata = json.load(f)
            self.threshold = threshold if threshold is not None else metadata.get("optimal_threshold", 0.5)
        else:
            self.threshold = threshold if threshold is not None else 0.5
    
    def detect_drift(
        self, 
        user_claims: list[str], 
        evidence_claims: list[str], 
        sources: list[str]
    ) -> list[dict]:
        """
        Detect citation drift for multiple claim-source pairs.
        
        Returns list of {is_valid, probability, drift_detected}.
        """
        results = []
        
        for user, evidence, source in zip(user_claims, evidence_claims, sources):
            text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}"
            
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
            
            results.append({
                "is_valid": prob >= self.threshold,
                "probability": prob,
                "drift_detected": prob < self.threshold
            })
        
        return results
```
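
Example usage of the module (the model path is a placeholder and the printed result is illustrative):

```python
module = DualTrackAlignmentModule("anonymous-acl2026/dualtrack-alignment")
results = module.detect_drift(
    user_claims=["Python was created in 1989."],
    evidence_claims=["Python was created in 1989."],
    sources=["Python is a programming language created by Guido van Rossum in 1991."],
)
print(results[0])
# e.g. {'is_valid': False, 'probability': 0.12, 'drift_detected': True}
```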

## Training Details

### Training Data

The model was trained on a curated dataset combining multiple sources:

| Source | Examples | Description |
|--------|----------|-------------|
| FEVER | ~8,000 | Fact verification with SUPPORTS/REFUTES labels |
| HAGRID | ~2,000 | Attributed QA with quote-based evidence |
| ASQA | ~3,000 | Ambiguous questions with long-form answers |

**Label Generation (V3 - LLM-Supervised)**:
- Training labels verified by GPT-4o-mini ("Does context support claim?")
- Evaluation uses an independent NLI model (DeBERTa-MNLI)
- This breaks circularity: the model learns from LLM judgments but is evaluated against an independent NLI model

**Data Augmentation**:
- **Negative perturbations**: date_change, number_change, entity_swap, false_detail, negation, topic_drift (date_change is sketched after this list)
- **Positive perturbations**: paraphrase, synonym_swap, formal_informal register changes
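
The augmentation pipeline is described only at this level, so the following is a hedged sketch of a single negative perturbation (`date_change`); the year-shifting rule is an illustrative assumption.

```python
import random
import re

def date_change(claim: str) -> str:
    """Perturb the first four-digit year in a claim so it contradicts the source."""
    def shift(match: re.Match) -> str:
        return str(int(match.group(0)) + random.choice([-3, -2, -1, 1, 2, 3]))
    return re.sub(r"\b(?:19|20)\d{2}\b", shift, claim, count=1)

print(date_change("Python was created by Guido van Rossum in 1991."))
# e.g. "Python was created by Guido van Rossum in 1993."  -> negative example
```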

### Training Procedure

| Hyperparameter | Value |
|----------------|-------|
| Base model | `microsoft/deberta-v3-base` |
| Max sequence length | 512 |
| Batch size | 8 |
| Gradient accumulation | 2 |
| Effective batch size | 16 |
| Learning rate | 2e-5 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Epochs | 5 |
| Early stopping patience | 3 |
| FP16 training | Yes |
| Optimizer | AdamW |

**Training Infrastructure**:
- Single GPU (NVIDIA T4/V100)
- Training time: ~2-3 hours
- Framework: HuggingFace Transformers + PyTorch
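
The training script itself is not released; the sketch below restates the table above as HuggingFace `TrainingArguments`, so the argument names are standard but the exact configuration is an assumption.

```python
from transformers import EarlyStoppingCallback, TrainingArguments

args = TrainingArguments(
    output_dir="dualtrack-alignment",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,   # effective batch size 16
    learning_rate=2e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    num_train_epochs=5,
    fp16=True,
    eval_strategy="epoch",           # "evaluation_strategy" on transformers < 4.41
    save_strategy="epoch",
    load_best_model_at_end=True,     # required for early stopping
    metric_for_best_model="f1",      # assumes a compute_metrics that reports "f1"
)

# Both are passed to transformers.Trainer together with the tokenized splits:
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
```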

### Evaluation

**Validation Set Performance** (15% held-out, stratified):

| Metric | Score |
|--------|-------|
| Accuracy | 0.87 |
| Precision | 0.88 |
| Recall | 0.90 |
| F1 | 0.89 |
| ROC-AUC | 0.94 |

**Optimal Threshold**: 0.50 (determined via F1 maximization on validation set)
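
A sketch of the threshold selection described above: sweep candidate thresholds over validation-set probabilities and keep the F1-maximizing one (the grid resolution is an assumption).

```python
# Requires scikit-learn in addition to the packages listed under Installation.
import numpy as np
from sklearn.metrics import f1_score

def pick_threshold(val_probs: np.ndarray, val_labels: np.ndarray) -> float:
    """Return the decision threshold that maximizes F1 on the validation set."""
    candidates = np.linspace(0.05, 0.95, 91)
    scores = [f1_score(val_labels, (val_probs >= t).astype(int)) for t in candidates]
    return float(candidates[int(np.argmax(scores))])
```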

**Performance by Perturbation Type**:

| Type | Accuracy | Notes |
|------|----------|-------|
| original | 0.91 | Clean examples |
| paraphrase | 0.88 | Meaning-preserving rewrites |
| entity_swap | 0.94 | Wrong person/place/org |
| date_change | 0.92 | Incorrect dates |
| negation | 0.89 | Reversed claims |
| topic_drift | 0.85 | Subtle semantic shifts |

## Limitations

1. **English only**: Trained on English source passages. Cross-lingual application requires translation or multilingual encoder.

2. **RAG-specific**: Optimized for RAG citation patterns; may not generalize to arbitrary NLI tasks.

3. **Passage length**: Max 512 tokens. Long documents require chunking or summarization (a chunking sketch follows this list).

4. **Threshold sensitivity**: Default threshold (0.5) may need tuning for specific applications. High-precision applications should use higher thresholds.

5. **Training data bias**: Performance may vary on domains not represented in FEVER/HAGRID/ASQA (e.g., legal, medical, code).
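
For limitation 3, here is a minimal chunking sketch built on `check_citation` from Basic Usage; scoring overlapping windows and taking the maximum probability is an illustrative choice, not a released utility.

```python
def check_long_source(user_claim: str, evidence: str, source: str,
                      chunk_words: int = 200, stride: int = 100,
                      threshold: float = 0.5) -> tuple[bool, float]:
    """Score overlapping chunks of a long source; the citation counts as
    valid if any single chunk supports the claim."""
    words = source.split()
    best_prob = 0.0
    for start in range(0, max(len(words) - chunk_words, 0) + 1, stride):
        chunk = " ".join(words[start:start + chunk_words])
        _, prob = check_citation(user_claim, evidence, chunk, threshold)
        best_prob = max(best_prob, prob)
    return best_prob >= threshold, best_prob
```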

## Ethical Considerations

### Intended Benefits
- Improved reliability of AI-generated citations
- Reduced misinformation from RAG hallucinations
- Better transparency in AI-assisted research

### Potential Risks
- Over-reliance on automated verification (human review still recommended for high-stakes applications)
- False negatives may incorrectly flag valid citations
- False positives may miss genuine attribution errors

### Recommendations
- Use as one signal among many, not sole arbiter
- Monitor performance on domain-specific data
- Combine with human review for critical applications


*This model is part of an anonymous submission to ACL 2026. Author information will be added upon acceptance.*