# SteadyDancer-GGUF / small_archs.py
import torch
import torch.nn as nn
class FactorConv3d(nn.Module):
"""
    (2+1)D factorization of a 3D convolution: depthwise 1×kH×kW spatial conv → SiLU → kT×1×1 temporal conv
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size,
stride: int = 1,
dilation: int = 1):
super().__init__()
if isinstance(kernel_size, int):
k_t, k_h, k_w = kernel_size, kernel_size, kernel_size
else:
k_t, k_h, k_w = kernel_size
        # "Same" padding for odd kernel sizes; pad height and width
        # independently in case k_h != k_w.
        pad_t = (k_t - 1) * dilation // 2
        pad_h = (k_h - 1) * dilation // 2
        pad_w = (k_w - 1) * dilation // 2
        # Depthwise convolution over the spatial axes only (1 x k_h x k_w).
        self.spatial = nn.Conv3d(
            in_channels, in_channels,
            kernel_size=(1, k_h, k_w),
            stride=(1, stride, stride),
            padding=(0, pad_h, pad_w),
            dilation=(1, dilation, dilation),
            groups=in_channels,
            bias=False
        )
        # Temporal convolution over T only (k_t x 1 x 1); also mixes channels.
        self.temporal = nn.Conv3d(
in_channels, out_channels,
kernel_size=(k_t, 1, 1),
stride=(stride, 1, 1),
padding=(pad_t, 0, 0),
dilation=(dilation, 1, 1),
bias=True
)
self.act = nn.SiLU()
    def forward(self, x):
        # Weights may be stored in a different dtype than the activations
        # (e.g. when loaded from quantized checkpoints), so cast around each
        # convolution and restore the input dtype afterwards.
        x_type = x.dtype
        x = x.to(self.spatial.weight.dtype)
        x = self.spatial(x)
        x = x.to(x_type)
        x = self.act(x)
        x = x.to(self.temporal.weight.dtype)
        x = self.temporal(x)
        x = x.to(x_type)
        return x
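
# Usage sketch (added for illustration, not part of the original file): a
# minimal smoke test for FactorConv3d. The helper name and shapes are
# assumptions; with stride=1 and odd kernels the output keeps the input's
# T/H/W, while stride > 1 would downsample all three axes, since the stride
# is applied in both sub-convolutions.
def _demo_factor_conv3d():
    conv = FactorConv3d(in_channels=8, out_channels=16, kernel_size=3)
    x = torch.randn(2, 8, 5, 32, 32)        # (B, C, T, H, W)
    y = conv(x)
    assert y.shape == (2, 16, 5, 32, 32)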
class LayerNorm2D(nn.Module):
"""
    LayerNorm over the channel axis of a 4-D tensor (B, C, H, W): statistics
    are computed over C independently at every spatial location.
"""
def __init__(self, num_channels, eps=1e-5, affine=True):
super().__init__()
self.num_channels = num_channels
self.eps = eps
self.affine = affine
if affine:
self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
def forward(self, x):
# x: (B, C, H, W)
        mean = x.mean(dim=1, keepdim=True)                  # (B, 1, H, W)
        var = x.var(dim=1, keepdim=True, unbiased=False)    # (B, 1, H, W)
        x = (x - mean) / torch.sqrt(var + self.eps)
if self.affine:
x = x * self.weight + self.bias
return x
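
# Usage sketch (added for illustration): since LayerNorm2D normalizes over the
# channel axis at each spatial location, it should agree with F.layer_norm
# applied to a channels-last view of the tensor. The helper name and tolerance
# are assumptions.
def _demo_layernorm2d():
    import torch.nn.functional as F
    ln = LayerNorm2D(16)                    # affine init: weight=1, bias=0
    x = torch.randn(2, 16, 8, 8)
    ref = F.layer_norm(x.permute(0, 2, 3, 1), (16,)).permute(0, 3, 1, 2)
    assert torch.allclose(ln(x), ref, atol=1e-5)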
class PoseRefNetNoBNV3(nn.Module):
def __init__(self,
in_channels_c: int,
in_channels_x: int,
hidden_dim: int = 256,
num_heads: int = 8,
dropout: float = 0.1):
super().__init__()
self.d_model = hidden_dim
self.nhead = num_heads
self.proj_p = nn.Conv2d(in_channels_c, hidden_dim, kernel_size=1)
self.proj_r = nn.Conv2d(in_channels_x, hidden_dim, kernel_size=1)
self.proj_p_back = nn.Conv2d(hidden_dim, in_channels_c, kernel_size=1)
        # Tokens are arranged as (batch=B*T, seq=H*W, embed) in forward(), so
        # the attention module must be batch-first; the default (seq-first)
        # layout would attend across frames instead of spatial positions.
        self.cross_attn = nn.MultiheadAttention(hidden_dim,
                                                num_heads=num_heads,
                                                dropout=dropout,
                                                batch_first=True)
self.ffn_pose = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
nn.SiLU(),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1)
)
self.norm1 = LayerNorm2D(hidden_dim)
self.norm2 = LayerNorm2D(hidden_dim)
def forward(self, pose, ref, mask=None):
"""
pose : (B, C1, T, H, W)
ref : (B, C2, T, H, W)
mask : (B, T*H*W) optional key_padding_mask
return: (B, d_model, T, H, W)
"""
pose_type, ref_type = pose.dtype, ref.dtype
        B, _, T, H, W = pose.shape
        # Fold time into the batch so attention runs per frame:
        # (B, C, T, H, W) -> (B*T, C, H, W)
        p_trans = pose.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 1)
        r_trans = ref.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 1)
p_trans, r_trans = p_trans.to(self.proj_p.weight.dtype), r_trans.to(self.proj_r.weight.dtype)
p_trans = self.proj_p(p_trans)
r_trans = self.proj_r(r_trans)
p_trans, r_trans = p_trans.to(pose_type), r_trans.to(ref_type)
        # Flatten spatial dims into tokens: (B*T, C, H, W) -> (B*T, H*W, C)
        p_trans = p_trans.flatten(2).transpose(1, 2)
        r_trans = r_trans.flatten(2).transpose(1, 2)
p_trans, r_trans = p_trans.to(self.cross_attn.in_proj_weight.dtype), r_trans.to(self.cross_attn.in_proj_weight.dtype)
out = self.cross_attn(query=r_trans,
key=p_trans,
value=p_trans,
key_padding_mask=mask)[0]
out = out.transpose(1, 2).contiguous().view(B*T, -1, H, W)
out = self.norm1(out)
out_type = out.dtype
out = out.to(self.ffn_pose[0].weight.dtype)
ffn_out = self.ffn_pose(out)
ffn_out = ffn_out.to(out_type)
out = out + ffn_out
out = self.norm2(out)
out_type = out.dtype
out = out.to(self.proj_p_back.weight.dtype)
out = self.proj_p_back(out)
out = out.to(out_type)
        # Restore video layout: (B*T, C1, H, W) -> (B, C1, T, H, W)
        out = out.view(B, T, -1, H, W).contiguous().transpose(1, 2)
return out
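
# Usage sketch (added for illustration): run PoseRefNetNoBNV3 on random
# tensors. Channel counts and sizes are arbitrary assumptions; the module
# cross-attends from reference features to pose features per frame and
# returns pose-shaped features (B, C1, T, H, W).
def _demo_pose_ref_net():
    net = PoseRefNetNoBNV3(in_channels_c=32, in_channels_x=64,
                           hidden_dim=128, num_heads=4).eval()  # disable dropout
    pose = torch.randn(1, 32, 4, 16, 16)    # (B, C1, T, H, W)
    ref = torch.randn(1, 64, 4, 16, 16)     # (B, C2, T, H, W)
    with torch.no_grad():
        out = net(pose, ref)
    assert out.shape == (1, 32, 4, 16, 16)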