| | """ |
| | WebSocket Support Module |
| | Provides real-time updates via WebSocket connections with connection management |
| | """ |
| |
|
| | import asyncio |
| | import json |
| | from datetime import datetime |
| | from typing import Set, Dict, Any, Optional, List |
| | from fastapi import WebSocket, WebSocketDisconnect, APIRouter |
| | from starlette.websockets import WebSocketState |
| | from utils.logger import setup_logger |
| | from database.db_manager import db_manager |
| | from monitoring.rate_limiter import rate_limiter |
| | from config import config |
| |
|
| | |
# Logger named "websocket" at INFO level, used by ConnectionManager and the route handlers below.
logger = setup_logger("websocket", level="INFO")


# Router carrying the /ws/live WebSocket route and the /ws/stats HTTP route defined in this module.
router: APIRouter = APIRouter()
| |
|
| |
|
class ConnectionManager:
    """
    Manages WebSocket connections and broadcasts messages to all connected clients.

    Tracks the live connections together with per-connection metadata
    (client id, connect time, time of the last pong) and owns two background
    tasks: a status broadcaster that runs every 10 seconds and a heartbeat
    that pings every client every 30 seconds.
    """

    def __init__(self):
        """Initialize an empty connection manager with no background tasks running."""
        self.active_connections: Set[WebSocket] = set()
        # Per-connection info with keys 'client_id', 'connected_at' and 'last_ping';
        # the two times are datetime.utcnow().isoformat() strings.
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        self._broadcast_task: Optional[asyncio.Task] = None
        self._heartbeat_task: Optional[asyncio.Task] = None
        # Polled by both background loops; cleared by stop_background_tasks().
        self._is_running = False

    async def connect(self, websocket: WebSocket, client_id: Optional[str] = None):
        """
        Accept and register a new WebSocket connection, then greet the client.

        Args:
            websocket: WebSocket connection
            client_id: Optional client identifier; defaults to "client_<id(websocket)>"
        """
        await websocket.accept()
        self.active_connections.add(websocket)

        resolved_id = client_id or f"client_{id(websocket)}"
        connected_at = datetime.utcnow().isoformat()
        self.connection_metadata[websocket] = {
            'client_id': resolved_id,
            'connected_at': connected_at,
            'last_ping': connected_at
        }

        logger.info(
            f"WebSocket connected: {resolved_id} "
            f"(Total connections: {len(self.active_connections)})"
        )

        # First message on every connection, so the client learns its assigned id.
        await self.send_personal_message(
            {
                'type': 'connection_established',
                'client_id': resolved_id,
                'timestamp': datetime.utcnow().isoformat(),
                'message': 'Connected to Crypto API Monitor WebSocket'
            },
            websocket
        )

    def disconnect(self, websocket: WebSocket):
        """
        Unregister a WebSocket connection (the socket itself is not closed here).

        Safe to call more than once for the same socket; only the first call
        for a registered socket logs the disconnect.

        Args:
            websocket: WebSocket connection to disconnect
        """
        # Always drop metadata, even if the socket already left active_connections,
        # so the two structures can never drift apart and leak entries.
        metadata = self.connection_metadata.pop(websocket, None)

        if websocket not in self.active_connections:
            return

        self.active_connections.discard(websocket)
        client_id = (metadata or {}).get('client_id', 'unknown')
        logger.info(
            f"WebSocket disconnected: {client_id} "
            f"(Remaining connections: {len(self.active_connections)})"
        )

    async def send_personal_message(self, message: Dict[str, Any], websocket: WebSocket):
        """
        Send a message to a specific WebSocket connection.

        The message is skipped if the socket is not in the CONNECTED state.
        A failed send is logged and the connection is unregistered.

        Args:
            message: Message dictionary to send (JSON-serialized by send_json)
            websocket: Target WebSocket connection
        """
        try:
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.send_json(message)
        except Exception as e:
            logger.error(f"Error sending personal message: {e}")
            self.disconnect(websocket)

    async def broadcast(self, message: Dict[str, Any]):
        """
        Broadcast a message to all connected clients, one after another.

        Clients that are no longer CONNECTED, or whose send fails, are
        unregistered once the pass is complete.

        Args:
            message: Message dictionary to broadcast
        """
        disconnected = []

        # Iterate over a snapshot: the set can change while we await each send.
        for connection in self.active_connections.copy():
            try:
                if connection.client_state == WebSocketState.CONNECTED:
                    await connection.send_json(message)
                else:
                    disconnected.append(connection)
            except Exception as e:
                logger.error(f"Error broadcasting to client: {e}")
                disconnected.append(connection)

        for connection in disconnected:
            self.disconnect(connection)

    async def broadcast_status_update(self):
        """
        Broadcast a 'status_update' message built from the latest stored system
        metrics and the count of unacknowledged alerts from the last hour.

        When no metrics row exists, the provider total falls back to the number
        of configured providers and every other field to 0 / 'unknown'.
        Errors are logged, not raised.
        """
        try:
            latest_metrics = db_manager.get_latest_system_metrics()
            alerts = db_manager.get_alerts(acknowledged=False, hours=1)

            if latest_metrics:
                system_metrics = {
                    'total_providers': latest_metrics.total_providers,
                    'online_count': latest_metrics.online_count,
                    'degraded_count': latest_metrics.degraded_count,
                    'offline_count': latest_metrics.offline_count,
                    'avg_response_time_ms': latest_metrics.avg_response_time_ms,
                    'total_requests_hour': latest_metrics.total_requests_hour,
                    'total_failures_hour': latest_metrics.total_failures_hour,
                    'system_health': latest_metrics.system_health
                }
            else:
                # The provider list is only needed for this fallback, so fetch it lazily.
                system_metrics = {
                    'total_providers': len(config.get_all_providers()),
                    'online_count': 0,
                    'degraded_count': 0,
                    'offline_count': 0,
                    'avg_response_time_ms': 0,
                    'total_requests_hour': 0,
                    'total_failures_hour': 0,
                    'system_health': 'unknown'
                }

            message = {
                'type': 'status_update',
                'timestamp': datetime.utcnow().isoformat(),
                'system_metrics': system_metrics,
                'alert_count': len(alerts),
                'active_websocket_clients': len(self.active_connections)
            }

            await self.broadcast(message)
            logger.debug(f"Broadcasted status update to {len(self.active_connections)} clients")

        except Exception as e:
            logger.error(f"Error broadcasting status update: {e}", exc_info=True)

    async def broadcast_new_log_entry(self, log_type: str, log_data: Dict[str, Any]):
        """
        Broadcast a new log entry as a 'new_log_entry' message.

        Args:
            log_type: Type of log (connection, failure, collection, rate_limit)
            log_data: Log data dictionary (must be JSON-serializable)
        """
        try:
            message = {
                'type': 'new_log_entry',
                'timestamp': datetime.utcnow().isoformat(),
                'log_type': log_type,
                'data': log_data
            }

            await self.broadcast(message)
            logger.debug(f"Broadcasted new {log_type} log entry")

        except Exception as e:
            logger.error(f"Error broadcasting log entry: {e}", exc_info=True)

    async def broadcast_rate_limit_alert(self, provider_name: str, percentage: float):
        """
        Broadcast a 'rate_limit_alert' message.

        Severity is 'critical' at 95% usage or above, otherwise 'warning'.

        Args:
            provider_name: Provider name
            percentage: Current usage percentage
        """
        try:
            message = {
                'type': 'rate_limit_alert',
                'timestamp': datetime.utcnow().isoformat(),
                'provider': provider_name,
                'percentage': percentage,
                'severity': 'critical' if percentage >= 95 else 'warning'
            }

            await self.broadcast(message)
            logger.info(f"Broadcasted rate limit alert for {provider_name} ({percentage}%)")

        except Exception as e:
            logger.error(f"Error broadcasting rate limit alert: {e}", exc_info=True)

    async def broadcast_provider_status_change(
        self,
        provider_name: str,
        old_status: str,
        new_status: str,
        details: Optional[Dict] = None
    ):
        """
        Broadcast a 'provider_status_change' message.

        Args:
            provider_name: Provider name
            old_status: Previous status
            new_status: New status
            details: Optional details about the change (sent as {} when omitted)
        """
        try:
            message = {
                'type': 'provider_status_change',
                'timestamp': datetime.utcnow().isoformat(),
                'provider': provider_name,
                'old_status': old_status,
                'new_status': new_status,
                'details': details or {}
            }

            await self.broadcast(message)
            logger.info(
                f"Broadcasted provider status change: {provider_name} "
                f"{old_status} -> {new_status}"
            )

        except Exception as e:
            logger.error(f"Error broadcasting provider status change: {e}", exc_info=True)

    async def _periodic_broadcast_loop(self):
        """
        Background task that broadcasts updates every 10 seconds.

        Each iteration sends a status update, then a rate-limit alert for every
        provider whose usage is at or above 80%. Errors are logged and the loop
        continues after the usual pause. Runs until ``_is_running`` is cleared
        or the task is cancelled.
        """
        logger.info("Starting periodic broadcast loop")

        while self._is_running:
            try:
                await self.broadcast_status_update()

                rate_limit_statuses = rate_limiter.get_all_statuses()
                for provider, status_data in rate_limit_statuses.items():
                    # A missing status or a None percentage counts as 0% rather than
                    # raising TypeError and skipping the remaining providers.
                    percentage = (status_data or {}).get('percentage') or 0
                    if percentage >= 80:
                        await self.broadcast_rate_limit_alert(provider, percentage)
            except asyncio.CancelledError:
                # Cancellation must end the loop; on Python < 3.8 CancelledError is
                # an Exception subclass and would otherwise be swallowed below.
                raise
            except Exception as e:
                logger.error(f"Error in periodic broadcast loop: {e}", exc_info=True)

            await asyncio.sleep(10)

        logger.info("Periodic broadcast loop stopped")

    async def _heartbeat_loop(self):
        """
        Background task that sends a 'ping' message to all clients every 30 seconds.

        Clients are expected to answer with {'type': 'pong'}; the endpoint
        records the reply time in the connection's 'last_ping' metadata.
        """
        logger.info("Starting heartbeat loop")

        while self._is_running:
            try:
                ping_message = {
                    'type': 'ping',
                    'timestamp': datetime.utcnow().isoformat()
                }

                await self.broadcast(ping_message)
            except asyncio.CancelledError:
                # See _periodic_broadcast_loop: never swallow cancellation.
                raise
            except Exception as e:
                logger.error(f"Error in heartbeat loop: {e}", exc_info=True)

            await asyncio.sleep(30)

        logger.info("Heartbeat loop stopped")

    async def start_background_tasks(self):
        """
        Start background broadcast and heartbeat tasks (no-op with a warning if already running).
        """
        if self._is_running:
            logger.warning("Background tasks already running")
            return

        # Set before creating the tasks: both loops exit as soon as this is False.
        self._is_running = True

        self._broadcast_task = asyncio.create_task(self._periodic_broadcast_loop())
        logger.info("Started periodic broadcast task")

        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        logger.info("Started heartbeat task")

    @staticmethod
    async def _cancel_and_wait(task: Optional[asyncio.Task]) -> bool:
        """
        Cancel *task* and wait for it to finish.

        Returns:
            True if there was a task to stop, False if *task* was None.
        """
        if task is None:
            return False
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # Don't let one task's failure prevent the other from being stopped.
            logger.error(f"Background task ended with error: {e}", exc_info=True)
        return True

    async def stop_background_tasks(self):
        """
        Stop background broadcast and heartbeat tasks (no-op with a warning if not running).
        """
        if not self._is_running:
            logger.warning("Background tasks not running")
            return

        self._is_running = False

        if await self._cancel_and_wait(self._broadcast_task):
            logger.info("Stopped periodic broadcast task")
        self._broadcast_task = None

        if await self._cancel_and_wait(self._heartbeat_task):
            logger.info("Stopped heartbeat task")
        self._heartbeat_task = None

    async def close_all_connections(self):
        """
        Close all active WebSocket connections with code 1000 and forget them.

        Close errors are logged; the registries are cleared regardless.
        """
        logger.info(f"Closing {len(self.active_connections)} active connections")

        for connection in self.active_connections.copy():
            try:
                if connection.client_state == WebSocketState.CONNECTED:
                    await connection.close(code=1000, reason="Server shutdown")
            except Exception as e:
                logger.error(f"Error closing connection: {e}")

        self.active_connections.clear()
        self.connection_metadata.clear()
        logger.info("All WebSocket connections closed")

    def get_connection_count(self) -> int:
        """
        Get the number of active connections.

        Returns:
            Number of active connections
        """
        return len(self.active_connections)

    def get_connection_info(self) -> List[Dict[str, Any]]:
        """
        Get information about all active connections.

        Returns:
            List of dicts with 'client_id', 'connected_at' and 'last_ping'
        """
        return [
            {
                'client_id': metadata['client_id'],
                'connected_at': metadata['connected_at'],
                'last_ping': metadata['last_ping']
            }
            for metadata in self.connection_metadata.values()
        ]
