AshenH commited on
Commit
4cad9bd
·
verified ·
1 Parent(s): 6c6d38f

Update tools/ts_forecast_tool.py

Browse files
Files changed (1) hide show
  1. tools/ts_forecast_tool.py +349 -82
tools/ts_forecast_tool.py CHANGED
@@ -1,115 +1,382 @@
1
  # space/tools/ts_forecast_tool.py
2
  import os
 
3
  from typing import Optional, Dict
4
 
5
  import torch
6
  import pandas as pd
 
7
 
8
  from utils.tracing import Tracer
9
  from utils.config import AppConfig
10
-
11
- # We avoid unavailable task-specific heads.
12
- # Use a generic AutoModel and attempt capability-based calls.
13
  from transformers import AutoModel, AutoConfig
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  class TimeseriesForecastTool:
17
  """
18
- Lightweight wrapper around 'ibm-granite/granite-timeseries-ttm-r1' for zero-shot forecasting.
19
 
20
  This wrapper:
21
- - loads the model with `AutoModel.from_pretrained`
22
- - checks for a `.predict(...)` method first
23
- - else tries calling the model with `prediction_length=horizon`
24
- - returns a Pandas DataFrame with a single 'forecast' column
 
25
 
26
  Expected input:
27
- - series: pd.Series with a DatetimeIndex (regular frequency recommended)
28
- - horizon: int, number of future steps
29
-
30
- NOTE:
31
- Different library versions expose different APIs. If your environment/model
32
- lacks a compatible inference method, we raise a clear RuntimeError with
33
- guidance rather than failing at import time.
34
  """
35
 
36
  def __init__(
37
  self,
38
  cfg: Optional[AppConfig],
39
  tracer: Optional[Tracer],
40
- model_id: str = "ibm-granite/granite-timeseries-ttm-r1",
41
  device: Optional[str] = None,
42
  ):
43
  self.cfg = cfg
44
  self.tracer = tracer
45
  self.model_id = model_id
46
-
 
 
 
47
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
48
- # Load config + model generically
49
- self.config = AutoConfig.from_pretrained(self.model_id)
50
- self.model = AutoModel.from_pretrained(self.model_id)
51
- self.model.to(self.device)
52
- self.model.eval()
53
-
54
- def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
55
- if not isinstance(series, pd.Series):
56
- raise ValueError("series must be a pandas Series")
57
- if series.empty:
58
- return pd.DataFrame(columns=["forecast"])
59
-
60
- # Ensure numeric tensor
61
- values = series.astype("float32").to_numpy()
62
- x = torch.tensor(values, dtype=torch.float32, device=self.device).unsqueeze(0)
63
-
64
- with torch.no_grad():
65
- # 1) Preferred: explicit .predict API
66
- if hasattr(self.model, "predict"):
67
- try:
68
- preds = self.model.predict(x, prediction_length=horizon)
69
- yhat = preds if isinstance(preds, torch.Tensor) else torch.tensor(preds)
70
- out = yhat.squeeze().detach().cpu().numpy()
71
- return pd.DataFrame({"forecast": out})
72
- except Exception as e:
73
- raise RuntimeError(
74
- f"Granite model has a 'predict' method but it failed at runtime: {e}"
75
- )
76
-
77
- # 2) Fallback: call forward with a 'prediction_length' kwarg if supported
78
  try:
79
- outputs = self.model(x, prediction_length=horizon)
80
- # Try common attribute names
81
- for k in ("predictions", "prediction", "logits", "output"):
82
- if hasattr(outputs, k):
83
- tensor = getattr(outputs, k)
84
- if isinstance(tensor, (tuple, list)):
85
- tensor = tensor[0]
86
- if not isinstance(tensor, torch.Tensor):
87
- tensor = torch.tensor(tensor)
88
- out = tensor.squeeze().detach().cpu().numpy()
89
- # If multi-dim, take last dimension as forecast
90
- if out.ndim > 1:
91
- out = out[-1] if out.shape[0] == horizon else out.reshape(-1)
92
- return pd.DataFrame({"forecast": out})
93
- # If outputs is a raw tensor
94
- if isinstance(outputs, torch.Tensor):
95
- out = outputs.squeeze().detach().cpu().numpy()
96
- if out.ndim > 1:
97
- out = out[-1] if out.shape[0] == horizon else out.reshape(-1)
98
- return pd.DataFrame({"forecast": out})
99
- except TypeError:
100
- # Some builds may not accept prediction_length at all
101
- pass
102
  except Exception as e:
103
- raise RuntimeError(
104
- f"Calling the model forward for forecasting failed: {e}"
 
 
 
 
 
 
105
  )
