# smi_forecast / app.py
"""
SMI Volatility Forecast - Hugging Face Gradio App
LΓ€uft direkt auf Hugging Face Spaces
"""
import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
# FΓΌr Plots
import matplotlib.pyplot as plt
import io
from PIL import Image
class VolatilityForecaster:
    """Downloads intraday prices and derives an annualized rolling volatility series."""

    def __init__(self, ticker, interval='5m', period='60d'):
        """
        Args:
            ticker: Yahoo Finance ticker symbol (e.g. 'NESN.SW').
            interval: bar interval accepted by yfinance (default 5-minute bars).
            period: lookback window accepted by yfinance (Yahoo caps 5m data at 60 days).
        """
        self.ticker = ticker
        self.interval = interval
        self.period = period
        self.data = None        # OHLCV DataFrame, filled by fetch_data()
        self.returns = None     # log-return series, filled by calculate_volatility()
        self.volatility = None  # annualized rolling volatility series

    def fetch_data(self):
        """Fetch OHLCV history from Yahoo Finance.

        Returns:
            pandas.DataFrame of downloaded bars.

        Raises:
            ValueError: if Yahoo returns no rows for the ticker.
        """
        stock = yf.Ticker(self.ticker)
        self.data = stock.history(period=self.period, interval=self.interval)
        if self.data.empty:
            raise ValueError(f"No data found for {self.ticker}")
        return self.data

    def calculate_volatility(self, window=20, periods_per_day=78):
        """Compute annualized rolling volatility from log returns.

        Args:
            window: rolling window length (in bars) for the standard deviation.
            periods_per_day: bars per trading day used for annualization.
                Default 78 = 6.5h of 5-minute bars (US session); SMI stocks
                trade a longer session, so callers may override this.

        Returns:
            pandas.Series of annualized volatility with leading NaNs dropped.
        """
        self.returns = np.log(self.data['Close'] / self.data['Close'].shift(1))
        # Annualize with 252 trading days: sigma_annual = sigma_bar * sqrt(bars/year).
        periods_per_year = periods_per_day * 252
        self.volatility = self.returns.rolling(window=window).std() * np.sqrt(periods_per_year)
        self.volatility = self.volatility.dropna()
        return self.volatility

    def prepare_forecast_data(self, forecast_horizon=12):
        """Split the volatility series into training values and a test window.

        Args:
            forecast_horizon: maximum number of held-out points to return.

        Returns:
            (train_values, test_values, test_index): first 80% of the series
            as a numpy array, then up to `forecast_horizon` out-of-sample
            points with their index labels.
        """
        train_size = int(len(self.volatility) * 0.8)
        train_data = self.volatility.iloc[:train_size].values
        test_data = self.volatility.iloc[train_size:train_size + forecast_horizon].values
        test_dates = self.volatility.index[train_size:train_size + forecast_horizon]
        return train_data, test_data, test_dates
