Spaces:
Runtime error
Runtime error
| ''' | |
| Author: Qiguang Chen | |
| LastEditors: Qiguang Chen | |
| Date: 2023-02-12 22:23:58 | |
| LastEditTime: 2023-02-19 14:14:56 | |
| Description: | |
| ''' | |
| import json | |
| import os | |
| import queue | |
| import shutil | |
| import torch | |
| import dill | |
| from common import utils | |
class Saver():
    """Write tokenizer, labels, model checkpoints and outputs under one directory.

    The target directory is ``config["save_dir"]`` when that key is set to a
    truthy value, otherwise ``"save/" + start_time``.

    Config keys read by this class:
        save_dir: optional target directory.
        save_mode: defaults to ``"save-by-eval"``; ``auto_save_step`` only
            saves when it equals ``"save-by-step"``.
        max_save_num: defaults to ``1``. When it is not ``1``, each checkpoint
            goes to its own ``<save_dir>/<step>`` sub-directory and the oldest
            one is removed once ``max_save_num`` checkpoints are queued.
        save_step: step interval used by ``auto_save_step``.
    """

    def __init__(self, config, start_time=None) -> None:
        """Resolve and create the save directory and the checkpoint queue.

        Args:
            config: mapping-like object supporting ``.get(key)`` and ``[key]``.
            start_time: string appended to ``"save/"`` when ``save_dir`` is
                not configured.

        Raises:
            ValueError: if neither ``config["save_dir"]`` nor ``start_time``
                is provided.
        """
        self.config = config
        if self.config.get("save_dir"):
            self.model_save_dir = self.config["save_dir"]
        else:
            if start_time is None:
                # Fail with a clear message before touching the filesystem
                # instead of an opaque TypeError from `str + None`.
                raise ValueError(
                    "Saver needs either config['save_dir'] or a start_time "
                    "to build the save directory."
                )
            self.model_save_dir = "save/" + start_time
        # makedirs also creates missing parents (including "save/") and does
        # not fail if the directory already exists.
        os.makedirs(self.model_save_dir, exist_ok=True)

        save_mode = config.get("save_mode")
        self.save_mode = save_mode if save_mode is not None else "save-by-eval"

        max_save_num = self.config.get("max_save_num")
        self.max_save_num = max_save_num if max_save_num is not None else 1
        # Use the defaulted value: a raw None would make Queue.full() raise.
        self.save_pool = queue.Queue(maxsize=self.max_save_num)

    def save_tokenizer(self, tokenizer):
        """Serialize ``tokenizer`` with dill to ``<save_dir>/tokenizer.pkl``."""
        with open(os.path.join(self.model_save_dir, "tokenizer.pkl"), 'wb') as f:
            dill.dump(tokenizer, f)

    def save_label(self, intent_list, slot_list):
        """Write the intent and slot lists to ``<save_dir>/label.json``."""
        utils.save_json(
            os.path.join(self.model_save_dir, "label.json"),
            {"intent": intent_list, "slot": slot_list},
        )

    def save_model(self, model, train_state, accelerator=None):
        """Save ``model`` (and training state) into the checkpoint directory.

        Args:
            model: object to save as ``model.pkl``.
            train_state: mapping with at least a ``"step"`` entry; saved as
                ``train_state.pkl`` when no accelerator is given.
            accelerator: optional object providing ``wait_for_everyone``,
                ``unwrap_model``, ``save`` and ``save_state`` (presumably a
                Hugging Face Accelerate ``Accelerator`` -- confirm with caller).
                In this branch ``train_state`` is not written separately;
                ``accelerator.save_state`` is called instead.
        """
        step = train_state["step"]
        if self.max_save_num != 1:
            # Rotating checkpoints: one sub-directory per step.
            model_save_dir = os.path.join(self.model_save_dir, str(step))
            if self.save_pool.full():
                oldest_dir = self.save_pool.get()
                # The directory may already be gone (e.g. the same step was
                # enqueued twice), so only remove it if it still exists.
                if os.path.isdir(oldest_dir):
                    shutil.rmtree(oldest_dir)
            self.save_pool.put(model_save_dir)
        else:
            # Single checkpoint: overwrite files directly in the save dir.
            model_save_dir = self.model_save_dir
        os.makedirs(model_save_dir, exist_ok=True)

        if accelerator is None:
            torch.save(model, os.path.join(model_save_dir, "model.pkl"))
            torch.save(
                train_state,
                os.path.join(model_save_dir, "train_state.pkl"),
                pickle_module=dill,
            )
        else:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            accelerator.save(unwrapped_model, os.path.join(model_save_dir, "model.pkl"))
            accelerator.save_state(output_dir=model_save_dir)

    def auto_save_step(self, model, train_state, accelerator=None):
        """Save the model if the current step hits the configured interval.

        A save happens only when ``save_mode`` is ``"save-by-step"``, the step
        is non-zero, and it is a multiple of ``config["save_step"]``.

        Returns:
            ``True`` if ``save_model`` was called, ``False`` otherwise.
        """
        step = train_state["step"]
        if self.save_mode == "save-by-step" and step % self.config.get("save_step") == 0 and step != 0:
            self.save_model(model, train_state, accelerator)
            return True
        return False

    def save_output(self, outputs, dataset):
        """Delegate to ``outputs.save`` with the save directory and ``dataset``."""
        outputs.save(self.model_save_dir, dataset)