106
-
107
- # If we get here, the installed combo doesn't expose an inference entrypoint we can use.
108
- raise RuntimeError(
109
- "The installed transformers/model combo does not expose a usable zero-shot "
110
- "forecasting interface (no `.predict` and forward(...) didn't accept "
111
- "`prediction_length`). Consider:\n"
112
- " Upgrading transformers/torch versions\n"
113
- " Using the 'granite-tsfm-public' PyPI if available in your region\n"
114
- " • Switching to a classic forecaster for now (e.g., ARIMA/XGBoost)\n"
115
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # space/tools/ts_forecast_tool.py
2
  import os
3
+ import logging
4
  from typing import Optional, Dict
5
 
6
  import torch
7
  import pandas as pd
8
+ import numpy as np
9
 
10
  from utils.tracing import Tracer
11
  from utils.config import AppConfig
 
 
 
12
  from transformers import AutoModel, AutoConfig
13
 
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Constants
17
+ MIN_SERIES_LENGTH = 2
18
+ MAX_SERIES_LENGTH = 10000
19
+ MIN_HORIZON = 1
20
+ MAX_HORIZON = 365
21
+ DEFAULT_MODEL_ID = "ibm-granite/granite-timeseries-ttm-r1"
22
+
23
+
24
+ class ForecastToolError(Exception):
25
+ """Custom exception for forecast tool errors."""
26
+ pass
27
+
28
 
29
  class TimeseriesForecastTool:
30
  """
31
+ Lightweight wrapper around Granite Time Series models for zero-shot forecasting.
32
 
33
  This wrapper:
34
+ - Loads the model with AutoModel.from_pretrained
35
+ - Validates input series and horizon
36
+ - Attempts multiple inference methods (predict, forward with prediction_length)
37
+ - Returns a Pandas DataFrame with forecast column
38
+ - Provides comprehensive error handling and logging
39
 
40
  Expected input:
41
+ - series: pd.Series with DatetimeIndex (regular frequency recommended)
42
+ - horizon: int, number of future steps to forecast
 
 
 
 
 
43
  """
44
 
45
  def __init__(
46
  self,
47
  cfg: Optional[AppConfig],
48
  tracer: Optional[Tracer],
49
+ model_id: str = DEFAULT_MODEL_ID,
50
  device: Optional[str] = None,
51
  ):
52
  self.cfg = cfg
53
  self.tracer = tracer
54
  self.model_id = model_id
55
+ self.model = None
56
+ self.config = None
57
+
58
+ # Determine device
59
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
60
+ logger.info(f"TimeseriesForecastTool initialized with device: {self.device}")
61
+
62
+ # Lazy loading - model loaded on first use
63
+ self._initialized = False
64
+
65
+ def _ensure_loaded(self):
66
+ """Lazy load the model and configuration."""
67
+ if self._initialized:
68
+ return
69
+
70
+ try:
71
+ logger.info(f"Loading Granite time series model: {self.model_id}")
72
+
73
+ # Load configuration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  try:
75
+ self.config = AutoConfig.from_pretrained(self.model_id)
76
+ logger.info(f"Model config loaded: {type(self.config).__name__}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  except Exception as e:
78
+ logger.warning(f"Could not load model config: {e}")
79
+ self.config = None
80
+
81
+ # Load model
82
+ try:
83
+ self.model = AutoModel.from_pretrained(
84
+ self.model_id,
85
+ trust_remote_code=True # Required for some custom models
86
  )
87
+ self.model.to(self.device)
88
+ self.model.eval()
89
+ logger.info(f"Model loaded successfully: {type(self.model).__name__}")
90
+
91
+ except Exception as e:
92
+ raise ForecastToolError(
93
+ f"Failed to load model '{self.model_id}': {e}\n"
94
+ "Ensure the model is available and transformers is up to date."
95
+ ) from e
96
+
97
+ self._initialized = True
98
+
99
+ except ForecastToolError:
100
+ raise
101
+ except Exception as e:
102
+ raise ForecastToolError(f"Model initialization failed: {e}") from e
103
+
104
+ def _validate_series(self, series: pd.Series) -> tuple[bool, str]:
105
+ """
106
+ Validate input time series.
107
+ Returns (is_valid, error_message).
108
+ """
109
+ if not isinstance(series, pd.Series):
110
+ return False, "Input must be a pandas Series"
111
+
112
+ if series.empty:
113
+ return False, "Series is empty"
114
+
115
+ if len(series) < MIN_SERIES_LENGTH:
116
+ return False, f"Series too short (min {MIN_SERIES_LENGTH} points required)"
117
+
118
+ if len(series) > MAX_SERIES_LENGTH:
119
+ return False, f"Series too long (max {MAX_SERIES_LENGTH} points allowed)"
120
+
121
+ # Check for nulls
122
+ if series.isnull().any():
123
+ null_count = series.isnull().sum()
124
+ return False, f"Series contains {null_count} null values. Please handle missing data first."
125
+
126
+ # Check for infinite values
127
+ if not np.isfinite(series).all():
128
+ return False, "Series contains infinite values"
129
+
130
+ # Warn if not numeric
131
+ if not pd.api.types.is_numeric_dtype(series):
132
+ return False, f"Series must be numeric, got dtype: {series.dtype}"
133
+
134
+ return True, ""
135
+
136
+ def _validate_horizon(self, horizon: int) -> tuple[bool, str]:
137
+ """
138
+ Validate forecast horizon.
139
+ Returns (is_valid, error_message).
140
+ """
141
+ try:
142
+ h = int(horizon)
143
+ except (TypeError, ValueError):
144
+ return False, f"Horizon must be an integer, got: {horizon}"
145
+
146
+ if h < MIN_HORIZON:
147
+ return False, f"Horizon too small (min {MIN_HORIZON})"
148
+
149
+ if h > MAX_HORIZON:
150
+ return False, f"Horizon too large (max {MAX_HORIZON})"
151
+
152
+ return True, ""
153
+
154
+ def _prepare_input_tensor(self, series: pd.Series) -> torch.Tensor:
155
+ """
156
+ Convert pandas Series to PyTorch tensor.
157
+ Handles type conversion and device placement.
158
+ """
159
+ try:
160
+ # Convert to float32 numpy array
161
+ values = series.astype("float32").to_numpy()
162
+
163
+ # Create tensor and move to device
164
+ tensor = torch.tensor(values, dtype=torch.float32, device=self.device)
165
+
166
+ # Add batch dimension [1, seq_len]
167
+ tensor = tensor.unsqueeze(0)
168
+
169
+ logger.debug(f"Input tensor shape: {tensor.shape}, device: {tensor.device}")
170
+
171
+ return tensor
172
+
173
+ except Exception as e:
174
+ raise ForecastToolError(f"Failed to prepare input tensor: {e}") from e
175
+
176
+ def _try_predict_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
177
+ """
178
+ Try using the model's .predict() method.
179
+ Returns None if method doesn't exist or fails.
180
+ """
181
+ if not hasattr(self.model, "predict"):
182
+ logger.debug("Model has no 'predict' method")
183
+ return None
184
+
185
+ try:
186
+ logger.info("Attempting forecast with .predict() method")
187
+ preds = self.model.predict(x, prediction_length=horizon)
188
+
189
+ # Convert to tensor if needed
190
+ if not isinstance(preds, torch.Tensor):
191
+ preds = torch.tensor(preds, device=self.device)
192
+
193
+ # Extract numpy array
194
+ output = preds.squeeze().detach().cpu().numpy()
195
+
196
+ # Validate output shape
197
+ if output.shape[-1] != horizon:
198
+ logger.warning(
199
+ f"Prediction length mismatch: expected {horizon}, got {output.shape[-1]}"
200
+ )
201
+
202
+ logger.info(f"Forecast successful via .predict(): {output.shape}")
203
+ return output
204
+
205
+ except Exception as e:
206
+ logger.warning(f"predict() method failed: {e}")
207
+ return None
208
+
209
+ def _try_forward_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
210
+ """
211
+ Try using the model's forward() method with prediction_length parameter.
212
+ Returns None if method fails.
213
+ """
214
+ try:
215
+ logger.info("Attempting forecast with forward(prediction_length=...)")
216
+ outputs = self.model(x, prediction_length=horizon)
217
+
218
+ # Try to extract predictions from various possible output formats
219
+ prediction_tensor = None
220
+
221
+ # Check common attribute names
222
+ for attr in ("predictions", "prediction", "logits", "forecast", "output"):
223
+ if hasattr(outputs, attr):
224
+ candidate = getattr(outputs, attr)
225
+
226
+ # Handle tuple/list outputs
227
+ if isinstance(candidate, (tuple, list)):
228
+ candidate = candidate[0]
229
+
230
+ # Convert to tensor if needed
231
+ if not isinstance(candidate, torch.Tensor):
232
+ candidate = torch.tensor(candidate, device=self.device)
233
+
234
+ prediction_tensor = candidate
235
+ logger.debug(f"Found predictions in attribute: {attr}")
236
+ break
237
+
238
+ # If outputs is directly a tensor
239
+ if prediction_tensor is None and isinstance(outputs, torch.Tensor):
240
+ prediction_tensor = outputs
241
+ logger.debug("Using raw tensor output")
242
+
243
+ if prediction_tensor is None:
244
+ logger.warning("Could not extract predictions from forward() output")
245
+ return None
246
+
247
+ # Convert to numpy
248
+ output = prediction_tensor.squeeze().detach().cpu().numpy()
249
+
250
+ # Handle multi-dimensional outputs
251
+ if output.ndim > 1:
252
+ # Take the last row or flatten based on shape
253
+ if output.shape[0] == horizon:
254
+ output = output.flatten()
255
+ else:
256
+ output = output[-1] if output.shape[0] < output.shape[1] else output.flatten()
257
+
258
+ # Ensure correct length
259
+ if len(output) != horizon:
260
+ logger.warning(
261
+ f"Output length {len(output)} doesn't match horizon {horizon}. Truncating/padding."
262
+ )
263
+ if len(output) > horizon:
264
+ output = output[:horizon]
265
+ else:
266
+ # Pad with last value
267
+ output = np.pad(output, (0, horizon - len(output)), mode='edge')
268
+
269
+ logger.info(f"Forecast successful via forward(): {output.shape}")
270
+ return output
271
+
272
+ except TypeError as e:
273
+ logger.warning(f"forward() doesn't accept prediction_length: {e}")
274
+ return None
275
+ except Exception as e:
276
+ logger.warning(f"forward() method failed: {e}")
277
+ return None
278
+
279
+ def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
280
+ """
281
+ Generate zero-shot forecast for input time series.
282
+
283
+ Args:
284
+ series: Input time series (pd.Series with numeric values)
285
+ horizon: Number of periods to forecast (default: 96)
286
+
287
+ Returns:
288
+ DataFrame with 'forecast' column containing predictions
289
+
290
+ Raises:
291
+ ForecastToolError: If forecasting fails
292
+ """
293
+ try:
294
+ # Validate inputs
295
+ is_valid, error_msg = self._validate_series(series)
296
+ if not is_valid:
297
+ raise ForecastToolError(f"Invalid series: {error_msg}")
298
+
299
+ is_valid, error_msg = self._validate_horizon(horizon)
300
+ if not is_valid:
301
+ raise ForecastToolError(f"Invalid horizon: {error_msg}")
302
+
303
+ # Ensure model is loaded
304
+ self._ensure_loaded()
305
+
306
+ # Log input statistics
307
+ logger.info(
308
+ f"Forecasting: series_length={len(series)}, "
309
+ f"horizon={horizon}, "
310
+ f"series_mean={series.mean():.2f}, "
311
+ f"series_std={series.std():.2f}"
312
+ )
313
+
314
+ # Prepare input tensor
315
+ x = self._prepare_input_tensor(series)
316
+
317
+ # Try prediction methods in order of preference
318
+ output = None
319
+
320
+ with torch.no_grad():
321
+ # Method 1: Try .predict()
322
+ output = self._try_predict_method(x, horizon)
323
+
324
+ # Method 2: Try forward with prediction_length
325
+ if output is None:
326
+ output = self._try_forward_method(x, horizon)
327
+
328
+ # If all methods failed
329
+ if output is None:
330
+ raise ForecastToolError(
331
+ "Could not generate forecast using available model methods.\n"
332
+ "The model may not support zero-shot forecasting with this interface.\n"
333
+ "Suggestions:\n"
334
+ " • Check model documentation for correct usage\n"
335
+ " • Ensure transformers library is up to date\n"
336
+ " • Try a different model or use traditional forecasting (ARIMA, Prophet)\n"
337
+ f" • Model type: {type(self.model).__name__}"
338
+ )
339
+
340
+ # Create output DataFrame
341
+ result_df = pd.DataFrame({"forecast": output})
342
+
343
+ # Log output statistics
344
+ logger.info(
345
+ f"Forecast complete: "
346
+ f"mean={output.mean():.2f}, "
347
+ f"std={output.std():.2f}, "
348
+ f"min={output.min():.2f}, "
349
+ f"max={output.max():.2f}"
350
+ )
351
+
352
+ # Trace event
353
+ if self.tracer:
354
+ self.tracer.trace_event("forecast", {
355
+ "series_length": len(series),
356
+ "horizon": horizon,
357
+ "forecast_mean": float(output.mean()),
358
+ "forecast_std": float(output.std())
359
+ })
360
+
361
+ return result_df
362
+
363
+ except ForecastToolError:
364
+ raise
365
+ except Exception as e:
366
+ error_msg = f"Forecasting failed unexpectedly: {str(e)}"
367
+ logger.error(error_msg)
368
+ if self.tracer:
369
+ self.tracer.trace_event("forecast_error", {"error": error_msg})
370
+ raise ForecastToolError(error_msg) from e
371
+
372
+ def get_model_info(self) -> Dict[str, any]:
373
+ """Get information about the loaded model."""
374
+ self._ensure_loaded()
375
+
376
+ return {
377
+ "model_id": self.model_id,
378
+ "model_type": type(self.model).__name__,
379
+ "device": str(self.device),
380
+ "has_predict": hasattr(self.model, "predict"),
381
+ "config": str(self.config) if self.config else None
382
+ }