InnSight-Backend / api /encryption.py
jackonthemike's picture
Initial commit: InnSight scraper backend with Playwright
d77abf8
"""
Field-level encryption utilities.
This module provides encryption for sensitive data fields
using Fernet (AES-128-CBC with HMAC-SHA256) and PBKDF2 key derivation.
Usage:
from encryption import field_encryption
# Encrypt sensitive data
encrypted = field_encryption.encrypt("sensitive value")
# Decrypt
decrypted = field_encryption.decrypt(encrypted)
# Encrypt specific fields in dict
secure_data = field_encryption.encrypt_dict(
user_data,
['phone_number', 'address']
)
"""
import base64
import hashlib
import hmac
import os
import secrets
from dataclasses import dataclass
from typing import Any, Optional
# Check for cryptography library
try:
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
CRYPTO_AVAILABLE = True
except ImportError:
CRYPTO_AVAILABLE = False
InvalidToken = Exception
class EncryptionError(Exception):
    """Root of this module's exception hierarchy; signals an encryption failure."""
class DecryptionError(EncryptionError):
    """Signals that a stored value could not be parsed or decrypted."""
class EncryptionNotConfigured(EncryptionError):
    """Signals that encryption is required but has not been properly configured.

    Note: FieldEncryption itself falls back to obfuscation when unconfigured
    rather than raising this.
    """
# Per-table list of columns that encrypt_sensitive_fields() and
# decrypt_sensitive_fields() route through field encryption.
ENCRYPTED_FIELDS: dict[str, list[str]] = {
    'users': [
        'phone_number',
        'address',
        'personal_id',
    ],
    'payment_info': [
        'card_last_four',
        'billing_address',
    ],
    'api_credentials': [
        'api_key',
        'api_secret',
        'webhook_secret',
    ],
    'hotels': [],  # No sensitive fields
}
@dataclass
class EncryptedValue:
    """Container for an encrypted value plus the metadata stored alongside it.

    Serialized form is ``v<version>:<algorithm>:<ciphertext>``. Strings that
    do not carry that prefix are treated as legacy, unprefixed ciphertext.
    """

    ciphertext: str
    version: int = 1
    # NOTE: the label is historical -- Fernet actually uses AES-128-CBC.
    # It is persisted in stored values, so it must not be changed casually.
    algorithm: str = "aes-256-fernet"

    def to_string(self) -> str:
        """Serialize to a storable ``v<version>:<algorithm>:<ciphertext>`` string."""
        return f"v{self.version}:{self.algorithm}:{self.ciphertext}"

    @classmethod
    def from_string(cls, value: str) -> "EncryptedValue":
        """Deserialize a string produced by :meth:`to_string`.

        A value is considered versioned only if it begins with ``v<digits>:``.
        Anything else is returned as legacy ciphertext with default metadata.
        This matters because legacy/obfuscated ciphertext is base64, which can
        legitimately start with the letter ``v``.

        Args:
            value: Stored string.

        Returns:
            Parsed EncryptedValue.

        Raises:
            ValueError: if the value has a ``v<digits>:`` prefix but no
                ``<algorithm>:<ciphertext>`` part after it.
        """
        head, sep, rest = value.partition(":")
        version_digits = head[1:]
        if not (sep and head.startswith("v") and version_digits.isdecimal()):
            # Legacy format - just ciphertext
            return cls(ciphertext=value)
        # Only the first colon after the algorithm is a separator; the
        # ciphertext itself may contain further colons.
        algorithm, sep, ciphertext = rest.partition(":")
        if not sep:
            raise ValueError("Invalid encrypted value format")
        return cls(
            ciphertext=ciphertext,
            version=int(version_digits),
            algorithm=algorithm
        )