| |
|
| |
|
| | |
# Module-wide ConnectionManager instance used by the /ws/live and /ws/stats handlers below.
manager: ConnectionManager = ConnectionManager()
| |
|
| |
|
@router.websocket("/ws/live")
async def websocket_live_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time updates

    Provides:
    - System status updates every 10 seconds
    - Real-time log entries
    - Rate limit alerts
    - Provider status changes
    - Heartbeat pings every 30 seconds

    Message Types:
    - connection_established: Sent when client connects
    - status_update: Periodic system status (every 10s)
    - new_log_entry: New log entry notification
    - rate_limit_alert: Rate limit warning
    - provider_status_change: Provider status change
    - ping: Heartbeat ping (every 30s)

    Client messages are JSON objects with a 'type' field. 'pong' refreshes
    the connection's 'last_ping' metadata. 'subscribe' and 'unsubscribe' are
    only logged. Malformed or non-object JSON is logged and ignored.
    """
    client_id = None

    try:
        await manager.connect(websocket)
        client_id = manager.connection_metadata.get(websocket, {}).get('client_id', 'unknown')

        # The shared broadcast/heartbeat loops are started lazily by the first client.
        if not manager._is_running:
            await manager.start_background_tasks()

        while True:
            try:
                data = await websocket.receive_text()

                try:
                    message = json.loads(data)
                except json.JSONDecodeError:
                    logger.warning(f"Received invalid JSON from {client_id}: {data}")
                    continue

                # Valid JSON that is not an object (a list, number or string) has no
                # 'type'; calling .get() on it would raise and drop the connection.
                if not isinstance(message, dict):
                    logger.warning(f"Received non-object JSON from {client_id}: {data}")
                    continue

                message_type = message.get('type')

                if message_type == 'pong':
                    if websocket in manager.connection_metadata:
                        manager.connection_metadata[websocket]['last_ping'] = datetime.utcnow().isoformat()
                    logger.debug(f"Received pong from {client_id}")

                elif message_type == 'subscribe':
                    # Subscriptions are not implemented yet; every client gets every message.
                    logger.debug(f"Client {client_id} subscription request: {message}")

                elif message_type == 'unsubscribe':
                    logger.debug(f"Client {client_id} unsubscribe request: {message}")

            except WebSocketDisconnect:
                logger.info(f"Client {client_id} disconnected")
                break

            except Exception as e:
                logger.error(f"Error handling message from {client_id}: {e}", exc_info=True)
                break

    except Exception as e:
        logger.error(f"WebSocket error for {client_id}: {e}", exc_info=True)

    finally:
        # Always unregister, whatever ended the session; disconnect() is idempotent.
        manager.disconnect(websocket)
| |
|
| |
|
@router.get("/ws/stats")
async def websocket_stats():
    """
    Report a snapshot of WebSocket connection statistics.

    Returns:
        Dict with the live connection count, per-connection metadata,
        whether the background loops are running, and a UTC timestamp.
    """
    stats: Dict[str, Any] = {}
    stats['active_connections'] = manager.get_connection_count()
    stats['connections'] = manager.get_connection_info()
    stats['background_tasks_running'] = manager._is_running
    stats['timestamp'] = datetime.utcnow().isoformat()
    return stats
| |
|
| |
|
| | |
# Public API: the router to mount, the shared manager singleton, and its class.
__all__ = ['router', 'manager', 'ConnectionManager']
| |
|