import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from collections.abc import Iterable
from itertools import repeat


def _ntuple(n):
    # Normalize a scalar into an n-tuple; iterables (except strings) pass through unchanged.
    def parse(x):
        if isinstance(x, Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
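# Usage sketch (illustrative, not part of the original file): these helpers are
# typically used to normalize patch/window sizes, e.g.
#   to_2tuple(7)       # -> (7, 7)
#   to_2tuple((4, 8))  # -> (4, 8), iterables pass through unchanged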
def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    # Tag every submodule so auto_grad_checkpoint knows to wrap its forward pass.
    assert isinstance(model, nn.Module)

    def set_attr(module):
        module.grad_checkpointing = True
        module.fp32_attention = use_fp32_attention
        module.grad_checkpointing_step = gc_step

    model.apply(set_attr)


def auto_grad_checkpoint(module, *args, **kwargs):
    # Run `module` under activation checkpointing if it was tagged by set_grad_checkpoint;
    # otherwise fall back to a plain forward call.
    if getattr(module, 'grad_checkpointing', False):
        if isinstance(module, Iterable):
            gc_step = module[0].grad_checkpointing_step
            return checkpoint_sequential(module, gc_step, *args, **kwargs)
        return checkpoint(module, *args, **kwargs)
    return module(*args, **kwargs)
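# Usage sketch (illustrative, not from the original file). `Block` is a hypothetical
# transformer block whose forward takes the hidden states as its first argument;
# any nn.Module with that signature works the same way.
#
#   blocks = nn.ModuleList([Block(dim=1152) for _ in range(28)])
#   set_grad_checkpoint(blocks, gc_step=2)          # tags every submodule
#   for block in blocks:
#       x = auto_grad_checkpoint(block, x)          # activations recomputed in backward
#
# Passing the whole ModuleList instead of a single block, e.g.
#   x = auto_grad_checkpoint(blocks, x)
# takes the Iterable branch and checkpoints the blocks in groups of gc_step via the
# custom checkpoint_sequential defined below.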
def checkpoint_sequential(functions, step, input, *args, **kwargs):
    # NOTE: this local definition shadows the checkpoint_sequential imported from
    # torch.utils.checkpoint above; unlike the torch version, it forwards *args to
    # every block and checkpoints the blocks in groups of `step`.
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input, *args)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # the last chunk has to be non-volatile (the trailing segment runs without checkpointing)
    end = -1
    segment = len(functions) // step
    for start in range(0, step * (segment - 1), step):
        end = start + step - 1
        input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
    return run_function(end + 1, len(functions) - 1, functions)(input)
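# Worked example (illustrative, not from the original file): with 10 blocks and
# step=3, segment = 10 // 3 = 3, so the loop checkpoints blocks [0..2] and [3..5]
# as two segments, and the remaining blocks [6..9] run as the final, non-checkpointed
# chunk. If step >= len(functions), the loop body never runs and every block executes
# without checkpointing.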
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]
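# Shape sketch (illustrative, not from the original file): for q_size=2, k_size=4
# and rel_pos of shape (7, C) (max_rel_dist = 2 * 4 - 1 = 7), relative_coords is
#   [[3, 2, 1, 0],
#    [5, 4, 3, 2]]
# and the returned tensor has shape (q_size, k_size, C) = (2, 4, C).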
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # (B, q_h, q_w, k_w)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)
    return attn
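# Usage sketch (illustrative, not from the original file), with square 14x14
# attention windows, per-head dim 64 and B playing the role of batch * heads:
#
#   B, h, w, dim = 2, 14, 14, 64
#   q = torch.randn(B, h * w, dim)
#   attn = q @ torch.randn(B, h * w, dim).transpose(-2, -1)  # (B, 196, 196)
#   rel_pos_h = torch.randn(2 * h - 1, dim)                  # (27, 64)
#   rel_pos_w = torch.randn(2 * w - 1, dim)                  # (27, 64)
#   attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (h, w), (h, w))
#   # attn keeps shape (B, h * w, h * w) and would then go through softmax as usual.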