File size: 6,104 Bytes
d77abf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import base64
import hashlib
import hmac
import json
import os
import struct
import time
import secrets
import string
from datetime import datetime, timedelta, timezone
from typing import Optional

# Environment Configuration for Utils
# TESTING is on only when the variable equals "true" (case-insensitive).
TESTING = os.environ.get("TESTING", "false").lower() == "true"
# ENV_MODE is exposed verbatim for any code that compares against it.
ENV_MODE = os.environ.get("ENV_MODE", "development")
# Normalise case and surrounding whitespace before the production check so a
# value such as "Production" or "production " cannot silently select
# development mode (in which the JWT secret setup falls back to a hard-coded
# development secret).
IS_PRODUCTION = ENV_MODE.strip().lower() == "production"

def get_required_env(key: str, default: Optional[str] = None) -> str:
    """Return the value of environment variable *key*.

    If the variable is absent, *default* is returned instead. When neither
    the environment nor *default* provides a value, RuntimeError is raised.
    A variable that is set to the empty string counts as present.
    """
    try:
        return os.environ[key]
    except KeyError:
        pass
    if default is not None:
        return default
    raise RuntimeError(f"Required environment variable {key} is not set")

# JWT Secret setup
if IS_PRODUCTION:
    JWT_SECRET = get_required_env("JWT_SECRET")
    # get_required_env only rejects an *unset* variable. JWT_SECRET="" would
    # become an empty HMAC key that anyone can use to forge tokens, and a
    # whitespace-only value is trivially guessable, so refuse to start.
    if not JWT_SECRET.strip():
        raise RuntimeError("Required environment variable JWT_SECRET is empty")
else:
    # Non-production fallback so development/tests run without configuration.
    JWT_SECRET = os.environ.get("JWT_SECRET", "dev-only-secret-change-in-production")

# ============= TOTP Functions =============

def decode_base32_safe(secret: str) -> bytes:
    """Decode a Base32 TOTP secret leniently.

    Accepts lower-case letters, whitespace anywhere in the string (secrets
    are commonly shown and typed in space-separated groups such as
    ``"JBSW Y3DP"``), and missing or partial ``=`` padding.

    Args:
        secret: Base32 text.

    Returns:
        The decoded key bytes (``b""`` for an empty secret).

    Raises:
        binascii.Error: if the cleaned-up text is not valid Base32.
    """
    # Remove every whitespace character (not only the ends), normalise case.
    cleaned = "".join(secret.split()).upper()
    # Discard any existing padding, then re-pad to a multiple of 8 characters.
    cleaned = cleaned.rstrip("=")
    cleaned += "=" * (-len(cleaned) % 8)
    return base64.b32decode(cleaned)

