InnSight-Backend / api /database.py
jackonthemike's picture
feat: Sync backend updates including AI Revenue Analyst
cef0de3
"""
Database module for PostgreSQL integration (Neon/Vercel Postgres).
This module provides a clean abstraction layer for database operations,
supporting both in-memory storage (for development/testing) and PostgreSQL
via Neon (for production).
Environment Variables:
    DATABASE_URL, POSTGRES_URL or POSTGRES_URL_NON_POOLING: PostgreSQL
        connection string (checked in that order; the first non-empty one wins)
    TESTING: Set to "true" (case-insensitive) to force in-memory database
Usage:
from database import db
# Create user
user = await db.create_user(email, password_hash, full_name)
# Get user
user = await db.get_user_by_email(email)
"""
import os
import sys
from datetime import datetime, timezone
from typing import Optional
from abc import ABC, abstractmethod
from urllib.parse import urlparse, parse_qs
import asyncio
import threading
import asyncio
import threading
import time
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# pg8000 for PostgreSQL - pure Python, lightweight.
# The driver is optional: the import is guarded so that this module can still
# be imported when pg8000 is missing or fails to load. PG8000_AVAILABLE is
# flipped to True only after both imports have succeeded.
PG8000_AVAILABLE = False
pg8000 = None
try:
    import pg8000
    import pg8000.native
    PG8000_AVAILABLE = True
except ImportError as e:
    logger.warning(f"pg8000 not available: {e}")
except Exception as e:
    # Any other failure while loading the driver is logged, not raised.
    logger.error(f"Error loading pg8000: {e}")
# ============= Abstract Database Interface =============
class DatabaseInterface(ABC):
    """Abstract interface for database operations.

    Every method is an abstract coroutine; a concrete backend must implement
    all of them before it can be instantiated. Records are exchanged as plain
    dicts.
    """

    @abstractmethod
    async def init_tables(self) -> None:
        """Initialize database tables."""

    # User operations
    @abstractmethod
    async def create_user(self, email: str, password_hash: str, full_name: str,
                          is_admin: bool = False, totp_secret: Optional[str] = None) -> dict:
        """Store a new user and return its record as a dict."""

    @abstractmethod
    async def get_user_by_email(self, email: str) -> Optional[dict]:
        """Return the user with this email, or None when there is none."""

    @abstractmethod
    async def get_user_by_id(self, user_id: int) -> Optional[dict]:
        """Return the user with this id, or None when there is none."""

    @abstractmethod
    async def update_user_last_login(self, email: str) -> None:
        """Record a login for the user with this email."""

    @abstractmethod
    async def delete_user(self, user_id: int) -> bool:
        """Delete the user with this id; return True if a user was removed."""

    @abstractmethod
    async def list_users(self) -> list[dict]:
        """Return all user records."""

    # Hotel operations
    @abstractmethod
    async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None,
                           website_url: Optional[str] = None) -> dict:
        """Store a new hotel owned by *owner_id* and return its record."""

    @abstractmethod
    async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]:
        """Return hotel records, optionally restricted by *owner_id*."""

    @abstractmethod
    async def get_hotel(self, hotel_id: int) -> Optional[dict]:
        """Return the hotel with this id, or None when there is none."""

    @abstractmethod
    async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool:
        """Delete a hotel, optionally checking ownership; return True if removed."""
