""" Database module for PostgreSQL integration (Neon/Vercel Postgres). This module provides a clean abstraction layer for database operations, supporting both in-memory storage (for development/testing) and PostgreSQL via Neon (for production). Environment Variables: DATABASE_URL or POSTGRES_URL: PostgreSQL connection string TESTING: Set to "true" to force in-memory database Usage: from database import db # Create user user = await db.create_user(email, password_hash, full_name) # Get user user = await db.get_user_by_email(email) """ import os import sys from datetime import datetime, timezone from typing import Optional from abc import ABC, abstractmethod from urllib.parse import urlparse, parse_qs import asyncio import threading import asyncio import threading import time import logging # Configure logger logger = logging.getLogger(__name__) # pg8000 for PostgreSQL - pure Python, lightweight # Wrap in try/except for safer imports PG8000_AVAILABLE = False pg8000 = None try: import pg8000 import pg8000.native PG8000_AVAILABLE = True PG8000_AVAILABLE = True except ImportError as e: logger.warning(f"pg8000 not available: {e}") except Exception as e: logger.error(f"Error loading pg8000: {e}") # ============= Abstract Database Interface ============= class DatabaseInterface(ABC): """Abstract interface for database operations""" @abstractmethod async def init_tables(self) -> None: """Initialize database tables""" pass # User operations @abstractmethod async def create_user(self, email: str, password_hash: str, full_name: str, is_admin: bool = False, totp_secret: Optional[str] = None) -> dict: pass @abstractmethod async def get_user_by_email(self, email: str) -> Optional[dict]: pass @abstractmethod async def get_user_by_id(self, user_id: int) -> Optional[dict]: pass @abstractmethod async def update_user_last_login(self, email: str) -> None: pass @abstractmethod async def delete_user(self, user_id: int) -> bool: pass @abstractmethod async def list_users(self) -> list[dict]: pass # Hotel operations @abstractmethod @abstractmethod async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None, website_url: Optional[str] = None) -> dict: pass @abstractmethod async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]: pass @abstractmethod async def get_hotel(self, hotel_id: int) -> Optional[dict]: pass @abstractmethod async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool: pass # ============= In-Memory Database (Development/Testing) ============= class InMemoryDatabase(DatabaseInterface): """In-memory database for development and testing""" def __init__(self): self._users: dict[str, dict] = {} self._hotels: list[dict] = [] self._user_id_counter = 0 self._hotel_id_counter = 0 async def init_tables(self) -> None: """No-op for in-memory database""" pass async def create_user(self, email: str, password_hash: str, full_name: str, is_admin: bool = False, totp_secret: Optional[str] = None) -> dict: self._user_id_counter += 1 user = { "id": self._user_id_counter, "email": email, "password_hash": password_hash, "full_name": full_name, "is_admin": is_admin, "totp_secret": totp_secret, "created_at": datetime.now(timezone.utc).isoformat(), "last_login": None } self._users[email] = user return user async def get_user_by_email(self, email: str) -> Optional[dict]: return self._users.get(email) async def get_user_by_id(self, user_id: int) -> Optional[dict]: for user in self._users.values(): if user["id"] == user_id: return user return None async def update_user_last_login(self, email: str) -> None: if email in self._users: self._users[email]["last_login"] = datetime.now(timezone.utc).isoformat() async def delete_user(self, user_id: int) -> bool: for email, user in list(self._users.items()): if user["id"] == user_id: del self._users[email] return True return False async def list_users(self) -> list[dict]: return list(self._users.values()) async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None, website_url: Optional[str] = None) -> dict: self._hotel_id_counter += 1 hotel = { "id": self._hotel_id_counter, "name": name, "owner_id": owner_id, "booking_url": booking_url, "website_url": website_url, "created_at": datetime.now(timezone.utc).isoformat() } self._hotels.append(hotel) return hotel async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]: if owner_id: return [h for h in self._hotels if h.get("owner_id") == owner_id or h.get("owner_id") is None] return self._hotels async def get_hotel(self, hotel_id: int) -> Optional[dict]: for hotel in self._hotels: if hotel["id"] == hotel_id: return hotel return None async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool: for i, hotel in enumerate(self._hotels): if hotel["id"] == hotel_id: if owner_id and hotel.get("owner_id") != owner_id: return False self._hotels.pop(i) return True return False def seed_admin_users(self, admin_configs: list[dict]) -> None: """Seed admin users for development""" for config in admin_configs: self._user_id_counter += 1 self._users[config["email"]] = { "id": self._user_id_counter, "email": config["email"], "password_hash": config["password_hash"], "full_name": config["full_name"], "is_admin": True, "totp_secret": config.get("totp_secret"), "created_at": datetime.now(timezone.utc).isoformat(), "last_login": None } def seed_regular_users(self, user_configs: list[dict]) -> None: """Seed regular (non-admin) users for development""" for config in user_configs: self._user_id_counter += 1 self._users[config["email"]] = { "id": self._user_id_counter, "email": config["email"], "password_hash": config["password_hash"], "full_name": config["full_name"], "is_admin": False, # Regular users are NOT admins "totp_secret": config.get("totp_secret"), "created_at": datetime.now(timezone.utc).isoformat(), "last_login": None } def seed_demo_hotels(self, hotels: list[dict]) -> None: """Seed demo hotels for development""" for hotel in hotels: self._hotel_id_counter += 1 self._hotels.append({ "id": self._hotel_id_counter, "name": hotel["name"], "owner_id": None, # Publicly visible demo hotel "booking_url": hotel.get("booking_url"), "website_url": hotel.get("website_url"), "created_at": datetime.now(timezone.utc).isoformat() }) # ============= PostgreSQL Database (Production - Neon/Vercel) ============= class PostgresDatabase(DatabaseInterface): """PostgreSQL database using pg8000 (pure Python, works with Neon, Vercel Postgres, etc.)""" def __init__(self, connection_url: str): """ Initialize with a PostgreSQL connection URL. Format: postgresql://user:password@host:port/database?sslmode=require """ self.connection_url = connection_url self._conn = None self._parse_connection_url() def _parse_connection_url(self): """Parse the connection URL into components""" from urllib.parse import unquote parsed = urlparse(self.connection_url) self._host = parsed.hostname or 'localhost' self._port = parsed.port or 5432 self._user = unquote(parsed.username or 'postgres') self._password = unquote(parsed.password or '') self._database = parsed.path.lstrip('/').split('?')[0] or 'postgres' # Parse query params for SSL query_params = parse_qs(parsed.query) self._ssl = 'sslmode' in query_params and query_params['sslmode'][0] != 'disable' # Neon and Vercel always require SSL if 'neon' in (self._host or '') or 'vercel' in (self._host or ''): self._ssl = True def _get_connection(self): """Get or create database connection""" # Ensure we have a lock if not hasattr(self, '_lock'): self._lock = threading.Lock() if self._conn is None: try: ssl_context = None if self._ssl: import ssl ssl_context = ssl.create_default_context() # pg8000.connect is the correct API self._conn = pg8000.connect( host=self._host, port=self._port, user=self._user, password=self._password, database=self._database, ssl_context=ssl_context ) self._conn.autocommit = True except Exception as e: logger.error(f"Failed to connect to database: {e}") raise return self._conn def _run_query_in_thread(self, query: str, params: Optional[tuple] = None) -> list[dict]: """Run query in a separate thread with locking and retry logic""" if not hasattr(self, '_lock'): self._lock = threading.Lock() with self._lock: # Simple retry mechanism for lost connections max_retries = 2 for attempt in range(max_retries): try: conn = self._get_connection() cursor = conn.cursor() try: # pg8000 uses %s style parameters just like psycopg2 if params is not None: cursor.execute(query, params) else: cursor.execute(query) # Check if this is a SELECT-like query that returns rows try: if cursor.description and len(cursor.description) > 0: columns = [desc[0] for desc in cursor.description] rows = cursor.fetchall() return [dict(zip(columns, row)) for row in rows] except (TypeError, AttributeError): # DDL statements don't return rows pass return [] finally: cursor.close() except (pg8000.InterfaceError, pg8000.DatabaseError, AttributeError) as e: # Connection might be dead, close and retry logger.warning(f"Database error (attempt {attempt+1}/{max_retries}): {e}") if self._conn: try: self._conn.close() except: pass self._conn = None if attempt == max_retries - 1: raise e # Wait a bit before retry time.sleep(0.5) # Should not be reached due to raise above return [] async def _execute_query(self, query: str, params: Optional[tuple] = None) -> list[dict]: """Execute a SQL query asynchronously using a thread pool""" loop = asyncio.get_running_loop() return await loop.run_in_executor(None, self._run_query_in_thread, query, params) async def init_tables(self) -> None: """Initialize database tables""" # Users table await self._execute_query(""" CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, email VARCHAR(255) UNIQUE NOT NULL, password_hash VARCHAR(512) NOT NULL, full_name VARCHAR(255) NOT NULL, is_admin BOOLEAN DEFAULT FALSE, totp_secret VARCHAR(64), created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, last_login TIMESTAMP ) """) # Hotels table await self._execute_query(""" CREATE TABLE IF NOT EXISTS hotels ( id SERIAL PRIMARY KEY, name VARCHAR(255) NOT NULL, booking_url TEXT, website_url TEXT, owner_id INTEGER REFERENCES users(id) ON DELETE CASCADE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) # Add owner_id column if it doesn't exist (for migration) await self._execute_query(""" ALTER TABLE hotels ADD COLUMN IF NOT EXISTS owner_id INTEGER REFERENCES users(id) ON DELETE CASCADE """) # Price comparisons table (for history) await self._execute_query(""" CREATE TABLE IF NOT EXISTS price_comparisons ( id SERIAL PRIMARY KEY, user_id INTEGER REFERENCES users(id) ON DELETE CASCADE, hotel_ids INTEGER[] NOT NULL, check_in DATE NOT NULL, check_out DATE NOT NULL, occupancy VARCHAR(50), results JSONB, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) # Create index for faster email lookups await self._execute_query(""" CREATE INDEX IF NOT EXISTS idx_users_email ON users(email) """) async def create_user(self, email: str, password_hash: str, full_name: str, is_admin: bool = False, totp_secret: Optional[str] = None) -> dict: rows = await self._execute_query( """ INSERT INTO users (email, password_hash, full_name, is_admin, totp_secret) VALUES (%s, %s, %s, %s, %s) RETURNING id, email, password_hash, full_name, is_admin, totp_secret, created_at::text, last_login::text """, (email, password_hash, full_name, is_admin, totp_secret) ) return rows[0] if rows else {} async def get_user_by_email(self, email: str) -> Optional[dict]: rows = await self._execute_query( """SELECT id, email, password_hash, full_name, is_admin, totp_secret, created_at::text, last_login::text FROM users WHERE email = %s""", (email,) ) return rows[0] if rows else None async def get_user_by_id(self, user_id: int) -> Optional[dict]: rows = await self._execute_query( """SELECT id, email, password_hash, full_name, is_admin, totp_secret, created_at::text, last_login::text FROM users WHERE id = %s""", (user_id,) ) return rows[0] if rows else None async def update_user_last_login(self, email: str) -> None: await self._execute_query( "UPDATE users SET last_login = CURRENT_TIMESTAMP WHERE email = %s", (email,) ) async def delete_user(self, user_id: int) -> bool: rows = await self._execute_query( "DELETE FROM users WHERE id = %s RETURNING id", (user_id,) ) return len(rows) > 0 async def list_users(self) -> list[dict]: return await self._execute_query( """SELECT id, email, password_hash, full_name, is_admin, totp_secret, created_at::text, last_login::text FROM users ORDER BY created_at DESC""" ) async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None, website_url: Optional[str] = None) -> dict: rows = await self._execute_query( """ INSERT INTO hotels (name, owner_id, booking_url, website_url) VALUES (%s, %s, %s, %s) RETURNING id, name, owner_id, booking_url, website_url, created_at::text """, (name, owner_id, booking_url, website_url) ) return rows[0] if rows else {} async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]: if owner_id: return await self._execute_query( """SELECT id, name, owner_id, booking_url, website_url, created_at::text FROM hotels WHERE owner_id = %s OR owner_id IS NULL ORDER BY name""", (owner_id,) ) return await self._execute_query( "SELECT id, name, owner_id, booking_url, website_url, created_at::text FROM hotels ORDER BY name" ) async def get_hotel(self, hotel_id: int) -> Optional[dict]: """Get a single hotel by ID""" rows = await self._execute_query( """SELECT id, name, owner_id, booking_url, website_url, created_at::text FROM hotels WHERE id = %s""", (hotel_id,) ) return rows[0] if rows else None async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool: """Delete hotel, optionally verifying ownership""" if owner_id: rows = await self._execute_query( "DELETE FROM hotels WHERE id = %s AND owner_id = %s RETURNING id", (hotel_id, owner_id) ) else: rows = await self._execute_query( "DELETE FROM hotels WHERE id = %s RETURNING id", (hotel_id,) ) return len(rows) > 0 rows = await self._execute_query( "DELETE FROM hotels WHERE id = %s RETURNING id", (hotel_id,) ) return len(rows) > 0 def close(self): """Close database connection""" if self._conn: self._conn.close() self._conn = None # ============= Database Factory ============= def create_database() -> DatabaseInterface: """ Create the appropriate database instance based on environment. Environment variables checked (in order): - DATABASE_URL: Standard PostgreSQL connection string - POSTGRES_URL: Vercel/Neon Postgres connection string Returns InMemoryDatabase for development/testing. Returns PostgresDatabase when a connection URL is configured. """ # Check for database URL (multiple common env var names) postgres_url = ( os.environ.get("DATABASE_URL") or os.environ.get("POSTGRES_URL") or os.environ.get("POSTGRES_URL_NON_POOLING") or "" ) testing = os.environ.get("TESTING", "false").lower() == "true" # Use in-memory database for testing or when no database URL is configured if testing or not postgres_url: return InMemoryDatabase() # Check if pg8000 is available if not PG8000_AVAILABLE: logger.warning("pg8000 not available, falling back to in-memory database") return InMemoryDatabase() return PostgresDatabase(postgres_url) # Singleton database instance db: DatabaseInterface = create_database()