abhaypratapsingh111 committed on
Commit 33ccadb · verified · 1 Parent(s): 15b68db

Upload folder using huggingface_hub

services/__init__.py ADDED
File without changes
services/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (164 Bytes)
services/__pycache__/cache_manager.cpython-311.pyc ADDED
Binary file (7.57 kB)
services/__pycache__/data_processor.cpython-311.pyc ADDED
Binary file (18 kB)
services/__pycache__/model_service.cpython-311.pyc ADDED
Binary file (19.8 kB)
services/cache_manager.py ADDED
@@ -0,0 +1,170 @@
"""
Cache manager for storing predictions and uploaded data
"""

import logging
from typing import Dict, Optional
from datetime import datetime, timedelta
import pandas as pd

from config.constants import MAX_PREDICTION_HISTORY

logger = logging.getLogger(__name__)


class CacheManager:
    """
    Manages caching of predictions and data to improve performance
    """

    def __init__(self):
        self.predictions = []  # List of prediction results
        self.uploaded_data = {}  # Dict of uploaded datasets
        self.max_predictions = MAX_PREDICTION_HISTORY

    def store_prediction(
        self,
        data_hash: str,
        horizon: int,
        confidence_levels: list,
        result: Dict
    ):
        """
        Store a prediction result

        Args:
            data_hash: Hash of the input data
            horizon: Forecast horizon used
            confidence_levels: Confidence levels used
            result: Prediction result dictionary
        """
        prediction_entry = {
            'data_hash': data_hash,
            'horizon': horizon,
            'confidence_levels': confidence_levels,
            'result': result,
            'timestamp': datetime.now()
        }

        self.predictions.append(prediction_entry)

        # Keep only the most recent predictions
        if len(self.predictions) > self.max_predictions:
            self.predictions = self.predictions[-self.max_predictions:]

        logger.debug(f"Stored prediction, cache size: {len(self.predictions)}")

    def get_prediction(
        self,
        data_hash: str,
        horizon: int,
        confidence_levels: list
    ) -> Optional[Dict]:
        """
        Retrieve a cached prediction if available

        Args:
            data_hash: Hash of the input data
            horizon: Forecast horizon
            confidence_levels: Confidence levels

        Returns:
            Cached prediction result or None
        """
        for entry in reversed(self.predictions):
            if (entry['data_hash'] == data_hash and
                    entry['horizon'] == horizon and
                    entry['confidence_levels'] == confidence_levels):

                logger.info("Cache hit for prediction")
                return entry['result']

        logger.debug("Cache miss for prediction")
        return None

    def store_data(self, filename: str, data: pd.DataFrame):
        """
        Store uploaded data

        Args:
            filename: Name of the uploaded file
            data: DataFrame containing the data
        """
        self.uploaded_data[filename] = {
            'data': data,
            'timestamp': datetime.now()
        }

        logger.info(f"Stored data for {filename}")

    def get_data(self, filename: str) -> Optional[pd.DataFrame]:
        """
        Retrieve uploaded data

        Args:
            filename: Name of the file

        Returns:
            DataFrame or None
        """
        if filename in self.uploaded_data:
            return self.uploaded_data[filename]['data']
        return None

    def clear_old_data(self, max_age_hours: int = 24):
        """
        Clear data older than specified hours

        Args:
            max_age_hours: Maximum age in hours
        """
        cutoff = datetime.now() - timedelta(hours=max_age_hours)

        # Clear old uploaded data
        old_files = [
            filename for filename, entry in self.uploaded_data.items()
            if entry['timestamp'] < cutoff
        ]

        for filename in old_files:
            del self.uploaded_data[filename]

        if old_files:
            logger.info(f"Cleared {len(old_files)} old data entries")

    def clear_all(self):
        """Clear all cached data"""
        self.predictions.clear()
        self.uploaded_data.clear()
        logger.info("Cleared all cache")

    def get_stats(self) -> Dict:
        """Get cache statistics"""
        return {
            'num_predictions': len(self.predictions),
            'num_datasets': len(self.uploaded_data),
            'total_memory_mb': self._estimate_memory()
        }

    def _estimate_memory(self) -> float:
        """Estimate memory usage in MB (rough estimate)"""
        try:
            total_bytes = 0

            # Estimate prediction cache size
            for entry in self.predictions:
                if 'forecast' in entry['result']:
                    total_bytes += entry['result']['forecast'].memory_usage(deep=True).sum()

            # Estimate data cache size
            for entry in self.uploaded_data.values():
                total_bytes += entry['data'].memory_usage(deep=True).sum()

            return total_bytes / (1024 * 1024)
        except Exception as e:
            logger.warning(f"Failed to estimate memory: {str(e)}")
            return 0.0


# Global cache instance
cache_manager = CacheManager()
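How callers derive the data hash is not shown in this commit; the sketch below assumes an MD5 digest of the uploaded frame and a stand-in result dict, purely to illustrate how the store/get pair is meant to be keyed:

    import hashlib
    import pandas as pd
    from services.cache_manager import cache_manager

    df = pd.DataFrame({"target": [1.0, 2.0, 3.0]})  # example data (hypothetical)
    data_hash = hashlib.md5(pd.util.hash_pandas_object(df).values).hexdigest()

    cached = cache_manager.get_prediction(data_hash, horizon=30, confidence_levels=[80, 95])
    if cached is None:
        result = {"forecast": df}  # stand-in for a real prediction result dict
        cache_manager.store_prediction(data_hash, 30, [80, 95], result)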
services/data_processor.py ADDED
@@ -0,0 +1,393 @@
"""
Data preprocessing pipeline for time series data
"""

import logging
from typing import Dict, List, Optional, Tuple, Any
import pandas as pd
import numpy as np
from io import BytesIO

from config.constants import (
    DATE_FORMATS,
    MAX_MISSING_PERCENT,
    MIN_DATA_POINTS_MULTIPLIER,
    ALLOWED_EXTENSIONS
)

logger = logging.getLogger(__name__)


class DataProcessor:
    """
    Handles all data preprocessing tasks for time series forecasting
    """

    def __init__(self):
        self.data = None
        self.original_data = None
        self.metadata = {}

    def _timedelta_to_freq_string(self, td: pd.Timedelta) -> str:
        """
        Convert a Timedelta to a pandas frequency string

        Args:
            td: Timedelta object

        Returns:
            Frequency string (e.g., 'H', 'D', '5min', etc.)
        """
        total_seconds = td.total_seconds()

        # Common time frequencies
        if total_seconds == 0:
            return 'D'  # Default to daily if zero
        elif total_seconds % 604800 == 0:  # Weekly (7 days)
            weeks = int(total_seconds / 604800)
            return f'{weeks}W' if weeks > 1 else 'W'
        elif total_seconds % 86400 == 0:  # Daily (24 hours)
            days = int(total_seconds / 86400)
            return f'{days}D' if days > 1 else 'D'
        elif total_seconds % 3600 == 0:  # Hourly
            hours = int(total_seconds / 3600)
            return f'{hours}H' if hours > 1 else 'H'
        elif total_seconds % 60 == 0:  # Minutes
            minutes = int(total_seconds / 60)
            return f'{minutes}min' if minutes > 1 else 'min'
        elif total_seconds % 1 == 0:  # Seconds
            seconds = int(total_seconds)
            return f'{seconds}s' if seconds > 1 else 's'
        else:
            # For irregular frequencies, default to daily
            logger.warning(f"Irregular frequency detected ({td}), defaulting to Daily")
            return 'D'

    def load_file(self, contents: bytes, filename: str) -> Dict[str, Any]:
        """
        Load data from uploaded file

        Args:
            contents: File contents as bytes
            filename: Original filename

        Returns:
            Dictionary with status and data/error
        """
        try:
            # Determine file type
            extension = filename.split('.')[-1].lower()

            if extension not in ALLOWED_EXTENSIONS:
                return {
                    'status': 'error',
                    'error': f'Invalid file type. Allowed: {", ".join(ALLOWED_EXTENSIONS)}'
                }

            # Load data based on file type
            if extension == 'csv':
                self.data = pd.read_csv(BytesIO(contents))
            elif extension in ['xlsx', 'xls']:
                self.data = pd.read_excel(BytesIO(contents))

            self.original_data = self.data.copy()

            logger.info(f"Loaded file {filename} with shape {self.data.shape}")

            # Generate initial metadata
            self.metadata = {
                'filename': filename,
                'rows': len(self.data),
                'columns': list(self.data.columns),
                'dtypes': {col: str(dtype) for col, dtype in self.data.dtypes.items()}
            }

            return {
                'status': 'success',
                'data': self.data,
                'metadata': self.metadata
            }

        except Exception as e:
            logger.error(f"Failed to load file {filename}: {str(e)}", exc_info=True)
            return {
                'status': 'error',
                'error': f'Failed to load file: {str(e)}'
            }

    def validate_data(
        self,
        date_column: str,
        target_column: str,
        id_column: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Validate the selected columns and data quality

        Args:
            date_column: Name of the date/time column
            target_column: Name of the target variable column
            id_column: Optional ID column for multivariate series

        Returns:
            Validation result dictionary
        """
        try:
            issues = []
            warnings = []

            # Check if columns exist
            if date_column not in self.data.columns:
                issues.append(f"Date column '{date_column}' not found")
            if target_column not in self.data.columns:
                issues.append(f"Target column '{target_column}' not found")
            if id_column and id_column not in self.data.columns:
                issues.append(f"ID column '{id_column}' not found")

            if issues:
                return {'status': 'error', 'issues': issues}

            # Check for missing values
            missing_pct = (self.data[target_column].isna().sum() / len(self.data)) * 100
            if missing_pct > MAX_MISSING_PERCENT:
                warnings.append(
                    f"Target column has {missing_pct:.1f}% missing values (>{MAX_MISSING_PERCENT}%)"
                )

            # Check data type of target
            if not pd.api.types.is_numeric_dtype(self.data[target_column]):
                issues.append(f"Target column must be numeric, found {self.data[target_column].dtype}")

            # Try to parse date column
            try:
                _ = pd.to_datetime(self.data[date_column])
            except Exception as e:
                issues.append(f"Cannot parse date column: {str(e)}")

            if issues:
                return {'status': 'error', 'issues': issues, 'warnings': warnings}

            return {
                'status': 'success',
                'warnings': warnings,
                'missing_pct': missing_pct
            }

        except Exception as e:
            logger.error(f"Validation failed: {str(e)}", exc_info=True)
            return {'status': 'error', 'issues': [str(e)]}

    def preprocess(
        self,
        date_column: str,
        target_column: Any,  # Can be a string or a list of strings
        id_column: Optional[str] = None,
        forecast_horizon: int = 30,
        max_rows: int = 100000
    ) -> Dict[str, Any]:
        """
        Complete preprocessing pipeline

        Args:
            date_column: Name of the date column
            target_column: Name of the target column (string) or list of target columns for multivariate
            id_column: Optional ID column
            forecast_horizon: Number of periods to forecast
            max_rows: Maximum number of rows to keep before sampling

        Returns:
            Processed data and metadata
        """
        try:
            logger.info("Starting preprocessing pipeline")

            # Step 0: Handle very large datasets
            original_row_count = len(self.data)
            if original_row_count > max_rows:
                logger.warning(f"Dataset has {original_row_count} rows, sampling to {max_rows} for performance")
                # Keep the most recent data for forecasting
                self.data = self.data.tail(max_rows).reset_index(drop=True)

            # Step 1: Parse dates
            logger.info("Parsing dates...")
            self.data[date_column] = pd.to_datetime(self.data[date_column])

            # Step 2: Sort by date and remove duplicate timestamps
            self.data = self.data.sort_values(date_column).reset_index(drop=True)

            # Check for and handle duplicate timestamps
            duplicate_count = self.data[date_column].duplicated().sum()
            if duplicate_count > 0:
                logger.warning(f"Found {duplicate_count} duplicate timestamps, keeping first occurrence")
                self.data = self.data.drop_duplicates(subset=[date_column], keep='first').reset_index(drop=True)

            # Step 3: Detect frequency
            logger.info("Detecting frequency...")
            freq = pd.infer_freq(self.data[date_column])
            if freq is None:
                # Try to infer from differences
                diffs = self.data[date_column].diff().dropna()
                if len(diffs) > 0:
                    # Get the most common time difference
                    mode_diff = diffs.mode()
                    if len(mode_diff) > 0 and mode_diff[0] != pd.Timedelta(0):
                        # Convert Timedelta to frequency string
                        td = mode_diff[0]
                        freq = self._timedelta_to_freq_string(td)
                        logger.warning(f"Could not auto-detect frequency, inferred from mode: {freq}")
                    else:
                        freq = 'D'
                        logger.warning("Using default frequency: Daily")
                else:
                    freq = 'D'
                    logger.warning("Using default frequency: Daily")

            # Step 4: Handle missing values in target(s)
            # Normalize target_column to a list
            target_columns = [target_column] if isinstance(target_column, str) else target_column
            logger.info(f"Processing {len(target_columns)} target column(s): {target_columns}")

            logger.info("Handling missing values...")
            total_missing_count = 0

            for tcol in target_columns:
                missing_count = self.data[tcol].isna().sum()
                total_missing_count += missing_count

                if missing_count > 0:
                    # Forward fill for small gaps
                    self.data[tcol] = self.data[tcol].ffill(limit=5)

                    # Linear interpolation for remaining
                    self.data[tcol] = self.data[tcol].interpolate(method='linear')

                    # Final fallback: backward fill
                    self.data[tcol] = self.data[tcol].bfill()

                    logger.info(f"Filled {missing_count} missing values in '{tcol}'")

            # Step 5: Detect outliers (IQR method) - only for the primary target
            logger.info("Detecting outliers...")
            primary_target = target_columns[0]
            Q1 = self.data[primary_target].quantile(0.25)
            Q3 = self.data[primary_target].quantile(0.75)
            IQR = Q3 - Q1
            outlier_mask = (
                (self.data[primary_target] < (Q1 - 3 * IQR)) |
                (self.data[primary_target] > (Q3 + 3 * IQR))
            )
            outlier_count = outlier_mask.sum()

            # Step 6: Check if sufficient data
            min_required = forecast_horizon * MIN_DATA_POINTS_MULTIPLIER
            if len(self.data) < min_required:
                return {
                    'status': 'error',
                    'error': f'Insufficient data. Need at least {min_required} points for {forecast_horizon}-period forecast.'
                }

            # Step 7: Prepare for Chronos 2 format
            # Chronos 2 expects columns: ['id', 'timestamp', 'target']
            # For multivariate: ['id', 'timestamp', 'target', 'covariate1', 'covariate2', ...]
            processed_df = pd.DataFrame({
                'id': self.data[id_column] if id_column else 'series_1',
                'timestamp': self.data[date_column],
                'target': self.data[target_columns[0]].astype(float)
            })

            # Add additional target columns as covariates
            if len(target_columns) > 1:
                logger.info(f"Adding {len(target_columns) - 1} additional target column(s) as covariates")
                for tcol in target_columns[1:]:
                    processed_df[tcol] = self.data[tcol].astype(float)

            # Generate quality report
            quality_report = {
                'total_points': len(processed_df),
                'original_points': original_row_count,
                'sampled': original_row_count > max_rows,
                'date_range': {
                    'start': processed_df['timestamp'].min().strftime('%Y-%m-%d'),
                    'end': processed_df['timestamp'].max().strftime('%Y-%m-%d')
                },
                'frequency': str(freq),
                'missing_filled': total_missing_count,
                'outliers_detected': outlier_count,
                'duplicates_removed': duplicate_count if duplicate_count > 0 else 0,
                'target_columns': target_columns,
                'statistics': {
                    'mean': float(processed_df['target'].mean()),
                    'std': float(processed_df['target'].std()),
                    'min': float(processed_df['target'].min()),
                    'max': float(processed_df['target'].max())
                }
            }

            logger.info("Preprocessing completed successfully")

            return {
                'status': 'success',
                'data': processed_df,
                'quality_report': quality_report,
                'frequency': freq
            }

        except Exception as e:
            logger.error(f"Preprocessing failed: {str(e)}", exc_info=True)
            return {
                'status': 'error',
                'error': str(e)
            }

    def get_column_info(self) -> Dict[str, List[str]]:
        """
        Get information about columns for UI dropdowns

        Returns:
            Dictionary with potential date and numeric columns
        """
        if self.data is None:
            return {'date_columns': [], 'numeric_columns': [], 'all_columns': []}

        date_columns = []
        numeric_columns = []

        for col in self.data.columns:
            # Check if column could be a date
            if self.data[col].dtype == 'object':
                # Try to parse a sample
                try:
                    pd.to_datetime(self.data[col].iloc[:5])
                    date_columns.append(col)
                except Exception:
                    pass
            elif pd.api.types.is_datetime64_any_dtype(self.data[col]):
                date_columns.append(col)

            # Check if column is numeric
            if pd.api.types.is_numeric_dtype(self.data[col]):
                numeric_columns.append(col)

        return {
            'date_columns': date_columns,
            'numeric_columns': numeric_columns,
            'all_columns': list(self.data.columns)
        }

    def get_preview(self, n_rows: int = 10) -> pd.DataFrame:
        """
        Get a preview of the data

        Args:
            n_rows: Number of rows to return

        Returns:
            DataFrame preview
        """
        if self.data is None:
            return pd.DataFrame()

        return self.data.head(n_rows)


# Global data processor instance
data_processor = DataProcessor()
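A minimal driving sequence for the processor above, assuming a hypothetical sales.csv upload with "date" and "sales" columns and that the referenced config.constants values exist in this repo:

    from services.data_processor import data_processor

    # "sales.csv" is a hypothetical upload with columns "date" and "sales"
    with open("sales.csv", "rb") as f:
        loaded = data_processor.load_file(f.read(), "sales.csv")

    if loaded["status"] == "success":
        check = data_processor.validate_data(date_column="date", target_column="sales")
        if check["status"] == "success":
            prepared = data_processor.preprocess(
                date_column="date", target_column="sales", forecast_horizon=30
            )
            # prepared["data"] is the ['id', 'timestamp', 'target'] frame the model service expects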
services/model_service.py ADDED
@@ -0,0 +1,430 @@
"""
Chronos 2 Model Service
Handles model loading, caching, and inference using Chronos2Pipeline
"""

import logging
import time
from typing import Dict, List, Optional, Tuple, Any
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline, Chronos2Pipeline

from config.constants import CHRONOS2_MODEL, CONFIDENCE_LEVELS
from config.settings import CONFIG, DEVICE, MODEL_CONFIG

logger = logging.getLogger(__name__)


class ChronosModelService:
    """
    Service for managing the Chronos 2 model lifecycle and inference
    Uses Chronos2Pipeline with a DataFrame-based API
    """

    def __init__(self):
        self.model = None
        self.device = None
        self.model_variant = None
        self.is_loaded = False
        self.load_time = None
        self.is_chronos2 = False  # Track which pipeline type is loaded

    def _get_device(self) -> str:
        """Determine the best available device"""
        if DEVICE == 'cuda':
            if not torch.cuda.is_available():
                logger.warning("CUDA requested but not available, falling back to CPU")
                return 'cpu'
            return 'cuda'
        elif DEVICE == 'cpu':
            return 'cpu'
        else:  # auto
            return 'cuda' if torch.cuda.is_available() else 'cpu'

    def load_model(self) -> Dict[str, Any]:
        """
        Load the Chronos 2 model at startup

        Returns:
            Dictionary with loading status and metadata
        """
        try:
            start_time = time.time()
            logger.info("Loading Chronos 2 model from HuggingFace (paper 2510.15821)")

            # Use the single Chronos-2 model
            model_path = CHRONOS2_MODEL
            self.model_variant = 'chronos-2'

            # Determine device
            self.device = self._get_device()
            logger.info(f"Using device: {self.device}")

            # Load model using Chronos2Pipeline
            self.model = Chronos2Pipeline.from_pretrained(
                model_path,
                device_map=self.device,
                torch_dtype=torch.bfloat16 if self.device == 'cuda' else torch.float32,
            )
            self.is_chronos2 = True

            self.load_time = time.time() - start_time
            self.is_loaded = True

            logger.info(f"Model loaded successfully in {self.load_time:.2f}s")

            # Warmup prediction
            if MODEL_CONFIG['warmup_enabled']:
                self._warmup()

            return {
                'status': 'success',
                'model': 'chronos-2',
                'device': self.device,
                'load_time': self.load_time,
                'model_name': model_path
            }

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            self.is_loaded = False
            return {
                'status': 'error',
                'error': str(e)
            }

    def _warmup(self):
        """Run a warmup prediction to initialize the model"""
        try:
            logger.info("Running warmup prediction")

            # Create warmup DataFrame in Chronos 2 format
            warmup_data = pd.DataFrame({
                'id': ['warmup'] * MODEL_CONFIG['warmup_length'],
                'timestamp': pd.date_range('2020-01-01', periods=MODEL_CONFIG['warmup_length'], freq='D'),
                'target': np.random.randn(MODEL_CONFIG['warmup_length'])
            })

            self.predict(
                warmup_data,
                horizon=MODEL_CONFIG['warmup_horizon'],
                confidence_levels=[80]
            )
            logger.info("Warmup completed successfully")

        except Exception as e:
            logger.warning(f"Warmup failed: {str(e)}")

    def predict(
        self,
        data: pd.DataFrame,
        horizon: int,
        confidence_levels: List[int] = None,
        future_df: Optional[pd.DataFrame] = None
    ) -> Dict[str, Any]:
        """
        Generate forecasts using the Chronos 2 model with the DataFrame API

        Args:
            data: DataFrame with columns ['id', 'timestamp', 'target']
                  Can also include covariates for multivariate forecasting
            horizon: Number of periods to forecast
            confidence_levels: List of confidence levels (e.g., [80, 90, 95])
            future_df: Optional DataFrame with future covariate values

        Returns:
            Dictionary with predictions and metadata
        """
        logger.info("=" * 80)
        logger.info("MODEL SERVICE: predict() - ENTRY")
        logger.info(f"Data shape: {data.shape}")
        logger.info(f"Data columns: {data.columns.tolist()}")
        logger.info(f"Horizon: {horizon}")
        logger.info(f"Confidence levels: {confidence_levels}")
        logger.info(f"Is loaded: {self.is_loaded}")
        logger.info("=" * 80)

        if not self.is_loaded:
            logger.error("✗ Model not loaded!")
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            start_time = time.time()
            logger.info("Starting prediction...")

            # Use default confidence levels if not provided
            if confidence_levels is None:
                confidence_levels = CONFIDENCE_LEVELS

            # Calculate quantile levels from confidence intervals
            quantile_levels = []
            for cl in sorted(confidence_levels):
                lower = (100 - cl) / 200  # e.g., 80% -> 0.10
                upper = 1 - lower         # e.g., 80% -> 0.90
                quantile_levels.extend([lower, upper])

            # Add median
            quantile_levels.append(0.5)
            quantile_levels = sorted(set(quantile_levels))

            logger.info(f"Generating forecast for horizon={horizon}, quantiles={quantile_levels}")

            # Ensure required columns exist
            required_cols = ['id', 'timestamp', 'target']
            logger.info(f"Checking for required columns: {required_cols}")
            if not all(col in data.columns for col in required_cols):
                error_msg = f"Data must contain columns: {required_cols}, but got: {data.columns.tolist()}"
                logger.error(f"✗ {error_msg}")
                raise ValueError(error_msg)
            logger.info("✓ All required columns present")

            # Generate forecast using the appropriate API
            if self.is_chronos2:
                logger.info("Using Chronos2Pipeline.predict_df() method")
                logger.info(f"Calling predict_df with prediction_length={horizon}, quantile_levels={quantile_levels}")
                # Use Chronos 2 DataFrame API
                pred_df = self.model.predict_df(
                    df=data,
                    future_df=future_df,
                    prediction_length=horizon,
                    quantile_levels=quantile_levels,
                    id_column='id',
                    timestamp_column='timestamp',
                    target='target'
                )
                logger.info(f"✓ predict_df completed - result shape: {pred_df.shape}")
            else:
                # Use the original Chronos tensor API
                # Convert DataFrame to tensor
                context_tensor = torch.tensor(data['target'].values, dtype=torch.float32).unsqueeze(0)

                # Generate forecast
                forecast_tensors = self.model.predict(
                    context=context_tensor,
                    prediction_length=horizon,
                    num_samples=20,  # Number of sample paths
                    limit_prediction_length=False
                )

                # Convert tensor output to DataFrame format
                # forecast_tensors shape: [batch, num_samples, prediction_length]
                quantiles_np = np.quantile(
                    forecast_tensors.squeeze(0).numpy(),
                    q=quantile_levels,
                    axis=0
                )

                # Create prediction DataFrame in Chronos 2 format
                last_timestamp = pd.to_datetime(data['timestamp'].iloc[-1])
                freq = pd.infer_freq(pd.to_datetime(data['timestamp']))
                if freq is None:
                    freq = 'D'  # Default to daily

                future_timestamps = pd.date_range(
                    start=last_timestamp,
                    periods=horizon + 1,
                    freq=freq
                )[1:]  # Exclude the last historical point

                pred_df = pd.DataFrame({
                    'id': [data['id'].iloc[0]] * horizon,
                    'timestamp': future_timestamps
                })

                # Add quantile columns
                for i, q in enumerate(quantile_levels):
                    pred_df[f'{q:.2f}'] = quantiles_np[i, :]

            # Process forecast results
            # pred_df contains columns: id, timestamp, and quantile columns

            # Extract forecast for the first series (if multiple)
            series_ids = pred_df['id'].unique()
            if len(series_ids) > 0:
                series_pred = pred_df[pred_df['id'] == series_ids[0]].copy()
            else:
                series_pred = pred_df.copy()

            # Create forecast dataframe with confidence intervals
            forecast_df = pd.DataFrame({
                'ds': series_pred['timestamp'],
                'forecast': series_pred['0.5']  # Median forecast
            })

            # Add confidence intervals
            for cl in confidence_levels:
                lower = (100 - cl) / 200
                upper = 1 - lower

                lower_col = f'{lower:.2f}'
                upper_col = f'{upper:.2f}'

                if lower_col in series_pred.columns:
                    forecast_df[f'lower_{cl}'] = series_pred[lower_col].values
                if upper_col in series_pred.columns:
                    forecast_df[f'upper_{cl}'] = series_pred[upper_col].values

            inference_time = time.time() - start_time

            logger.info(f"✓ Forecast generated successfully in {inference_time:.2f}s")
            logger.info(f"Returning forecast DataFrame with {len(forecast_df)} rows")
            logger.info("MODEL SERVICE: predict() - EXIT (success)")
            logger.info("=" * 80)

            return {
                'status': 'success',
                'forecast': forecast_df,
                'inference_time': inference_time,
                'horizon': horizon,
                'confidence_levels': confidence_levels,
                'full_prediction': pred_df  # Include full prediction for multivariate
            }

        except Exception as e:
            logger.error(f"✗ EXCEPTION in predict(): {str(e)}", exc_info=True)
            logger.info("MODEL SERVICE: predict() - EXIT (exception)")
            logger.info("=" * 80)
            return {
                'status': 'error',
                'error': str(e)
            }

    def backtest(
        self,
        data: pd.DataFrame,
        test_size: int,
        forecast_horizon: int,
        confidence_levels: List[int] = None
    ) -> Dict[str, Any]:
        """
        Perform backtesting on historical data to evaluate model performance

        Args:
            data: DataFrame with columns ['id', 'timestamp', 'target']
            test_size: Number of periods to use for testing
            forecast_horizon: Forecast horizon for each prediction
            confidence_levels: List of confidence levels

        Returns:
            Dictionary with backtest results including predictions vs actuals
        """
        logger.info("=" * 80)
        logger.info("MODEL SERVICE: backtest() - ENTRY")
        logger.info(f"Data shape: {data.shape}")
        logger.info(f"Test size: {test_size}")
        logger.info(f"Forecast horizon: {forecast_horizon}")
        logger.info("=" * 80)

        if not self.is_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            start_time = time.time()

            # Split data into train and test
            train_size = len(data) - test_size
            if train_size < forecast_horizon * 2:
                raise ValueError(f"Insufficient training data. Need at least {forecast_horizon * 2} points.")

            # Use a rolling window approach
            # We'll make predictions for the test period using the training data
            train_data = data.iloc[:train_size].copy()
            test_data = data.iloc[train_size:].copy()

            logger.info(f"Train size: {len(train_data)}, Test size: {len(test_data)}")

            # Make prediction on the test period
            forecast_result = self.predict(
                data=train_data,
                horizon=test_size,
                confidence_levels=confidence_levels
            )

            if forecast_result['status'] == 'error':
                return forecast_result

            forecast_df = forecast_result['forecast']

            # Align forecast with actual values
            backtest_df = pd.DataFrame({
                'timestamp': test_data['timestamp'].values,
                'actual': test_data['target'].values,
                'predicted': forecast_df['forecast'].values[:len(test_data)]
            })

            # Add confidence intervals if available
            for cl in (confidence_levels or []):
                lower_col = f'lower_{cl}'
                upper_col = f'upper_{cl}'
                if lower_col in forecast_df.columns:
                    backtest_df[lower_col] = forecast_df[lower_col].values[:len(test_data)]
                if upper_col in forecast_df.columns:
                    backtest_df[upper_col] = forecast_df[upper_col].values[:len(test_data)]

            # Calculate metrics
            actual = backtest_df['actual'].values
            predicted = backtest_df['predicted'].values

            # Remove any NaN values
            mask = ~(np.isnan(actual) | np.isnan(predicted))
            actual = actual[mask]
            predicted = predicted[mask]

            if len(actual) == 0:
                raise ValueError("No valid data points for metric calculation")

            mae = np.mean(np.abs(actual - predicted))
            rmse = np.sqrt(np.mean((actual - predicted) ** 2))
            mape = np.mean(np.abs((actual - predicted) / (actual + 1e-10))) * 100

            # R-squared
            ss_res = np.sum((actual - predicted) ** 2)
            ss_tot = np.sum((actual - np.mean(actual)) ** 2)
            r2 = 1 - (ss_res / (ss_tot + 1e-10))

            metrics = {
                'MAE': float(mae),
                'RMSE': float(rmse),
                'MAPE': float(mape),
                'R2': float(r2)
            }

            inference_time = time.time() - start_time

            logger.info(f"✓ Backtest completed in {inference_time:.2f}s")
            logger.info(f"Metrics: MAE={mae:.2f}, RMSE={rmse:.2f}, MAPE={mape:.2f}%, R2={r2:.4f}")
            logger.info("MODEL SERVICE: backtest() - EXIT (success)")
            logger.info("=" * 80)

            return {
                'status': 'success',
                'backtest_data': backtest_df,
                'metrics': metrics,
                'inference_time': inference_time,
                'train_size': train_size,
                'test_size': test_size
            }

        except Exception as e:
            logger.error(f"✗ EXCEPTION in backtest(): {str(e)}", exc_info=True)
            logger.info("MODEL SERVICE: backtest() - EXIT (exception)")
            logger.info("=" * 80)
            return {
                'status': 'error',
                'error': str(e)
            }

    def get_status(self) -> Dict[str, Any]:
        """Get current model status"""
        return {
            'is_loaded': self.is_loaded,
            'variant': self.model_variant,
            'device': self.device,
            'load_time': self.load_time
        }


# Global model service instance
model_service = ChronosModelService()
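A minimal end-to-end sketch of the service above on synthetic data, assuming the chronos package and this repo's config modules are importable; the 120-day random series is purely illustrative:

    import numpy as np
    import pandas as pd
    from services.model_service import model_service

    # Synthetic daily history in the ['id', 'timestamp', 'target'] layout predict() expects
    history = pd.DataFrame({
        "id": "series_1",
        "timestamp": pd.date_range("2024-01-01", periods=120, freq="D"),
        "target": np.random.rand(120),
    })

    if model_service.load_model()["status"] == "success":
        result = model_service.predict(history, horizon=30, confidence_levels=[80, 95])
        if result["status"] == "success":
            forecast_df = result["forecast"]  # columns: ds, forecast, lower_80, upper_80, ...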