# ============= In-Memory Database (Development/Testing) =============
class InMemoryDatabase(DatabaseInterface):
    """In-memory database for development and testing.

    Users live in a dict keyed by email and hotels in a list. Ids are handed
    out by two per-instance counters whose first value is 1. Nothing is
    persisted; the data disappears with the instance.
    """

    def __init__(self):
        self._users: dict[str, dict] = {}
        self._hotels: list[dict] = []
        self._user_id_counter = 0
        self._hotel_id_counter = 0

    # ----- private helpers -----

    def _store_user(self, email: str, password_hash: str, full_name: str,
                    is_admin: bool, totp_secret: Optional[str]) -> dict:
        """Allocate the next user id, build the record and store it under *email*.

        An existing record with the same email is replaced.
        """
        self._user_id_counter += 1
        user = {
            "id": self._user_id_counter,
            "email": email,
            "password_hash": password_hash,
            "full_name": full_name,
            "is_admin": is_admin,
            "totp_secret": totp_secret,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "last_login": None,
        }
        self._users[email] = user
        return user

    def _store_hotel(self, name: str, owner_id: Optional[int],
                     booking_url: Optional[str], website_url: Optional[str]) -> dict:
        """Allocate the next hotel id, build the record and append it."""
        self._hotel_id_counter += 1
        hotel = {
            "id": self._hotel_id_counter,
            "name": name,
            "owner_id": owner_id,
            "booking_url": booking_url,
            "website_url": website_url,
            "created_at": datetime.now(timezone.utc).isoformat(),
        }
        self._hotels.append(hotel)
        return hotel

    def _seed_users(self, configs: list[dict], is_admin: bool) -> None:
        """Store one user per config dict with the given admin flag.

        Each config must provide "email", "password_hash" and "full_name";
        "totp_secret" is optional.
        """
        for config in configs:
            self._store_user(
                config["email"],
                config["password_hash"],
                config["full_name"],
                is_admin,
                config.get("totp_secret"),
            )

    # ----- DatabaseInterface implementation -----

    async def init_tables(self) -> None:
        """No-op for in-memory database"""
        pass

    async def create_user(self, email: str, password_hash: str, full_name: str,
                          is_admin: bool = False, totp_secret: Optional[str] = None) -> dict:
        """Create a user and return the stored record (the same dict that is kept internally)."""
        return self._store_user(email, password_hash, full_name, is_admin, totp_secret)

    async def get_user_by_email(self, email: str) -> Optional[dict]:
        """Return the user stored under *email*, or None."""
        return self._users.get(email)

    async def get_user_by_id(self, user_id: int) -> Optional[dict]:
        """Return the first user whose id equals *user_id*, or None (linear scan)."""
        for user in self._users.values():
            if user["id"] == user_id:
                return user
        return None

    async def update_user_last_login(self, email: str) -> None:
        """Set last_login to the current UTC time (ISO string); unknown emails are ignored."""
        if email in self._users:
            self._users[email]["last_login"] = datetime.now(timezone.utc).isoformat()

    async def delete_user(self, user_id: int) -> bool:
        """Remove the user with this id; return True if one was found."""
        # Iterate over a snapshot because the dict is modified inside the loop.
        for email, user in list(self._users.items()):
            if user["id"] == user_id:
                del self._users[email]
                return True
        return False

    async def list_users(self) -> list[dict]:
        """Return a new list holding every stored user record."""
        return list(self._users.values())

    async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None,
                           website_url: Optional[str] = None) -> dict:
        """Create a hotel and return the stored record."""
        return self._store_hotel(name, owner_id, booking_url, website_url)

    async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]:
        """Return hotels visible to *owner_id*, or every hotel when it is falsy.

        With a truthy *owner_id* the result holds that owner's hotels plus the
        hotels whose owner_id is None (the seeded demo hotels). A falsy value
        (None or 0) selects all hotels. A new list is returned in both cases
        so callers cannot mutate the internal storage by accident.
        """
        if owner_id:
            return [
                h for h in self._hotels
                if h.get("owner_id") == owner_id or h.get("owner_id") is None
            ]
        return list(self._hotels)

    async def get_hotel(self, hotel_id: int) -> Optional[dict]:
        """Return the hotel with this id, or None."""
        for hotel in self._hotels:
            if hotel["id"] == hotel_id:
                return hotel
        return None

    async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool:
        """Delete the hotel with this id; return True if it was removed.

        When *owner_id* is truthy the hotel is only removed if it belongs to
        that owner; otherwise False is returned and nothing changes. A falsy
        *owner_id* (None or 0) skips the ownership check.
        """
        for i, hotel in enumerate(self._hotels):
            if hotel["id"] == hotel_id:
                if owner_id and hotel.get("owner_id") != owner_id:
                    return False
                self._hotels.pop(i)
                return True
        return False

    # ----- seeding helpers (development) -----

    def seed_admin_users(self, admin_configs: list[dict]) -> None:
        """Seed admin users for development"""
        self._seed_users(admin_configs, is_admin=True)

    def seed_regular_users(self, user_configs: list[dict]) -> None:
        """Seed regular (non-admin) users for development"""
        # Regular users are NOT admins
        self._seed_users(user_configs, is_admin=False)

    def seed_demo_hotels(self, hotels: list[dict]) -> None:
        """Seed demo hotels for development"""
        for hotel in hotels:
            # owner_id None makes the demo hotel visible to every owner in get_hotels().
            self._store_hotel(
                hotel["name"],
                None,
                hotel.get("booking_url"),
                hotel.get("website_url"),
            )
