lz211 commited on
Commit
a25d0f6
Β·
verified Β·
1 Parent(s): a7cf454

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +423 -0
app.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SMI Volatility Forecast - Hugging Face Gradio App
3
+ LΓ€uft direkt auf Hugging Face Spaces
4
+ """
5
+
6
+ import gradio as gr
7
+ import yfinance as yf
8
+ import pandas as pd
9
+ import numpy as np
10
+ from datetime import datetime
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+ # FΓΌr Plots
15
+ import matplotlib.pyplot as plt
16
+ import io
17
+ from PIL import Image
18
+
19
+ class VolatilityForecaster:
20
+ def __init__(self, ticker, interval='5m', period='60d'):
21
+ self.ticker = ticker
22
+ self.interval = interval
23
+ self.period = period
24
+ self.data = None
25
+ self.returns = None
26
+ self.volatility = None
27
+
28
+ def fetch_data(self):
29
+ """Fetch data from Yahoo Finance"""
30
+ stock = yf.Ticker(self.ticker)
31
+ self.data = stock.history(period=self.period, interval=self.interval)
32
+
33
+ if self.data.empty:
34
+ raise ValueError(f"No data found for {self.ticker}")
35
+
36
+ return self.data
37
+
38
+ def calculate_volatility(self, window=20):
39
+ """Calculate rolling volatility from returns"""
40
+ self.returns = np.log(self.data['Close'] / self.data['Close'].shift(1))
41
+ periods_per_day = 78
42
+ periods_per_year = periods_per_day * 252
43
+
44
+ self.volatility = self.returns.rolling(window=window).std() * np.sqrt(periods_per_year)
45
+ self.volatility = self.volatility.dropna()
46
+
47
+ return self.volatility
48
+
49
+ def prepare_forecast_data(self, forecast_horizon=12):
50
+ """Prepare data for forecasting"""
51
+ train_size = int(len(self.volatility) * 0.8)
52
+
53
+ train_data = self.volatility.iloc[:train_size].values
54
+ test_data = self.volatility.iloc[train_size:train_size+forecast_horizon].values
55
+ test_dates = self.volatility.index[train_size:train_size+forecast_horizon]
56
+
57
+ return train_data, test_data, test_dates
58
+
59
+
60
+ class ModelComparison:
61
+ def __init__(self, train_data, test_data, test_dates, forecast_horizon=12):
62
+ self.train_data = train_data
63
+ self.test_data = test_data
64
+ self.test_dates = test_dates
65
+ self.forecast_horizon = forecast_horizon
66
+ self.results = {}
67
+
68
+ def forecast_chronos(self):
69
+ """Chronos-Modell von Amazon"""
70
+ try:
71
+ from chronos import ChronosPipeline
72
+ import torch
73
+
74
+ pipeline = ChronosPipeline.from_pretrained(
75
+ "amazon/chronos-t5-small",
76
+ device_map="cpu",
77
+ torch_dtype=torch.bfloat16,
78
+ )
79
+
80
+ context = torch.tensor(self.train_data[-100:])
81
+ forecast = pipeline.predict(
82
+ context=context,
83
+ prediction_length=self.forecast_horizon,
84
+ num_samples=20
85
+ )
86
+
87
+ forecast_median = np.median(forecast[0].numpy(), axis=0)
88
+
89
+ self.results['Chronos'] = {
90
+ 'forecast': forecast_median,
91
+ 'actual': self.test_data,
92
+ 'dates': self.test_dates
93
+ }
94
+
95
+ return True
96
+
97
+ except Exception as e:
98
+ print(f"Chronos failed: {str(e)}")
99
+ return False
100
+
101
+ def forecast_moirai(self):
102
+ """Moirai-Modell"""
103
+ try:
104
+ from uni2ts.model.moirai import MoiraiForecast
105
+
106
+ model = MoiraiForecast.load_from_checkpoint(
107
+ checkpoint_path="Salesforce/moirai-1.0-R-small",
108
+ map_location="cpu"
109
+ )
110
+
111
+ forecast = model.forecast(
112
+ past_data=self.train_data[-512:],
113
+ prediction_length=self.forecast_horizon
114
+ )
115
+
116
+ self.results['Moirai'] = {
117
+ 'forecast': forecast.mean().numpy(),
118
+ 'actual': self.test_data,
119
+ 'dates': self.test_dates
120
+ }
121
+
122
+ return True
123
+
124
+ except Exception as e:
125
+ print(f"Moirai failed: {str(e)}")
126
+ return False
127
+
128
+ def forecast_moment(self):
129
+ """MOMENT-Modell"""
130
+ try:
131
+ from momentfm import MOMENTPipeline
132
+
133
+ model = MOMENTPipeline.from_pretrained(
134
+ "AutonLab/MOMENT-1-large",
135
+ model_kwargs={'task_name': 'forecasting'}
136
+ )
137
+
138
+ context = self.train_data[-512:].reshape(1, -1)
139
+ forecast = model(context, output_length=self.forecast_horizon)
140
+
141
+ self.results['MOMENT'] = {
142
+ 'forecast': forecast[0],
143
+ 'actual': self.test_data,
144
+ 'dates': self.test_dates
145
+ }
146
+
147
+ return True
148
+
149
+ except Exception as e:
150
+ print(f"MOMENT failed: {str(e)}")
151
+ return False
152
+
153
+ def forecast_timesfm(self):
154
+ """TimesFM-Modell"""
155
+ try:
156
+ import timesfm
157
+
158
+ tfm = timesfm.TimesFm(
159
+ context_len=512,
160
+ horizon_len=self.forecast_horizon,
161
+ input_patch_len=32,
162
+ output_patch_len=128,
163
+ )
164
+ tfm.load_from_checkpoint()
165
+
166
+ forecast = tfm.forecast(
167
+ inputs=[self.train_data[-512:]],
168
+ freq=[0]
169
+ )
170
+
171
+ self.results['TimesFM'] = {
172
+ 'forecast': forecast[0],
173
+ 'actual': self.test_data,
174
+ 'dates': self.test_dates
175
+ }
176
+
177
+ return True
178
+
179
+ except Exception as e:
180
+ print(f"TimesFM failed: {str(e)}")
181
+ return False
182
+
183
+ def calculate_metrics(self):
184
+ """Calculate comprehensive performance metrics"""
185
+ metrics_df = []
186
+
187
+ for model_name, result in self.results.items():
188
+ if result is None:
189
+ continue
190
+
191
+ forecast = result['forecast']
192
+ actual = result['actual']
193
+
194
+ mae = np.mean(np.abs(forecast - actual))
195
+ rmse = np.sqrt(np.mean((forecast - actual)**2))
196
+ mape = np.mean(np.abs((actual - forecast) / (actual + 1e-10))) * 100
197
+
198
+ if len(actual) > 1:
199
+ actual_direction = np.sign(np.diff(actual))
200
+ forecast_direction = np.sign(np.diff(forecast))
201
+ directional_accuracy = np.mean(actual_direction == forecast_direction) * 100
202
+ else:
203
+ directional_accuracy = 0
204
+
205
+ ss_res = np.sum((actual - forecast)**2)
206
+ ss_tot = np.sum((actual - np.mean(actual))**2)
207
+ r2 = 1 - (ss_res / (ss_tot + 1e-10))
208
+
209
+ metrics_df.append({
210
+ 'Model': model_name,
211
+ 'MAE': mae,
212
+ 'RMSE': rmse,
213
+ 'MAPE (%)': mape,
214
+ 'RΒ²': r2,
215
+ 'Dir. Acc. (%)': directional_accuracy
216
+ })
217
+
218
+ return pd.DataFrame(metrics_df)
219
+
220
+ def run_all_forecasts(self):
221
+ """Run all model forecasts"""
222
+ success_count = 0
223
+
224
+ if self.forecast_chronos():
225
+ success_count += 1
226
+ if self.forecast_moirai():
227
+ success_count += 1
228
+ if self.forecast_moment():
229
+ success_count += 1
230
+ if self.forecast_timesfm():
231
+ success_count += 1
232
+
233
+ return self.calculate_metrics(), success_count
234
+
235
+
236
+ def create_plot(comparison, stock_name):
237
+ """Create visualization"""
238
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
239
+
240
+ colors = {'Chronos': 'red', 'Moirai': 'blue', 'MOMENT': 'green', 'TimesFM': 'orange'}
241
+
242
+ # Plot 1: Forecasts
243
+ for model_name, result in comparison.results.items():
244
+ if result is not None:
245
+ ax1.plot(result['dates'], result['actual'], 'k-',
246
+ linewidth=2.5, label='Actual', marker='o')
247
+ break
248
+
249
+ for model_name, result in comparison.results.items():
250
+ if result is not None:
251
+ ax1.plot(result['dates'], result['forecast'],
252
+ color=colors.get(model_name, 'gray'),
253
+ linestyle='--', linewidth=2,
254
+ label=f'{model_name}', marker='x')
255
+
256
+ ax1.set_xlabel('Time')
257
+ ax1.set_ylabel('Volatility (annualized)')
258
+ ax1.set_title(f'{stock_name} - Volatility Forecast')
259
+ ax1.legend()
260
+ ax1.grid(True, alpha=0.3)
261
+ ax1.tick_params(axis='x', rotation=45)
262
+
263
+ # Plot 2: Metrics
264
+ metrics_df = comparison.calculate_metrics()
265
+ if not metrics_df.empty:
266
+ models = metrics_df['Model'].tolist()
267
+ mae_values = metrics_df['MAE'].tolist()
268
+ rmse_values = metrics_df['RMSE'].tolist()
269
+
270
+ x = np.arange(len(models))
271
+ width = 0.35
272
+
273
+ ax2.bar(x - width/2, mae_values, width, label='MAE', alpha=0.8)
274
+ ax2.bar(x + width/2, rmse_values, width, label='RMSE', alpha=0.8)
275
+
276
+ ax2.set_xlabel('Model')
277
+ ax2.set_ylabel('Error')
278
+ ax2.set_title(f'{stock_name} - MAE & RMSE Comparison')
279
+ ax2.set_xticks(x)
280
+ ax2.set_xticklabels(models, rotation=45)
281
+ ax2.legend()
282
+ ax2.grid(True, alpha=0.3, axis='y')
283
+
284
+ plt.tight_layout()
285
+
286
+ # Convert to image
287
+ buf = io.BytesIO()
288
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
289
+ buf.seek(0)
290
+ img = Image.open(buf)
291
+ plt.close()
292
+
293
+ return img
294
+
295
+
296
+ def run_forecast(stock_ticker, forecast_minutes):
297
+ """Main function for Gradio interface"""
298
+ try:
299
+ forecast_horizon = forecast_minutes // 5 # Convert to 5-min periods
300
+
301
+ status = f"πŸš€ Starting forecast for {stock_ticker}...\n\n"
302
+
303
+ # Fetch data
304
+ status += "πŸ“₯ Fetching data from Yahoo Finance...\n"
305
+ forecaster = VolatilityForecaster(ticker=stock_ticker, interval='5m', period='60d')
306
+ forecaster.fetch_data()
307
+
308
+ status += f"βœ… Downloaded {len(forecaster.data)} data points\n"
309
+ status += f"πŸ“… Date range: {forecaster.data.index[0]} to {forecaster.data.index[-1]}\n\n"
310
+
311
+ # Calculate volatility
312
+ status += "πŸ“Š Calculating volatility...\n"
313
+ forecaster.calculate_volatility(window=20)
314
+
315
+ train_data, test_data, test_dates = forecaster.prepare_forecast_data(
316
+ forecast_horizon=forecast_horizon
317
+ )
318
+
319
+ status += f"πŸ“Š Training data points: {len(train_data)}\n"
320
+ status += f"πŸ“Š Test data points: {len(test_data)}\n"
321
+ status += f"πŸ“… Test period: {test_dates[0]} to {test_dates[-1]}\n\n"
322
+
323
+ # Run forecasts
324
+ status += "πŸ€– Running model forecasts...\n\n"
325
+ comparison = ModelComparison(train_data, test_data, test_dates, forecast_horizon)
326
+ metrics_df, success_count = comparison.run_all_forecasts()
327
+
328
+ status += f"βœ… Successfully ran {success_count}/4 models\n\n"
329
+
330
+ # Create plot
331
+ plot_img = create_plot(comparison, stock_ticker)
332
+
333
+ # Format results
334
+ if not metrics_df.empty:
335
+ metrics_str = metrics_df.to_string(index=False)
336
+
337
+ best_rmse = metrics_df.loc[metrics_df['RMSE'].idxmin(), 'Model']
338
+ best_r2 = metrics_df.loc[metrics_df['RΒ²'].idxmax(), 'Model']
339
+
340
+ status += "="*60 + "\n"
341
+ status += "πŸ“Š RESULTS\n"
342
+ status += "="*60 + "\n\n"
343
+ status += metrics_str + "\n\n"
344
+ status += f"πŸ† Best Model (RMSE): {best_rmse}\n"
345
+ status += f"πŸ† Best Model (RΒ²): {best_r2}\n"
346
+ else:
347
+ status += "❌ No models completed successfully\n"
348
+ plot_img = None
349
+
350
+ return status, plot_img, metrics_df
351
+
352
+ except Exception as e:
353
+ return f"❌ Error: {str(e)}", None, None
354
+
355
+
356
+ # Gradio Interface
357
+ with gr.Blocks(title="SMI Volatility Forecast") as demo:
358
+ gr.Markdown("""
359
+ # πŸ“Š SMI Volatility Forecast - Model Comparison
360
+
361
+ Compare **Chronos, Moirai, MOMENT, and TimesFM** foundation models for volatility forecasting.
362
+
363
+ This app uses 5-minute data from Yahoo Finance (max 60 days) to predict volatility.
364
+ """)
365
+
366
+ with gr.Row():
367
+ with gr.Column(scale=1):
368
+ stock_input = gr.Dropdown(
369
+ choices=['NESN.SW', 'NOVN.SW', 'ROG.SW', 'UBSG.SW', 'ABBN.SW'],
370
+ value='NESN.SW',
371
+ label="πŸ“ˆ Select SMI Stock"
372
+ )
373
+
374
+ forecast_input = gr.Slider(
375
+ minimum=30,
376
+ maximum=120,
377
+ value=60,
378
+ step=30,
379
+ label="⏱️ Forecast Horizon (minutes)"
380
+ )
381
+
382
+ run_button = gr.Button("πŸš€ Run Forecast", variant="primary")
383
+
384
+ with gr.Column(scale=2):
385
+ status_output = gr.Textbox(
386
+ label="πŸ“‹ Status & Results",
387
+ lines=20,
388
+ max_lines=30
389
+ )
390
+
391
+ with gr.Row():
392
+ plot_output = gr.Image(label="πŸ“Š Visualization")
393
+
394
+ with gr.Row():
395
+ metrics_output = gr.Dataframe(
396
+ label="πŸ“ˆ Detailed Metrics",
397
+ headers=["Model", "MAE", "RMSE", "MAPE (%)", "RΒ²", "Dir. Acc. (%)"]
398
+ )
399
+
400
+ run_button.click(
401
+ fn=run_forecast,
402
+ inputs=[stock_input, forecast_input],
403
+ outputs=[status_output, plot_output, metrics_output]
404
+ )
405
+
406
+ gr.Markdown("""
407
+ ## πŸ“– How it works
408
+
409
+ 1. **Data Collection**: Fetches 5-minute historical data (60 days max from Yahoo Finance)
410
+ 2. **Volatility Calculation**: Computes rolling volatility from log returns
411
+ 3. **Train/Test Split**: 80% training, 20% testing (out-of-sample validation)
412
+ 4. **Model Forecasting**: Runs 4 foundation models in parallel
413
+ 5. **Evaluation**: Compares models using MAE, RMSE, MAPE, RΒ², and Directional Accuracy
414
+
415
+ ### πŸ† Metrics Explained
416
+ - **MAE/RMSE**: Error measures (lower is better)
417
+ - **MAPE**: Percentage error (lower is better)
418
+ - **RΒ²**: Explained variance 0-1 (higher is better, >0.5 is good)
419
+ - **Directional Accuracy**: Trend prediction accuracy (>50% beats random)
420
+ """)
421
+
422
+ if __name__ == "__main__":
423
+ demo.launch()