"""Chart generation for forecast visualization."""
from typing import Any, List, Optional

import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots

from config.constants import COLORS, CHART_CONFIG


def _hover_template(label: str) -> str:
    """Build a Plotly hovertemplate showing the x value and a 2-decimal y value.

    ``%{x}`` and ``%{y:.2f}`` are Plotly placeholders (not Python format
    fields); ``<br>`` is the line break Plotly renders inside hover labels.
    """
    return "Date: %{x}<br>" + label + ": %{y:.2f}"


def _add_backtest_traces(fig: go.Figure, backtest_data: pd.DataFrame, y_axis_label: str) -> None:
    """Overlay backtest actual vs. predicted lines (columns 'timestamp', 'actual', 'predicted')."""
    fig.add_trace(go.Scatter(
        x=backtest_data['timestamp'],
        y=backtest_data['actual'],
        mode='lines',
        name='Backtest Actual',
        line=dict(color='rgba(100, 100, 100, 0.6)', width=2, dash='dot'),
        hovertemplate=_hover_template(f'{y_axis_label} (Actual)'),
    ))
    fig.add_trace(go.Scatter(
        x=backtest_data['timestamp'],
        y=backtest_data['predicted'],
        mode='lines',
        name='Backtest Predicted',
        line=dict(color='rgba(255, 100, 100, 0.8)', width=2),
        hovertemplate=_hover_template(f'{y_axis_label} (Predicted)'),
    ))


def _add_confidence_bands(fig: go.Figure, forecast_data: pd.DataFrame, confidence_levels: List[int]) -> None:
    """Add one filled band per confidence level that has 'lower_{cl}'/'upper_{cl}' columns.

    Levels are drawn in descending order (widest first) so narrower bands are
    painted on top. Levels without both columns are skipped. Each plotted
    level needs an entry in ``COLORS['confidence']`` (KeyError otherwise).
    """
    if not confidence_levels:
        return
    dates = forecast_data['ds'].tolist()
    # Closed polygon: walk the upper bound forward, then the lower bound back.
    band_x = dates + dates[::-1]
    for cl in sorted(confidence_levels, reverse=True):
        lower_col = f'lower_{cl}'
        upper_col = f'upper_{cl}'
        if lower_col not in forecast_data.columns or upper_col not in forecast_data.columns:
            continue
        fig.add_trace(go.Scatter(
            x=band_x,
            y=forecast_data[upper_col].tolist() + forecast_data[lower_col].tolist()[::-1],
            fill='toself',
            fillcolor=COLORS['confidence'][cl],
            line=dict(width=0),
            name=f'{cl}% Confidence',
            showlegend=True,
            hoverinfo='skip',
        ))


def _add_forecast_start_marker(fig: go.Figure, historical_data: pd.DataFrame) -> None:
    """Draw a dashed vertical line plus label at the last historical 'ds' value (assumes rows are sorted by date)."""
    if len(historical_data) == 0:
        return
    last_historical_date = historical_data['ds'].iloc[-1]
    # add_shape instead of add_vline to avoid Timestamp arithmetic issues.
    fig.add_shape(
        type="line",
        x0=last_historical_date,
        x1=last_historical_date,
        y0=0,
        y1=1,
        yref="paper",
        line=dict(color=COLORS['separator'], dash="dash", width=1),
    )
    fig.add_annotation(
        x=last_historical_date,
        y=1.0,
        yref="paper",
        text="Forecast Start",
        showarrow=False,
        yanchor="bottom",
    )


def create_forecast_chart(
    historical_data: pd.DataFrame,
    forecast_data: pd.DataFrame,
    confidence_levels: List[int],
    title: str = "Time Series Forecast",
    y_axis_label: str = "Value",
    backtest_data: Optional[pd.DataFrame] = None
) -> go.Figure:
    """
    Create an interactive forecast chart with confidence intervals

    Trace order: historical line, optional backtest lines, confidence bands
    (widest first), forecast line; then a "Forecast Start" marker at the last
    historical date.

    Args:
        historical_data: DataFrame with columns ['ds', 'y']
        forecast_data: DataFrame with columns 'ds' and 'forecast', plus optional
            'lower_{cl}'/'upper_{cl}' columns for each confidence level
        confidence_levels: List of confidence levels to plot
        title: Chart title
        y_axis_label: Label for y-axis (variable name being forecasted)
        backtest_data: Optional DataFrame with columns 'timestamp', 'actual',
            'predicted' (model performance on historical data)

    Returns:
        Plotly figure
    """
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=historical_data['ds'],
        y=historical_data['y'],
        mode='lines',
        name='Historical',
        line=dict(color=COLORS['historical'], width=2),
        hovertemplate=_hover_template(y_axis_label),
    ))

    if backtest_data is not None and len(backtest_data) > 0:
        _add_backtest_traces(fig, backtest_data, y_axis_label)

    _add_confidence_bands(fig, forecast_data, confidence_levels)

    fig.add_trace(go.Scatter(
        x=forecast_data['ds'],
        y=forecast_data['forecast'],
        mode='lines',
        name='Forecast',
        line=dict(color=COLORS['forecast'], width=2),
        hovertemplate=_hover_template(f'{y_axis_label} (Forecast)'),
    ))

    _add_forecast_start_marker(fig, historical_data)

    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor='center'),
        xaxis_title="Date",
        yaxis_title=y_axis_label,
        hovermode='x unified',
        template='plotly_white',
        height=700,  # Taller to accommodate the range slider
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        margin=dict(l=50, r=50, t=80, b=150),  # Larger bottom margin for the range slider
        xaxis=dict(
            rangeslider=dict(
                visible=True,
                thickness=0.12  # Slider height as a fraction of the plot area
            ),
            type='date'
        ),
        # Extra modebar buttons; this is a layout setting, not the figure config.
        modebar_add=['v1hovermode', 'toggleSpikelines'],
    )

    return fig