# ============= PostgreSQL Database (Production - Neon/Vercel) =============
class PostgresDatabase(DatabaseInterface):
    """PostgreSQL database using pg8000 (pure Python, works with Neon, Vercel Postgres, etc.)

    One connection is opened lazily on first use and then reused. Queries run
    synchronously inside the event loop's default executor and are serialized
    by a per-instance lock, so only one statement uses the connection at a
    time. The connection is put in autocommit mode.
    """

    def __init__(self, connection_url: str):
        """
        Initialize with a PostgreSQL connection URL.
        Format: postgresql://user:password@host:port/database?sslmode=require

        No connection is opened here; that happens on the first query.
        """
        self.connection_url = connection_url
        self._conn = None
        # Created eagerly: creating the lock lazily from worker threads would
        # let two threads each create (and hold) a different lock.
        self._lock = threading.Lock()
        self._parse_connection_url()

    def _parse_connection_url(self):
        """Parse the connection URL into components.

        Sets _host, _port, _user, _password, _database and _ssl. Missing parts
        fall back to localhost, 5432, "postgres", an empty password and the
        "postgres" database. User name and password are percent-decoded.
        """
        from urllib.parse import unquote

        parsed = urlparse(self.connection_url)
        self._host = parsed.hostname or 'localhost'
        self._port = parsed.port or 5432
        self._user = unquote(parsed.username or 'postgres')
        self._password = unquote(parsed.password or '')
        self._database = parsed.path.lstrip('/').split('?')[0] or 'postgres'

        # Parse query params for SSL: any sslmode other than "disable" enables it.
        query_params = parse_qs(parsed.query)
        self._ssl = 'sslmode' in query_params and query_params['sslmode'][0] != 'disable'

        # Neon and Vercel always require SSL
        if 'neon' in self._host or 'vercel' in self._host:
            self._ssl = True

    def _get_connection(self):
        """Get or create database connection.

        Within this class it is only called from _run_query_in_thread while
        the instance lock is held. Connection failures are logged and
        re-raised.
        """
        if self._conn is None:
            try:
                ssl_context = None
                if self._ssl:
                    import ssl
                    ssl_context = ssl.create_default_context()
                # pg8000.connect is the correct API
                conn = pg8000.connect(
                    host=self._host,
                    port=self._port,
                    user=self._user,
                    password=self._password,
                    database=self._database,
                    ssl_context=ssl_context
                )
                conn.autocommit = True
                # Publish the connection only once it is fully configured.
                self._conn = conn
            except Exception as e:
                logger.error(f"Failed to connect to database: {e}")
                raise
        return self._conn

    def _discard_connection(self) -> None:
        """Close the cached connection on a best-effort basis and forget it."""
        if self._conn:
            try:
                self._conn.close()
            except Exception:
                # The connection is presumed broken already; a failing close
                # must not mask the error that led here.
                pass
        self._conn = None

    def _run_query_in_thread(self, query: str, params: Optional[tuple] = None) -> list[dict]:
        """Run query in a separate thread with locking and retry logic.

        Returns one dict per result row (column name -> value), or an empty
        list for statements that produce no result set.

        On pg8000 InterfaceError/DatabaseError (or AttributeError) the cached
        connection is dropped and the statement is attempted once more after
        0.5 s; the error of the last attempt is re-raised. Note that every
        DatabaseError is retried, not only lost connections, so a statement
        that failed for another reason is executed a second time before the
        error surfaces.
        """
        max_retries = 2
        with self._lock:
            # Simple retry mechanism for lost connections
            for attempt in range(max_retries):
                try:
                    conn = self._get_connection()
                    cursor = conn.cursor()
                    try:
                        # pg8000 uses %s style parameters just like psycopg2
                        if params is not None:
                            cursor.execute(query, params)
                        else:
                            cursor.execute(query)
                        # Check if this is a SELECT-like query that returns rows
                        try:
                            if cursor.description and len(cursor.description) > 0:
                                columns = [desc[0] for desc in cursor.description]
                                rows = cursor.fetchall()
                                return [dict(zip(columns, row)) for row in rows]
                        except (TypeError, AttributeError):
                            # DDL statements don't return rows
                            pass
                        return []
                    finally:
                        cursor.close()
                except (pg8000.InterfaceError, pg8000.DatabaseError, AttributeError) as e:
                    # Connection might be dead, close and retry
                    logger.warning(f"Database error (attempt {attempt+1}/{max_retries}): {e}")
                    self._discard_connection()
                    if attempt == max_retries - 1:
                        raise
                    # Wait a bit before retry
                    time.sleep(0.5)
        # Should not be reached due to raise above
        return []

    async def _execute_query(self, query: str, params: Optional[tuple] = None) -> list[dict]:
        """Execute a SQL query asynchronously using a thread pool.

        The blocking driver call is handed to the running loop's default
        executor so the event loop itself is not blocked.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run_query_in_thread, query, params)

    async def init_tables(self) -> None:
        """Initialize database tables.

        Every statement uses IF NOT EXISTS, so the method can be run again on
        an existing schema.
        """
        # Users table
        await self._execute_query("""
            CREATE TABLE IF NOT EXISTS users (
                id SERIAL PRIMARY KEY,
                email VARCHAR(255) UNIQUE NOT NULL,
                password_hash VARCHAR(512) NOT NULL,
                full_name VARCHAR(255) NOT NULL,
                is_admin BOOLEAN DEFAULT FALSE,
                totp_secret VARCHAR(64),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_login TIMESTAMP
            )
        """)
        # Hotels table
        await self._execute_query("""
            CREATE TABLE IF NOT EXISTS hotels (
                id SERIAL PRIMARY KEY,
                name VARCHAR(255) NOT NULL,
                booking_url TEXT,
                website_url TEXT,
                owner_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Add owner_id column if it doesn't exist (for migration)
        await self._execute_query("""
            ALTER TABLE hotels ADD COLUMN IF NOT EXISTS owner_id INTEGER REFERENCES users(id) ON DELETE CASCADE
        """)
        # Price comparisons table (for history)
        await self._execute_query("""
            CREATE TABLE IF NOT EXISTS price_comparisons (
                id SERIAL PRIMARY KEY,
                user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
                hotel_ids INTEGER[] NOT NULL,
                check_in DATE NOT NULL,
                check_out DATE NOT NULL,
                occupancy VARCHAR(50),
                results JSONB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Create index for faster email lookups
        await self._execute_query("""
            CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)
        """)

    async def create_user(self, email: str, password_hash: str, full_name: str,
                          is_admin: bool = False, totp_secret: Optional[str] = None) -> dict:
        """Insert a user and return the new row.

        Timestamps are returned as text. An empty dict is returned if the
        INSERT yields no row.
        """
        rows = await self._execute_query(
            """
            INSERT INTO users (email, password_hash, full_name, is_admin, totp_secret)
            VALUES (%s, %s, %s, %s, %s)
            RETURNING id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
            """,
            (email, password_hash, full_name, is_admin, totp_secret)
        )
        return rows[0] if rows else {}

    async def get_user_by_email(self, email: str) -> Optional[dict]:
        """Return the user row with this email, or None."""
        rows = await self._execute_query(
            """SELECT id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
               FROM users WHERE email = %s""",
            (email,)
        )
        return rows[0] if rows else None

    async def get_user_by_id(self, user_id: int) -> Optional[dict]:
        """Return the user row with this id, or None."""
        rows = await self._execute_query(
            """SELECT id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
               FROM users WHERE id = %s""",
            (user_id,)
        )
        return rows[0] if rows else None

    async def update_user_last_login(self, email: str) -> None:
        """Set last_login to the database's CURRENT_TIMESTAMP for this email."""
        await self._execute_query(
            "UPDATE users SET last_login = CURRENT_TIMESTAMP WHERE email = %s",
            (email,)
        )

    async def delete_user(self, user_id: int) -> bool:
        """Delete the user with this id; return True if a row was deleted."""
        rows = await self._execute_query(
            "DELETE FROM users WHERE id = %s RETURNING id",
            (user_id,)
        )
        return len(rows) > 0

    async def list_users(self) -> list[dict]:
        """Return all user rows, newest created_at first."""
        return await self._execute_query(
            """SELECT id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
               FROM users ORDER BY created_at DESC"""
        )

    async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None,
                           website_url: Optional[str] = None) -> dict:
        """Insert a hotel and return the new row (empty dict if none came back)."""
        rows = await self._execute_query(
            """
            INSERT INTO hotels (name, owner_id, booking_url, website_url)
            VALUES (%s, %s, %s, %s)
            RETURNING id, name, owner_id, booking_url, website_url, created_at::text
            """,
            (name, owner_id, booking_url, website_url)
        )
        return rows[0] if rows else {}

    async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]:
        """Return hotels ordered by name.

        With a truthy *owner_id* only that owner's hotels and hotels without
        an owner (owner_id IS NULL) are returned; a falsy value (None or 0)
        returns every hotel.
        """
        if owner_id:
            return await self._execute_query(
                """SELECT id, name, owner_id, booking_url, website_url, created_at::text
                   FROM hotels WHERE owner_id = %s OR owner_id IS NULL ORDER BY name""",
                (owner_id,)
            )
        return await self._execute_query(
            "SELECT id, name, owner_id, booking_url, website_url, created_at::text FROM hotels ORDER BY name"
        )

    async def get_hotel(self, hotel_id: int) -> Optional[dict]:
        """Get a single hotel by ID"""
        rows = await self._execute_query(
            """SELECT id, name, owner_id, booking_url, website_url, created_at::text
               FROM hotels WHERE id = %s""",
            (hotel_id,)
        )
        return rows[0] if rows else None

    async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool:
        """Delete hotel, optionally verifying ownership.

        With a truthy *owner_id* the row is deleted only if it belongs to that
        owner; a falsy value (None or 0) deletes by id alone. Returns True if
        a row was deleted.
        """
        if owner_id:
            rows = await self._execute_query(
                "DELETE FROM hotels WHERE id = %s AND owner_id = %s RETURNING id",
                (hotel_id, owner_id)
            )
        else:
            rows = await self._execute_query(
                "DELETE FROM hotels WHERE id = %s RETURNING id",
                (hotel_id,)
            )
        return len(rows) > 0

    def close(self):
        """Close database connection.

        The cached connection is forgotten even if the driver's close()
        raises, so the next query opens a fresh one.
        """
        if self._conn:
            try:
                self._conn.close()
            finally:
                self._conn = None