class ModelComparison:
    """Runs several zero-shot time-series foundation models on one
    train/test split and scores each against the held-out data.

    Each forecast_* method is best-effort: the model libraries are heavy
    optional dependencies, so any import or runtime failure is caught,
    printed, and reported via a False return instead of crashing the app.
    """

    def __init__(self, train_data, test_data, test_dates, forecast_horizon=12):
        """
        Args:
            train_data: 1-D numpy array of historical volatility values.
            test_data: 1-D numpy array of held-out values to score against.
            test_dates: index labels aligned with test_data (used for plotting).
            forecast_horizon: number of steps each model must predict.
        """
        self.train_data = train_data
        self.test_data = test_data
        self.test_dates = test_dates
        self.forecast_horizon = forecast_horizon
        self.results = {}  # model name -> {'forecast', 'actual', 'dates'}

    def _store(self, name, forecast):
        """Record one model's forecast alongside the shared actuals/dates."""
        self.results[name] = {
            'forecast': forecast,
            'actual': self.test_data,
            'dates': self.test_dates
        }

    def forecast_chronos(self):
        """Amazon Chronos (T5-small): median over 20 sampled paths."""
        try:
            from chronos import ChronosPipeline
            import torch
            pipeline = ChronosPipeline.from_pretrained(
                "amazon/chronos-t5-small",
                device_map="cpu",
                torch_dtype=torch.bfloat16,
            )
            # Context capped at the last 100 observations.
            context = torch.tensor(self.train_data[-100:])
            forecast = pipeline.predict(
                context=context,
                prediction_length=self.forecast_horizon,
                num_samples=20
            )
            self._store('Chronos', np.median(forecast[0].numpy(), axis=0))
            return True
        except Exception as e:
            print(f"Chronos failed: {str(e)}")
            return False

    def forecast_moirai(self):
        """Salesforce Moirai (small) point forecast via the mean."""
        try:
            from uni2ts.model.moirai import MoiraiForecast
            model = MoiraiForecast.load_from_checkpoint(
                checkpoint_path="Salesforce/moirai-1.0-R-small",
                map_location="cpu"
            )
            forecast = model.forecast(
                past_data=self.train_data[-512:],
                prediction_length=self.forecast_horizon
            )
            self._store('Moirai', forecast.mean().numpy())
            return True
        except Exception as e:
            print(f"Moirai failed: {str(e)}")
            return False

    def forecast_moment(self):
        """AutonLab MOMENT-1-large in forecasting mode."""
        try:
            from momentfm import MOMENTPipeline
            model = MOMENTPipeline.from_pretrained(
                "AutonLab/MOMENT-1-large",
                model_kwargs={'task_name': 'forecasting'}
            )
            # MOMENT expects a (batch, length) array; use the last 512 bars.
            context = self.train_data[-512:].reshape(1, -1)
            forecast = model(context, output_length=self.forecast_horizon)
            self._store('MOMENT', forecast[0])
            return True
        except Exception as e:
            print(f"MOMENT failed: {str(e)}")
            return False

    def forecast_timesfm(self):
        """Google TimesFM with a 512-bar context."""
        try:
            import timesfm
            tfm = timesfm.TimesFm(
                context_len=512,
                horizon_len=self.forecast_horizon,
                input_patch_len=32,
                output_patch_len=128,
            )
            tfm.load_from_checkpoint()
            forecast = tfm.forecast(
                inputs=[self.train_data[-512:]],
                freq=[0]
            )
            self._store('TimesFM', forecast[0])
            return True
        except Exception as e:
            print(f"TimesFM failed: {str(e)}")
            return False

    def calculate_metrics(self):
        """Score every stored forecast against the held-out actuals.

        Returns:
            pandas.DataFrame with one row per model: MAE, RMSE, MAPE (%),
            R-squared, and directional accuracy (%).
        """
        rows = []
        for model_name, result in self.results.items():
            if result is None:
                continue
            forecast = result['forecast']
            actual = result['actual']
            mae = np.mean(np.abs(forecast - actual))
            rmse = np.sqrt(np.mean((forecast - actual) ** 2))
            # abs() in the denominator keeps MAPE well-defined for any sign;
            # the epsilon guards against division by zero.
            mape = np.mean(np.abs((actual - forecast) / (np.abs(actual) + 1e-10))) * 100
            if len(actual) > 1:
                # Fraction of steps where the forecast moves the same way
                # as reality (needs at least two points to define a move).
                actual_direction = np.sign(np.diff(actual))
                forecast_direction = np.sign(np.diff(forecast))
                directional_accuracy = np.mean(actual_direction == forecast_direction) * 100
            else:
                directional_accuracy = 0
            ss_res = np.sum((actual - forecast) ** 2)
            ss_tot = np.sum((actual - np.mean(actual)) ** 2)
            r2 = 1 - (ss_res / (ss_tot + 1e-10))
            rows.append({
                'Model': model_name,
                'MAE': mae,
                'RMSE': rmse,
                'MAPE (%)': mape,
                'RΒ²': r2,
                'Dir. Acc. (%)': directional_accuracy
            })
        return pd.DataFrame(rows)

    def run_all_forecasts(self):
        """Run every model in sequence.

        Returns:
            (metrics DataFrame, number of models that succeeded).
        """
        runners = (
            self.forecast_chronos,
            self.forecast_moirai,
            self.forecast_moment,
            self.forecast_timesfm,
        )
        success_count = sum(1 for run in runners if run())
        return self.calculate_metrics(), success_count
