from functools import wraps
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from scipy import interpolate


def max_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
    if len(tensors) == 1:
        return tensors[0]
    return torch.stack(tensors, dim=-1).max(dim=-1).values


def last_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
    return tensors[-1]


def first_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
    return tensors[0]
def softmax_stack(
    tensors: list[torch.Tensor], temperature: float = 1.0
) -> torch.Tensor:
    if len(tensors) == 1:
        return tensors[0]
    stacked = torch.stack(tensors, dim=-1)
    # Softmax-weighted sum over the stacked tensors; summing the raw softmax
    # alone would always return ones.
    return (stacked * F.softmax(stacked / temperature, dim=-1)).sum(dim=-1)
def mean_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
    if len(tensors) == 1:
        return tensors[0]
    return torch.stack(tensors, dim=-1).mean(dim=-1)


def sum_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
    if len(tensors) == 1:
        return tensors[0]
    return torch.stack(tensors, dim=-1).sum(dim=-1)
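
# Usage sketch (hypothetical inputs, not part of this module): every reducer
# maps a list of same-shaped tensors to a single tensor of that shape, e.g.
#   feats = [torch.randn(2, 8) for _ in range(3)]
#   assert max_stack(feats).shape == (2, 8)
#   assert torch.allclose(mean_stack(feats), sum_stack(feats) / 3)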

def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()

def format_seconds(seconds):
    # Truncate to whole seconds so the :d format specifiers do not fail on floats.
    seconds = int(seconds)
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:d}:{minutes:02d}:{seconds:02d}"

def get_params(module, lr, wd):
    skip_list = set()
    skip_keywords = set()
    if hasattr(module, "no_weight_decay"):
        skip_list = module.no_weight_decay()
    if hasattr(module, "no_weight_decay_keywords"):
        skip_keywords = module.no_weight_decay_keywords()
    has_decay = []
    no_decay = []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if (
            (name in skip_list)
            or any(kw in name for kw in skip_keywords)
            or len(param.shape) == 1
            or name.endswith(".gamma")
            or name.endswith(".beta")
            or name.endswith(".bias")
        ):
            no_decay.append(param)
        else:
            has_decay.append(param)
    group1 = {
        "params": has_decay,
        "weight_decay": wd,
        "lr": lr,
        "weight_decay_init": wd,
        "weight_decay_base": wd,
        "lr_base": lr,
    }
    group2 = {
        "params": no_decay,
        "weight_decay": 0.0,
        "lr": lr,
        "weight_decay_init": 0.0,
        "weight_decay_base": 0.0,
        "weight_decay_final": 0.0,
        "lr_base": lr,
    }
    return [group1, group2], [lr, lr]
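
# Usage sketch (hypothetical model): the returned groups plug straight into a
# torch optimizer; the extra *_base/*_init keys are preserved for schedulers:
#   param_groups, lrs = get_params(model, lr=1e-4, wd=0.05)
#   optimizer = torch.optim.AdamW(param_groups)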

def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
    if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
        return 0
    elif var_name.startswith("patch_embed"):
        return 0
    elif var_name.startswith("layers"):
        if var_name.split(".")[2] == "blocks":
            stage_id = int(var_name.split(".")[1])
            layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id])
            return layer_id + 1
        elif var_name.split(".")[2] == "downsample":
            stage_id = int(var_name.split(".")[1])
            layer_id = sum(layers_per_stage[: stage_id + 1])
            return layer_id
    else:
        return num_max_layer - 1

def get_params_layerdecayswin(module, lr, wd, ld):
    skip_list = set()
    skip_keywords = set()
    if hasattr(module, "no_weight_decay"):
        skip_list = module.no_weight_decay()
    if hasattr(module, "no_weight_decay_keywords"):
        skip_keywords = module.no_weight_decay_keywords()
    layers_per_stage = module.depths
    num_layers = sum(layers_per_stage) + 1
    lrs = []
    params = []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            print(f"{name} frozen")
            continue  # frozen weights
        layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage)
        lr_cur = lr * ld ** (num_layers - layer_id - 1)
        # if (name in skip_list) or any(kw in name for kw in skip_keywords) or len(param.shape) == 1 or name.endswith(".bias"):
        if (name in skip_list) or any(kw in name for kw in skip_keywords):
            wd_cur = 0.0
        else:
            wd_cur = wd
        params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur})
        lrs.append(lr_cur)
    return params, lrs

def log(t, eps: float = 1e-5):
    return torch.log(t.clamp(min=eps))


def l2norm(t):
    return F.normalize(t, dim=-1)


def exists(val):
    return val is not None


def identity(t, *args, **kwargs):
    return t


def divisible_by(numer, denom):
    return (numer % denom) == 0


def first(arr, d=None):
    if len(arr) == 0:
        return d
    return arr[0]


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def maybe(fn):
    def inner(x):
        if not exists(x):
            return x
        return fn(x)

    return inner


def once(fn):
    called = False

    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner


def _many(fn):
    def inner(tensors, pattern, **kwargs):
        return (fn(tensor, pattern, **kwargs) for tensor in tensors)

    return inner


rearrange_many = _many(rearrange)
repeat_many = _many(repeat)
reduce_many = _many(reduce)
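
# Usage sketch: the *_many wrappers apply one einops pattern across several
# tensors; note they return a generator, not a list (hypothetical q, k, v):
#   q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=8)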

def load_pretrained(model, checkpoint):
    checkpoint_model = checkpoint["model"]
    if any("encoder." in k for k in checkpoint_model.keys()):
        checkpoint_model = {
            k.replace("encoder.", ""): v
            for k, v in checkpoint_model.items()
            if k.startswith("encoder.")
        }
        print("Detected pre-trained model, removing [encoder.] prefix.")
    else:
        print("Detected non-pre-trained model, passing through unchanged.")
    print(">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
    return load_checkpoint_swin(model, checkpoint_model)

def load_checkpoint_swin(model, checkpoint_model):
    state_dict = model.state_dict()
    # Geometric interpolation when the pre-trained window size mismatches the
    # fine-tuned window size
    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_bias_table" in key:
            relative_position_bias_table_pretrained = checkpoint_model[key]
            relative_position_bias_table_current = state_dict[key]
            L1, nH1 = relative_position_bias_table_pretrained.size()
            L2, nH2 = relative_position_bias_table_current.size()
            if nH1 != nH2:
                print(f"Error in loading {key}, passing......")
            else:
                if L1 != L2:
                    print(f"{key}: Interpolate relative_position_bias_table using geo.")
                    src_size = int(L1**0.5)
                    dst_size = int(L2**0.5)

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r**n) / (1.0 - r)

                    # Bisect for the common ratio q whose geometric progression
                    # covers the destination half-window.
                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q
                    # if q > 1.090307:
                    #     q = 1.090307
                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q ** (i + 1)
                    r_ids = [-_ for _ in reversed(dis)]
                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis
                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)
                    print("Original positions = %s" % str(x))
                    print("Target positions = %s" % str(dx))
                    all_rel_pos_bias = []
                    for i in range(nH1):
                        z = (
                            relative_position_bias_table_pretrained[:, i]
                            .view(src_size, src_size)
                            .float()
                            .numpy()
                        )
                        # NOTE: scipy.interpolate.interp2d was deprecated in
                        # SciPy 1.10 and removed in 1.14; this path needs an
                        # older SciPy.
                        f_cubic = interpolate.interp2d(x, y, z, kind="cubic")
                        all_rel_pos_bias.append(
                            torch.Tensor(f_cubic(dx, dy))
                            .contiguous()
                            .view(-1, 1)
                            .to(relative_position_bias_table_pretrained.device)
                        )
                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
                    checkpoint_model[key] = new_rel_pos_bias
    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [
        k for k in checkpoint_model.keys() if "relative_position_index" in k
    ]
    for k in relative_position_index_keys:
        del checkpoint_model[k]
    # delete relative_coords_table since we always re-init it
    relative_coords_table_keys = [
        k for k in checkpoint_model.keys() if "relative_coords_table" in k
    ]
    for k in relative_coords_table_keys:
        del checkpoint_model[k]
    # re-map keys due to name change
    rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k]
    for k in rpe_mlp_keys:
        checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)
    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del checkpoint_model[k]
    encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")]
    for k in encoder_keys:
        checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)
    return checkpoint_model
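
# Usage sketch (hypothetical checkpoint path): remap a Swin checkpoint onto a
# fine-tuning model, then load non-strictly since some buffers are re-inited:
#   ckpt = torch.load("swin_pretrained.pth", map_location="cpu")
#   remapped = load_pretrained(model, ckpt)
#   model.load_state_dict(remapped, strict=False)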

def add_padding_metas(out, image_metas):
    # left, right, top, bottom
    paddings = [img_meta.get("paddings", [0] * 4) for img_meta in image_metas]
    # F.pad expects a plain int sequence per sample, so keep the paddings as
    # lists instead of stacking them into a tensor.
    outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)]
    return torch.stack(outs)

# left, right, top, bottom
def remove_padding(out, paddings):
    H, W = out.shape[-2:]
    outs = [
        o[..., padding[2] : H - padding[3], padding[0] : W - padding[1]]
        for padding, o in zip(paddings, out)
    ]
    return torch.stack(outs)

def remove_padding_metas(out, image_metas):
    # left, right, top, bottom
    paddings = [
        torch.tensor(img_meta.get("paddings", [0] * 4)) for img_meta in image_metas
    ]
    return remove_padding(out, paddings)

def ssi_helper(tensor1, tensor2):
    stability_mat = 1e-4 * torch.eye(2, device=tensor1.device)
    tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
    scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
        tensor2_one.T @ tensor1.unsqueeze(1)
    )
    scale, shift = scale_shift.squeeze().chunk(2, dim=0)
    return scale, shift
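
# ssi_helper solves the regularized least squares min_{s,t} ||s * tensor2 + t - tensor1||^2
# in closed form, e.g. to align a scale-and-shift-invariant depth prediction
# with ground truth (hypothetical flat 1D inputs):
#   scale, shift = ssi_helper(gt.flatten(), pred.flatten())
#   pred_aligned = scale * pred + shift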

def calculate_mean_values(names, values):
    # Accumulate a running sum and count for each name
    name_values = {name: {} for name in names}
    for name, value in zip(names, values):
        name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value
        name_values[name]["count"] = name_values[name].get("count", 0.0) + 1
    # Reduce to per-name means
    output_dict = {
        name: name_values[name]["sum"] / name_values[name]["count"]
        for name in name_values
    }
    return output_dict
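
# e.g. calculate_mean_values(["a", "b", "a"], [1.0, 2.0, 3.0]) -> {"a": 2.0, "b": 2.0}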

def remove_leading_dim(infos):
    if isinstance(infos, dict):
        return {k: remove_leading_dim(v) for k, v in infos.items()}
    elif isinstance(infos, torch.Tensor):
        return infos.squeeze(0)
    else:
        return infos


def recursive_index(infos, index):
    if isinstance(infos, dict):
        return {k: recursive_index(v, index) for k, v in infos.items()}
    elif isinstance(infos, torch.Tensor):
        return infos[index]
    else:
        return infos

def to_cpu(infos):
    if isinstance(infos, dict):
        return {k: to_cpu(v) for k, v in infos.items()}
    elif isinstance(infos, torch.Tensor):
        # Detach and actually move to CPU, as the name promises.
        return infos.detach().cpu()
    else:
        return infos

def masked_mean(
    data: torch.Tensor,
    mask: torch.Tensor | None = None,
    dim: list[int] | None = None,
    keepdim: bool = False,
) -> torch.Tensor:
    dim = dim if dim is not None else list(range(data.dim()))
    if mask is None:
        return data.mean(dim=dim, keepdim=keepdim)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    return mask_mean.squeeze(dim) if not keepdim else mask_mean
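
# Usage sketch: average a metric over valid pixels only (the multi-dim squeeze
# above needs PyTorch >= 2.0):
#   err = torch.rand(4, 1, 32, 32)
#   valid = err < 0.5
#   per_image = masked_mean(err, valid, dim=[1, 2, 3])  # shape (4,)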

class ProfileMethod:
    # Timings are kept per function name at class level so that statistics
    # accumulate across calls (the decorator builds a new instance per call).
    _timings: dict = {}

    def __init__(self, model, func_name, track_statistics=True, verbose=False):
        self.model = model
        self.func_name = func_name
        self.verbose = verbose
        self.track_statistics = track_statistics
        self.timings = ProfileMethod._timings.setdefault(func_name, [])

    def __enter__(self):
        # Start timing; synchronize first so pending CUDA work is not counted
        if self.verbose:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self.start_time = time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.verbose:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self.end_time = time()
            elapsed_time = self.end_time - self.start_time
            self.timings.append(elapsed_time)
            if self.track_statistics and len(self.timings) > 25:
                # Compute statistics over all recorded calls
                timings_array = np.array(self.timings)
                mean_time = np.mean(timings_array)
                std_time = np.std(timings_array)
                quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100])
                print(
                    f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds"
                )
                print(f"Mean Time: {mean_time:.4f} seconds")
                print(f"Std Time: {std_time:.4f} seconds")
                print(
                    f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}"
                )
            else:
                print(
                    f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds"
                )

def profile_method(track_statistics=True, verbose=False):
    def decorator(func):
        @wraps(func)  # preserve the wrapped method's name and docstring
        def wrapper(self, *args, **kwargs):
            with ProfileMethod(self, func.__name__, track_statistics, verbose):
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
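
# Usage sketch (hypothetical module): timings print only when verbose=True:
#   class Net(nn.Module):
#       @profile_method(verbose=True)
#       def forward(self, x):
#           return x * 2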

class ProfileFunction:
    # Shared per-name timing store, mirroring ProfileMethod.
    _timings: dict = {}

    def __init__(self, func_name, track_statistics=True, verbose=False):
        self.func_name = func_name
        self.verbose = verbose
        self.track_statistics = track_statistics
        self.timings = ProfileFunction._timings.setdefault(func_name, [])

    def __enter__(self):
        # Start timing; synchronize first so pending CUDA work is not counted
        if self.verbose:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self.start_time = time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.verbose:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self.end_time = time()
            elapsed_time = self.end_time - self.start_time
            self.timings.append(elapsed_time)
            if self.track_statistics and len(self.timings) > 25:
                # Compute statistics over all recorded calls
                timings_array = np.array(self.timings)
                mean_time = np.mean(timings_array)
                std_time = np.std(timings_array)
                quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100])
                print(f"{self.func_name} took {elapsed_time:.4f} seconds")
                print(f"Mean Time: {mean_time:.4f} seconds")
                print(f"Std Time: {std_time:.4f} seconds")
                print(
                    f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}"
                )
            else:
                print(f"{self.func_name} took {elapsed_time:.4f} seconds")

def profile_function(track_statistics=True, verbose=False):
    def decorator(func):
        @wraps(func)  # preserve the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            # Plain functions have no self, so forward all arguments as-is.
            with ProfileFunction(func.__name__, track_statistics, verbose):
                return func(*args, **kwargs)

        return wrapper

    return decorator

def recursive_apply(inputs, func):
    if isinstance(inputs, list):
        return [recursive_apply(item, func) for item in inputs]
    else:
        return func(inputs)

def squeeze_list(nested_list, dim, current_dim=0):
    # Drop a singleton list level when the current nesting depth matches dim
    if isinstance(nested_list, list) and len(nested_list) == 1 and current_dim == dim:
        return squeeze_list(nested_list[0], dim, current_dim + 1)
    elif isinstance(nested_list, list):
        return [squeeze_list(item, dim, current_dim + 1) for item in nested_list]
    else:
        return nested_list
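
# e.g. squeeze_list([[1, 2, 3]], dim=0) -> [1, 2, 3]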

def match_gt(tensor1, tensor2, padding1, padding2, mode: str = "bilinear"):
    """
    Transform each item in the tensor1 batch to match tensor2's dimensions and padding.

    Args:
        tensor1 (torch.Tensor): The input tensor to transform, with shape (batch_size, channels, height, width).
        tensor2 (torch.Tensor): The target tensor to match, with shape (batch_size, channels, height, width).
        padding1 (tuple): Padding applied to tensor1 (pad_left, pad_right, pad_top, pad_bottom).
        padding2 (tuple): Desired padding to be applied to match tensor2 (pad_left, pad_right, pad_top, pad_bottom).

    Returns:
        torch.Tensor: The batch of transformed tensors matching tensor2's size and padding.
    """
    batch_size = len(tensor1)
    src_dtype = tensor1[0].dtype
    tgt_dtype = tensor2[0].dtype
    transformed_tensors = []
    for i in range(batch_size):
        item1 = tensor1[i]
        item2 = tensor2[i]
        h1, w1 = item1.shape[1], item1.shape[2]
        pad1_l, pad1_r, pad1_t, pad1_b = (
            padding1[i] if padding1 is not None else (0, 0, 0, 0)
        )
        pad2_l, pad2_r, pad2_t, pad2_b = (
            padding2[i] if padding2 is not None else (0, 0, 0, 0)
        )
        # Strip tensor1's padding, resize to tensor2's unpadded size, re-pad.
        item1_unpadded = item1[:, pad1_t : h1 - pad1_b, pad1_l : w1 - pad1_r]
        h2, w2 = (
            item2.shape[1] - pad2_t - pad2_b,
            item2.shape[2] - pad2_l - pad2_r,
        )
        item1_resized = F.interpolate(
            item1_unpadded.unsqueeze(0).to(tgt_dtype), size=(h2, w2), mode=mode
        )
        item1_padded = F.pad(item1_resized, (pad2_l, pad2_r, pad2_t, pad2_b))
        transformed_tensors.append(item1_padded)
    transformed_batch = torch.cat(transformed_tensors)
    return transformed_batch.to(src_dtype)
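
# Usage sketch (hypothetical tensors): resize ground-truth maps to match the
# padded prediction resolution; mode="nearest" avoids blending label values:
#   gt_matched = match_gt(gt, pred, padding1=None, padding2=pads, mode="nearest")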