Spaces:
Sleeping
Sleeping
| import base64 | |
| import hashlib | |
| import hmac | |
| import json | |
| import os | |
| import struct | |
| import time | |
| import secrets | |
| import string | |
| from datetime import datetime, timedelta, timezone | |
| from typing import Optional | |
# Environment flags read once at import time and used by the helpers below.
# TESTING is true only when the TESTING variable equals "true" (any letter case).
TESTING = os.getenv("TESTING", "false").lower() == "true"
# Deployment mode name; "development" is used when ENV_MODE is not set.
ENV_MODE = os.getenv("ENV_MODE", "development")
IS_PRODUCTION = ENV_MODE == "production"
def get_required_env(key: str, default: Optional[str] = None) -> str:
    """Return the value of environment variable ``key``.

    Args:
        key: Name of the environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        The variable's value, or ``default`` when the variable is unset.

    Raises:
        RuntimeError: If the variable is unset and ``default`` is ``None``.
    """
    resolved = os.environ.get(key, default)
    if resolved is not None:
        return resolved
    raise RuntimeError(f"Required environment variable {key} is not set")
# JWT signing secret. In production the JWT_SECRET variable is mandatory
# (get_required_env raises if it is missing); in any other mode a
# development-only placeholder is used when the variable is unset.
JWT_SECRET = (
    get_required_env("JWT_SECRET")
    if IS_PRODUCTION
    else os.environ.get("JWT_SECRET", "dev-only-secret-change-in-production")
)
| # ============= TOTP Functions ============= | |
def decode_base32_safe(secret: str) -> bytes:
    """Decode a Base32 string, tolerating case, whitespace and missing padding.

    The input is upper-cased, all whitespace is removed (including spaces
    between character groups, a common way of displaying secrets), any
    existing ``=`` padding is dropped, and the padding needed to reach a
    multiple of 8 characters is re-added before decoding.

    Args:
        secret: Base32 text, with or without padding.

    Returns:
        The decoded bytes; an empty string decodes to ``b""``.

    Raises:
        binascii.Error: If the cleaned text is not valid Base32.
    """
    # split()/join removes whitespace everywhere, not only at the ends.
    cleaned = "".join(secret.upper().split())
    # Normalise padding: strip whatever is there, then add exactly what
    # base64.b32decode requires.
    cleaned = cleaned.rstrip("=")
    cleaned += "=" * (-len(cleaned) % 8)
    return base64.b32decode(cleaned)
