# Initial commit: InnSight scraper backend with Playwright (commit d77abf8)
"""
Caching utilities for API performance optimization.
This module provides a simple in-memory cache with TTL support,
suitable for serverless environments. For production with multiple
instances, consider using Redis or a similar distributed cache.
Usage:
from cache import cache_manager
# Cache a value
await cache_manager.set("key", value, ttl=300)
# Get cached value
value = await cache_manager.get("key")
# Use as decorator
@cached(ttl=300)
async def get_dashboard(hotel_id: int):
...
"""
import asyncio
import hashlib
import json
import time
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, ParamSpec
import os
P = ParamSpec('P')
T = TypeVar('T')
class CacheEntry:
    """A single cached value together with its expiration metadata."""

    # Many short-lived instances are expected; slots drop the per-instance
    # __dict__ and speed up attribute access.
    __slots__ = ("value", "created_at", "expires_at")

    def __init__(self, value: Any, ttl_seconds: int):
        # Capture the clock once so created_at + ttl_seconds == expires_at
        # exactly; two separate time.time() calls could drift slightly.
        now = time.time()
        self.value = value            # the cached payload
        self.created_at = now         # insertion timestamp (drives eviction order)
        self.expires_at = now + ttl_seconds

    @property
    def is_expired(self) -> bool:
        """True once the entry has passed its expiration time."""
        return time.time() > self.expires_at

    @property
    def ttl_remaining(self) -> float:
        """Seconds until expiration; 0 if already expired."""
        return max(0, self.expires_at - time.time())
class InMemoryCache:
"""
Thread-safe in-memory cache with TTL support.
Features:
- Automatic expiration
- LRU-style cleanup when max size reached
- Stats tracking
"""
def __init__(self, max_size: int = 1000, cleanup_interval: int = 60):
self._cache: dict[str, CacheEntry] = {}
self._max_size = max_size
self._lock = asyncio.Lock()
self._hits = 0
self._misses = 0
self._cleanup_interval = cleanup_interval
self._last_cleanup = time.time()
async def get(self, key: str) -> Optional[Any]:
"""Get value from cache if exists and not expired"""
async with self._lock:
entry = self._cache.get(key)
if entry is None:
self._misses += 1
return None
if entry.is_expired:
del self._cache[key]
self._misses += 1
return None
self._hits += 1
return entry.value
async def set(
self,
key: str,
value: Any,
ttl: int | timedelta = 300
) -> None:
"""Set value in cache with TTL"""
if isinstance(ttl, timedelta):
ttl_seconds = int(ttl.total_seconds())
else:
ttl_seconds = ttl
async with self._lock:
# Cleanup if needed
await self._maybe_cleanup()
# Evict oldest if at capacity
if len(self._cache) >= self._max_size and key not in self._cache:
await self._evict_oldest()
self._cache[key] = CacheEntry(value, ttl_seconds)
async def delete(self, key: str) -> bool:
"""Delete a specific key from cache"""
async with self._lock:
if key in self._cache:
del self._cache[key]
return True
return False
async def delete_pattern(self, pattern: str) -> int:
"""Delete all keys matching pattern (simple prefix match)"""
async with self._lock:
keys_to_delete = [
k for k in self._cache.keys()
if k.startswith(pattern)
]
for key in keys_to_delete:
del self._cache[key]
return len(keys_to_delete)
async def clear(self) -> None:
"""Clear all cache entries"""
async with self._lock:
self._cache.clear()
self._hits = 0
self._misses = 0
async def _maybe_cleanup(self) -> None:
"""Remove expired entries periodically"""
now = time.time()
if now - self._last_cleanup < self._cleanup_interval:
return
self._last_cleanup = now
expired_keys = [
k for k, v in self._cache.items()
if v.is_expired
]
for key in expired_keys:
del self._cache[key]
async def _evict_oldest(self) -> None:
"""Evict oldest entry when at capacity"""
if not self._cache:
return
oldest_key = min(
self._cache.keys(),
key=lambda k: self._cache[k].created_at
)
del self._cache[oldest_key]
@property
def stats(self) -> dict:
"""Get cache statistics"""
total = self._hits + self._misses
hit_rate = (self._hits / total * 100) if total > 0 else 0
return {
"size": len(self._cache),
"max_size": self._max_size,
"hits": self._hits,
"misses": self._misses,
"hit_rate": f"{hit_rate:.1f}%"
}
# Module-level singleton shared by the cached() decorator and direct callers.
# Capacity and sweep cadence are tunable via environment variables.
_CACHE_MAX_SIZE = int(os.environ.get('CACHE_MAX_SIZE', '1000'))
_CACHE_CLEANUP_INTERVAL = int(os.environ.get('CACHE_CLEANUP_INTERVAL', '60'))

