Spaces:
Build error
Build error
| """ | |
| SMI Volatility Forecast - Hugging Face Gradio App | |
| Runs directly on Hugging Face Spaces | |
| """ | |
| import gradio as gr | |
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| from datetime import datetime | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # For plots | |
| import matplotlib.pyplot as plt | |
| import io | |
| from PIL import Image | |
class VolatilityForecaster:
    """Download intraday prices for one ticker and derive rolling volatility.

    Intended call order: ``fetch_data`` -> ``calculate_volatility`` ->
    ``prepare_forecast_data``. Each later step raises ``ValueError`` when the
    step it depends on has not produced data yet.
    """

    def __init__(self, ticker, interval='5m', period='60d'):
        """Store the request parameters; nothing is downloaded here.

        Args:
            ticker: Yahoo Finance symbol passed to ``yf.Ticker``.
            interval: Bar size passed to ``history``.
            period: Look-back span passed to ``history``.
        """
        self.ticker = ticker
        self.interval = interval
        self.period = period
        self.data = None
        self.returns = None
        self.volatility = None

    def fetch_data(self):
        """Fetch price history from Yahoo Finance and keep it on ``self.data``.

        Returns:
            The downloaded frame.

        Raises:
            ValueError: If the download comes back empty.
        """
        stock = yf.Ticker(self.ticker)
        self.data = stock.history(period=self.period, interval=self.interval)
        if self.data.empty:
            raise ValueError(f"No data found for {self.ticker}")
        return self.data

    def calculate_volatility(self, window=20, periods_per_day=78):
        """Compute annualized rolling volatility from log returns of ``Close``.

        Args:
            window: Number of consecutive returns in each rolling window.
            periods_per_day: Bars per trading day used for annualization.
                The default of 78 equals 6.5 hours of 5-minute bars and is
                kept for backward compatibility; pass another value when the
                session length or ``interval`` differs.

        Returns:
            The volatility series with the leading NaN warm-up rows dropped.

        Raises:
            ValueError: If no price data has been loaded yet.
        """
        if self.data is None:
            raise ValueError("No price data loaded; call fetch_data() first")

        close = self.data['Close']
        self.returns = np.log(close / close.shift(1))

        # 252 trading days per year is the annualization convention used here.
        periods_per_year = periods_per_day * 252
        rolling_std = self.returns.rolling(window=window).std()
        self.volatility = (rolling_std * np.sqrt(periods_per_year)).dropna()
        return self.volatility

    def prepare_forecast_data(self, forecast_horizon=12):
        """Split the volatility series into training context and a test window.

        The first 80% of the series is the training data; the next
        ``forecast_horizon`` observations are the held-out test window.

        Args:
            forecast_horizon: Number of observations to hold out.

        Returns:
            ``(train_data, test_data, test_dates)`` where the first two are
            NumPy arrays and the last is the index slice of the test window.

        Raises:
            ValueError: If volatility has not been calculated, the horizon is
                not positive, or the series is too short to supply both a
                non-empty training set and a full test window.
        """
        if self.volatility is None:
            raise ValueError(
                "No volatility series available; call calculate_volatility() first"
            )

        # Positional slicing needs a real int (UI sliders may hand over floats).
        forecast_horizon = int(forecast_horizon)
        if forecast_horizon < 1:
            raise ValueError("forecast_horizon must be at least 1")

        total = len(self.volatility)
        train_size = int(total * 0.8)
        test_end = train_size + forecast_horizon
        if train_size < 1 or test_end > total:
            raise ValueError(
                f"Not enough volatility observations ({total}) for a "
                f"forecast horizon of {forecast_horizon}"
            )

        train_data = self.volatility.iloc[:train_size].values
        test_data = self.volatility.iloc[train_size:test_end].values
        test_dates = self.volatility.index[train_size:test_end]
        return train_data, test_data, test_dates
