import base64 import hashlib import hmac import json import os import struct import time import secrets import string from datetime import datetime, timedelta, timezone from typing import Optional # Environment Configuration for Utils TESTING = os.environ.get("TESTING", "false").lower() == "true" ENV_MODE = os.environ.get("ENV_MODE", "development") IS_PRODUCTION = ENV_MODE == "production" def get_required_env(key: str, default: Optional[str] = None) -> str: """Get environment variable, raise error if required and not set""" value = os.environ.get(key, default) if value is None: raise RuntimeError(f"Required environment variable {key} is not set") return value # JWT Secret setup if IS_PRODUCTION: JWT_SECRET = get_required_env("JWT_SECRET") else: JWT_SECRET = os.environ.get("JWT_SECRET", "dev-only-secret-change-in-production") # ============= TOTP Functions ============= def decode_base32_safe(secret: str) -> bytes: """Safe Base32 decoding with automatic padding""" secret = secret.upper().strip() # Remove any existing padding first secret = secret.rstrip('=') # Calculate required padding padding_len = (8 - (len(secret) % 8)) % 8 secret += '=' * padding_len return base64.b32decode(secret) def get_totp_token(secret: str) -> str: """Generate current TOTP token""" try: key = decode_base32_safe(secret) except Exception: # Fallback for empty or invalid secrets return "000000" # Get current time step (30 second intervals) counter = int(time.time() // 30) # Pack counter as big-endian 8-byte integer counter_bytes = struct.pack('>Q', counter) # Generate HMAC-SHA1 hmac_hash = hmac.new(key, counter_bytes, hashlib.sha1).digest() # Dynamic truncation offset = hmac_hash[-1] & 0x0F code = struct.unpack('>I', hmac_hash[offset:offset+4])[0] & 0x7FFFFFFF # Return 6-digit code return str(code % 1000000).zfill(6) def verify_totp(secret: str, token: str) -> bool: """Verify TOTP token with 1 step tolerance""" try: key = decode_base32_safe(secret) except Exception: return False current_counter = int(time.time() // 30) # Check window: current, previous, next (30s tolerance) for offset in [0, -1, 1]: counter = current_counter + offset counter_bytes = struct.pack('>Q', counter) hmac_hash = hmac.new(key, counter_bytes, hashlib.sha1).digest() offset_val = hmac_hash[-1] & 0x0F code = struct.unpack('>I', hmac_hash[offset_val:offset_val+4])[0] & 0x7FFFFFFF expected = str(code % 1000000).zfill(6) # Constant time comparison to prevent timing attacks if hmac.compare_digest(token, expected): return True return False def generate_totp_secret() -> str: """Generate a random Base32 TOTP secret""" # Generate 20 random bytes (160 bits as recommended for TOTP) random_bytes = secrets.token_bytes(20) # Encode as base32 and remove padding return base64.b32encode(random_bytes).decode('utf-8').rstrip('=') # ============= Password Hashing ============= def hash_password(password: str) -> str: """Hash password using PBKDF2 with SHA256""" salt = os.urandom(32) key = hashlib.pbkdf2_hmac('sha256', password.encode('utf-8'), salt, 100000) return base64.b64encode(salt + key).decode('utf-8') def verify_password(password: str, hashed: str) -> bool: """Verify password against hash""" try: decoded = base64.b64decode(hashed.encode('utf-8')) salt = decoded[:32] stored_key = decoded[32:] new_key = hashlib.pbkdf2_hmac('sha256', password.encode('utf-8'), salt, 100000) return hmac.compare_digest(stored_key, new_key) except Exception: return False def generate_random_password(length: int = 12) -> str: """Generate a secure random password""" alphabet = string.ascii_letters + string.digits + "!@#$%" return ''.join(secrets.choice(alphabet) for _ in range(length)) def generate_username_from_email(email: str) -> str: """Generate username from email""" username_base = email.split('@')[0].lower() # Remove special characters username = ''.join(c for c in username_base if c.isalnum() or c in '_-.') return username # ============= JWT Tokens ============= def create_jwt_token(payload: dict, expires_delta: timedelta = timedelta(days=7)) -> str: """Create a simple JWT token""" header = {"alg": "HS256", "typ": "JWT"} payload["exp"] = (datetime.now(timezone.utc) + expires_delta).timestamp() payload["iat"] = datetime.now(timezone.utc).timestamp() header_b64 = base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") message = f"{header_b64}.{payload_b64}" signature = hmac.new(JWT_SECRET.encode(), message.encode(), hashlib.sha256).digest() signature_b64 = base64.urlsafe_b64encode(signature).decode().rstrip("=") return f"{message}.{signature_b64}" def verify_jwt_token(token: str) -> Optional[dict]: """Verify and decode JWT token""" try: parts = token.split(".") if len(parts) != 3: return None header_b64, payload_b64, signature_b64 = parts # Verify signature message = f"{header_b64}.{payload_b64}" expected_sig = hmac.new(JWT_SECRET.encode(), message.encode(), hashlib.sha256).digest() expected_sig_b64 = base64.urlsafe_b64encode(expected_sig).decode().rstrip("=") if not hmac.compare_digest(signature_b64, expected_sig_b64): return None # Decode payload with padding correction payload_b64 += "=" * ((4 - len(payload_b64) % 4) % 4) payload = json.loads(base64.urlsafe_b64decode(payload_b64)) # Check expiration if payload.get("exp", 0) < datetime.now(timezone.utc).timestamp(): return None return payload except Exception: return None