def create_plot(comparison, stock_name):
    """Render forecast-vs-actual and error-bar charts as one PIL image."""
    palette = {'Chronos': 'red', 'Moirai': 'blue', 'MOMENT': 'green', 'TimesFM': 'orange'}
    fig, (left, right) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: draw the actual series once (taken from the first
    # available result), then overlay every model's forecast.
    available = [(name, res) for name, res in comparison.results.items() if res is not None]
    if available:
        _, first = available[0]
        left.plot(first['dates'], first['actual'], 'k-',
                  linewidth=2.5, label='Actual', marker='o')
    for name, res in available:
        left.plot(res['dates'], res['forecast'],
                  color=palette.get(name, 'gray'),
                  linestyle='--', linewidth=2,
                  label=f'{name}', marker='x')
    left.set_xlabel('Time')
    left.set_ylabel('Volatility (annualized)')
    left.set_title(f'{stock_name} - Volatility Forecast')
    left.legend()
    left.grid(True, alpha=0.3)
    left.tick_params(axis='x', rotation=45)

    # Right panel: grouped MAE/RMSE bars, one pair per model.
    metrics = comparison.calculate_metrics()
    if not metrics.empty:
        names = metrics['Model'].tolist()
        positions = np.arange(len(names))
        bar_width = 0.35
        right.bar(positions - bar_width / 2, metrics['MAE'].tolist(),
                  bar_width, label='MAE', alpha=0.8)
        right.bar(positions + bar_width / 2, metrics['RMSE'].tolist(),
                  bar_width, label='RMSE', alpha=0.8)
        right.set_xlabel('Model')
        right.set_ylabel('Error')
        right.set_title(f'{stock_name} - MAE & RMSE Comparison')
        right.set_xticks(positions)
        right.set_xticklabels(names, rotation=45)
        right.legend()
        right.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand back a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close()
    return img
def run_forecast(stock_ticker, forecast_minutes):
    """Gradio callback: fetch data, compute volatility, run all model forecasts.

    Args:
        stock_ticker: Yahoo Finance symbol chosen in the dropdown.
        forecast_minutes: forecast horizon in minutes from the slider
            (Gradio may deliver this as a float).

    Returns:
        (status_text, plot_image_or_None, metrics_DataFrame_or_None).
        On any failure the first element carries the error message and the
        other two are None, so the UI never sees an exception.
    """
    try:
        # Gradio sliders can return floats; coerce to int so the downstream
        # .iloc slicing in prepare_forecast_data gets integer bounds.
        forecast_horizon = int(forecast_minutes) // 5  # Convert to 5-min periods
        status = f"πŸš€ Starting forecast for {stock_ticker}...\n\n"
        # Fetch data
        status += "πŸ“₯ Fetching data from Yahoo Finance...\n"
        forecaster = VolatilityForecaster(ticker=stock_ticker, interval='5m', period='60d')
        forecaster.fetch_data()
        status += f"βœ… Downloaded {len(forecaster.data)} data points\n"
        status += f"πŸ“… Date range: {forecaster.data.index[0]} to {forecaster.data.index[-1]}\n\n"
        # Calculate volatility
        status += "πŸ“Š Calculating volatility...\n"
        forecaster.calculate_volatility(window=20)
        train_data, test_data, test_dates = forecaster.prepare_forecast_data(
            forecast_horizon=forecast_horizon
        )
        status += f"πŸ“Š Training data points: {len(train_data)}\n"
        status += f"πŸ“Š Test data points: {len(test_data)}\n"
        status += f"πŸ“… Test period: {test_dates[0]} to {test_dates[-1]}\n\n"
        # Run forecasts (each model is best-effort; failures are counted, not fatal)
        status += "πŸ€– Running model forecasts...\n\n"
        comparison = ModelComparison(train_data, test_data, test_dates, forecast_horizon)
        metrics_df, success_count = comparison.run_all_forecasts()
        status += f"βœ… Successfully ran {success_count}/4 models\n\n"
        # Create plot
        plot_img = create_plot(comparison, stock_ticker)
        # Format results
        if not metrics_df.empty:
            metrics_str = metrics_df.to_string(index=False)
            best_rmse = metrics_df.loc[metrics_df['RMSE'].idxmin(), 'Model']
            best_r2 = metrics_df.loc[metrics_df['RΒ²'].idxmax(), 'Model']
            status += "=" * 60 + "\n"
            status += "πŸ“Š RESULTS\n"
            status += "=" * 60 + "\n\n"
            status += metrics_str + "\n\n"
            status += f"πŸ† Best Model (RMSE): {best_rmse}\n"
            status += f"πŸ† Best Model (RΒ²): {best_r2}\n"
        else:
            status += "❌ No models completed successfully\n"
            plot_img = None
        return status, plot_img, metrics_df
    except Exception as e:
        # Surface the error in the status box instead of raising into Gradio.
        return f"❌ Error: {str(e)}", None, None
# ---------------------------------------------------------------------------
# Gradio interface: declarative UI layout plus the click -> run_forecast
# wiring. Built at import time; `demo` is launched from the __main__ guard.
# ---------------------------------------------------------------------------
with gr.Blocks(title="SMI Volatility Forecast") as demo:
    # Header / intro text.
    gr.Markdown("""
# πŸ“Š SMI Volatility Forecast - Model Comparison
Compare **Chronos, Moirai, MOMENT, and TimesFM** foundation models for volatility forecasting.
This app uses 5-minute data from Yahoo Finance (max 60 days) to predict volatility.
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Input controls: ticker picker and forecast-horizon slider.
            stock_input = gr.Dropdown(
                choices=['NESN.SW', 'NOVN.SW', 'ROG.SW', 'UBSG.SW', 'ABBN.SW'],
                value='NESN.SW',
                label="πŸ“ˆ Select SMI Stock"
            )
            forecast_input = gr.Slider(
                minimum=30,
                maximum=120,
                value=60,
                step=30,
                label="⏱️ Forecast Horizon (minutes)"
            )
            run_button = gr.Button("πŸš€ Run Forecast", variant="primary")
        with gr.Column(scale=2):
            # Free-form status log produced by run_forecast.
            status_output = gr.Textbox(
                label="πŸ“‹ Status & Results",
                lines=20,
                max_lines=30
            )
    with gr.Row():
        # Matplotlib figure rendered to PNG by create_plot.
        plot_output = gr.Image(label="πŸ“Š Visualization")
    with gr.Row():
        # Per-model metrics table (same DataFrame returned by run_forecast).
        metrics_output = gr.Dataframe(
            label="πŸ“ˆ Detailed Metrics",
            headers=["Model", "MAE", "RMSE", "MAPE (%)", "RΒ²", "Dir. Acc. (%)"]
        )
    # Wire the button to the forecasting pipeline; outputs map 1:1 to the
    # three components above.
    run_button.click(
        fn=run_forecast,
        inputs=[stock_input, forecast_input],
        outputs=[status_output, plot_output, metrics_output]
    )
    # Footer: methodology notes shown under the results.
    gr.Markdown("""
## πŸ“– How it works
1. **Data Collection**: Fetches 5-minute historical data (60 days max from Yahoo Finance)
2. **Volatility Calculation**: Computes rolling volatility from log returns
3. **Train/Test Split**: 80% training, 20% testing (out-of-sample validation)
4. **Model Forecasting**: Runs 4 foundation models in parallel
5. **Evaluation**: Compares models using MAE, RMSE, MAPE, RΒ², and Directional Accuracy
### πŸ† Metrics Explained
- **MAE/RMSE**: Error measures (lower is better)
- **MAPE**: Percentage error (lower is better)
- **RΒ²**: Explained variance 0-1 (higher is better, >0.5 is good)
- **Directional Accuracy**: Trend prediction accuracy (>50% beats random)
""")

if __name__ == "__main__":
    # Standard Spaces entry point: start the Gradio server.
    demo.launch()