| |
| import numbers |
| from mamba_ssm.modules.mamba_simple import Mamba |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| from timm.models.layers import DropPath, to_2tuple |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from einops import rearrange, repeat |
|
|
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None
|
|
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
except ImportError:
    selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None
|
|
| try: |
| from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| except ImportError: |
| selective_state_update = None |
|
|
| try: |
| from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn |
| except ImportError: |
| RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None |
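# The imports above are optional at import time: causal_conv1d has a pure-PyTorch fallback
# below, the selective-scan ops are needed whenever a forward pass is run, and the fused
# norm functions are only required when fused_add_norm is enabled in Block.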
|
|
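# Depthwise-separable convolution: a 3x3 depthwise conv followed by a 1x1 pointwise conv,
# used as a cheap local-context branch inside ConvMamba.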
| class LightweightModel(nn.Module): |
| def __init__(self, in_channels, out_channels): |
| super(LightweightModel, self).__init__() |
| self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels) |
| self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) |
|
|
| def forward(self, x): |
| x = self.depthwise_conv(x) |
| x = self.pointwise_conv(x) |
| return x |
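

# A minimal shape sketch (illustrative only, not part of the model): the depthwise/pointwise
# pair preserves the spatial size and only changes the channel count.
def _lightweight_model_shape_example():
    m = LightweightModel(in_channels=32, out_channels=64)
    x = torch.randn(2, 32, 16, 16)
    assert m(x).shape == (2, 64, 16, 16)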
|
|
|
|
| class ConvMamba(nn.Module): |
| def __init__( |
| self, |
| d_model, |
| d_state=16, |
| d_conv=4, |
| expand=2, |
| dt_rank="auto", |
| dt_min=0.001, |
| dt_max=0.1, |
| dt_init="random", |
| dt_scale=1.0, |
| dt_init_floor=1e-4, |
| conv_bias=True, |
| bias=False, |
| use_fast_path=True, |
| layer_idx=None, |
| device=None, |
| dtype=None, |
| bimamba_type="none", |
        conv_mode="deepwise",  # selects the depthwise-separable local-convolution branch below
| ): |
| factory_kwargs = {"device": device, "dtype": dtype} |
| super().__init__() |
| self.conv_mode = conv_mode |
| self.d_model = d_model |
| self.d_state = d_state |
| self.d_conv = d_conv |
| self.expand = expand |
| self.d_inner = int(self.expand * self.d_model) |
| self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank |
| self.use_fast_path = use_fast_path |
| self.layer_idx = layer_idx |
| self.bimamba_type = bimamba_type |
|
|
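        # Local-relation branch: a small 2-D conv stack that maps (B, d_model, H, W) -> (B, d_inner, H, W).
        # The variant is selected by conv_mode: the "deepwise*" options use the depthwise-separable
        # LightweightModel blocks, the "orignal*" options use plain 3x3 convolutions.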
| if self.conv_mode == "orignal": |
| self.local_relation = nn.Sequential( |
| nn.Conv2d(in_channels=self.d_model, out_channels=self.d_model, kernel_size=3, stride=1, padding=1), |
| nn.SiLU(), |
| nn.Conv2d(in_channels=self.d_model, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1), |
| ) |
| elif self.conv_mode == "orignal_1_5_dmodel": |
| self.local_relation = nn.Sequential( |
| nn.Conv2d(in_channels=self.d_model, out_channels=int(1.5*self.d_model), kernel_size=3, stride=1, padding=1), |
| nn.SiLU(), |
| nn.Conv2d(in_channels=int(1.5*self.d_model), out_channels=self.d_inner, kernel_size=3, stride=1, padding=1), |
| ) |
| elif self.conv_mode == "orignal_dinner": |
| self.local_relation = nn.Sequential( |
| nn.Conv2d(in_channels=self.d_model, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1), |
| nn.SiLU(), |
| nn.Conv2d(in_channels=self.d_inner, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1), |
| ) |
| elif self.conv_mode == "deepwise": |
| self.local_relation = nn.Sequential( |
| LightweightModel(in_channels=self.d_model, out_channels=self.d_model), |
| nn.SiLU(), |
| LightweightModel(in_channels=self.d_model, out_channels=self.d_inner), |
| ) |
| elif self.conv_mode == "deepwise_dinner": |
| self.local_relation = nn.Sequential( |
| LightweightModel(in_channels=self.d_model, out_channels=self.d_inner), |
| nn.SiLU(), |
| LightweightModel(in_channels=self.d_inner, out_channels=self.d_inner), |
| ) |
|
|
| self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) |
|
|
| self.conv1d = nn.Conv1d( |
| in_channels=self.d_inner, |
| out_channels=self.d_inner, |
| bias=conv_bias, |
| kernel_size=d_conv, |
| groups=self.d_inner, |
| padding=d_conv - 1, |
| **factory_kwargs, |
| ) |
|
|
| self.activation = "silu" |
| self.act = nn.SiLU() |
|
|
| self.x_proj = nn.Linear( |
| self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs |
| ) |
| self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) |
|
|
        # Initialize the dt projection so that the variance is preserved at initialization.
| dt_init_std = self.dt_rank**-0.5 * dt_scale |
| if dt_init == "constant": |
| nn.init.constant_(self.dt_proj.weight, dt_init_std) |
| elif dt_init == "random": |
| nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) |
| else: |
| raise NotImplementedError |
|
|
        # Initialize the dt bias so that F.softplus(dt_bias) falls between dt_min and dt_max.
| dt = torch.exp( |
| torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) |
| + math.log(dt_min) |
| ).clamp(min=dt_init_floor) |
        # Inverse of softplus, so that softplus(dt_proj.bias) recovers dt.
| inv_dt = dt + torch.log(-torch.expm1(-dt)) |
| with torch.no_grad(): |
| self.dt_proj.bias.copy_(inv_dt) |
        # Mark this bias so that later weight re-initialization skips it.
| self.dt_proj.bias._no_reinit = True |
|
|
        # S4D real initialization: A is stored in log space for numerical stability.
| A = repeat( |
| torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), |
| "n -> d n", |
| d=self.d_inner, |
| ).contiguous() |
| A_log = torch.log(A) |
| self.A_log = nn.Parameter(A_log) |
| self.A_log._no_weight_decay = True |
|
|
        # D "skip" (direct feed-through) parameter.
| self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) |
| self.D._no_weight_decay = True |
|
|
        # This module only implements the bidirectional ("v2") variant.
        assert bimamba_type == "v2"
|
|
| A_b = repeat( |
| torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), |
| "n -> d n", |
| d=self.d_inner, |
| ).contiguous() |
| A_b_log = torch.log(A_b) |
| self.A_b_log = nn.Parameter(A_b_log) |
| self.A_b_log._no_weight_decay = True |
|
|
| self.conv1d_b = nn.Conv1d( |
| in_channels=self.d_inner, |
| out_channels=self.d_inner, |
| bias=conv_bias, |
| kernel_size=d_conv, |
| groups=self.d_inner, |
| padding=d_conv - 1, |
| **factory_kwargs, |
| ) |
|
|
| self.x_proj_b = nn.Linear( |
| self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs |
| ) |
| self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) |
|
|
| self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) |
| self.D_b._no_weight_decay = True |
|
|
| self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) |
|
|
| def forward(self, hidden_states, inference_params=None): |
| """ |
| hidden_states: (B, L, D) |
| Returns: same shape as hidden_states |
| """ |
        batch, seqlen, dim = hidden_states.shape
        h = int(math.sqrt(seqlen))  # assumes the tokens come from a square feature map (seqlen == h * h)
|
|
| local_relation = self.local_relation(rearrange(hidden_states, "b (h w) d -> b d h w", h=h)) |
| local_relation = rearrange(local_relation, "b d h w -> b d (h w)") |
|
|
| conv_state, ssm_state = None, None |
| if inference_params is not None: |
| conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) |
| if inference_params.seqlen_offset > 0: |
                # The cached states are updated in place, so only the output is returned.
| out, _, _ = self.step(hidden_states, conv_state, ssm_state) |
| return out |
|
|
        # Compute the input projection and the (B, L, D) -> (D, B*L) transpose in a single matmul.
| xz = rearrange( |
| self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), |
| "d (b l) -> b d l", |
| l=seqlen, |
| ) |
| if self.in_proj.bias is not None: |
| xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") |
|
|
| A = -torch.exp(self.A_log.float()) |
        # Fast path: fused kernels that combine the causal conv and the selective scan.
| if self.use_fast_path and inference_params is None: |
| if self.bimamba_type == "v2": |
| A_b = -torch.exp(self.A_b_log.float()) |
| out = mamba_inner_fn_no_out_proj( |
| xz, |
| self.conv1d.weight, |
| self.conv1d.bias, |
| self.x_proj.weight, |
| self.dt_proj.weight, |
| A, |
| None, |
| None, |
| self.D.float(), |
| delta_bias=self.dt_proj.bias.float(), |
| delta_softplus=True, |
| ) |
| out_b = mamba_inner_fn_no_out_proj( |
| xz.flip([-1]), |
| self.conv1d_b.weight, |
| self.conv1d_b.bias, |
| self.x_proj_b.weight, |
| self.dt_proj_b.weight, |
| A_b, |
| None, |
| None, |
| self.D_b.float(), |
| delta_bias=self.dt_proj_b.bias.float(), |
| delta_softplus=True, |
| ) |
                # Combine the forward scan, the (flipped-back) backward scan, and the local
                # convolution branch, then apply the output projection.
| out = F.linear(rearrange(out + out_b.flip([-1]) + local_relation, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) |
| else: |
| out = mamba_inner_fn( |
| xz, |
| self.conv1d.weight, |
| self.conv1d.bias, |
| self.x_proj.weight, |
| self.dt_proj.weight, |
| self.out_proj.weight, |
| self.out_proj.bias, |
| A, |
| None, |
| None, |
| self.D.float(), |
| delta_bias=self.dt_proj.bias.float(), |
| delta_softplus=True, |
| ) |
| else: |
| x, z = xz.chunk(2, dim=1) |
            # Slow path: explicit short convolution followed by the selective scan.
            if conv_state is not None:
                # Cache the last d_conv inputs; F.pad keeps this safe when seqlen < d_conv.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # (B, D, W)
| if causal_conv1d_fn is None: |
| x = self.act(self.conv1d(x)[..., :seqlen]) |
| else: |
| assert self.activation in ["silu", "swish"] |
| x = causal_conv1d_fn( |
| x, |
| rearrange(self.conv1d.weight, "d 1 w -> d w"), |
| self.conv1d.bias, |
| self.activation, |
| ) |
|
|
            # Project x to the input-dependent SSM parameters: dt (low rank), B, and C.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (b*l, dt_rank + 2*d_state)
| dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) |
| dt = self.dt_proj.weight @ dt.t() |
| dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) |
| B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() |
| C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() |
| assert self.activation in ["silu", "swish"] |
| y = selective_scan_fn( |
| x, |
| dt, |
| A, |
| B, |
| C, |
| self.D.float(), |
| z=z, |
| delta_bias=self.dt_proj.bias.float(), |
| delta_softplus=True, |
| return_last_state=ssm_state is not None, |
| ) |
| if ssm_state is not None: |
| y, last_state = y |
| ssm_state.copy_(last_state) |
| y = rearrange(y, "b d l -> b l d") |
| out = self.out_proj(y) |
| return out |
|
|
| def step(self, hidden_states, conv_state, ssm_state): |
| dtype = hidden_states.dtype |
| assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" |
| xz = self.in_proj(hidden_states.squeeze(1)) |
| x, z = xz.chunk(2, dim=-1) |
|
|
        # Convolution step: shift the cached window left by one and apply the depthwise kernel.
| if causal_conv1d_update is None: |
| conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) |
| conv_state[:, :, -1] = x |
| x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) |
| if self.conv1d.bias is not None: |
| x = x + self.conv1d.bias |
| x = self.act(x).to(dtype=dtype) |
| else: |
| x = causal_conv1d_update( |
| x, |
| conv_state, |
| rearrange(self.conv1d.weight, "d 1 w -> d w"), |
| self.conv1d.bias, |
| self.activation, |
| ) |
|
|
| x_db = self.x_proj(x) |
| dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) |
        # Note: dt_proj.bias is intentionally not added here; it is applied inside softplus below.
        dt = F.linear(dt, self.dt_proj.weight)  # (B, d_inner)
| A = -torch.exp(self.A_log.float()) |
|
|
        # SSM recurrence step.
| if selective_state_update is None: |
            # Discretize A and B (zero-order hold) and update the recurrent state.
| dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) |
| dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) |
| dB = torch.einsum("bd,bn->bdn", dt, B) |
| ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) |
| y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) |
| y = y + self.D.to(dtype) * x |
| y = y * self.act(z) |
| else: |
| y = selective_state_update( |
| ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True |
| ) |
|
|
| out = self.out_proj(y) |
| return out.unsqueeze(1), conv_state, ssm_state |
|
|
| def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
| device = self.out_proj.weight.device |
| conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype |
| conv_state = torch.zeros( |
| batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype |
| ) |
| ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype |
        # SSM state has shape (batch, d_inner, d_state).
| ssm_state = torch.zeros( |
| batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype |
| ) |
| return conv_state, ssm_state |
|
|
| def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): |
| assert self.layer_idx is not None |
| if self.layer_idx not in inference_params.key_value_memory_dict: |
| conv_state = torch.zeros( |
| batch_size, |
| self.d_model * self.expand, |
| self.d_conv, |
| device=self.conv1d.weight.device, |
| dtype=self.conv1d.weight.dtype, |
| ) |
| ssm_state = torch.zeros( |
| batch_size, |
| self.d_model * self.expand, |
| self.d_state, |
| device=self.dt_proj.weight.device, |
| dtype=self.dt_proj.weight.dtype, |
            )
| inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) |
| else: |
| conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] |
        # Optionally reset the cached states (e.g. at the start of a new sequence).
| if initialize_states: |
| conv_state.zero_() |
| ssm_state.zero_() |
| return conv_state, ssm_state |
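

# A minimal usage sketch for ConvMamba (illustrative only; the d_model value and the CUDA
# device are assumptions, and the fast path requires the mamba_ssm/causal_conv1d CUDA kernels).
def _conv_mamba_forward_example():
    layer = ConvMamba(d_model=96, bimamba_type="v2", conv_mode="deepwise").cuda()
    tokens = torch.randn(2, 16 * 16, 96, device="cuda")  # (B, L, D) with L a perfect square
    out = layer(tokens)
    assert out.shape == tokens.shape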
|
|
| def to_3d(x): |
| return rearrange(x, 'b c h w -> b (h w) c') |
|
|
| def to_4d(x, h, w): |
| return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) |
|
|
| class WithBias_LayerNorm(nn.Module): |
| def __init__(self, normalized_shape): |
| super(WithBias_LayerNorm, self).__init__() |
| if isinstance(normalized_shape, numbers.Integral): |
| normalized_shape = (normalized_shape,) |
| normalized_shape = torch.Size(normalized_shape) |
| self.weight = nn.Parameter(torch.ones(normalized_shape)) |
| self.bias = nn.Parameter(torch.zeros(normalized_shape)) |
| def forward(self, x): |
| mu = x.mean(-1, keepdim=True) |
| sigma = x.var(-1, keepdim=True, unbiased=False) |
| return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias |
|
|
| class BiasFree_LayerNorm(nn.Module): |
| def __init__(self, normalized_shape): |
| super(BiasFree_LayerNorm, self).__init__() |
| if isinstance(normalized_shape, numbers.Integral): |
| normalized_shape = (normalized_shape,) |
| normalized_shape = torch.Size(normalized_shape) |
| self.weight = nn.Parameter(torch.ones(normalized_shape)) |
| def forward(self, x): |
| sigma = x.var(-1, keepdim=True, unbiased=False) |
| return x / torch.sqrt(sigma + 1e-5) * self.weight |
|
|
| class LayerNorm(nn.Module): |
| def __init__(self, dim, norm_type='with_bias'): |
| super(LayerNorm, self).__init__() |
| if norm_type == 'BiasFree': |
| self.body = BiasFree_LayerNorm(dim) |
| else: |
| self.body = WithBias_LayerNorm(dim) |
| def forward(self, x): |
| if len(x.shape) == 4: |
| h, w = x.shape[-2:] |
| return to_4d(self.body(to_3d(x)), h, w) |
| else: |
| return self.body(x) |
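

# A small illustration (for documentation only) of how the LayerNorm above handles 4-D inputs:
# it normalizes over the channel dimension by flattening to (B, H*W, C) and reshaping back.
def _layernorm_4d_example():
    ln = LayerNorm(dim=64, norm_type='with_bias')
    x = torch.randn(2, 64, 8, 8)
    assert ln(x).shape == x.shape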
|
|
| class M3(nn.Module): |
| def __init__(self, dim): |
| super(M3, self).__init__() |
        self.multi_modal_mamba_block = Mamba(dim, bimamba_type="m3")  # requires a Mamba build whose forward accepts bimamba_type="m3" and extra_emb1/extra_emb2
| self.norm1 = LayerNorm(dim, 'with_bias') |
| self.norm2 = LayerNorm(dim, 'with_bias') |
| self.norm3 = LayerNorm(dim, 'with_bias') |
| self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) |
|
|
| def forward(self, I1, I2, fusion, test_h, test_w): |
| fusion = self.norm1(fusion) |
| I2 = self.norm2(I2) |
| I1 = self.norm3(I1) |
| global_f = self.multi_modal_mamba_block(fusion, extra_emb1=I2, extra_emb2=I1) |
| B, HW, C = global_f.shape |
| fusion = global_f.transpose(1, 2).view(B, C, test_h, test_w) |
| fusion = (self.dwconv(fusion) + fusion).flatten(2).transpose(1, 2) |
| return fusion, None |
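

# A minimal call sketch for M3 (illustrative only): it assumes the patched mamba_simple.Mamba
# that accepts bimamba_type="m3" and the extra_emb1/extra_emb2 keyword arguments, plus a CUDA
# build of the selective-scan kernels.
def _m3_fusion_example():
    block = M3(dim=64).cuda()
    I1 = torch.randn(2, 16 * 16, 64, device="cuda")
    I2 = torch.randn(2, 16 * 16, 64, device="cuda")
    fusion = torch.randn(2, 16 * 16, 64, device="cuda")
    fused, _ = block(I1, I2, fusion, test_h=16, test_w=16)
    assert fused.shape == fusion.shape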
|
|
| class PatchEmbed(nn.Module): |
| def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): |
| super(PatchEmbed, self).__init__() |
| img_size = to_2tuple(img_size) |
| patch_size = to_2tuple(patch_size) |
| self.img_size = img_size |
| self.patch_size = patch_size |
| self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] |
| self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] |
| self.in_chans = in_chans |
| self.embed_dim = embed_dim |
| self.norm = norm_layer(embed_dim) if norm_layer is not None else None |
| def forward(self, x): |
        # (B, C, H, W) -> (B, H*W, C)
        x = x.flatten(2).transpose(1, 2)
| if self.norm is not None: |
| x = self.norm(x) |
| return x |
|
|
| class PatchUnEmbed(nn.Module): |
| def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): |
| super(PatchUnEmbed, self).__init__() |
| img_size = to_2tuple(img_size) |
| patch_size = to_2tuple(patch_size) |
| self.img_size = img_size |
| self.patch_size = patch_size |
| self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] |
| self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] |
| self.in_chans = in_chans |
| self.embed_dim = embed_dim |
| def forward(self, x, x_size): |
| B, HW, C = x.shape |
| x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) |
| return x |
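

# Note: PatchEmbed above applies no projection; it only flattens (B, C, H, W) to (B, H*W, C),
# and PatchUnEmbed reshapes back using embed_dim. A small round-trip sketch (illustrative only):
def _patch_embed_roundtrip_example():
    embed = PatchEmbed(embed_dim=96)
    unembed = PatchUnEmbed(embed_dim=96)
    x = torch.randn(2, 96, 14, 14)
    tokens = embed(x)  # (2, 196, 96)
    assert torch.equal(unembed(tokens, (14, 14)), x)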
|
|
|
|
| class Block(nn.Module): |
| def __init__( |
| self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False |
| ): |
| """ |
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
| |
| This Block has a slightly different structure compared to a regular |
| prenorm Transformer block. |
| The standard block is: LN -> MHA/MLP -> Add. |
| [Ref: https://arxiv.org/abs/2002.04745] |
| Here we have: Add -> LN -> Mixer, returning both |
| the hidden_states (output of the mixer) and the residual. |
| This is purely for performance reasons, as we can fuse add and LayerNorm. |
| The residual needs to be provided (except for the very first block). |
| """ |
| super().__init__() |
| self.residual_in_fp32 = residual_in_fp32 |
| self.fused_add_norm = fused_add_norm |
| self.mixer = mixer_cls(dim) |
| self.norm = norm_cls(dim) |
| if self.fused_add_norm: |
| assert RMSNorm is not None, "RMSNorm import fails" |
| assert isinstance( |
| self.norm, (nn.LayerNorm, RMSNorm) |
| ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" |
|
|
| def forward( |
| self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None |
| ): |
| r"""Pass the input through the encoder layer. |
| |
| Args: |
| hidden_states: the sequence to the encoder layer (required). |
| residual: hidden_states = Mixer(LN(residual)) |
| """ |
| if not self.fused_add_norm: |
| residual = (hidden_states + residual) if residual is not None else hidden_states |
| hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) |
| if self.residual_in_fp32: |
| residual = residual.to(torch.float32) |
| else: |
| fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn |
| hidden_states, residual = fused_add_norm_fn( |
| hidden_states, |
| self.norm.weight, |
| self.norm.bias, |
| residual=residual, |
| prenorm=True, |
| residual_in_fp32=self.residual_in_fp32, |
| eps=self.norm.eps, |
| ) |
| hidden_states = self.mixer(hidden_states, inference_params=inference_params) |
| return hidden_states, residual |
|
|
| def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
| return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) |
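

# A minimal composition sketch (illustrative only): wiring ConvMamba into Block as the mixer.
# functools.partial and the chosen sizes are assumptions for the example, not part of this module.
def _block_with_convmamba_example():
    from functools import partial
    mixer_cls = partial(ConvMamba, bimamba_type="v2", conv_mode="deepwise")
    block = Block(dim=96, mixer_cls=mixer_cls, norm_cls=nn.LayerNorm).cuda()
    tokens = torch.randn(2, 16 * 16, 96, device="cuda")
    hidden, residual = block(tokens)
    assert hidden.shape == tokens.shape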
|
|