| | """ |
| | Database Initialization and Session Management |
| | """ |
| |
|
| | from sqlalchemy import create_engine |
| | from sqlalchemy.orm import sessionmaker, Session |
| | from contextlib import contextmanager |
| | from config import config |
| | from database.models import Base, Provider, ProviderStatusEnum |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | engine = create_engine( |
| | config.DATABASE_URL, |
| | connect_args={"check_same_thread": False} if "sqlite" in config.DATABASE_URL else {} |
| | ) |
| |
|
| | |
| | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) |
| |
|
| |
|
| | def init_database(): |
| | """Initialize database and populate providers""" |
| | try: |
| | |
| | Base.metadata.create_all(bind=engine) |
| | logger.info("Database tables created successfully") |
| |
|
| | |
| | db = SessionLocal() |
| | try: |
| | for provider_config in config.PROVIDERS: |
| | existing = db.query(Provider).filter(Provider.name == provider_config.name).first() |
| | if not existing: |
| | provider = Provider( |
| | name=provider_config.name, |
| | category=provider_config.category, |
| | endpoint_url=provider_config.endpoint_url, |
| | requires_key=provider_config.requires_key, |
| | api_key_masked=mask_api_key(provider_config.api_key) if provider_config.api_key else None, |
| | rate_limit_type=provider_config.rate_limit_type, |
| | rate_limit_value=provider_config.rate_limit_value, |
| | timeout_ms=provider_config.timeout_ms, |
| | priority_tier=provider_config.priority_tier, |
| | status=ProviderStatusEnum.UNKNOWN |
| | ) |
| | db.add(provider) |
| |
|
| | db.commit() |
| | logger.info(f"Initialized {len(config.PROVIDERS)} providers") |
| | finally: |
| | db.close() |
| |
|
| | except Exception as e: |
| | logger.error(f"Database initialization failed: {e}") |
| | raise |
| |
|
| |
|
| | @contextmanager |
| | def get_db() -> Session: |
| | """Get database session""" |
| | db = SessionLocal() |
| | try: |
| | yield db |
| | finally: |
| | db.close() |
| |
|
| |
|
| | def mask_api_key(key: str) -> str: |
| | """Mask API key showing only first 4 and last 4 characters""" |
| | if not key or len(key) < 8: |
| | return "****" |
| | return f"{key[:4]}...{key[-4:]}" |
| |
|