def get_totp_token(secret: str) -> str:
    """Return the current 6-digit TOTP code for a Base32 *secret*.

    Computes HMAC-SHA1 over the current 30-second time step and applies
    dynamic truncation to get a zero-padded 6-digit string. Returns the
    placeholder ``"000000"`` instead of raising when *secret* is empty or
    cannot be decoded.
    """
    try:
        key = decode_base32_safe(secret)
    except Exception:
        # Fallback for invalid secrets
        return "000000"
    if not key:
        # An empty secret decodes to b"" without raising; treat it like an
        # invalid one rather than emitting a code derived from an empty key.
        return "000000"

    # Current time step (30 second intervals), packed as big-endian uint64.
    step = int(time.time() // 30)
    digest = hmac.new(key, struct.pack(">Q", step), hashlib.sha1).digest()
    # Dynamic truncation: low nibble of the last byte picks a 4-byte window.
    start = digest[-1] & 0x0F
    code = struct.unpack(">I", digest[start:start + 4])[0] & 0x7FFFFFFF
    return str(code % 1000000).zfill(6)

def verify_totp(secret: str, token: str) -> bool:
    """Check a user-supplied TOTP code against *secret*.

    Codes for the current 30-second step and one step either side are
    accepted.

    Args:
        secret: Base32 TOTP secret.
        token: Code entered by the user. Whitespace (e.g. ``"123 456"``) is
            ignored; anything that is not then exactly six ASCII digits is
            rejected.

    Returns:
        True if *token* matches one of the accepted codes, otherwise False.
        Malformed input of either argument yields False rather than raising.
    """
    # Validate the untrusted token first: hmac.compare_digest raises
    # TypeError for non-str values and for str containing non-ASCII
    # characters, which would otherwise propagate to the caller.
    if not isinstance(token, str):
        return False
    candidate = "".join(token.split())
    if len(candidate) != 6 or not all(ch in string.digits for ch in candidate):
        return False

    try:
        key = decode_base32_safe(secret)
    except Exception:
        return False
    if not key:
        # An empty secret decodes to b""; codes from an empty HMAC key can be
        # computed by anyone, so never accept them.
        return False

    current_step = int(time.time() // 30)
    # Check window: current, previous, next (30s tolerance)
    for drift in (0, -1, 1):
        step_bytes = struct.pack(">Q", current_step + drift)
        digest = hmac.new(key, step_bytes, hashlib.sha1).digest()
        start = digest[-1] & 0x0F
        code = struct.unpack(">I", digest[start:start + 4])[0] & 0x7FFFFFFF
        expected = str(code % 1000000).zfill(6)
        # Constant time comparison to prevent timing attacks
        if hmac.compare_digest(candidate, expected):
            return True
    return False

def generate_totp_secret() -> str:
    """Create a fresh random TOTP secret.

    Draws 20 bytes (160 bits) from ``secrets`` and renders them as
    upper-case Base32 with any ``=`` padding removed (32 characters).
    """
    raw = secrets.token_bytes(20)
    encoded = base64.b32encode(raw).decode("utf-8")
    return encoded.rstrip("=")

# ============= Password Hashing =============

def hash_password(password: str) -> str:
    """Derive a salted PBKDF2-HMAC-SHA256 hash of *password*.

    Returns standard Base64 text of ``salt || derived_key``: a fresh 32-byte
    random salt followed by the key from 100,000 PBKDF2 iterations.
    """
    # NOTE(review): the iteration count is not recorded in the output, so it
    # cannot be raised later without introducing a versioned hash format.
    salt_bytes = os.urandom(32)
    derived = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt_bytes, 100000
    )
    blob = salt_bytes + derived
    return base64.b64encode(blob).decode("utf-8")

def verify_password(password: str, hashed: str) -> bool:
    """Check *password* against a stored hash string.

    *hashed* is Base64 text whose first 32 decoded bytes are the salt and
    whose remainder is the stored key. The candidate key is re-derived with
    PBKDF2-HMAC-SHA256 (100,000 iterations) and compared in constant time.
    Any error while decoding or deriving yields False.
    """
    try:
        blob = base64.b64decode(hashed.encode("utf-8"))
        candidate_key = hashlib.pbkdf2_hmac(
            "sha256", password.encode("utf-8"), blob[:32], 100000
        )
    except Exception:
        return False
    return hmac.compare_digest(blob[32:], candidate_key)

def generate_random_password(length: int = 12) -> str:
    """Return a random password of *length* characters.

    Each character is drawn independently with ``secrets.choice`` from
    ASCII letters, digits and ``!@#$%``. A non-positive length gives "".
    """
    charset = string.ascii_letters + string.digits + "!@#$%"
    picks = [secrets.choice(charset) for _ in range(length)]
    return "".join(picks)

def generate_username_from_email(email: str) -> str:
    """Build a username from the local part of *email*.

    Takes the text before the first ``@`` (the whole string if there is
    none), lower-cases it, and keeps only characters that are alphanumeric
    (Unicode-aware ``str.isalnum``) or one of ``_``, ``-``, ``.``.
    """
    local_part, _, _ = email.partition("@")
    lowered = local_part.lower()
    allowed_punct = "_-."
    return "".join(filter(lambda ch: ch.isalnum() or ch in allowed_punct, lowered))

# ============= JWT Tokens =============

def create_jwt_token(
    payload: dict,
    expires_delta: timedelta = timedelta(days=7),
    *,
    secret: Optional[str] = None,
) -> str:
    """Create an HS256-signed JWT.

    Args:
        payload: Claims to embed. The caller's dict is not modified; ``exp``
            and ``iat`` are set on a copy, overriding any existing values.
        expires_delta: Token lifetime, measured from now.
        secret: HMAC key to sign with; defaults to the module-level
            ``JWT_SECRET``.

    Returns:
        The compact ``header.payload.signature`` token. Each part is
        unpadded URL-safe Base64.
    """
    key = JWT_SECRET if secret is None else secret
    # Copy so the caller's dict is not mutated by adding exp/iat.
    claims = dict(payload)
    # One clock reading for both claims, so exp - iat == expires_delta.
    issued_at = datetime.now(timezone.utc)
    claims["exp"] = (issued_at + expires_delta).timestamp()
    claims["iat"] = issued_at.timestamp()

    header = {"alg": "HS256", "typ": "JWT"}

    def _segment(data: bytes) -> str:
        # JWT segments are URL-safe Base64 without "=" padding.
        return base64.urlsafe_b64encode(data).decode().rstrip("=")

    header_b64 = _segment(json.dumps(header).encode())
    payload_b64 = _segment(json.dumps(claims).encode())
    message = f"{header_b64}.{payload_b64}"
    signature = hmac.new(key.encode(), message.encode(), hashlib.sha256).digest()
    return f"{message}.{_segment(signature)}"

def verify_jwt_token(token: str, *, secret: Optional[str] = None) -> Optional[dict]:
    """Verify an HS256 JWT and return its claims.

    Args:
        token: Compact ``header.payload.signature`` token (untrusted input).
        secret: HMAC key to verify with; defaults to the module-level
            ``JWT_SECRET``.

    Returns:
        The decoded payload dict, or None when the token is malformed, the
        signature does not match, the header does not declare ``HS256``, or
        the token has expired (a missing ``exp`` is treated as expired).
    """

    def _b64url_decode(segment: str) -> bytes:
        # Restore the "=" padding stripped from JWT segments before decoding.
        return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))

    try:
        key = JWT_SECRET if secret is None else secret
        parts = token.split(".")
        if len(parts) != 3:
            return None
        header_b64, payload_b64, signature_b64 = parts

        # Verify signature over the exact header/payload text received.
        message = f"{header_b64}.{payload_b64}"
        expected_sig = hmac.new(key.encode(), message.encode(), hashlib.sha256).digest()
        expected_sig_b64 = base64.urlsafe_b64encode(expected_sig).decode().rstrip("=")
        # compare_digest raises TypeError on non-ASCII str; handled below.
        if not hmac.compare_digest(signature_b64, expected_sig_b64):
            return None

        # Only HS256 tokens are issued; reject any other declared algorithm.
        header = json.loads(_b64url_decode(header_b64))
        if header.get("alg") != "HS256":
            return None

        payload = json.loads(_b64url_decode(payload_b64))
        # The token is invalid at or after "exp", not only strictly after it.
        if payload.get("exp", 0) <= datetime.now(timezone.utc).timestamp():
            return None

        return payload
    except Exception:
        # Deliberate best-effort: any malformed token (bad Base64/JSON,
        # wrong types, non-string input) is reported as invalid.
        return None