def create_empty_chart(message: str = "No data available") -> go.Figure:
    """
    Create an empty placeholder chart

    Args:
        message: Message to display centered on the (axis-less) chart

    Returns:
        Plotly figure
    """
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=20, color='gray')
    )
    fig.update_layout(
        template='plotly_white',
        height=600,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False)
    )
    return fig


def _format_metric(key: str, value: Optional[float]) -> str:
    """Format a metric for display: MAPE as percent, R2 with 4 decimals, others with 2; None -> 'N/A'."""
    if value is None:
        return "N/A"
    if key == 'MAPE':
        return f"{value:.2f}%"
    if key == 'R2':
        return f"{value:.4f}"
    return f"{value:.2f}"


def create_metrics_display(metrics: dict, inference_time: Optional[float] = None) -> list:
    """
    Create metrics display components

    Args:
        metrics: Dictionary of metric values keyed by 'MAE', 'RMSE', 'MAPE', 'R2';
            missing or None entries are not shown. ``None`` is treated as empty.
        inference_time: Time taken for inference in seconds (shown first if given)

    Returns:
        List of Dash components (one dbc.Col card per metric)
    """
    import dash_bootstrap_components as dbc
    from dash import html

    def metric_card(title: str, value_text: str):
        # One centered card per metric, laid out as a 2-wide grid column.
        return dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.H6(title, className="text-muted mb-2"),
                    html.H4(value_text)
                ])
            ], className="text-center")
        ], md=2)

    metrics = metrics or {}
    metric_names = {
        'MAE': 'Mean Absolute Error',
        'RMSE': 'Root Mean Squared Error',
        'MAPE': 'Mean Absolute % Error',
        'R2': 'R-Squared'
    }

    metric_cards = []
    if inference_time is not None:
        metric_cards.append(metric_card("Inference Time", f"{inference_time:.2f}s"))

    for key, name in metric_names.items():
        value = metrics.get(key)
        if value is not None:
            metric_cards.append(metric_card(name, _format_metric(key, value)))

    return metric_cards


def create_backtest_metrics_display(metrics: dict) -> Any:
    """
    Create backtest metrics display components

    Args:
        metrics: Dictionary of backtest metric values (MAE, RMSE, MAPE, R2).
            A missing or None metric is shown as 'N/A' rather than a
            misleading 0.

    Returns:
        Dash component card (dbc.Card)
    """
    import dash_bootstrap_components as dbc
    from dash import html

    metrics = metrics or {}
    # (metrics key, display label)
    fields = (('MAE', 'MAE'), ('RMSE', 'RMSE'), ('MAPE', 'MAPE'), ('R2', 'R²'))
    columns = [
        dbc.Col([
            html.Small(label, className="text-muted"),
            html.H5(_format_metric(key, metrics.get(key)), className="mb-0")
        ], md=3)
        for key, label in fields
    ]

    return dbc.Card([
        dbc.CardHeader([
            html.I(className="fas fa-chart-bar me-2"),
            html.Span("Backtest Performance Metrics", className="fw-bold")
        ]),
        dbc.CardBody([
            html.P("Model performance on historical data validation:",
                   className="text-muted small mb-3"),
            dbc.Row(columns),
            html.Hr(),
            html.Small([
                html.I(className="fas fa-info-circle me-1"),
                "Lower MAE/RMSE/MAPE and higher R² (closer to 1.0) indicate better model performance"
            ], className="text-muted")
        ])
    ], className="mt-3")


def decimate_data(df: pd.DataFrame, max_points: int = 10000) -> pd.DataFrame:
    """
    Reduce number of data points for visualization

    Keeps every ``step``-th row, with ``step`` chosen so that at most
    ``max_points`` rows remain.

    Args:
        df: Input DataFrame
        max_points: Maximum number of points to keep (must be >= 1)

    Returns:
        ``df`` unchanged if it already fits, otherwise a decimated DataFrame
        with a fresh RangeIndex

    Raises:
        ValueError: if ``max_points`` < 1
    """
    if max_points < 1:
        raise ValueError(f"max_points must be >= 1, got {max_points}")
    if len(df) <= max_points:
        return df

    # Ceil division: floor division could give step=1 (e.g. 15000 rows, max
    # 10000) and return every row, exceeding max_points.
    step = (len(df) + max_points - 1) // max_points
    return df.iloc[::step].reset_index(drop=True)