class ModelComparison:
    """Run several pretrained forecasting models on one split and score them.

    Every ``forecast_*`` method is best-effort: the model library is imported
    lazily inside the method, any exception is printed and turned into a
    ``False`` return value, and a successful run stores its forecast in
    ``self.results`` under the model's display name.
    """

    def __init__(self, train_data, test_data, test_dates, forecast_horizon=12):
        """Keep the shared inputs and start with an empty result table.

        Args:
            train_data: History the models receive as context.
            test_data: Held-out values the forecasts are scored against.
            test_dates: Time axis stored alongside each result for plotting.
            forecast_horizon: Number of steps each model is asked to predict.
        """
        self.train_data = train_data
        self.test_data = test_data
        self.test_dates = test_dates
        self.forecast_horizon = forecast_horizon
        self.results = {}

    def _store_result(self, model_name, forecast):
        """Record one model's forecast as a flat float array.

        Libraries return differently shaped outputs (for example a leading
        batch axis, or more steps than requested), so the forecast is
        flattened and capped at the horizon before it is stored. Conversion
        errors propagate to the caller's ``except`` and mark the model failed.
        """
        flat = np.asarray(forecast, dtype=float).reshape(-1)
        self.results[model_name] = {
            'forecast': flat[:int(self.forecast_horizon)],
            'actual': self.test_data,
            'dates': self.test_dates,
        }

    def forecast_chronos(self):
        """Forecast with Amazon's Chronos model (``amazon/chronos-t5-small``).

        Returns:
            True if a forecast was stored under ``'Chronos'``, else False.
        """
        try:
            from chronos import ChronosPipeline
            import torch

            pipeline = ChronosPipeline.from_pretrained(
                "amazon/chronos-t5-small",
                device_map="cpu",
                torch_dtype=torch.bfloat16,
            )
            # Only the most recent 100 observations are used as context.
            context = torch.tensor(self.train_data[-100:])
            forecast = pipeline.predict(
                context=context,
                prediction_length=self.forecast_horizon,
                num_samples=20,
            )
            # Collapse the sampled paths to their per-step median.
            # NOTE(review): assumes forecast[0] is (num_samples, horizon) - confirm.
            forecast_median = np.median(forecast[0].numpy(), axis=0)
            self._store_result('Chronos', forecast_median)
            return True
        except Exception as e:
            print(f"Chronos failed: {str(e)}")
            return False

    def forecast_moirai(self):
        """Forecast with the Moirai model (``Salesforce/moirai-1.0-R-small``).

        Returns:
            True if a forecast was stored under ``'Moirai'``, else False.
        """
        try:
            from uni2ts.model.moirai import MoiraiForecast

            # NOTE(review): verify this loading call against the uni2ts API;
            # any failure here is caught below and reported as a failed model.
            model = MoiraiForecast.load_from_checkpoint(
                checkpoint_path="Salesforce/moirai-1.0-R-small",
                map_location="cpu",
            )
            forecast = model.forecast(
                past_data=self.train_data[-512:],
                prediction_length=self.forecast_horizon,
            )
            self._store_result('Moirai', forecast.mean().numpy())
            return True
        except Exception as e:
            print(f"Moirai failed: {str(e)}")
            return False

    def forecast_moment(self):
        """Forecast with the MOMENT model (``AutonLab/MOMENT-1-large``).

        Returns:
            True if a forecast was stored under ``'MOMENT'``, else False.
        """
        try:
            from momentfm import MOMENTPipeline

            model = MOMENTPipeline.from_pretrained(
                "AutonLab/MOMENT-1-large",
                model_kwargs={'task_name': 'forecasting'},
            )
            # Single series presented as a batch of one.
            context = self.train_data[-512:].reshape(1, -1)
            forecast = model(context, output_length=self.forecast_horizon)
            self._store_result('MOMENT', forecast[0])
            return True
        except Exception as e:
            print(f"MOMENT failed: {str(e)}")
            return False

    def forecast_timesfm(self):
        """Forecast with the TimesFM model.

        Returns:
            True if a forecast was stored under ``'TimesFM'``, else False.
        """
        try:
            import timesfm

            tfm = timesfm.TimesFm(
                context_len=512,
                horizon_len=self.forecast_horizon,
                input_patch_len=32,
                output_patch_len=128,
            )
            tfm.load_from_checkpoint()
            forecast = tfm.forecast(
                inputs=[self.train_data[-512:]],
                freq=[0],
            )
            self._store_result('TimesFM', forecast[0])
            return True
        except Exception as e:
            print(f"TimesFM failed: {str(e)}")
            return False

    def calculate_metrics(self):
        """Score every stored forecast against its actual values.

        Forecast and actual are flattened and trimmed to their common length
        first, so a forecast with an extra batch axis or surplus steps is
        still scored instead of raising a broadcasting error. Results with
        nothing to compare are skipped.

        Returns:
            A DataFrame with one row per model and the columns Model, MAE,
            RMSE, MAPE (%), R-squared and Dir. Acc. (%); empty when no model
            produced a usable forecast.
        """
        rows = []
        for model_name, result in self.results.items():
            if result is None:
                continue

            forecast = np.asarray(result['forecast'], dtype=float).reshape(-1)
            actual = np.asarray(result['actual'], dtype=float).reshape(-1)
            common = min(len(forecast), len(actual))
            if common == 0:
                continue
            forecast = forecast[:common]
            actual = actual[:common]

            error = forecast - actual
            mae = np.mean(np.abs(error))
            rmse = np.sqrt(np.mean(error ** 2))
            # The small epsilon guards against division by zero.
            mape = np.mean(np.abs((actual - forecast) / (actual + 1e-10))) * 100

            if common > 1:
                # Share of steps where forecast and actual move the same way.
                actual_direction = np.sign(np.diff(actual))
                forecast_direction = np.sign(np.diff(forecast))
                directional_accuracy = (
                    np.mean(actual_direction == forecast_direction) * 100
                )
            else:
                directional_accuracy = 0

            ss_res = np.sum((actual - forecast) ** 2)
            ss_tot = np.sum((actual - np.mean(actual)) ** 2)
            r2 = 1 - (ss_res / (ss_tot + 1e-10))

            rows.append({
                'Model': model_name,
                'MAE': mae,
                'RMSE': rmse,
                'MAPE (%)': mape,
                'RΒ²': r2,
                'Dir. Acc. (%)': directional_accuracy,
            })
        return pd.DataFrame(rows)

    def run_all_forecasts(self):
        """Run every model in a fixed order and score the ones that succeed.

        Returns:
            ``(metrics_dataframe, success_count)``.
        """
        runners = (
            self.forecast_chronos,
            self.forecast_moirai,
            self.forecast_moment,
            self.forecast_timesfm,
        )
        # Models run one after another; each reports success as a bool.
        success_count = sum(1 for run in runners if run())
        return self.calculate_metrics(), success_count
