Gil Stetler committed on
Commit
f92e274
·
1 Parent(s): 7ef12ac

remove calibration layer

Browse files
Files changed (1) hide show
  1. app.py +140 -98
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os, random
2
  from typing import Tuple
3
  import numpy as np
@@ -16,19 +17,17 @@ import pipeline_v2 as pipe2 # update_ticker_csv(...)
16
  # Config
17
  # --------------------
18
  MODEL_ID = "amazon/chronos-t5-large"
19
- PREDICTION_LENGTH = 60 # <-- was 30, now 60 to support affine split
20
- TEST_H = 30 # evaluate on last 30
21
- CAL_H = PREDICTION_LENGTH - TEST_H # = 30
22
- NUM_SAMPLES = 1
23
- RV_WINDOW = 20
24
- ANNUALIZE = True
25
  EPS = 1e-8
26
 
27
  # --------------------
28
  # Model load (once)
29
  # --------------------
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
- dtype = torch.bfloat16 if device == "cuda" else torch.float32
32
 
33
  pipe = ChronosPipeline.from_pretrained(
34
  MODEL_ID,
@@ -40,13 +39,22 @@ pipe = ChronosPipeline.from_pretrained(
40
  # Helpers
41
  # --------------------
42
  def _extract_close(df: pd.DataFrame) -> pd.Series:
 
 
 
 
 
 
43
  if isinstance(df.columns, pd.MultiIndex):
 
44
  for name in ["Adj Close", "Adj_Close", "adj close", "adj_close"]:
45
  if name in df.columns.get_level_values(0):
46
  sub = df.xs(name, axis=1, level=0)
 
47
  if sub.shape[1] > 1:
48
  sub = sub.iloc[:, 0]
49
  return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
 
50
  for name in ["Close", "close", "Price", "price"]:
51
  if name in df.columns.get_level_values(0):
52
  sub = df.xs(name, axis=1, level=0)
@@ -54,39 +62,45 @@ def _extract_close(df: pd.DataFrame) -> pd.Series:
54
  sub = sub.iloc[:, 0]
55
  return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
56
 
 
57
  mapping = {c.lower(): c for c in df.columns}
58
  for name in ["adj close", "adj_close", "close", "price"]:
59
  if name in mapping:
60
  col = df[mapping[name]]
61
  return pd.to_numeric(col, errors="coerce").dropna()
62
 
 
63
  num_cols = df.select_dtypes(include=[np.number]).columns
64
  if len(num_cols) == 0:
65
  raise gr.Error("No numeric price column found in downloaded data.")
66
  return pd.Series(df[num_cols[-1]]).astype(float)
67
 
68
- def _ensure_datetime_index(df: pd.DataFrame) -> pd.DataFrame:
 
 
69
  if isinstance(df.index, pd.DatetimeIndex):
70
- out = df.copy()
71
- out.index.name = "Date"
72
- return out
73
- if "Date" in df.columns:
74
- out = df.copy()
75
- out["Date"] = pd.to_datetime(out["Date"], errors="coerce")
76
- out = out.set_index("Date")
77
- out.index.name = "Date"
78
- return out
79
- out = df.copy()
80
- out.index = pd.to_datetime(out.index, errors="coerce")
81
- out.index.name = "Date"
82
- return out
83
 
84
  def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
85
  r = np.log(close).diff().dropna()
86
  rv = r.rolling(window, min_periods=window).std()
87
  if annualize:
88
- rv = rv * np.sqrt(252)
89
- return rv.dropna()
 
 
 
 
90
 
91
  def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
92
  err = y_pred - y_true
@@ -96,133 +110,161 @@ def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
96
  rmse = float(np.sqrt(np.mean(err**2)))
97
  return {"MAPE": mape, "MPE": mpe, "RMSE": rmse}
98
 
99
- # ✅ NEW affine calibration
100
- def affine_calibration(y_true: np.ndarray, y_pred: np.ndarray) -> tuple[float, float, np.ndarray]:
101
- X = np.vstack([y_pred, np.ones_like(y_pred)]).T
102
- sol, *_ = np.linalg.lstsq(X, y_true, rcond=None)
103
- a, b = float(sol[0]), float(sol[1])
104
- return a, b, a * y_pred + b
105
-
106
  # --------------------
107
  # Core routine
108
  # --------------------
109
- def run_for_ticker(tickers: str, start: str, interval: str, dummy_flag: bool):
 
 
 
 
 
 
110
  tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
111
  if not tick_list:
112
- raise gr.Error("Please enter a ticker like AAPL or BMW.DE")
113
- ticker = tick_list[0]
114
 
 
 
 
115
  try:
116
  csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
117
  except Exception as e:
118
- raise gr.Error(f"Data fetch failed for '{ticker}'.\n{e}")
 
 
119
 
 
120
  try:
121
  df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
 
 
 
122
  except Exception:
123
  df = pd.read_csv(csv_path)
124
 
125
- df = _ensure_datetime_index(df)
126
  close = _extract_close(df)
127
- rv = compute_realized_vol(close, window=RV_WINDOW, annualize=ANNUALIZE)
128
- dates = rv.index
129
- rv = rv.to_numpy()
130
- n = len(rv)
131
 
132
- H = PREDICTION_LENGTH
 
133
  if n <= H + 5:
134
- raise gr.Error(f"Vol series too short. Need > {H+5}, got {n}.")
135
 
136
- rv_train = rv[: n - H]
137
- rv_cal_true = rv[n - H : n - TEST_H] # first 30 of 60
138
- rv_test_true = rv[n - TEST_H :] # final 30
139
 
 
140
  random.seed(0); np.random.seed(0); torch.manual_seed(0)
141
- if torch.cuda.is_available(): torch.cuda.manual_seed_all(0)
 
142
 
143
  context = torch.tensor(rv_train, dtype=torch.float32)
144
- fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)
145
- samples = fcst[0].cpu().numpy()
146
- path_pred = samples[0]
147
-
148
- rv_cal_pred = path_pred[:CAL_H]
149
- rv_test_pred = path_pred[CAL_H:]
150
-
151
- metrics_raw = compute_metrics(rv_test_true, rv_test_pred)
152
- a, b, rv_test_pred_cal = affine_calibration(rv_cal_true, rv_cal_pred)
153
- metrics_cal = compute_metrics(rv_test_true, rv_test_pred_cal)
154
-
155
- # ---- Plot ----
156
- fig = plt.figure(figsize=(10,4))
157
- x_hist = dates[: len(rv_train)]
158
- x_cal = dates[len(rv_train): len(rv_train)+CAL_H]
159
- x_test = dates[len(rv_train)+CAL_H :]
160
-
161
- plt.plot(x_hist, rv_train, label="history RV")
162
- plt.plot(x_cal, rv_cal_true, label="calibration RV")
163
- plt.plot(x_test, rv_test_true, label="actual last 30 RV")
164
- plt.plot(x_test, rv_test_pred, "--", label="forecast raw")
165
- plt.plot(x_test, rv_test_pred_cal, "--", label=f"forecast calibrated a={a:.3f}, b={b:.3f}")
166
-
167
- plt.title(f"{ticker.upper()} Volatility Forecast (60d → eval last 30)")
168
- plt.xlabel("date"); plt.ylabel("realized vol")
169
- plt.legend(); plt.tight_layout()
170
-
171
- # ---- Table ----
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  df_days = pd.DataFrame({
173
- "date": x_test,
174
- "actual_vol": rv_test_true,
175
- "forecast_raw": rv_test_pred,
176
- "forecast_calibrated": rv_test_pred_cal,
177
- "abs_pct_err_raw_%": np.abs((rv_test_pred - rv_test_true) / np.maximum(EPS, np.abs(rv_test_true))) * 100,
178
- "abs_pct_err_cal_%": np.abs((rv_test_pred_cal - rv_test_true) / np.maximum(EPS, np.abs(rv_test_true))) * 100,
179
  })
180
-
 
 
 
 
 
 
 
181
  out = {
182
  "ticker": ticker.upper(),
 
183
  "config": {
184
  "start": start,
185
  "interval": interval,
186
  "rv_window": RV_WINDOW,
187
  "prediction_length": H,
188
- "calibration_window": CAL_H,
 
 
189
  },
190
- "metrics_raw": {k: round(v,4) for k,v in metrics_raw.items()},
191
- "metrics_calibrated": {k: round(v,4) for k,v in metrics_cal.items()},
192
- "affine_params": {"a": a, "b": b},
193
  }
 
194
 
195
- metrics_md = (
196
- f"**RAW** — MAPE {metrics_raw['MAPE']:.2f}% | MPE {metrics_raw['MPE']:.2f}% | RMSE {metrics_raw['RMSE']:.5f}\n"
197
- f"**AFFINE CAL** MAPE {metrics_cal['MAPE']:.2f}% | MPE {metrics_cal['MPE']:.2f}% | RMSE {metrics_cal['RMSE']:.5f}"
198
- )
199
 
200
  return fig, out, df_days, metrics_md
201
 
202
  # --------------------
203
  # UI
204
  # --------------------
205
- with gr.Blocks(title="Volatility Forecast • Chronos + Affine Calibration") as demo:
206
  gr.Markdown(
207
- "### Predict realized volatility with Chronos, with **non-leaky affine calibration (a,b)**\n"
208
- "- 60-day forecast calibrate on first 30, evaluate last 30\n"
209
- "- Works with tickers like `AAPL`, `BMW.DE`, `BTC-USD`, `NESN.SW`"
 
 
 
210
  )
211
  with gr.Row():
212
- tickers_in = gr.Textbox(value="AAPL", label="Ticker")
213
  with gr.Row():
214
- start_in = gr.Textbox(value="2015-01-01", label="Start date")
215
- interval_in = gr.Dropdown(choices=["1d","1wk","1mo"], value="1d", label="Interval")
 
216
  run_btn = gr.Button("Run", variant="primary")
217
 
218
- plot = gr.Plot(label="Forecast vs Actual")
219
  meta = gr.JSON(label="Run config & metrics")
220
  table = gr.Dataframe(label="Per-day comparison", wrap=True)
221
  metrics = gr.Markdown(label="Summary")
222
 
223
- run_btn.click(run_for_ticker,
224
- inputs=[tickers_in, start_in, interval_in, gr.Checkbox(value=False)],
225
  outputs=[plot, meta, table, metrics])
226
 
227
  if __name__ == "__main__":
228
  demo.launch()
 
 
1
+ # app.py
2
  import os, random
3
  from typing import Tuple
4
  import numpy as np
 
17
  # Config
18
  # --------------------
19
  MODEL_ID = "amazon/chronos-t5-large"
20
+ PREDICTION_LENGTH = 30 # forecast last 30 days
21
+ NUM_SAMPLES = 1 # single path -> day-by-day point prediction
22
+ RV_WINDOW = 20 # realized vol window (trading days)
23
+ ANNUALIZE = True # annualize by sqrt(252)
 
 
24
  EPS = 1e-8
25
 
26
  # --------------------
27
  # Model load (once)
28
  # --------------------
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
31
 
32
  pipe = ChronosPipeline.from_pretrained(
33
  MODEL_ID,
 
39
  # Helpers
40
  # --------------------
41
  def _extract_close(df: pd.DataFrame) -> pd.Series:
42
+ """
43
+ Robustly extract the close or adjusted close price as a numeric Series.
44
+ Handles both flat and MultiIndex columns (yfinance often returns MultiIndex
45
+ when multiple tickers or suffixes are used).
46
+ """
47
+ # --- Case 1: MultiIndex (e.g., ('Adj Close', 'BMW.DE')) ---
48
  if isinstance(df.columns, pd.MultiIndex):
49
+ # Try Adj Close first
50
  for name in ["Adj Close", "Adj_Close", "adj close", "adj_close"]:
51
  if name in df.columns.get_level_values(0):
52
  sub = df.xs(name, axis=1, level=0)
53
+ # If multiple tickers, pick first column
54
  if sub.shape[1] > 1:
55
  sub = sub.iloc[:, 0]
56
  return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
57
+ # Fallback to Close
58
  for name in ["Close", "close", "Price", "price"]:
59
  if name in df.columns.get_level_values(0):
60
  sub = df.xs(name, axis=1, level=0)
 
62
  sub = sub.iloc[:, 0]
63
  return pd.to_numeric(sub.squeeze(), errors="coerce").dropna()
64
 
65
+ # --- Case 2: Flat columns ---
66
  mapping = {c.lower(): c for c in df.columns}
67
  for name in ["adj close", "adj_close", "close", "price"]:
68
  if name in mapping:
69
  col = df[mapping[name]]
70
  return pd.to_numeric(col, errors="coerce").dropna()
71
 
72
+ # --- Fallback: last numeric column ---
73
  num_cols = df.select_dtypes(include=[np.number]).columns
74
  if len(num_cols) == 0:
75
  raise gr.Error("No numeric price column found in downloaded data.")
76
  return pd.Series(df[num_cols[-1]]).astype(float)
77
 
78
+
79
def _extract_dates(df: pd.DataFrame) -> np.ndarray:
    """Return one x-axis value per row of ``df`` as a NumPy array.

    Resolution order:
      1. the index itself, when it already is a ``DatetimeIndex``;
      2. the first column named ``date``, ``time`` or ``timestamp``
         (case-insensitive) that ``pd.to_datetime`` can parse;
      3. a positional ``0 .. len(df) - 1`` integer range.

    Callers can tell the fallback apart from real dates by the dtype of the
    result: the positional range is an integer array.
    """
    if isinstance(df.index, pd.DatetimeIndex):
        return df.index.to_numpy()

    # Column labels are not guaranteed to be strings (a header-less CSV yields
    # integer labels), so go through str() before lowercasing.
    mapping = {str(c).lower(): c for c in df.columns}
    for name in ("date", "time", "timestamp"):
        if name not in mapping:
            continue
        try:
            return pd.to_datetime(df[mapping[name]]).to_numpy()
        except (ValueError, TypeError, OverflowError):
            # Best effort: an unparseable column just means "try the next
            # candidate"; anything else is a real bug and should surface.
            continue

    # Nothing date-like found: fall back to a simple positional range.
    return np.arange(len(df))
 
 
93
 
94
def compute_realized_vol(
    close: pd.Series,
    window: int = 20,
    annualize: bool = True,
    periods_per_year: float = 252.0,
) -> pd.Series:
    """Rolling realized volatility of log returns.

    Args:
        close: Price series, one value per bar.
        window: Number of consecutive log returns per estimate. Must be at
            least 2 because the sample standard deviation is undefined for a
            single observation.
        annualize: When true, scale by ``sqrt(periods_per_year)``.
        periods_per_year: Bars per year used for annualisation. The default
            of 252 matches daily bars; pass e.g. 52 for weekly or 12 for
            monthly data.

    Returns:
        The volatility series with incomplete windows dropped and a fresh
        ``0..n-1`` index (the original date index is discarded).

    Raises:
        ValueError: If ``window`` is smaller than 2.
    """
    if window < 2:
        raise ValueError(f"window must be >= 2 to compute a standard deviation, got {window}")

    log_returns = np.log(close).diff().dropna()
    # min_periods=window keeps only full windows; partial ones become NaN and
    # are removed by the final dropna().
    rv = log_returns.rolling(window, min_periods=window).std()
    if annualize:
        rv = rv * np.sqrt(periods_per_year)
    return rv.dropna().reset_index(drop=True)
100
+
101
def bias_scale_calibration(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps: "float | None" = None,
) -> Tuple[float, np.ndarray]:
    """Fit a single multiplicative factor ``alpha`` so that ``alpha * y_pred ≈ y_true``.

    ``alpha`` is the least-squares solution of a regression through the
    origin: ``sum(y_true * y_pred) / sum(y_pred ** 2)``. Despite the name
    there is no additive (bias) term; only the scale is adjusted.

    Args:
        y_true: Observed values.
        y_pred: Predicted values, same shape as ``y_true``.
        eps: Small constant added to the denominator so an all-zero forecast
            does not divide by zero. ``None`` uses the module-level ``EPS``.

    Returns:
        ``(alpha, alpha * y_pred)`` with the rescaled forecast as float64.

    Raises:
        ValueError: If the two inputs differ in shape (which would otherwise
            broadcast silently and give a meaningless factor).
    """
    truth = np.asarray(y_true, dtype=float)
    pred = np.asarray(y_pred, dtype=float)
    if truth.shape != pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, got {truth.shape} and {pred.shape}"
        )

    guard = EPS if eps is None else eps
    alpha = float(np.sum(truth * pred) / (np.sum(pred**2) + guard))
    return alpha, alpha * pred
104
 
105
  def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
106
  err = y_pred - y_true
 
110
  rmse = float(np.sqrt(np.mean(err**2)))
111
  return {"MAPE": mape, "MPE": mpe, "RMSE": rmse}
112
 
 
 
 
 
 
 
 
113
  # --------------------
114
  # Core routine
115
  # --------------------
116
def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
    """Fetch one ticker, forecast the last ``PREDICTION_LENGTH`` realized-vol points and score them.

    Args:
        tickers: One or more symbols separated by commas, semicolons, pipes
            or whitespace. Only the FIRST symbol is used. Dots and dashes
            inside a symbol (``NESN.SW``, ``BRK-B``) are preserved.
        start: History start date, ``YYYY-MM-DD``; passed to the data pipeline.
        interval: Bar interval, one of ``'1d'``, ``'1wk'``, ``'1mo'``.
        use_calibration: When true, additionally report a scale-calibrated
            forecast (see the note below).

    Returns:
        ``(fig, out, df_days, metrics_md)``: the matplotlib figure, a
        JSON-friendly dict with config and metrics, the per-step comparison
        table, and a Markdown summary line.

    Raises:
        gr.Error: If no ticker was entered, the data fetch fails, or the
            realized-vol series is too short for the forecast horizon.

    Note:
        The calibration factor is fitted on the very window it is evaluated
        on (``rv_test``). The calibrated metrics are therefore in-sample and
        optimistic; they are not an out-of-sample estimate.
    """
    # Parse the first ticker. str.split() with no argument splits on any
    # whitespace and drops empty tokens, so "AAPL, MSFT" and "AAPL MSFT"
    # both work.
    tick_list = tickers.replace(";", ",").replace("|", ",").replace(",", " ").split()
    if not tick_list:
        raise gr.Error("Please enter at least one ticker, e.g. AAPL or NESN.SW")
    ticker = tick_list[0]  # keep original form; pipeline handles uppercasing

    # 1) Fetch/update CSV via pipeline
    try:
        csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
    except Exception as e:
        raise gr.Error(
            f"Data fetch failed for '{ticker}'. Tip: ensure exchange suffixes (e.g., NESN.SW, BMW.DE, VOD.L).\n{e}"
        ) from e

    # 2) Load CSV and build realized vol
    try:
        df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
        if not isinstance(df.index, pd.DatetimeIndex):
            # First column did not parse as dates: reload flat so a named
            # date column (if any) stays available to _extract_dates.
            df = pd.read_csv(csv_path)
    except Exception:
        # Deliberate best effort: retry with a plain read; a second failure
        # propagates to the caller.
        df = pd.read_csv(csv_path)

    dates = _extract_dates(df)
    close = _extract_close(df)

    # NOTE: compute_realized_vol annualizes with sqrt(252) by default,
    # independent of `interval`.
    rv = compute_realized_vol(close, window=RV_WINDOW, annualize=ANNUALIZE).to_numpy()
    n = len(rv)
    H = PREDICTION_LENGTH
    if n <= H + 5:
        raise gr.Error(f"Vol series too short after rolling window. Need > {H+5}, got {n}.")

    rv_train = rv[: n - H]
    rv_test = rv[n - H :]

    # 3) Forecast a single sample path (deterministic via seed)
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    context = torch.tensor(rv_train, dtype=torch.float32)
    fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)  # [1, 1, H]
    samples = fcst[0].cpu().numpy()  # (1, H)
    path_pred = samples[0]  # (H,)

    # 4) Optional scale calibration (fitted in-sample on rv_test, see docstring)
    metrics_raw = compute_metrics(rv_test, path_pred)
    alpha = None
    path_pred_cal = None
    metrics_cal = None
    if use_calibration:
        alpha, path_pred_cal = bias_scale_calibration(rv_test, path_pred)
        metrics_cal = compute_metrics(rv_test, path_pred_cal)

    # 5) Plot. Use the explicit Axes API instead of pyplot's implicit "current
    # figure", which is global state shared by concurrent requests.
    # NOTE(review): pyplot keeps a reference to every figure it creates;
    # confirm the UI layer releases it, otherwise figures accumulate.
    fig, ax = plt.subplots(figsize=(10, 4))
    n_train = len(rv_train)

    if isinstance(dates, np.ndarray) and len(dates) >= len(close):
        # Align on the tail: rv is shorter than the raw rows by the rolling
        # warm-up (and any rows dropped while extracting prices).
        x_all = np.array(dates[-len(rv):])
        x_hist = x_all[:n_train]
        x_fcst = x_all[n_train:]
        # _extract_dates falls back to an integer range when no dates exist;
        # label the axis accordingly instead of claiming "date".
        x_lbl = "time index" if x_all.dtype.kind in "iu" else "date"
    else:
        x_hist = np.arange(n_train)
        x_fcst = np.arange(n_train, n_train + H)
        x_lbl = "time index"

    ax.plot(x_hist, rv_train, label="realized vol (history)")
    ax.plot(x_fcst, rv_test, label=f"realized vol (actual last {H})")
    ax.plot(x_fcst, path_pred, linestyle="--", label="forecast (raw path)")
    if use_calibration:
        ax.plot(x_fcst, path_pred_cal, linestyle="--", label=f"forecast (calibrated, α={alpha:.3f})")

    ax.set_title(f"{ticker.upper()} — Volatility Forecast (RV={RV_WINDOW}, H={H}, interval={interval})")
    ax.set_xlabel(x_lbl)
    ax.set_ylabel("realized volatility")
    ax.legend(loc="best")
    fig.tight_layout()

    # 6) Per-day table (insertion order fixes the column order)
    denom = np.maximum(EPS, np.abs(rv_test))  # guard against division by zero
    columns = {
        "date": x_fcst,
        "actual_vol": rv_test,
        "forecast_raw": path_pred,
    }
    if use_calibration:
        columns["forecast_calibrated"] = path_pred_cal
    columns["abs_pct_error_raw_%"] = np.abs((path_pred - rv_test) / denom) * 100
    if use_calibration:
        columns["abs_pct_error_cal_%"] = np.abs((path_pred_cal - rv_test) / denom) * 100
    df_days = pd.DataFrame(columns)

    # 7) JSON + metrics text
    out = {
        "ticker": ticker.upper(),
        # str() keeps the payload JSON-serialisable should the pipeline
        # return a path object rather than a string.
        "csv_path": str(csv_path),
        "config": {
            "start": start,
            "interval": interval,
            "rv_window": RV_WINDOW,
            "prediction_length": H,
            "num_samples": NUM_SAMPLES,
            "annualized": ANNUALIZE,
            "point_forecast": "single_sample_path",
        },
        "metrics_raw": {k: round(v, 4) for k, v in metrics_raw.items()},
    }
    metrics_md = (
        f"**RAW** — MAPE {metrics_raw['MAPE']:.2f}% | MPE {metrics_raw['MPE']:.2f}% "
        f"| RMSE {metrics_raw['RMSE']:.5f}"
    )

    if use_calibration and metrics_cal is not None:
        out["alpha"] = alpha
        out["metrics_calibrated"] = {k: round(v, 4) for k, v in metrics_cal.items()}
        metrics_md += (
            f"\n**CALIBRATED** — MAPE {metrics_cal['MAPE']:.2f}% | MPE {metrics_cal['MPE']:.2f}% "
            f"| RMSE {metrics_cal['RMSE']:.5f}"
        )

    return fig, out, df_days, metrics_md