class FieldEncryption:
    """
    Field-level encryption using Fernet (AES-128-CBC with HMAC-SHA256).

    Provides encryption for individual fields with:
    - Key derivation from a master key (PBKDF2-HMAC-SHA256, 100,000 iterations)
    - Automatic str <-> bytes encoding/decoding
    - A ``v<version>:<algorithm>:`` prefix on stored values (see
      ``EncryptedValue``). The version is written but not used to select
      keys, so key rotation is not implemented here.

    WARNING: if no master key is configured, or the ``cryptography`` package
    is not installed, values are only XOR-obfuscated and base64-encoded.
    That fallback provides no real confidentiality.
    """

    def __init__(
        self,
        master_key: Optional[str] = None,
        salt: bytes = b'innsight_field_encryption_v1'
    ):
        """
        Initialize encryption with a master key.

        Args:
            master_key: Master secret. Falls back to the
                ``ENCRYPTION_MASTER_KEY`` environment variable. The string is
                stretched with PBKDF2, so any value works; a random 32-byte
                base64 string is recommended.
            salt: Salt for key derivation. Changing it changes the derived
                key, so values encrypted under the old salt no longer decrypt.

        Raises:
            EncryptionError: if key derivation or cipher setup fails.
        """
        self._master_key = master_key or os.environ.get('ENCRYPTION_MASTER_KEY')
        self._salt = salt
        self._fernet = None
        self._initialized = False
        if self._master_key:
            self._initialize()

    def _initialize(self) -> None:
        """Derive the Fernet key from the master key and build the cipher.

        Without ``cryptography`` this only marks the instance initialized, so
        encrypt()/decrypt() fall back to XOR obfuscation.
        """
        if not CRYPTO_AVAILABLE:
            # Use simple obfuscation as fallback
            self._initialized = True
            return
        if not self._master_key:
            return
        try:
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=self._salt,
                iterations=100000,
                backend=default_backend()
            )
            # Fernet expects the urlsafe-base64 encoding of a 32-byte key.
            key = base64.urlsafe_b64encode(
                kdf.derive(self._master_key.encode())
            )
            self._fernet = Fernet(key)
            self._initialized = True
        except Exception as e:
            raise EncryptionError(f"Failed to initialize encryption: {e}") from e

    @property
    def is_configured(self) -> bool:
        """True once a master key is set and initialization has completed.

        This is also True when ``cryptography`` is unavailable; in that case
        "configured" means obfuscation keyed by the master key, not encryption.
        """
        return self._initialized and self._master_key is not None

    def encrypt(self, plaintext: str) -> str:
        """
        Encrypt a string value.

        Args:
            plaintext: Value to encrypt. Empty/falsy values are returned as-is.

        Returns:
            ``v1:aes-256-fernet:<ciphertext>`` when configured; otherwise a
            bare, unprefixed obfuscated string.

        Raises:
            EncryptionError: if the configured cipher fails.
        """
        if not plaintext:
            return plaintext
        if not self.is_configured:
            # Return obfuscated value if not configured (NOT secure)
            return self._obfuscate(plaintext)
        try:
            if CRYPTO_AVAILABLE and self._fernet:
                ciphertext = self._fernet.encrypt(plaintext.encode()).decode()
            else:
                # NOTE(review): obfuscated output still carries the default
                # "aes-256-fernet" label; once cryptography is available,
                # decrypt() will try Fernet on it and fail.
                ciphertext = self._obfuscate(plaintext)
            encrypted = EncryptedValue(ciphertext=ciphertext)
            return encrypted.to_string()
        except Exception as e:
            raise EncryptionError(f"Encryption failed: {e}") from e

    def decrypt(self, ciphertext: str) -> str:
        """
        Decrypt a value produced by :meth:`encrypt`.

        Args:
            ciphertext: Stored value, with or without the version prefix.
                Empty/falsy values are returned as-is.

        Returns:
            Decrypted plaintext.

        Raises:
            DecryptionError: if the value cannot be parsed or decrypted.
        """
        if not ciphertext:
            return ciphertext
        try:
            encrypted = EncryptedValue.from_string(ciphertext)
            if not self.is_configured:
                return self._deobfuscate(encrypted.ciphertext)
            if CRYPTO_AVAILABLE and self._fernet:
                return self._fernet.decrypt(
                    encrypted.ciphertext.encode()
                ).decode()
            return self._deobfuscate(encrypted.ciphertext)
        except InvalidToken as e:
            # Without cryptography, InvalidToken is aliased to Exception, so
            # this clause catches every failure in that configuration.
            raise DecryptionError("Invalid or corrupted ciphertext") from e
        except Exception as e:
            raise DecryptionError(f"Decryption failed: {e}") from e

    def _obfuscation_key(self) -> bytes:
        """Return the 32-character XOR key used by the obfuscation fallback, as bytes."""
        key = (self._master_key or "default_key")[:32].ljust(32, '0')
        try:
            # Latin-1 maps each character c to the single byte ord(c), which
            # reproduces the key stream of the original per-character code.
            return key.encode("latin-1")
        except UnicodeEncodeError:
            # Keys with characters above U+00FF could never obfuscate anything
            # before (XOR overflowed a byte), so no legacy data depends on them.
            return key.encode("utf-8")

    def _obfuscate(self, plaintext: str) -> str:
        """XOR the UTF-8 bytes of ``plaintext`` with the key and base64-encode.

        This is obfuscation only, NOT encryption. For ASCII input the output
        is identical to the earlier per-character implementation.
        """
        key = self._obfuscation_key()
        data = plaintext.encode("utf-8")
        obfuscated = bytes(b ^ key[i % len(key)] for i, b in enumerate(data))
        return base64.b64encode(obfuscated).decode()

    def _deobfuscate(self, obfuscated: str) -> str:
        """Reverse :meth:`_obfuscate`; returns the input unchanged if it is not base64."""
        key = self._obfuscation_key()
        try:
            decoded = base64.b64decode(obfuscated)
        except (ValueError, TypeError):
            return obfuscated  # Return as-is if can't decode
        raw = bytes(b ^ key[i % len(key)] for i, b in enumerate(decoded))
        try:
            return raw.decode("utf-8")
        except UnicodeDecodeError:
            # Legacy values stored one byte per character (code points
            # < 256); Latin-1 decoding reverses exactly that mapping.
            return raw.decode("latin-1")

    def encrypt_dict(
        self,
        data: dict,
        fields: list[str]
    ) -> dict:
        """
        Encrypt specified fields in a dictionary.

        Args:
            data: Dictionary containing data. Falsy input is returned as-is.
            fields: Field names to encrypt. Missing or falsy fields are skipped;
                other values are converted with ``str()`` first.

        Returns:
            Shallow copy of ``data`` with the listed fields encrypted.
        """
        if not data:
            return data
        result = data.copy()
        for field in fields:
            if field in result and result[field]:
                result[field] = self.encrypt(str(result[field]))
        return result

    def decrypt_dict(
        self,
        data: dict,
        fields: list[str]
    ) -> dict:
        """
        Decrypt specified fields in a dictionary.

        Args:
            data: Dictionary containing encrypted data. Falsy input is
                returned as-is.
            fields: Field names to decrypt. Missing or falsy fields are skipped.

        Returns:
            Shallow copy of ``data`` with the listed fields decrypted. A field
            that fails to decrypt is left unchanged (best-effort).
        """
        if not data:
            return data
        result = data.copy()
        for field in fields:
            if field in result and result[field]:
                try:
                    result[field] = self.decrypt(str(result[field]))
                except DecryptionError:
                    # Deliberate best-effort: leave the stored value as-is.
                    pass
        return result

    def hash_for_search(self, plaintext: str) -> str:
        """
        Create a deterministic, case-insensitive keyed hash of ``plaintext``.

        Store this next to the encrypted value so equality lookups can run
        without decryption.

        Args:
            plaintext: Value to hash.

        Returns:
            First 32 hex characters (128 bits) of HMAC-SHA256 over the
            lower-cased value; ``""`` for empty input.
        """
        if not plaintext:
            return ""
        key = (self._master_key or "search_key").encode()
        h = hmac.new(key, plaintext.lower().encode(), hashlib.sha256)
        return h.hexdigest()[:32]
