Spaces:
Runtime error
Runtime error
model path changed update
Browse files- src/model.py +4 -6
src/model.py
CHANGED
|
@@ -8,7 +8,7 @@ import torch
|
|
| 8 |
from prophet.serialize import model_to_json, model_from_json
|
| 9 |
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
|
| 10 |
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
|
| 11 |
-
|
| 12 |
# at beginning of the script
|
| 13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 14 |
|
|
@@ -38,11 +38,9 @@ class Model_Load:
|
|
| 38 |
def store_model_load(self,model_option):
|
| 39 |
if model_option=='TFT':
|
| 40 |
# best_model_path="models/store_item_10_lead_1_v2/lightning_logs/lightning_logs/version_2/checkpoints/epoch=7-step=4472.ckpt"
|
| 41 |
-
best_model_path="models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt"
|
| 42 |
-
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
|
| 43 |
-
|
| 44 |
-
# best_tft.load_state_dict(torch.load(best_model_path,map_location=torch.device('cpu')))
|
| 45 |
-
# best_tft.to('cpu')
|
| 46 |
print('Model Load Sucessfully.')
|
| 47 |
return best_tft
|
| 48 |
elif model_option=='Prophet':
|
|
|
|
| 8 |
from prophet.serialize import model_to_json, model_from_json
|
| 9 |
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
|
| 10 |
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
|
| 11 |
+
import pickle
|
| 12 |
# at beginning of the script
|
| 13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 14 |
|
|
|
|
| 38 |
def store_model_load(self,model_option):
|
| 39 |
if model_option=='TFT':
|
| 40 |
# best_model_path="models/store_item_10_lead_1_v2/lightning_logs/lightning_logs/version_2/checkpoints/epoch=7-step=4472.ckpt"
|
| 41 |
+
# best_model_path="models/store_item_10_lead_1_v3/lightning_logs/lightning_logs/version_0/checkpoints/epoch=7-step=4472.ckpt"
|
| 42 |
+
# best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
|
| 43 |
+
best_tft=pickle.load(open("models/cpu_finalized_model_v1.sav", 'rb'))
|
|
|
|
|
|
|
| 44 |
print('Model Load Sucessfully.')
|
| 45 |
return best_tft
|
| 46 |
elif model_option=='Prophet':
|