---
language:
- en
license: cc-by-nc-2.0
library_name: transformers
tags:
- citation-verification
- retrieval-augmented-generation
- rag
- cross-lingual
- deberta
- cross-encoder
- nli
- attribution
pipeline_tag: text-classification
datasets:
- fever
- din0s/asqa
- miracl/hagrid
metrics:
- f1
- precision
- recall
- accuracy
- roc_auc
base_model: microsoft/deberta-v3-base
model-index:
- name: dualtrack-alignment-module
  results:
  - task:
      type: text-classification
      name: Citation Verification
    metrics:
    - type: f1
      value: 0.89
      name: F1 Score
    - type: accuracy
      value: 0.87
      name: Accuracy
    - type: roc_auc
      value: 0.94
      name: ROC-AUC
---

# DualTrack Alignment Module

> **Anonymous submission to ACL 2026**

A cross-encoder model for detecting **citation drift** in Retrieval-Augmented Generation (RAG) systems. Given a user-facing claim, an evidence representation, and a source passage, the model predicts whether the citation is valid (i.e., the source supports the claim).

## Model Description

This model addresses a critical reliability problem in RAG systems: **citation drift**, where generated text diverges from source documents in ways that break attribution. The problem is particularly severe in cross-lingual settings, where the answer language differs from the source-document language.

### Architecture

```
Input: "[CLS] User claim: {claim} [SEP] Evidence: {evidence} [SEP] Source passage: {context} [SEP]"
  ↓
DeBERTa-v3-base (184M parameters)
  ↓
[CLS] embedding (768-dim)
  ↓
Linear(768, 2) → Softmax
  ↓
Output: P(valid citation)
```

### Why Cross-Encoder?

Unlike embedding-based approaches that encode each text separately, the cross-encoder sees all three components **together**, enabling:

- Cross-attention between claim and source
- Detection of subtle semantic mismatches
- Better handling of paraphrases vs. factual errors

## Intended Use

### Primary Use Cases

1. **Post-hoc citation verification**: Validate citations in RAG outputs before serving them to users
2. **Citation drift detection**: Identify claims that have semantically drifted from their sources
3. **Training signal**: Provide rewards for citation-aware generation

### Out of Scope

- General NLI/entailment (the model is specialized for RAG citation patterns)
- Fact-checking against world knowledge (the model requires a source passage)
- Non-English source documents (trained on English sources only)

## How to Use

### Installation

```bash
pip install transformers torch
```

### Basic Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
model_name = "anonymous-acl2026/dualtrack-alignment"  # Replace with actual path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()


def check_citation(user_claim: str, evidence: str, source: str,
                   threshold: float = 0.5) -> tuple[bool, float]:
    """
    Check if a citation is valid.

    Args:
        user_claim: The claim shown to the user
        evidence: Evidence track representation (can be same as user_claim)
        source: The source passage being cited
        threshold: Classification threshold (default from training)

    Returns:
        (is_valid, probability)
    """
    # Format input
    text = f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()

    return prob >= threshold, prob


# Example: Valid citation
is_valid, prob = check_citation(
    user_claim="Python was created by Guido van Rossum.",
    evidence="Python was created by Guido van Rossum.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: True, Probability: 0.95

# Example: Invalid citation (wrong date)
is_valid, prob = check_citation(
    user_claim="Python was created in 1989.",
    evidence="Python was created in 1989.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: False, Probability: 0.12
```
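For the post-hoc verification use case, the same function can gate an answer's citations before it is served. The following is a minimal sketch, assuming the RAG pipeline exposes each user-facing claim alongside the passage it cites; the `Citation` container, `verify_answer` helper, and the higher 0.7 threshold are illustrative choices, not part of the released code:

```python
from dataclasses import dataclass


@dataclass
class Citation:
    claim: str   # user-facing sentence carrying the citation
    source: str  # retrieved passage that the sentence cites


def verify_answer(citations: list[Citation], threshold: float = 0.7) -> list[dict]:
    """Score every citation in an answer and flag the ones that look drifted."""
    report = []
    for c in citations:
        # With no separate evidence track, reuse the user-facing claim for both slots.
        is_valid, prob = check_citation(c.claim, c.claim, c.source, threshold=threshold)
        report.append({
            "claim": c.claim,
            "probability": prob,
            "flag_for_review": not is_valid,
        })
    return report
```

Flagged items can then be dropped, re-retrieved, or routed to human review rather than served as-is.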
### Batch Processing

```python
def batch_check_citations(examples: list[dict], batch_size: int = 16) -> list[float]:
    """
    Check multiple citations efficiently.

    Args:
        examples: List of dicts with keys 'user', 'evidence', 'source'
        batch_size: Batch size for inference

    Returns:
        List of probabilities
    """
    all_probs = []

    for i in range(0, len(examples), batch_size):
        batch = examples[i:i + batch_size]
        texts = [
            f"User claim: {ex['user']}\n\nEvidence: {ex['evidence']}\n\nSource passage: {ex['source']}"
            for ex in batch
        ]

        inputs = tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[:, 1].tolist()

        all_probs.extend(probs)

    return all_probs
```

### Integration with DualTrack

```python
import json
import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class DualTrackAlignmentModule:
    """
    Alignment module for the DualTrack RAG system.
    Detects citation drift between the user track and source documents.
    """

    def __init__(self, model_path: str, threshold: float | None = None, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load optimal threshold from metadata, falling back to 0.5
        metadata_path = os.path.join(model_path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path) as f:
                metadata = json.load(f)
            self.threshold = threshold or metadata.get("optimal_threshold", 0.5)
        else:
            self.threshold = threshold or 0.5

    def detect_drift(
        self,
        user_claims: list[str],
        evidence_claims: list[str],
        sources: list[str]
    ) -> list[dict]:
        """
        Detect citation drift for multiple claim-source pairs.
        Returns a list of {is_valid, probability, drift_detected}.
        """
        results = []

        for user, evidence, source in zip(user_claims, evidence_claims, sources):
            text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}"

            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()

            results.append({
                "is_valid": prob >= self.threshold,
                "probability": prob,
                "drift_detected": prob < self.threshold
            })

        return results
```
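A brief usage sketch for the module above, reusing the invalid-citation example from Basic Usage; the model path is the placeholder used throughout this card, and the printed numbers are illustrative rather than guaranteed outputs:

```python
# Hypothetical usage; point model_path at wherever the checkpoint is stored.
aligner = DualTrackAlignmentModule("anonymous-acl2026/dualtrack-alignment")

report = aligner.detect_drift(
    user_claims=["Python was created in 1989."],
    evidence_claims=["Python was created in 1989."],
    sources=["Python is a programming language created by Guido van Rossum in 1991."],
)
print(report[0])
# e.g. {'is_valid': False, 'probability': 0.12, 'drift_detected': True}
```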
""" results = [] for user, evidence, source in zip(user_claims, evidence_claims, sources): text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}" inputs = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=512 ).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item() results.append({ "is_valid": prob >= self.threshold, "probability": prob, "drift_detected": prob < self.threshold }) return results ``` ## Training Details ### Training Data The model was trained on a curated dataset combining multiple sources: | Source | Examples | Description | |--------|----------|-------------| | FEVER | ~8,000 | Fact verification with SUPPORTS/REFUTES labels | | HAGRID | ~2,000 | Attributed QA with quote-based evidence | | ASQA | ~3,000 | Ambiguous questions with long-form answers | **Label Generation (V3 - LLM-Supervised)**: - Training labels verified by GPT-4o-mini ("Does context support claim?") - Evaluation uses independent NLI model (DeBERTa-MNLI) - This breaks circularity: model learns LLM judgment, evaluated by NLI **Data Augmentation**: - **Negative perturbations**: date_change, number_change, entity_swap, false_detail, negation, topic_drift - **Positive perturbations**: paraphrase, synonym_swap, formal_informal register changes ### Training Procedure | Hyperparameter | Value | |----------------|-------| | Base model | `microsoft/deberta-v3-base` | | Max sequence length | 512 | | Batch size | 8 | | Gradient accumulation | 2 | | Effective batch size | 16 | | Learning rate | 2e-5 | | Warmup ratio | 0.1 | | Weight decay | 0.01 | | Epochs | 5 | | Early stopping patience | 3 | | FP16 training | Yes | | Optimizer | AdamW | **Training Infrastructure**: - Single GPU (NVIDIA T4/V100) - Training time: ~2-3 hours - Framework: HuggingFace Transformers + PyTorch ### Evaluation **Validation Set Performance** (15% held-out, stratified): | Metric | Score | |--------|-------| | Accuracy | 0.87 | | Precision | 0.88 | | Recall | 0.90 | | F1 | 0.89 | | ROC-AUC | 0.94 | **Optimal Threshold**: 0.50 (determined via F1 maximization on validation set) **Performance by Perturbation Type**: | Type | Accuracy | Notes | |------|----------|-------| | original | 0.91 | Clean examples | | paraphrase | 0.88 | Meaning-preserving rewrites | | entity_swap | 0.94 | Wrong person/place/org | | date_change | 0.92 | Incorrect dates | | negation | 0.89 | Reversed claims | | topic_drift | 0.85 | Subtle semantic shifts | ## Limitations 1. **English only**: Trained on English source passages. Cross-lingual application requires translation or multilingual encoder. 2. **RAG-specific**: Optimized for RAG citation patterns; may not generalize to arbitrary NLI tasks. 3. **Passage length**: Max 512 tokens. Long documents require chunking or summarization. 4. **Threshold sensitivity**: Default threshold (0.5) may need tuning for specific applications. High-precision applications should use higher thresholds. 5. **Training data bias**: Performance may vary on domains not represented in FEVER/HAGRID/ASQA (e.g., legal, medical, code). 
## Ethical Considerations

### Intended Benefits

- Improved reliability of AI-generated citations
- Reduced misinformation from RAG hallucinations
- Better transparency in AI-assisted research

### Potential Risks

- Over-reliance on automated verification (human review is still recommended for high-stakes applications)
- False negatives may incorrectly flag valid citations
- False positives may miss genuine attribution errors

### Recommendations

- Use as one signal among many, not as the sole arbiter
- Monitor performance on domain-specific data
- Combine with human review for critical applications

*This model is part of an anonymous submission to ACL 2026. Author information will be added upon acceptance.*