# Shared module-level instance, built at import time. With no explicit key it
# reads ENCRYPTION_MASTER_KEY from the environment.
field_encryption: FieldEncryption = FieldEncryption()
def encrypt_sensitive_fields(
    data: dict,
    table_name: str
) -> dict:
    """
    Encrypt whichever fields ENCRYPTED_FIELDS lists for ``table_name``.

    Tables not present in ENCRYPTED_FIELDS have no fields encrypted.

    Args:
        data: Mapping of field name to value.
        table_name: Table/entity key into ENCRYPTED_FIELDS.

    Returns:
        Result of ``field_encryption.encrypt_dict`` for those fields.
    """
    return field_encryption.encrypt_dict(
        data, ENCRYPTED_FIELDS.get(table_name, [])
    )
def decrypt_sensitive_fields(
    data: dict,
    table_name: str
) -> dict:
    """
    Decrypt whichever fields ENCRYPTED_FIELDS lists for ``table_name``.

    Tables not present in ENCRYPTED_FIELDS have no fields decrypted.

    Args:
        data: Mapping containing encrypted field values.
        table_name: Table/entity key into ENCRYPTED_FIELDS.

    Returns:
        Result of ``field_encryption.decrypt_dict`` for those fields.
    """
    return field_encryption.decrypt_dict(
        data, ENCRYPTED_FIELDS.get(table_name, [])
    )
def generate_encryption_key() -> str:
    """Return a fresh master key: 32 random bytes from ``secrets``, base64-encoded."""
    raw_key = secrets.token_bytes(32)
    encoded = base64.b64encode(raw_key)
    return encoded.decode()
def mask_sensitive_value(value: str, visible_chars: int = 4) -> str:
    """
    Mask a sensitive value for display.

    Args:
        value: Value to mask. Empty/falsy values are returned unchanged.
        visible_chars: Number of trailing characters to leave visible.
            Values <= 0 mask the whole string.

    Returns:
        Masked string like "****1234". A value no longer than
        ``visible_chars`` is masked entirely.
    """
    if not value:
        return value
    # Clamp: value[-0:] is the WHOLE string, so 0 (or a negative count)
    # would otherwise append the unmasked secret after the stars.
    visible = max(visible_chars, 0)
    if visible == 0 or len(value) <= visible:
        return '*' * len(value)
    return '*' * (len(value) - visible) + value[-visible:]