cache_manager = InMemoryCache(
    max_size=_CACHE_MAX_SIZE,
    cleanup_interval=_CACHE_CLEANUP_INTERVAL,
)
def generate_cache_key(*args, **kwargs) -> str:
    """Build a deterministic cache key from arbitrary arguments.

    Positional arguments are stringified in order; keyword arguments are
    appended as "k=v" pairs in sorted-key order so the result does not
    depend on call-site keyword ordering. The joined string is md5-hashed
    to keep keys short and uniformly shaped.
    """
    parts: list[str] = [str(arg) for arg in args]
    for name in sorted(kwargs):
        parts.append(f"{name}={kwargs[name]}")
    raw_key = ":".join(parts)
    return hashlib.md5(raw_key.encode()).hexdigest()
def cached(
    ttl: int | timedelta = 300,
    key_prefix: str = "",
    key_builder: Optional[Callable[..., str]] = None
):
    """
    Decorator to cache async function results in the global cache_manager.

    Args:
        ttl: Time to live in seconds or timedelta
        key_prefix: Prefix for cache key
        key_builder: Optional custom function to build cache key

    Note:
        A result of None is indistinguishable from a cache miss, so
        functions that return None are re-executed on every call.

    Example:
        @cached(ttl=300, key_prefix="dashboard")
        async def get_dashboard(hotel_id: int):
            ...
    """
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            # Build cache key. Default-built keys are namespaced with the
            # function name so clear_cache() below can prefix-match them.
            # (Previously default keys were bare md5 digests, so
            # delete_pattern(func.__name__) never matched anything and
            # clear_cache was a silent no-op without key_prefix.)
            if key_builder:
                cache_key = key_builder(*args, **kwargs)
            else:
                cache_key = f"{func.__name__}:{generate_cache_key(*args, **kwargs)}"
            if key_prefix:
                cache_key = f"{key_prefix}:{cache_key}"
            # Try cache first
            cached_value = await cache_manager.get(cache_key)
            if cached_value is not None:
                return cached_value
            # Call function and cache result
            result = await func(*args, **kwargs)
            await cache_manager.set(cache_key, result, ttl)
            return result
        # Clear every key this function created (prefix match).
        wrapper.clear_cache = lambda: cache_manager.delete_pattern(
            f"{key_prefix}:" if key_prefix else f"{func.__name__}:"
        )
        return wrapper
    return decorator
# Predefined cache TTLs for different data types
class CacheTTL:
    """Standard cache TTLs (in seconds) for different data types."""

    # Tiered durations, shortest to longest.
    REALTIME = 30          # near real-time data
    SHORT = 1 * 60         # frequently changing
    MEDIUM = 5 * 60        # standard API responses
    LONG = 15 * 60         # slow-changing data
    VERY_LONG = 60 * 60    # static reference data

    # Aliases for specific use cases.
    DASHBOARD = MEDIUM     # dashboard data
    PRICES = MEDIUM        # price data
    COMPARISON = LONG      # comparison data
    HOTELS = VERY_LONG     # hotel list (rarely changes)
    USER_SESSION = LONG    # user session data