Spaces:
Runtime error
Runtime error
| import copy | |
| from pathlib import Path | |
| import warnings | |
| import lightning.pytorch as pl | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from prophet.serialize import model_to_json, model_from_json | |
| from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet | |
| from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters | |
| import pickle | |
# Choose the compute device once at import time: the first CUDA GPU when one
# is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class Model_Load:
    """Load the pre-trained forecasting models used by the app.

    Two datasets (energy consumption and store/item sales) each have two
    model variants, selected by ``model_option``:

    * ``'TFT'`` -- a pickled model object stored in a ``.sav`` file.
    * ``'Prophet'`` -- a Prophet model serialized to JSON and restored with
      ``prophet.serialize.model_from_json``.

    File names are resolved inside ``model_dir``. That path, like the
    previously hard-coded ``models/...`` paths, is relative to the current
    working directory unless an absolute path is given.
    """

    # File name (inside ``model_dir``) for each supported model_option.
    _ENERGY_FILES = {
        "TFT": "cpu_energy_tft_model_v1.sav",
        "Prophet": "fb_energy_model.json",
    }
    _STORE_FILES = {
        "TFT": "cpu_finalized_model_v1.sav",
        "Prophet": "fb_store_model_new.json",
    }

    def __init__(self, model_dir="models"):
        """Create a loader.

        Args:
            model_dir: Directory containing the model files. Defaults to
                ``"models"``, matching the paths the loader used before.
        """
        self._model_dir = Path(model_dir)

    def energy_model_load(self, model_option):
        """Load the energy-consumption forecasting model.

        Args:
            model_option: ``'TFT'`` or ``'Prophet'``.

        Returns:
            The deserialized model object.

        Raises:
            ValueError: If ``model_option`` is not supported.
            FileNotFoundError: If the model file does not exist.
        """
        return self._load(self._ENERGY_FILES, model_option)

    def store_model_load(self, model_option):
        """Load the store/item sales forecasting model.

        Args:
            model_option: ``'TFT'`` or ``'Prophet'``.

        Returns:
            The deserialized model object.

        Raises:
            ValueError: If ``model_option`` is not supported.
            FileNotFoundError: If the model file does not exist.
        """
        return self._load(self._STORE_FILES, model_option)

    def _load(self, files, model_option):
        """Resolve the file for ``model_option`` and deserialize it."""
        # Previously an unknown option silently returned None, which only
        # failed later at prediction time; fail fast with a clear message.
        if model_option not in files:
            raise ValueError(
                f"Unsupported model_option {model_option!r}; "
                f"expected one of {sorted(files)}"
            )
        path = self._model_dir / files[model_option]
        if model_option == "TFT":
            return self._load_pickle(path)
        return self._load_prophet(path)

    @staticmethod
    def _load_pickle(path):
        """Unpickle the model stored at ``path``."""
        # SECURITY: unpickling can execute arbitrary code; only load model
        # files from trusted sources.
        # ``with`` guarantees the handle is closed (the original leaked it).
        with open(path, "rb") as fin:
            model = pickle.load(fin)
        print('Model Load Sucessfully.')
        return model

    @staticmethod
    def _load_prophet(path):
        """Restore a Prophet model from its JSON serialization at ``path``."""
        with open(path, "r", encoding="utf-8") as fin:
            return model_from_json(fin.read())
if __name__ == '__main__':
    # Smoke test: deserialize every model the loader supports. Model_Load has
    # no ``load`` method, so call the per-dataset loaders explicitly.
    loader = Model_Load()
    for option in ('TFT', 'Prophet'):
        loader.energy_model_load(option)
        loader.store_model_load(option)