def create_plot(comparison, stock_name):
    """Render the forecast comparison as a two-panel PNG image.

    The left panel shows the actual series and every stored forecast; the
    right panel shows MAE and RMSE per model as grouped bars.

    Args:
        comparison: Object exposing ``results`` (model name -> dict with
            ``'dates'``, ``'actual'``, ``'forecast'``, or None) and
            ``calculate_metrics()``.
        stock_name: Label used in both panel titles.

    Returns:
        A PIL image of the rendered figure.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    # All drawing goes through ``fig``/axes rather than pyplot's global
    # "current figure", and the figure is closed even when plotting raises,
    # so overlapping requests cannot save or close each other's figure.
    try:
        colors = {'Chronos': 'red', 'Moirai': 'blue', 'MOMENT': 'green', 'TimesFM': 'orange'}
        available = {
            name: result
            for name, result in comparison.results.items()
            if result is not None
        }

        # Plot 1: forecasts. The actual series is drawn once, taken from the
        # first available result.
        first = next(iter(available.values()), None)
        if first is not None:
            ax1.plot(first['dates'], first['actual'], 'k-',
                     linewidth=2.5, label='Actual', marker='o')
        for model_name, result in available.items():
            ax1.plot(result['dates'], result['forecast'],
                     color=colors.get(model_name, 'gray'),
                     linestyle='--', linewidth=2,
                     label=f'{model_name}', marker='x')
        ax1.set_xlabel('Time')
        ax1.set_ylabel('Volatility (annualized)')
        ax1.set_title(f'{stock_name} - Volatility Forecast')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        ax1.tick_params(axis='x', rotation=45)

        # Plot 2: error metrics as side-by-side bars per model.
        metrics_df = comparison.calculate_metrics()
        if not metrics_df.empty:
            models = metrics_df['Model'].tolist()
            mae_values = metrics_df['MAE'].tolist()
            rmse_values = metrics_df['RMSE'].tolist()
            x = np.arange(len(models))
            width = 0.35
            ax2.bar(x - width / 2, mae_values, width, label='MAE', alpha=0.8)
            ax2.bar(x + width / 2, rmse_values, width, label='RMSE', alpha=0.8)
            ax2.set_xlabel('Model')
            ax2.set_ylabel('Error')
            ax2.set_title(f'{stock_name} - MAE & RMSE Comparison')
            ax2.set_xticks(x)
            ax2.set_xticklabels(models, rotation=45)
            ax2.legend()
            ax2.grid(True, alpha=0.3, axis='y')

        fig.tight_layout()

        # Convert to an in-memory image.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        plt.close(fig)
def run_forecast(stock_ticker, forecast_minutes):
    """Gradio callback: fetch data, run all models and report the outcome.

    Args:
        stock_ticker: Yahoo Finance symbol to analyse.
        forecast_minutes: Forecast horizon in minutes; converted to whole
            5-minute periods.

    Returns:
        ``(status_text, plot_image, metrics_dataframe)``. On any failure the
        status text holds the progress made so far followed by the error
        line, and the other two entries are None. When no model succeeds the
        plot is None and the metrics frame is empty.
    """
    # Progress lines are collected outside the try so that they survive into
    # the error report.
    log = []
    try:
        # Convert to 5-min periods; int() guards against float slider values.
        forecast_horizon = int(forecast_minutes) // 5
        if forecast_horizon < 1:
            raise ValueError("Forecast horizon must be at least 5 minutes")

        log.append(f"π Starting forecast for {stock_ticker}...\n\n")

        # Fetch data
        log.append("π₯ Fetching data from Yahoo Finance...\n")
        forecaster = VolatilityForecaster(ticker=stock_ticker, interval='5m', period='60d')
        forecaster.fetch_data()
        log.append(f"β Downloaded {len(forecaster.data)} data points\n")
        log.append(
            f"π Date range: {forecaster.data.index[0]} to {forecaster.data.index[-1]}\n\n"
        )

        # Calculate volatility
        log.append("π Calculating volatility...\n")
        forecaster.calculate_volatility(window=20)
        train_data, test_data, test_dates = forecaster.prepare_forecast_data(
            forecast_horizon=forecast_horizon
        )
        log.append(f"π Training data points: {len(train_data)}\n")
        log.append(f"π Test data points: {len(test_data)}\n")
        log.append(f"π Test period: {test_dates[0]} to {test_dates[-1]}\n\n")

        # Run forecasts
        log.append("π€ Running model forecasts...\n\n")
        comparison = ModelComparison(train_data, test_data, test_dates, forecast_horizon)
        metrics_df, success_count = comparison.run_all_forecasts()
        log.append(f"β Successfully ran {success_count}/4 models\n\n")

        if metrics_df.empty:
            # Nothing to draw, so the plot is skipped entirely.
            log.append("β No models completed successfully\n")
            return "".join(log), None, metrics_df

        # Create plot
        plot_img = create_plot(comparison, stock_ticker)

        # Format results
        metrics_str = metrics_df.to_string(index=False)
        best_rmse = metrics_df.loc[metrics_df['RMSE'].idxmin(), 'Model']
        best_r2 = metrics_df.loc[metrics_df['RΒ²'].idxmax(), 'Model']
        log.append("=" * 60 + "\n")
        log.append("π RESULTS\n")
        log.append("=" * 60 + "\n\n")
        log.append(metrics_str + "\n\n")
        log.append(f"π Best Model (RMSE): {best_rmse}\n")
        log.append(f"π Best Model (RΒ²): {best_r2}\n")
        return "".join(log), plot_img, metrics_df
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        log.append(f"β Error: {str(e)}")
        return "".join(log), None, None
# Gradio interface: inputs on the left, status text on the right, then the
# plot and the metrics table in rows of their own. Components are laid out
# in creation order inside the nested context managers.
with gr.Blocks(title="SMI Volatility Forecast") as demo:
    gr.Markdown("""
    # π SMI Volatility Forecast - Model Comparison

    Compare **Chronos, Moirai, MOMENT, and TimesFM** foundation models for volatility forecasting.

    This app uses 5-minute data from Yahoo Finance (max 60 days) to predict volatility.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            stock_input = gr.Dropdown(
                choices=['NESN.SW', 'NOVN.SW', 'ROG.SW', 'UBSG.SW', 'ABBN.SW'],
                value='NESN.SW',
                label="π Select SMI Stock"
            )
            forecast_input = gr.Slider(
                minimum=30,
                maximum=120,
                value=60,
                step=30,
                label="β±οΈ Forecast Horizon (minutes)"
            )
            run_button = gr.Button("π Run Forecast", variant="primary")
        with gr.Column(scale=2):
            status_output = gr.Textbox(
                label="π Status & Results",
                lines=20,
                max_lines=30
            )

    with gr.Row():
        plot_output = gr.Image(label="π Visualization")

    with gr.Row():
        # Headers mirror the column names produced by the metrics table.
        metrics_output = gr.Dataframe(
            label="π Detailed Metrics",
            headers=["Model", "MAE", "RMSE", "MAPE (%)", "RΒ²", "Dir. Acc. (%)"]
        )

    # The callback returns (status text, plot image, metrics frame) in the
    # same order as the outputs listed here.
    run_button.click(
        fn=run_forecast,
        inputs=[stock_input, forecast_input],
        outputs=[status_output, plot_output, metrics_output]
    )

    gr.Markdown("""
    ## π How it works

    1. **Data Collection**: Fetches 5-minute historical data (60 days max from Yahoo Finance)
    2. **Volatility Calculation**: Computes rolling volatility from log returns
    3. **Train/Test Split**: The first 80% of the series is the training context; the forecast-horizon window right after it is held out for out-of-sample evaluation
    4. **Model Forecasting**: Runs 4 foundation models one after another
    5. **Evaluation**: Compares models using MAE, RMSE, MAPE, RΒ², and Directional Accuracy

    ### π Metrics Explained

    - **MAE/RMSE**: Error measures (lower is better)
    - **MAPE**: Percentage error (lower is better)
    - **RΒ²**: Explained variance (higher is better, 1 is perfect, >0.5 is good; can be negative for poor forecasts)
    - **Directional Accuracy**: Trend prediction accuracy (>50% beats random)
    """)

if __name__ == "__main__":
    demo.launch()