"""
Difficulty Gate for ContinuumAgent Project

Smart routing system to determine whether to use patches based on query complexity
"""

import hashlib
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from llama_cpp import Llama

class DifficultyGate:
    """
    Smart routing system to determine whether to use patches based on query complexity.

    Uses a simple heuristic approach for the initial implementation (length,
    factual/recency keywords, capitalized-word ratio) and can optionally defer
    to a small LLM gate model. Decisions are cached on disk keyed by a
    normalized hash of the query. Can be replaced with a learned classifier.
    """

    def __init__(self,
                 model_path: str,
                 gate_threshold: float = 0.7,
                 cache_dir: str = "models/gates",
                 n_gpu_layers: int = 0):
        """
        Initialize the difficulty gate.

        Args:
            model_path: Path to GGUF model file
            gate_threshold: Threshold for routing to patched model (0.0-1.0)
            cache_dir: Directory for caching gate decisions
            n_gpu_layers: Number of layers to offload to GPU
        """
        self.model_path = model_path
        self.gate_threshold = gate_threshold
        self.cache_dir = cache_dir
        self.n_gpu_layers = n_gpu_layers

        os.makedirs(cache_dir, exist_ok=True)

        # Decisions persist across runs in a single JSON file.
        self.cache_file = os.path.join(cache_dir, "gate_cache.json")
        self.decision_cache = self._load_cache()

        self._init_gate_model()

    def _init_gate_model(self) -> None:
        """Initialize small gate model; on failure, heuristics remain available."""
        try:
            print(f"Loading gate model from {self.model_path}...")
            self.gate_model = Llama(
                model_path=self.model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=512
            )
        except Exception as e:
            # A missing/broken model only disables the model path;
            # should_use_patches falls back to heuristics when gate_model is None.
            print(f"Error loading gate model: {e}")
            self.gate_model = None

    def _load_cache(self) -> Dict[str, Any]:
        """
        Load decision cache from file.

        Returns:
            Cache dictionary (with a "queries" mapping); a fresh empty cache
            if the file is missing or unreadable.
        """
        if os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, "r") as f:
                    cache = json.load(f)
                print(f"Loaded {len(cache.get('queries', []))} cached gate decisions")
                return cache
            except Exception as e:
                print(f"Error loading cache: {e}")

        return {"queries": {}}

    def _save_cache(self) -> None:
        """Save decision cache to file (best-effort; failures are logged only)."""
        try:
            with open(self.cache_file, "w") as f:
                json.dump(self.decision_cache, f, indent=2)
        except Exception as e:
            print(f"Error saving cache: {e}")

    def _query_hash(self, query: str) -> str:
        """
        Create simple hash for query caching.

        Args:
            query: Query string

        Returns:
            Hex digest of the normalized (stripped, lowercased) query.
        """
        # MD5 is acceptable here: this is a cache key, not a security boundary.
        return hashlib.md5(query.strip().lower().encode()).hexdigest()

    def _heuristic_features(self, query: str) -> Dict[str, float]:
        """
        Extract heuristic features from query.

        Args:
            query: Query string

        Returns:
            Dictionary of feature values, each in [0.0, 1.0].
        """
        query_lower = query.lower()

        # Length, normalized so that queries >= 200 chars saturate at 1.0.
        length = len(query)
        norm_length = min(1.0, length / 200.0)

        # Phrases suggesting a factual lookup question.
        factual_indicators = [
            "what is", "when did", "where is", "who is",
            "which", "how many", "list the", "tell me about",
            "explain", "define"
        ]
        has_factual = any(indicator in query_lower for indicator in factual_indicators)

        # Phrases suggesting the answer depends on recent events.
        time_indicators = [
            "recent", "latest", "current", "today", "now",
            "this week", "this month", "this year",
            "2023", "2024", "2025"
        ]
        has_time = any(indicator in query_lower for indicator in time_indicators)

        # Crude named-entity proxy: ratio of capitalized words.
        # NOTE(review): the sentence-initial word is always capitalized, so this
        # slightly overestimates entity density for short queries.
        words = query.split()
        capitalized_words = [w for w in words if w[0:1].isupper()]
        entity_ratio = len(capitalized_words) / max(1, len(words))

        # Reasoning-style phrasing; capped at 1.0 after dividing by 3.
        complex_indicators = [
            "why", "how does", "explain", "compare", "contrast",
            "what if", "analyze", "evaluate", "synthesize"
        ]
        complexity_score = sum(indicator in query_lower for indicator in complex_indicators) / 3.0
        complexity_score = min(1.0, complexity_score)

        return {
            "length": norm_length,
            "has_factual": float(has_factual),
            "has_time": float(has_time),
            "entity_ratio": entity_ratio,
            "complexity": complexity_score
        }

    def _heuristic_decision(self, features: Dict[str, float]) -> Tuple[bool, float]:
        """
        Make decision based on heuristic features.

        Args:
            features: Feature dictionary from _heuristic_features

        Returns:
            Tuple of (needs_patches, confidence)
        """
        # Recency is the strongest signal; complexity counts slightly
        # against needing patches (reasoning questions are less time-bound).
        weights = {
            "length": 0.1,
            "has_factual": 0.3,
            "has_time": 0.4,
            "entity_ratio": 0.1,
            "complexity": -0.1
        }

        score = sum(features[f] * weights[f] for f in features)

        # Clamp to [0, 1] so the score doubles as a confidence value.
        score = max(0.0, min(1.0, score))

        needs_patches = score >= self.gate_threshold

        return needs_patches, score

    def _model_decision(self, query: str) -> Tuple[bool, float]:
        """
        Ask the model to decide if the query needs up-to-date knowledge.

        Args:
            query: Query string

        Returns:
            Tuple of (needs_patches, confidence). Falls back to the heuristic
            decision when the gate model is unavailable or the call fails.
        """
        if not self.gate_model:
            features = self._heuristic_features(query)
            return self._heuristic_decision(features)

        prompt = f"""<s>[INST] Please analyze this question and determine if it requires the most up-to-date knowledge to answer correctly.
Respond with only a single word: 'YES' if up-to-date knowledge is needed, or 'NO' if it can be answered with general knowledge.

Question: "{query}"

Requires up-to-date knowledge? [/INST]"""

        try:
            completion = self.gate_model.create_completion(
                prompt=prompt,
                max_tokens=5,
                temperature=0.1,
                stop=["</s>", "\n"]
            )
        except Exception as e:
            # A transient model failure should not crash the router;
            # degrade gracefully to the heuristic path.
            print(f"Error querying gate model: {e}")
            features = self._heuristic_features(query)
            return self._heuristic_decision(features)

        response_text = completion.get("choices", [{}])[0].get("text", "").strip().upper()

        # Fixed placeholder confidence for model decisions; the completion API
        # used here does not expose a calibrated probability.
        confidence = 0.7

        needs_patches = "YES" in response_text

        return needs_patches, confidence

    def should_use_patches(self, query: str, use_model: bool = True) -> Dict[str, Any]:
        """
        Determine if the query requires up-to-date knowledge patches.

        Args:
            query: Query string
            use_model: Whether to use model for decision (vs pure heuristics)

        Returns:
            Decision dictionary with keys:
                - needs_patches: Boolean decision
                - confidence: Confidence score (0.0-1.0)
                - method: Decision method used ("model" or "heuristic")
                - features: Feature values (always computed)
                - from_cache: Whether the decision was served from cache
        """
        query_hash = self._query_hash(query)
        cached = self.decision_cache.get("queries", {}).get(query_hash)
        if cached is not None:
            # Return a copy: mutating the cached entry would persist
            # from_cache=True to disk on the next _save_cache call.
            result = dict(cached)
            result["from_cache"] = True
            return result

        features = self._heuristic_features(query)

        if use_model and self.gate_model:
            needs_patches, confidence = self._model_decision(query)
            method = "model"
        else:
            needs_patches, confidence = self._heuristic_decision(features)
            method = "heuristic"

        decision = {
            "needs_patches": needs_patches,
            "confidence": confidence,
            "method": method,
            "features": features,
            "from_cache": False
        }

        self.decision_cache.setdefault("queries", {})[query_hash] = decision
        self._save_cache()

        return decision