def get_totp_token(secret: str) -> str:
    """Return the current 6-digit TOTP code for a Base32 ``secret``.

    The code is derived with HMAC-SHA1 over the number of 30-second
    intervals elapsed since the Unix epoch, followed by dynamic truncation.

    Args:
        secret: Base32-encoded shared secret.

    Returns:
        A zero-padded 6-digit string, or ``"000000"`` when the secret cannot
        be decoded or decodes to an empty key.
    """
    try:
        key = decode_base32_safe(secret)
    except Exception:
        # Deliberate best-effort fallback for undecodable secrets.
        return "000000"
    if not key:
        # An empty secret decodes to b"" without raising; give it the same
        # fallback instead of deriving a code from an empty HMAC key.
        return "000000"

    # Current time step (30-second intervals), packed as a big-endian
    # 8-byte integer.
    counter = int(time.time() // 30)
    digest = hmac.new(key, struct.pack(">Q", counter), hashlib.sha1).digest()

    # Dynamic truncation: the low nibble of the last byte selects a 4-byte
    # window, whose top bit is masked off.
    offset = digest[-1] & 0x0F
    code = struct.unpack_from(">I", digest, offset)[0] & 0x7FFFFFFF
    return f"{code % 1000000:06d}"
def verify_totp(secret: str, token: str) -> bool:
    """Check ``token`` against the TOTP codes for the current time window.

    Codes for the current 30-second step and for the steps immediately
    before and after it are accepted.

    Args:
        secret: Base32-encoded shared secret.
        token: The 6-digit code supplied by the user.

    Returns:
        ``True`` if ``token`` matches one of the accepted codes. ``False``
        if it does not match, if the secret cannot be decoded or is empty,
        or if ``token`` is not an ASCII string.
    """
    if not isinstance(token, str):
        return False
    try:
        # Compare as bytes: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters.
        candidate = token.encode("ascii")
    except UnicodeEncodeError:
        return False

    try:
        key = decode_base32_safe(secret)
    except Exception:
        return False
    if not key:
        # An empty secret decodes to b""; codes for an empty key can be
        # computed by anyone, so never accept them.
        return False

    current_counter = int(time.time() // 30)
    # Check window: current, previous, next (30s tolerance).
    for step in (0, -1, 1):
        counter_bytes = struct.pack(">Q", current_counter + step)
        digest = hmac.new(key, counter_bytes, hashlib.sha1).digest()
        # Dynamic truncation: low nibble of the last byte picks the window.
        window_start = digest[-1] & 0x0F
        code = struct.unpack_from(">I", digest, window_start)[0] & 0x7FFFFFFF
        expected = f"{code % 1000000:06d}".encode("ascii")
        # Constant-time comparison to prevent timing attacks.
        if hmac.compare_digest(candidate, expected):
            return True
    return False
def generate_totp_secret() -> str:
    """Create a fresh random TOTP secret.

    Returns:
        The Base32 encoding of 20 random bytes (160 bits) with the trailing
        ``=`` padding removed.
    """
    raw_secret = secrets.token_bytes(20)
    encoded = base64.b32encode(raw_secret)
    return encoded.decode('utf-8').rstrip('=')
| # ============= Password Hashing ============= | |
def hash_password(password: str) -> str:
    """Derive a salted PBKDF2-HMAC-SHA256 hash of ``password``.

    Args:
        password: The plain-text password; it is UTF-8 encoded before hashing.

    Returns:
        Base64 text encoding a 32-byte random salt followed by the derived
        key (100000 iterations).
    """
    salt = os.urandom(32)
    password_bytes = password.encode('utf-8')
    derived = hashlib.pbkdf2_hmac('sha256', password_bytes, salt, 100000)
    blob = salt + derived
    return base64.b64encode(blob).decode('utf-8')
def verify_password(password: str, hashed: str) -> bool:
    """Check ``password`` against a Base64 salt-plus-key hash.

    The first 32 decoded bytes are used as the salt and the remainder as
    the stored key; the password is re-derived with PBKDF2-HMAC-SHA256
    (100000 iterations) and compared in constant time.

    Args:
        password: The plain-text password to check.
        hashed: Base64 text holding the salt followed by the derived key.

    Returns:
        ``True`` when the derived key matches the stored key; ``False`` on a
        mismatch or when decoding or derivation raises.
    """
    try:
        raw = base64.b64decode(hashed.encode('utf-8'))
        salt, stored_key = raw[:32], raw[32:]
        candidate = hashlib.pbkdf2_hmac(
            'sha256', password.encode('utf-8'), salt, 100000
        )
    except Exception:
        return False
    # Both operands are bytes here, so the comparison itself cannot raise.
    return hmac.compare_digest(stored_key, candidate)
def generate_random_password(length: int = 12) -> str:
    """Build a random password from letters, digits and ``!@#$%``.

    Args:
        length: Number of characters to generate (default 12).

    Returns:
        A string of ``length`` characters, each drawn independently with
        ``secrets.choice``.
    """
    alphabet = string.ascii_letters + string.digits + "!@#$%"
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return ''.join(picked)
def generate_username_from_email(email: str) -> str:
    """Derive a username from the part of ``email`` before the first ``@``.

    That part is lower-cased, and every character that is neither
    alphanumeric nor one of ``_``, ``-``, ``.`` is dropped.

    Args:
        email: The e-mail address (or any string) to derive the name from.

    Returns:
        The filtered, lower-cased local part; may be empty.
    """
    local_part = email.split('@')[0].lower()
    kept = [ch for ch in local_part if ch.isalnum() or ch in '_-.']
    return ''.join(kept)
| # ============= JWT Tokens ============= | |
def create_jwt_token(payload: dict, expires_delta: timedelta = timedelta(days=7)) -> str:
    """Create an HS256-signed JWT carrying ``payload``.

    ``exp`` and ``iat`` claims are added (overriding any already present) to
    a copy of ``payload``; the caller's dict is left unmodified.

    Args:
        payload: Claims to embed; must be JSON-serialisable.
        expires_delta: Lifetime of the token, counted from now (default 7 days).

    Returns:
        The compact token ``header.payload.signature``, each part
        URL-safe Base64 without padding, signed with ``JWT_SECRET``.
    """

    def _b64url(data: bytes) -> str:
        # URL-safe Base64 with the trailing "=" padding removed.
        return base64.urlsafe_b64encode(data).decode().rstrip("=")

    header = {"alg": "HS256", "typ": "JWT"}

    # Work on a shallow copy so the caller's dict is not mutated, and read
    # the clock once so exp and iat are derived from the same instant.
    issued_at = datetime.now(timezone.utc)
    claims = dict(payload)
    claims["exp"] = (issued_at + expires_delta).timestamp()
    claims["iat"] = issued_at.timestamp()

    header_b64 = _b64url(json.dumps(header).encode())
    payload_b64 = _b64url(json.dumps(claims).encode())
    message = f"{header_b64}.{payload_b64}"

    signature = hmac.new(JWT_SECRET.encode(), message.encode(), hashlib.sha256).digest()
    return f"{message}.{_b64url(signature)}"
def verify_jwt_token(token: str) -> Optional[dict]:
    """Validate a compact HS256 token and return its payload.

    Args:
        token: Token text of the form ``header.payload.signature``.

    Returns:
        The decoded payload when the signature matches the one computed
        with ``JWT_SECRET`` and the ``exp`` claim is not in the past
        (a missing ``exp`` is treated as 0). ``None`` when the token does
        not have three parts, the signature differs, the token has expired,
        or any step raises.
    """
    try:
        segments = token.split(".")
        if len(segments) != 3:
            return None
        header_part, payload_part, signature_part = segments

        # Recompute the signature over "header.payload" and compare it in
        # constant time with the one supplied.
        signing_input = f"{header_part}.{payload_part}"
        digest = hmac.new(
            JWT_SECRET.encode(), signing_input.encode(), hashlib.sha256
        ).digest()
        expected_signature = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        if not hmac.compare_digest(signature_part, expected_signature):
            return None

        # Restore the Base64 padding that was stripped when the token was built.
        padded = payload_part + "=" * ((4 - len(payload_part) % 4) % 4)
        payload = json.loads(base64.urlsafe_b64decode(padded))

        # Reject expired tokens.
        expires_at = payload.get("exp", 0)
        if expires_at < datetime.now(timezone.utc).timestamp():
            return None
        return payload
    except Exception:
        return None