"""
Chart generation for forecast visualization
"""
from typing import Any, List, Optional

import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots

from config.constants import COLORS, CHART_CONFIG


def create_forecast_chart(
    historical_data: pd.DataFrame,
    forecast_data: pd.DataFrame,
    confidence_levels: List[int],
    title: str = "Time Series Forecast",
    y_axis_label: str = "Value",
    backtest_data: Optional[pd.DataFrame] = None
) -> go.Figure:
    """
    Create an interactive forecast chart with confidence intervals.

    Args:
        historical_data: DataFrame with columns ['ds', 'y']
        forecast_data: DataFrame with a 'ds' column, a 'forecast' column and,
            for each level ``cl`` in ``confidence_levels``, optional
            ``lower_{cl}`` / ``upper_{cl}`` columns
        confidence_levels: Confidence levels to plot; a level is skipped when
            either of its bound columns is missing from ``forecast_data``
        title: Chart title
        y_axis_label: Label for the y-axis (variable being forecasted); also
            used in the hover text of every trace
        backtest_data: Optional DataFrame with columns 'timestamp', 'actual'
            and 'predicted'; ignored when None or empty

    Returns:
        Plotly figure
    """
    fig = go.Figure()

    # Historical observations
    fig.add_trace(go.Scatter(
        x=historical_data['ds'],
        y=historical_data['y'],
        mode='lines',
        name='Historical',
        line=dict(color=COLORS['historical'], width=2),
        hovertemplate=_hover_template(y_axis_label),
    ))

    # Backtest traces show how the model performed on historical data
    if backtest_data is not None and len(backtest_data) > 0:
        _add_backtest_traces(fig, backtest_data, y_axis_label)

    _add_confidence_bands(fig, forecast_data, confidence_levels)

    # Point forecast, drawn after the bands so it sits on top of them
    fig.add_trace(go.Scatter(
        x=forecast_data['ds'],
        y=forecast_data['forecast'],
        mode='lines',
        name='Forecast',
        line=dict(color=COLORS['forecast'], width=2),
        hovertemplate=_hover_template(y_axis_label, ' (Forecast)'),
    ))

    if len(historical_data) > 0:
        _add_forecast_start_marker(fig, historical_data['ds'].iloc[-1])

    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor='center'),
        xaxis_title="Date",
        yaxis_title=y_axis_label,
        hovermode='x unified',
        template='plotly_white',
        height=700,  # Extra height to accommodate the rangeslider
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        margin=dict(l=50, r=50, t=80, b=150),  # Larger bottom margin for the rangeslider
        xaxis=dict(
            rangeslider=dict(
                visible=True,
                thickness=0.12  # Slider takes 12% of the plot height
            ),
            type='date'
        ),
        # Extra modebar buttons: hover-mode toggle and spike lines
        modebar_add=['v1hovermode', 'toggleSpikelines'],
    )
    return fig


def _hover_template(y_axis_label: str, qualifier: str = "") -> str:
    """Build a two-line Plotly hover template: the date, then the labelled y value."""
    # '<br>' is Plotly's line break inside hover text; '%{x}' / '%{y:.2f}' are
    # Plotly placeholders (braces doubled to escape them in the f-string).
    return f'Date: %{{x}}<br>{y_axis_label}{qualifier}: %{{y:.2f}}'


def _add_backtest_traces(fig: go.Figure, backtest_data: pd.DataFrame, y_axis_label: str) -> None:
    """Add the backtest actual (dotted grey) and predicted (red) lines to ``fig``."""
    fig.add_trace(go.Scatter(
        x=backtest_data['timestamp'],
        y=backtest_data['actual'],
        mode='lines',
        name='Backtest Actual',
        line=dict(color='rgba(100, 100, 100, 0.6)', width=2, dash='dot'),
        hovertemplate=_hover_template(y_axis_label, ' (Actual)'),
    ))
    fig.add_trace(go.Scatter(
        x=backtest_data['timestamp'],
        y=backtest_data['predicted'],
        mode='lines',
        name='Backtest Predicted',
        line=dict(color='rgba(255, 100, 100, 0.8)', width=2),
        hovertemplate=_hover_template(y_axis_label, ' (Predicted)'),
    ))


def _add_confidence_bands(fig: go.Figure, forecast_data: pd.DataFrame, confidence_levels: List[int]) -> None:
    """Add one filled band per confidence level, highest level first so narrower bands draw on top."""
    dates = forecast_data['ds'].tolist()
    for cl in sorted(confidence_levels, reverse=True):
        lower_col = f'lower_{cl}'
        upper_col = f'upper_{cl}'
        if lower_col not in forecast_data.columns or upper_col not in forecast_data.columns:
            continue
        # Walk the upper bound forward and the lower bound backward so that
        # fill='toself' closes a polygon covering the whole interval.
        fig.add_trace(go.Scatter(
            x=dates + dates[::-1],
            y=forecast_data[upper_col].tolist() + forecast_data[lower_col].tolist()[::-1],
            fill='toself',
            fillcolor=COLORS['confidence'][cl],
            line=dict(width=0),
            name=f'{cl}% Confidence',
            showlegend=True,
            hoverinfo='skip',
        ))