def main():
    """Smoke-test the difficulty gate on a handful of sample queries."""
    model_dir = "models/slow"

    # Guard against a missing directory: os.listdir raises FileNotFoundError
    # otherwise, which is an unhelpful failure mode for a test driver.
    if not os.path.isdir(model_dir):
        print(f"No GGUF models found in {model_dir}")
        return

    model_files = [f for f in os.listdir(model_dir) if f.endswith(".gguf")]

    if not model_files:
        print(f"No GGUF models found in {model_dir}")
        return

    # Use the first model found.
    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    gate = DifficultyGate(model_path=model_path)

    test_queries = [
        "What is the capital of France?",
        "Who is the current president of the United States?",
        "Explain the theory of relativity",
        "What are the latest developments in the conflict in Ukraine?",
        "Who won the most recent Super Bowl?",
        "How do I write a for loop in Python?"
    ]

    for query in test_queries:
        # Heuristic-only decision first (cheap, no model call).
        decision = gate.should_use_patches(query, use_model=False)
        print(f"\nQuery: {query}")
        print(f"Heuristic Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")
        print(f"Features: {decision['features']}")

        if gate.gate_model:
            # NOTE(review): the heuristic call above already cached a decision
            # for this query, so this "model" decision is served from the cache
            # and repeats the heuristic result — confirm whether that is intended.
            decision = gate.should_use_patches(query, use_model=True)
            print(f"Model Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")


if __name__ == "__main__":
    main()