# ============= Database Factory =============
def create_database() -> DatabaseInterface:
    """
    Pick the database backend that matches the current environment.

    The connection string is the first non-empty value among these
    environment variables, looked up in this order:
    - DATABASE_URL: Standard PostgreSQL connection string
    - POSTGRES_URL: Vercel/Neon Postgres connection string
    - POSTGRES_URL_NON_POOLING

    Returns PostgresDatabase when a connection string is set, TESTING is not
    "true" (case-insensitive) and pg8000 could be imported.
    Returns InMemoryDatabase in every other case; a warning is logged when
    the only reason is that pg8000 is unavailable.
    """
    url_variables = ("DATABASE_URL", "POSTGRES_URL", "POSTGRES_URL_NON_POOLING")
    # First truthy value wins; fall back to "" when none of them is set.
    postgres_url = next(
        (value for value in map(os.environ.get, url_variables) if value),
        "",
    )
    force_in_memory = os.environ.get("TESTING", "false").lower() == "true"

    if postgres_url and not force_in_memory:
        if PG8000_AVAILABLE:
            return PostgresDatabase(postgres_url)
        # A URL is configured but the driver is missing.
        logger.warning("pg8000 not available, falling back to in-memory database")

    return InMemoryDatabase()
# Singleton database instance, built once when this module is first imported.
# create_database() chooses the backend from the environment at that moment.
# Use this shared object (`from database import db`) instead of constructing
# a backend directly.
db: DatabaseInterface = create_database()