| |
| |
| |
| import math |
| from enum import Enum |
|
|
| import einops |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| |
| from rscd.models.decoderheads.vision_lstm_util import interpolate_sincos, to_ntuple, VitPatchEmbed, VitPosEmbed2d, DropPath |
|
|
| class SequenceTraversal(Enum): |
| ROWWISE_FROM_TOP_LEFT = "rowwise_from_top_left" |
| ROWWISE_FROM_BOT_RIGHT = "rowwise_from_bot_right" |
|
|
|
|
| def bias_linspace_init_(param: torch.Tensor, start: float = 3.4, end: float = 6.0) -> torch.Tensor: |
| """Linearly spaced bias init across dimensions.""" |
| assert param.dim() == 1, f"param must be 1-dimensional (typically a bias), got {param.dim()}" |
| n_dims = param.shape[0] |
| init_vals = torch.linspace(start, end, n_dims) |
| with torch.no_grad(): |
| param.copy_(init_vals) |
| return param |
|
|
|
|
| def small_init_(param: torch.Tensor, dim: int) -> torch.Tensor: |
| """ |
| Fills the input Tensor with values according to the method described in Transformers without Tears: Improving |
| the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution. |
    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py.
| """ |
| std = math.sqrt(2 / (5 * dim)) |
| torch.nn.init.normal_(param, mean=0.0, std=std) |
| return param |
|
|
|
|
| def wang_init_(param: torch.Tensor, dim: int, num_blocks: int): |
    """ Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py. """
| std = 2 / num_blocks / math.sqrt(dim) |
| torch.nn.init.normal_(param, mean=0.0, std=std) |
| return param |
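# Illustrative usage of the init helpers above (a minimal sketch, kept as comments so it is not executed at import time):
# >>> lin = nn.Linear(64, 64)
# >>> small_init_(lin.weight, dim=64)                          # std = sqrt(2 / (5 * 64)) ~ 0.079
# >>> wang_init_(lin.weight, dim=64, num_blocks=12)            # std = 2 / 12 / sqrt(64) ~ 0.021
# >>> bias_linspace_init_(torch.empty(8), start=3.0, end=6.0)  # evenly spaced gate biases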
|
|
|
|
| def parallel_stabilized_simple( |
| queries: torch.Tensor, |
| keys: torch.Tensor, |
| values: torch.Tensor, |
| igate_preact: torch.Tensor, |
| fgate_preact: torch.Tensor, |
| lower_triangular_matrix: torch.Tensor = None, |
| stabilize_rowwise: bool = True, |
| eps: float = 1e-6, |
| ) -> torch.Tensor: |
| """ |
| This is the mLSTM cell in parallel form. |
| This version is stabilized. We control the range of exp() arguments by |
| ensuring that they are always smaller than 0.0 by subtracting the maximum. |
| |
| Args: |
| :param queries: (torch.Tensor) (B, NH, S, DH) |
| :param keys: (torch.Tensor) (B, NH, S, DH) |
| :param values: (torch.Tensor) (B, NH, S, DH) |
| :param igate_preact: (torch.Tensor) (B, NH, S, 1) |
| :param fgate_preact: (torch.Tensor) (B, NH, S, 1) |
| :param lower_triangular_matrix: (torch.Tensor) (S,S). Defaults to None. |
        :param stabilize_rowwise: (bool) Whether to stabilize the combination matrix C rowwise (take maximum per row).
| Alternative: Subtract the maximum over all rows. Defaults to True. |
| :param eps: (float) small constant to avoid division by 0. Defaults to 1e-6. |
| |
| Returns: |
| torch.Tensor: (B, NH, S, DH), h_tilde_state |
| """ |
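    # Recap of the parallel mLSTM computation implemented below (a restatement of the code, not an extra step):
    # with log forget gates lf_t = logsigmoid(fgate_preact_t) and input gate pre-activations i_t, the gate matrix is
    #     log D[i, j] = sum_{k=j+1..i} lf_k + i_j   for j <= i, and -inf otherwise.
    # After subtracting the (row-wise or global) maximum m for numerical stability, the output is
    #     C = (Q @ K^T / sqrt(DH)) * exp(log D - m)
    #     h_tilde = (C / (max(|rowsum(C)|, exp(-m)) + eps)) @ V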
|
|
| B, NH, S, DH = queries.shape |
| _dtype, _device = queries.dtype, queries.device |
|
|
| |
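    # log of the forget gates: log(sigmoid(fgate_preact)), shape (B, NH, S, 1)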
| log_fgates = torch.nn.functional.logsigmoid(fgate_preact) |
| if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1): |
| ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device)) |
| else: |
| ltr = lower_triangular_matrix |
| assert ltr.dtype == torch.bool, f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}" |
|
|
| log_fgates_cumsum = torch.cat( |
| [ |
| torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device), |
| torch.cumsum(log_fgates, dim=-2), |
| ], |
| dim=-2, |
| ) |
| |
| |
| |
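    # build an (S+1, S+1) matrix per batch/head with entry [i, j] = cumulative log forget gate up to step i;
    # subtracting its transpose below yields sum_{k=j+1..i} log f_k at position (i, j)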
| rep_log_fgates_cumsum = log_fgates_cumsum.repeat(1, 1, 1, S + 1) |
| |
| |
| _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(-2, -1) |
| |
| |
| log_fg_matrix = torch.where(ltr, _log_fg_matrix[:, :, 1:, 1:], -float("inf")) |
|
|
| |
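    # gate decay matrix D in log space: accumulated log forget gates plus the input gate pre-activation of the source position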
| log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1) |
| |
| if stabilize_rowwise: |
| max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True) |
| else: |
| max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[0].unsqueeze(-1) |
| |
| log_D_matrix_stabilized = log_D_matrix - max_log_D |
| D_matrix = torch.exp(log_D_matrix_stabilized) |
|
|
| keys_scaled = keys / math.sqrt(DH) |
|
|
| |
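    # combination matrix C: scaled attention-like scores modulated by the gate matrix D,
    # normalized by max(|row sum|, exp(-max)) to keep the denominator away from zero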
| qk_matrix = queries @ keys_scaled.transpose(-2, -1) |
| C_matrix = qk_matrix * D_matrix |
| normalizer = torch.maximum(C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D)) |
| |
| C_matrix_normalized = C_matrix / (normalizer + eps) |
|
|
| |
| h_tilde_state = C_matrix_normalized @ values |
|
|
| return h_tilde_state |
|
|
|
|
| class LinearHeadwiseExpand(nn.Module): |
    """
    A structured (block-diagonal) projection layer that applies an independent linear map to each of
    num_heads blocks of the input. In this simplified variant the per-head output size equals the
    per-head input size, so the overall feature dimension is preserved (no up-projection).
    """
|
|
| def __init__(self, dim, num_heads, bias=False): |
| super().__init__() |
| assert dim % num_heads == 0 |
| self.dim = dim |
| self.num_heads = num_heads |
|
|
| dim_per_head = dim // num_heads |
| self.weight = nn.Parameter(torch.empty(num_heads, dim_per_head, dim_per_head)) |
| if bias: |
| self.bias = nn.Parameter(torch.empty(dim)) |
| else: |
| self.bias = None |
| self.reset_parameters() |
|
|
| def reset_parameters(self): |
| nn.init.normal_(self.weight.data, mean=0.0, std=math.sqrt(2 / 5 / self.weight.shape[-1])) |
| if self.bias is not None: |
| nn.init.zeros_(self.bias.data) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = einops.rearrange(x, "... (nh d) -> ... nh d", nh=self.num_heads) |
| x = einops.einsum( |
| x, |
| self.weight, |
| "... nh d, nh out_d d -> ... nh out_d", |
| ) |
| x = einops.rearrange(x, "... nh out_d -> ... (nh out_d)") |
| if self.bias is not None: |
| x = x + self.bias |
| return x |
|
|
| def extra_repr(self): |
| return ( |
| f"dim={self.dim}, " |
| f"num_heads={self.num_heads}, " |
| f"bias={self.bias is not None}, " |
| ) |
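# Shape sketch for LinearHeadwiseExpand (illustrative only): the block-diagonal projection keeps the
# feature dimension, applying an independent (dim/num_heads x dim/num_heads) weight per head.
# >>> proj = LinearHeadwiseExpand(dim=8, num_heads=2)
# >>> proj(torch.randn(1, 5, 8)).shape
# torch.Size([1, 5, 8])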
|
|
|
|
| class CausalConv1d(nn.Module): |
    """
    Implements causal depthwise convolution of a time series tensor.
    Input: Tensor of shape (B, T, F), i.e. (batch, time, feature)
    Output: Tensor of shape (B, T, F)

    Args:
        dim: number of features (channels) in the input tensor
        kernel_size: size of the kernel for the depthwise convolution
        bias: whether to use a bias in the depthwise convolution
    """
|
|
| def __init__(self, dim, kernel_size=4, bias=True): |
| super().__init__() |
| self.dim = dim |
| self.kernel_size = kernel_size |
| self.bias = bias |
| |
| self.pad = kernel_size - 1 |
| self.conv = nn.Conv1d( |
| in_channels=dim, |
| out_channels=dim, |
| kernel_size=kernel_size, |
| padding=self.pad, |
| groups=dim, |
| bias=bias, |
| ) |
| self.reset_parameters() |
|
|
| def reset_parameters(self): |
| self.conv.reset_parameters() |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| |
| x = einops.rearrange(x, "b l d -> b d l") |
| |
        x = self.conv(x)
        # drop the trailing positions introduced by the symmetric padding to keep the convolution causal
        if self.pad > 0:
            x = x[:, :, :-self.pad]
| |
| x = einops.rearrange(x, "b d l -> b l d") |
| return x |
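# Causality sanity check for CausalConv1d (illustrative sketch, not part of the module):
# perturbing a future timestep must not change earlier outputs.
# >>> conv = CausalConv1d(dim=3, kernel_size=4)
# >>> a = torch.randn(1, 6, 3); b = a.clone(); b[:, -1] += 1.0
# >>> torch.allclose(conv(a)[:, :-1], conv(b)[:, :-1])
# True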
|
|
|
|
| class LayerNorm(nn.Module): |
    """ LayerNorm with optional affine weight and bias. With residual_weight=True the learnable weight is stored as an offset from 1 (initialized to zero). """
|
|
| def __init__( |
| self, |
| ndim: int = -1, |
| weight: bool = True, |
| bias: bool = False, |
| eps: float = 1e-5, |
| residual_weight: bool = True, |
| ): |
| super().__init__() |
| self.weight = nn.Parameter(torch.zeros(ndim)) if weight else None |
| self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None |
| self.eps = eps |
| self.residual_weight = residual_weight |
| self.ndim = ndim |
| self.reset_parameters() |
|
|
| @property |
| def weight_proxy(self) -> torch.Tensor: |
| if self.weight is None: |
| return None |
| if self.residual_weight: |
| return 1.0 + self.weight |
| else: |
| return self.weight |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return F.layer_norm( |
| x, |
| normalized_shape=(self.ndim,), |
| weight=self.weight_proxy, |
| bias=self.bias, |
| eps=self.eps, |
| ) |
|
|
| def reset_parameters(self): |
| if self.weight_proxy is not None: |
| if self.residual_weight: |
| nn.init.zeros_(self.weight) |
| else: |
| nn.init.ones_(self.weight) |
| if self.bias is not None: |
| nn.init.zeros_(self.bias) |
|
|
|
|
| class MultiHeadLayerNorm(LayerNorm): |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| assert x.ndim == 4, "Input must be 4D tensor (B, NH, S, DH)" |
| B, NH, S, DH = x.shape |
|
|
| gn_in_1 = x.transpose(1, 2) |
| gn_in_2 = gn_in_1.reshape(B * S, NH * DH) |
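        # group_norm with num_groups=NH normalizes each head's DH features independently per token,
        # with a single learnable scale/shift over the flattened NH * DH channels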
| out = F.group_norm( |
| gn_in_2, |
| num_groups=NH, |
| weight=self.weight_proxy, |
| bias=self.bias, |
| eps=self.eps, |
| ) |
| |
| out = out.view(B, S, NH, DH).transpose(1, 2) |
| return out |
|
|
|
|
| class MatrixLSTMCell(nn.Module): |
| def __init__(self, dim, num_heads): |
| super().__init__() |
| self.dim = dim |
| self.num_heads = num_heads |
|
|
| self.igate = nn.Linear(3 * dim, num_heads) |
| self.fgate = nn.Linear(3 * dim, num_heads) |
| self.outnorm = MultiHeadLayerNorm(ndim=dim, weight=True, bias=False) |
| self.causal_mask_cache = {} |
| self.reset_parameters() |
|
|
| def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: |
| B, S, _ = q.shape |
|
|
| if_gate_input = torch.cat([q, k, v], dim=-1) |
| q = q.view(B, S, self.num_heads, -1) |
| k = k.view(B, S, self.num_heads, -1) |
| v = v.view(B, S, self.num_heads, -1) |
|
|
| q = q.transpose(1, 2) |
| k = k.transpose(1, 2) |
| v = v.transpose(1, 2) |
|
|
| |
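        # input/forget gate pre-activations are computed from the concatenated (q, k, v) features,
        # then reshaped from (B, S, NH) to (B, NH, S, 1)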
| igate_preact = self.igate(if_gate_input) |
| igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1) |
| fgate_preact = self.fgate(if_gate_input) |
| fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1) |
|
|
| |
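        # cache the causal mask per (sequence length, device) to avoid re-allocating it every step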
        if (S, str(q.device)) in self.causal_mask_cache:
            causal_mask = self.causal_mask_cache[(S, str(q.device))]
        else:
            causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q.device))
            self.causal_mask_cache[(S, str(q.device))] = causal_mask
|
|
| h_state = parallel_stabilized_simple( |
| queries=q, |
| keys=k, |
| values=v, |
| igate_preact=igate_preact, |
| fgate_preact=fgate_preact, |
| lower_triangular_matrix=causal_mask, |
| ) |
|
|
| h_state_norm = self.outnorm(h_state) |
| h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1) |
|
|
| return h_state_norm |
|
|
| def reset_parameters(self): |
| self.outnorm.reset_parameters() |
| |
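        # forget gate: zero weights and biases linearly spaced in [3, 6], so sigmoid(fgate) starts
        # close to 1 and the cell initially retains most of its state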
| torch.nn.init.zeros_(self.fgate.weight) |
| bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0) |
| |
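        # input gate: zero weights and a small random bias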
| torch.nn.init.zeros_(self.igate.weight) |
| torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1) |
|
|
|
|
| class ViLLayer(nn.Module): |
| def __init__( |
| self, |
| dim, |
| direction, |
| expansion=2, |
| qkv_block_size=4, |
| proj_bias=False, |
| conv_bias=True, |
| kernel_size=4, |
| ): |
| super().__init__() |
        # fall back to a smaller qkv block size if dim is not divisible by the default
        if dim % qkv_block_size != 0:
            qkv_block_size = 2
| |
| self.dim = dim |
| self.direction = direction |
| self.expansion = expansion |
| self.qkv_block_size = qkv_block_size |
| self.proj_bias = proj_bias |
| self.conv_bias = conv_bias |
| self.kernel_size = kernel_size |
|
|
| inner_dim = expansion * dim |
| num_heads = inner_dim // qkv_block_size |
| self.proj_up = nn.Linear( |
| in_features=dim, |
| out_features=2 * inner_dim, |
| bias=proj_bias, |
| ) |
| self.q_proj = LinearHeadwiseExpand( |
| dim=inner_dim, |
| num_heads=num_heads, |
| bias=proj_bias, |
| ) |
| self.k_proj = LinearHeadwiseExpand( |
| dim=inner_dim, |
| num_heads=num_heads, |
| bias=proj_bias, |
| ) |
| self.v_proj = LinearHeadwiseExpand( |
| dim=inner_dim, |
| num_heads=num_heads, |
| bias=proj_bias, |
| ) |
|
|
| self.conv1d = CausalConv1d( |
| dim=inner_dim, |
| kernel_size=kernel_size, |
| bias=conv_bias, |
| ) |
| self.mlstm_cell = MatrixLSTMCell( |
| dim=inner_dim, |
| num_heads=qkv_block_size, |
| ) |
| self.learnable_skip = nn.Parameter(torch.ones(inner_dim)) |
|
|
| self.proj_down = nn.Linear( |
| in_features=inner_dim, |
| out_features=dim, |
| bias=proj_bias, |
| ) |
| self.reset_parameters() |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| B, S, _ = x.shape |
|
|
| |
| if self.direction == SequenceTraversal.ROWWISE_FROM_TOP_LEFT: |
| pass |
| elif self.direction == SequenceTraversal.ROWWISE_FROM_BOT_RIGHT: |
| x = x.flip(dims=[1]) |
| else: |
| raise NotImplementedError |
|
|
| |
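        # up-project to 2 * inner_dim and split into the mLSTM branch (x_mlstm) and the gating branch (z)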
| x_inner = self.proj_up(x) |
| x_mlstm, z = torch.chunk(x_inner, chunks=2, dim=-1) |
|
|
| |
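        # mLSTM branch: causal conv + SiLU feeds q and k; v bypasses the convolution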
| x_mlstm_conv = self.conv1d(x_mlstm) |
| x_mlstm_conv_act = F.silu(x_mlstm_conv) |
| q = self.q_proj(x_mlstm_conv_act) |
| k = self.k_proj(x_mlstm_conv_act) |
| v = self.v_proj(x_mlstm) |
| h_tilde_state = self.mlstm_cell(q=q, k=k, v=v) |
| h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act) |
|
|
| |
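        # gate the mLSTM output with SiLU(z) (output branch)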
| h_state = h_tilde_state_skip * F.silu(z) |
|
|
| |
| x = self.proj_down(h_state) |
|
|
| |
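        # undo the flip so the output token order matches the input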
| if self.direction == SequenceTraversal.ROWWISE_FROM_TOP_LEFT: |
| pass |
| elif self.direction == SequenceTraversal.ROWWISE_FROM_BOT_RIGHT: |
| x = x.flip(dims=[1]) |
| else: |
| raise NotImplementedError |
|
|
| return x |
|
|
| def reset_parameters(self): |
| |
| small_init_(self.proj_up.weight, dim=self.dim) |
| if self.proj_up.bias is not None: |
| nn.init.zeros_(self.proj_up.bias) |
| |
| wang_init_(self.proj_down.weight, dim=self.dim, num_blocks=1) |
| if self.proj_down.bias is not None: |
| nn.init.zeros_(self.proj_down.bias) |
|
|
| nn.init.ones_(self.learnable_skip) |
|
|
| def _init_qkv_proj(qkv_proj: LinearHeadwiseExpand): |
| |
| small_init_(qkv_proj.weight, dim=self.dim) |
| if qkv_proj.bias is not None: |
| nn.init.zeros_(qkv_proj.bias) |
|
|
| _init_qkv_proj(self.q_proj) |
| _init_qkv_proj(self.k_proj) |
| _init_qkv_proj(self.v_proj) |
|
|
| self.mlstm_cell.reset_parameters() |
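# Shape sketch for ViLLayer (illustrative only): the layer is sequence-to-sequence and preserves dim.
# >>> layer = ViLLayer(dim=64, direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT)
# >>> layer(torch.randn(2, 196, 64)).shape
# torch.Size([2, 196, 64])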
|
|
|
|
| class ViLBlock(nn.Module): |
| def __init__(self, dim, direction, drop_path=0.0, norm_bias=False): |
| super().__init__() |
| self.dim = dim |
| self.direction = direction |
        self.drop_path_rate = drop_path
| self.norm_bias = norm_bias |
|
|
| self.drop_path = DropPath(drop_prob=drop_path) |
| self.norm = LayerNorm(ndim=dim, weight=True, bias=norm_bias) |
| self.layer = ViLLayer(dim=dim, direction=direction) |
|
|
| self.reset_parameters() |
|
|
| def _forward_path(self, x): |
| x = self.norm(x) |
| x = self.layer(x) |
| return x |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
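        # pre-norm residual block with stochastic depth: DropPath applies norm -> ViLLayer as the residual branch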
| x = self.drop_path(x, self._forward_path) |
| |
| return x |
|
|
| def reset_parameters(self): |
| self.layer.reset_parameters() |
| self.norm.reset_parameters() |
|
|
|
|
| class VisionLSTM(nn.Module): |
| def __init__( |
| self, |
| dim=192, |
| input_shape=(3, 224, 224), |
| patch_size=16, |
| depth=24, |
| output_shape=(1000,), |
| mode="classifier", |
| pooling="bilateral_avg", |
| drop_path_rate=0.0, |
| stride=None, |
| alternation="bidirectional", |
| drop_path_decay=False, |
| legacy_norm=False, |
| ): |
| super().__init__() |
| self.input_shape = input_shape |
| self.output_shape = output_shape |
| ndim = len(self.input_shape) - 1 |
| self.patch_size = to_ntuple(patch_size, n=ndim) |
| self.dim = dim |
| self.depth = depth |
| self.stride = stride |
| self.mode = mode |
| self.pooling = pooling |
| self.alternation = alternation |
| self.drop_path_rate = drop_path_rate |
| self.drop_path_decay = drop_path_decay |
|
|
| |
| self.patch_embed = VitPatchEmbed( |
| dim=dim, |
| stride=stride, |
| num_channels=self.input_shape[0], |
| resolution=self.input_shape[1:], |
| patch_size=self.patch_size, |
| ) |
|
|
| |
| self.pos_embed = VitPosEmbed2d(seqlens=self.patch_embed.seqlens, dim=dim) |
|
|
| |
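        # stochastic depth: optionally increase the drop-path rate linearly over the blocks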
| if drop_path_decay and drop_path_rate > 0.: |
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] |
| else: |
| dpr = [drop_path_rate] * depth |
|
|
| |
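        # alternate the traversal direction: even blocks scan rowwise from the top-left,
        # odd blocks rowwise from the bottom-right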
| directions = [] |
| if alternation == "bidirectional": |
| for i in range(depth): |
| if i % 2 == 0: |
| directions.append(SequenceTraversal.ROWWISE_FROM_TOP_LEFT) |
| else: |
| directions.append(SequenceTraversal.ROWWISE_FROM_BOT_RIGHT) |
| else: |
| raise NotImplementedError(f"invalid alternation '{alternation}'") |
|
|
| |
| self.blocks = nn.ModuleList( |
| [ |
| ViLBlock( |
| dim=dim, |
| drop_path=dpr[i], |
| direction=directions[i], |
| ) |
| for i in range(depth) |
| ] |
| ) |
| |
| if legacy_norm: |
| self.legacy_norm = LayerNorm(dim, bias=False) |
| else: |
| self.legacy_norm = nn.Identity() |
| self.norm = nn.LayerNorm(dim, eps=1e-6) |
|
|
| |
| if mode is None: |
| |
| assert self.output_shape is None |
| assert self.pooling is None |
| self.head = None |
| self.output_shape = (self.patch_embed.num_patches, dim) |
| elif mode == "classifier": |
| |
            assert self.output_shape is not None and len(self.output_shape) == 1, \
                "define the number of classes via output_shape=(num_classes,) (e.g. output_shape=(1000,) for ImageNet-1K)"
| self.head = nn.Linear(dim, self.output_shape[0]) |
| |
| nn.init.trunc_normal_(self.head.weight, std=2e-5) |
| nn.init.zeros_(self.head.bias) |
| else: |
| raise NotImplementedError |
|
|
| def load_state_dict(self, state_dict, strict=True): |
| |
        # interpolate the position embedding if the checkpoint was trained at a different resolution
        old_pos_embed = state_dict.get("pos_embed.embed")
        if old_pos_embed is not None and old_pos_embed.shape != self.pos_embed.embed.shape:
            state_dict["pos_embed.embed"] = interpolate_sincos(embed=old_pos_embed, seqlens=self.pos_embed.seqlens)
| return super().load_state_dict(state_dict=state_dict, strict=strict) |
|
|
| @torch.jit.ignore |
| def no_weight_decay(self): |
| return {"pos_embed.embed"} |
|
|
| def forward(self, x): |
| |
| x = self.patch_embed(x) |
| |
| x = self.pos_embed(x) |
|
|
| |
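        # flatten the 2d patch grid into a 1d token sequence: (B, H', W', dim) -> (B, H'*W', dim)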
| x = einops.rearrange(x, "b ... d -> b (...) d") |
|
|
| |
| for block in self.blocks: |
| x = block(x) |
| x = self.legacy_norm(x) |
|
|
| |
| if self.pooling is None: |
| x = self.norm(x) |
| elif self.pooling == "bilateral_avg": |
| |
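            # "bilateral" pooling: average the first and the last token of the sequence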
| x = (x[:, 0] + x[:, -1]) / 2 |
| x = self.norm(x) |
| else: |
| raise NotImplementedError(f"pooling '{self.pooling}' is not implemented") |
|
|
| |
| if self.head is not None: |
| x = self.head(x) |
|
|
| return x |
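# Minimal smoke test for VisionLSTM (illustrative sketch; assumes the vision_lstm_util imports above resolve):
# >>> model = VisionLSTM(dim=192, input_shape=(3, 224, 224), patch_size=16, depth=2, output_shape=(1000,))
# >>> model(torch.randn(1, 3, 224, 224)).shape
# torch.Size([1, 1000])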