import torch from typing import Tuple, Union import torch from einops import rearrange from torch import nn def make_triple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]: value = (value,) * 3 if isinstance(value, int) else value assert len(value) == 3 return value class AudioPack(nn.Module): def __init__( self, in_channels: int, patch_size: Union[int, Tuple[int, int, int]], dim: int, layernorm=False, ): super().__init__() t, h, w = make_triple(patch_size) self.patch_size = t, h, w self.proj = nn.Linear(in_channels * t * h * w, dim) if layernorm: self.norm_out = nn.LayerNorm(dim) else: self.norm_out = None def forward( self, vid: torch.Tensor, ) -> torch.Tensor: t, h, w = self.patch_size vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) vid = self.proj(vid) if self.norm_out is not None: vid = self.norm_out(vid) return vid