239
 
240
  # --------------------
241
  # UI
242
  # --------------------
243
# Gradio UI. The forecast horizon shown to the user is derived from
# PREDICTION_LENGTH so the copy cannot drift from the configured value.
with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
    gr.Markdown(
        f"### Predict last {PREDICTION_LENGTH} days of realized volatility for any ticker\n"
        "- Works with symbols like `AAPL`, `NESN.SW`, `BMW.DE`, `VOD.L`, `BRK-B`, `BTC-USD`.\n"
        "- Data fetched via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
        "- Forecast uses **Chronos-T5-Large** (single path, deterministic seed).\n"
        "- Day-by-day comparison with **MAPE/MPE/RMSE**.\n"
        "- Optional **Bias/Scale Calibration (α)**."
    )
    with gr.Row():
        tickers_in = gr.Textbox(value="AAPL", label="Ticker (you can use suffixes like NESN.SW, BMW.DE)")
    # NOTE(review): the source this was recovered from had lost its
    # indentation; the checkbox is assumed to share the row with the start
    # date and interval inputs. Confirm the intended layout.
    with gr.Row():
        start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
        interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
        calib_in = gr.Checkbox(value=True, label="Apply bias/scale calibration (α)")
    run_btn = gr.Button("Run", variant="primary")

    plot = gr.Plot(label=f"Forecast vs Actual (last {PREDICTION_LENGTH} days)")
    meta = gr.JSON(label="Run config & metrics")
    table = gr.Dataframe(label="Per-day comparison", wrap=True)
    metrics = gr.Markdown(label="Summary")

    # Inputs map positionally onto run_for_ticker's parameters; outputs map
    # onto its 4-tuple return value.
    run_btn.click(
        run_for_ticker,
        inputs=[tickers_in, start_in, interval_in, calib_in],
        outputs=[plot, meta, table, metrics],
    )

if __name__ == "__main__":
    demo.launch()
270
+