| """ Utility file for trainers """ | |
| import os | |
| import shutil | |
| from glob import glob | |
| import torch | |
| import torch.distributed as dist | |
''' checkpoint functions '''
# save model/optimizer/scheduler state for the given epoch
def save_checkpoint(model,
                    optimizer,
                    scheduler,
                    epoch,
                    checkpoint_dir,
                    name,
                    model_name):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch
    }
    checkpoint_path = os.path.join(checkpoint_dir, '{}_{}_{}.pt'.format(name, model_name, epoch))
    torch.save(checkpoint_state, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path))

# reload model weights (and optionally optimizer/scheduler state) from a checkpoint file
def reload_ckpt(args,
                network,
                optimizer,
                scheduler,
                gpu,
                model_name,
                manual_reload_name=None,
                manual_reload=False,
                manual_reload_dir=None,
                epoch=None,
                fit_sefa=False):
    if manual_reload:
        reload_name = manual_reload_name
    else:
        reload_name = args.name
    if manual_reload_dir:
        ckpt_dir = manual_reload_dir + reload_name + "/ckpt/"
    else:
        ckpt_dir = args.output_dir + reload_name + "/ckpt/"
    temp_ckpt_dir = f'{args.output_dir}{reload_name}/ckpt_temp/'
    reload_epoch = epoch
    # find the best or latest epoch
    if epoch is None:
        reload_epoch_temp = 0
        reload_epoch_ckpt = 0
        if len(os.listdir(temp_ckpt_dir)) != 0:
            reload_epoch_temp = find_best_epoch(temp_ckpt_dir)
        if len(os.listdir(ckpt_dir)) != 0:
            reload_epoch_ckpt = find_best_epoch(ckpt_dir)
        if reload_epoch_ckpt >= reload_epoch_temp:
            reload_epoch = reload_epoch_ckpt
        else:
            reload_epoch = reload_epoch_temp
            ckpt_dir = temp_ckpt_dir
    else:
        if os.path.isfile(f"{temp_ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"):
            ckpt_dir = temp_ckpt_dir
    # reload the weights
    if model_name is None:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{reload_epoch}.pt"
    else:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"
    if gpu == 0:
        print("===Resume checkpoint from: {}===".format(resuming_path))
    loc = 'cuda:{}'.format(gpu)
    checkpoint = torch.load(resuming_path, map_location=loc)
    start_epoch = 0 if manual_reload and not fit_sefa else checkpoint["epoch"]
    if manual_reload_dir is not None and 'parameter_estimation' in manual_reload_dir:
        # checkpoint was saved without the DistributedDataParallel wrapper,
        # so prepend `module.` to every key to match the wrapped model
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint["model"].items():
            name = 'module.' + k
            new_state_dict[name] = v
        network.load_state_dict(new_state_dict)
    else:
        network.load_state_dict(checkpoint["model"])
    if not manual_reload:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    if gpu == 0:
        print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, reload_epoch))
    return start_epoch
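
# --- Usage sketch (illustrative only) ---
# A minimal sketch of resuming training with `reload_ckpt`; the Namespace
# fields, run name "my_run" and single-GPU (gpu=0) setup are assumptions for
# demonstration. It expects the directory layout produced by `save_checkpoint`
# above (./results/my_run/ckpt/<epoch>/ and ./results/my_run/ckpt_temp/<epoch>/).
def _example_reload_ckpt(network, optimizer, scheduler):
    from argparse import Namespace
    args = Namespace(name="my_run", output_dir="./results/")
    # epoch=None lets reload_ckpt pick the latest epoch found in ckpt/ or ckpt_temp/
    start_epoch = reload_ckpt(args, network, optimizer, scheduler,
                              gpu=0, model_name="network", epoch=None)
    return start_epoch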

# find the best (latest) epoch available for reloading the current model
def find_best_epoch(input_dir):
    cur_epochs = glob("{}*".format(input_dir))
    return find_by_name(cur_epochs)

# convert epoch directory names to integers and return the largest one
def find_by_name(epochs):
    int_epochs = []
    for e in epochs:
        int_epochs.append(int(os.path.basename(e)))
    int_epochs.sort()
    return int_epochs[-1]

# remove old checkpoint directories, keeping only the most recent `leave` epochs
def remove_ckpt(cur_ckpt_path_dir, leave=2):
    ckpt_nums = [int(i) for i in os.listdir(cur_ckpt_path_dir)]
    ckpt_nums.sort()
    del_num = len(ckpt_nums) - leave
    cur_del_num = 0
    while del_num > 0:
        shutil.rmtree("{}{}".format(cur_ckpt_path_dir, ckpt_nums[cur_del_num]))
        del_num -= 1
        cur_del_num += 1
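
# --- Usage sketch (illustrative only) ---
# A typical pruning call after saving a new epoch: keep only the two most
# recent epoch directories under the run's ckpt/ folder (path is an assumed
# example). Note that the directory string must end with "/" because the
# epoch number is concatenated directly onto it.
def _example_remove_ckpt():
    remove_ckpt("./results/my_run/ckpt/", leave=2)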

''' multi-GPU functions '''
# gather function adapted from DirectCLR
class GatherLayer_Direct(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    this implementation does not cut the gradients as torch.distributed.all_gather does.
    """
    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]

from classy_vision.generic.distributed_util import (
    convert_to_distributed_tensor,
    convert_to_normal_tensor,
    is_distributed_training_run,
)

def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all,
    except that it does not cut the gradients.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so unsqueeze first
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer_Direct.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
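
# --- Usage sketch (illustrative only) ---
# How `gather_from_all` would typically be used inside a distributed training
# step: collect embeddings from every rank before computing a loss over the
# full global batch. The (batch_size, dim) shape is an assumption for
# demonstration; if torch.distributed is not initialised, the local tensor is
# returned unchanged.
def _example_gather_embeddings(local_embeddings: torch.Tensor) -> torch.Tensor:
    # local_embeddings: (batch_size, dim) tensor produced on this rank
    all_embeddings = gather_from_all(local_embeddings)  # (world_size * batch_size, dim)
    # unlike a plain dist.all_gather, gradients still flow back to the local slice
    return all_embeddings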