def _add_forecast_start_marker(fig: go.Figure, last_historical_date) -> None:
    """Draw a dashed vertical line labelled 'Forecast Start' at ``last_historical_date``."""
    # add_shape is used instead of add_vline to avoid Timestamp arithmetic issues
    fig.add_shape(
        type="line",
        x0=last_historical_date,
        x1=last_historical_date,
        y0=0,
        y1=1,
        yref="paper",
        line=dict(color=COLORS['separator'], dash="dash", width=1),
    )
    fig.add_annotation(
        x=last_historical_date,
        y=1.0,
        yref="paper",
        text="Forecast Start",
        showarrow=False,
        yanchor="bottom",
    )
def create_empty_chart(message: str = "No data available") -> go.Figure:
    """
    Build a blank placeholder figure with a message in the centre.

    Args:
        message: Text shown in the middle of the plotting area

    Returns:
        Plotly figure with both axes hidden
    """
    # Both axes are hidden so only the centred message is visible
    hidden_axes = {axis_name: dict(visible=False) for axis_name in ("xaxis", "yaxis")}
    return (
        go.Figure()
        .add_annotation(
            text=message,
            showarrow=False,
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            font=dict(size=20, color='gray'),
        )
        .update_layout(template='plotly_white', height=600, **hidden_axes)
    )
def create_metrics_display(metrics: dict, inference_time: Optional[float] = None) -> list:
    """
    Create metrics display components.

    Args:
        metrics: Dictionary of metric values keyed by 'MAE', 'RMSE', 'MAPE'
            and 'R2'. Missing keys and None values are skipped; a None
            dictionary is treated as empty.
        inference_time: Time taken for inference in seconds; no card is
            produced when it is None

    Returns:
        List of Dash components (one ``dbc.Col`` card per available metric)
    """
    import dash_bootstrap_components as dbc
    from dash import html

    def _metric_card(label: str, value_text: str):
        # One centred card per metric; md=2 lets several share a row
        return dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.H6(label, className="text-muted mb-2"),
                    html.H4(value_text)
                ])
            ], className="text-center")
        ], md=2)

    metric_cards = []

    if inference_time is not None:
        metric_cards.append(_metric_card("Inference Time", f"{inference_time:.2f}s"))

    # metric key -> (display name, format template)
    metric_specs = {
        'MAE': ('Mean Absolute Error', '{:.2f}'),
        'RMSE': ('Root Mean Squared Error', '{:.2f}'),
        'MAPE': ('Mean Absolute % Error', '{:.2f}%'),
        'R2': ('R-Squared', '{:.4f}'),
    }
    available = metrics or {}
    for key, (name, template) in metric_specs.items():
        value = available.get(key)
        if value is None:
            continue
        metric_cards.append(_metric_card(name, template.format(value)))

    return metric_cards
def create_backtest_metrics_display(metrics: dict) -> Any:
    """
    Create a card summarising backtest performance metrics.

    Args:
        metrics: Dictionary of backtest metric values keyed by 'MAE', 'RMSE',
            'MAPE' and 'R2'. Missing keys or None values are shown as "N/A";
            a None dictionary is treated as empty.

    Returns:
        A single ``dbc.Card`` Dash component (not a list)
    """
    import dash_bootstrap_components as dbc
    from dash import html

    values = metrics or {}

    def _format(key: str, spec: str, suffix: str = "") -> str:
        # Showing 0 for an absent metric would read as a perfect score, and
        # formatting None raises TypeError, so fall back to "N/A".
        value = values.get(key)
        if value is None:
            return "N/A"
        return f"{value:{spec}}{suffix}"

    def _metric_column(label: str, text: str):
        # Four equal-width columns (md=3) on a single row
        return dbc.Col([
            html.Small(label, className="text-muted"),
            html.H5(text, className="mb-0")
        ], md=3)

    return dbc.Card([
        dbc.CardHeader([
            html.I(className="fas fa-chart-bar me-2"),
            html.Span("Backtest Performance Metrics", className="fw-bold")
        ]),
        dbc.CardBody([
            html.P("Model performance on historical data validation:", className="text-muted small mb-3"),
            dbc.Row([
                _metric_column("MAE", _format('MAE', '.2f')),
                _metric_column("RMSE", _format('RMSE', '.2f')),
                _metric_column("MAPE", _format('MAPE', '.2f', '%')),
                _metric_column("R²", _format('R2', '.4f')),
            ]),
            html.Hr(),
            html.Small([
                html.I(className="fas fa-info-circle me-1"),
                "Lower MAE/RMSE/MAPE and higher R² (closer to 1.0) indicate better model performance"
            ], className="text-muted")
        ])
    ], className="mt-3")
def decimate_data(df: pd.DataFrame, max_points: int = 10000) -> pd.DataFrame:
    """
    Reduce the number of rows in a DataFrame for visualization.

    Uses systematic sampling: every ``step``-th row is kept, starting with the
    first row. ``step`` is rounded up, so the result never has more than
    ``max_points`` rows. The final row is not guaranteed to be kept.

    Args:
        df: Input DataFrame
        max_points: Maximum number of rows to keep; must be at least 1

    Returns:
        ``df`` itself (index untouched) when it already has at most
        ``max_points`` rows; otherwise a new DataFrame with a fresh 0..n-1 index

    Raises:
        ValueError: If ``max_points`` is less than 1
    """
    if max_points < 1:
        # Zero would divide by zero and a negative value would yield a
        # negative step, silently reversing the frame.
        raise ValueError(f"max_points must be >= 1, got {max_points}")

    n_rows = len(df)
    if n_rows <= max_points:
        return df

    # Ceiling division: floor division gives step == 1 whenever
    # max_points < n_rows < 2 * max_points, which would keep every row.
    step = -(-n_rows // max_points)
    return df.iloc[::step].reset_index(drop=True)