import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional
from einops import rearrange
from ..utils.io_utils import hash_state_dict_keys
from .audio_pack import AudioPack
from ..utils.args_config import args

if args.sp_size > 1:
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group)

try:
    import flash_attn_interface
    print('using flash_attn_interface')
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    print('using flash_attn')
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    print('using sageattention')
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
    """Multi-head attention over packed [batch, seq, num_heads * head_dim] tensors.

    Dispatches to FlashAttention-3, FlashAttention-2, or SageAttention when available,
    and falls back to torch's scaled_dot_product_attention otherwise.
    """
    if compatibility_mode:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x

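# Illustrative sketch (not used by the model): the layout flash_attention expects.
# q, k, v arrive packed as [batch, seq_len, num_heads * head_dim] and the result comes
# back in the same packed layout. The sizes below are invented for the example, and
# compatibility_mode=True forces the plain SDPA path so the sketch does not depend on
# which optional attention backends are installed.
def _flash_attention_shape_example():
    b, s, n, d = 2, 16, 4, 32
    q = torch.randn(b, s, n * d)
    k = torch.randn(b, s, n * d)
    v = torch.randn(b, s, n * d)
    out = flash_attention(q, k, v, num_heads=n, compatibility_mode=True)
    assert out.shape == (b, s, n * d)
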
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    return x * (1 + scale) + shift


def sinusoidal_embedding_1d(dim, position):
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)


def precompute_freqs_cos_sin(dim: int, end: int = 1024, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64)[:(dim//2)] / dim))
    angles = torch.outer(torch.arange(end, dtype=torch.float64, device=freqs.device), freqs)
    return angles.cos().to(torch.float32), angles.sin().to(torch.float32)


def precompute_freqs_cos_sin_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    # Split the head dimension across the frame, height and width axes.
    fdim = dim - 2 * (dim // 3)
    hdim = dim // 3
    wdim = dim // 3
    fcos, fsin = precompute_freqs_cos_sin(fdim, end, theta)
    hcos, hsin = precompute_freqs_cos_sin(hdim, end, theta)
    wcos, wsin = precompute_freqs_cos_sin(wdim, end, theta)
    return (fcos, hcos, wcos), (fsin, hsin, wsin)


def rope_apply_real(x, cos, sin, num_heads):
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    d2 = x.shape[-1] // 2
    x = x.reshape(*x.shape[:-1], d2, 2)
    x1, x2 = x[..., 0], x[..., 1]
    rot_x1 = x1 * cos - x2 * sin
    rot_x2 = x1 * sin + x2 * cos
    out = torch.stack((rot_x1, rot_x2), dim=-1).reshape(*x.shape[:-2], -1)
    return rearrange(out, "b s n d -> b s (n d)")

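# Illustrative sketch (not used by the model): rope_apply_real rotates each (even, odd)
# channel pair by a per-position angle, i.e. (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos),
# which is RoPE written with real-valued cos/sin tables. The sizes below are invented.
def _rope_apply_real_example():
    num_heads, head_dim, seq = 2, 8, 5
    cos, sin = precompute_freqs_cos_sin(head_dim, end=seq)  # each [seq, head_dim // 2]
    x = torch.randn(1, seq, num_heads * head_dim)
    # Insert a broadcast dim so cos/sin align against [batch, seq, heads, head_dim // 2].
    out = rope_apply_real(x, cos[:, None, :], sin[:, None, :], num_heads)
    assert out.shape == x.shape
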
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        dtype = x.dtype
        return self.norm(x.float()).to(dtype) * self.weight


class AttentionModule(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
        return x

class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs):
        cos, sin = freqs
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        # Rotary position embedding is applied to queries and keys only.
        q = rope_apply_real(q, cos, sin, self.num_heads)
        k = rope_apply_real(k, cos, sin, self.num_heads)
        x = self.attn(q, k, v)
        return self.o(x)

class CrossAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if self.has_image_input:
            # The first 257 tokens of y carry the image condition; the rest is the text context.
            img = y[:, :257]
            ctx = y[:, 257:]
        else:
            ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        x = self.attn(q, k, v)
        if self.has_image_input:
            k_img = self.norm_k_img(self.k_img(img))
            v_img = self.v_img(img)
            y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
            x = x + y
        return self.o(x)

class GateModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        return x + gate * residual


class DiTBlock(nn.Module):
    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(
            dim, num_heads, eps, has_image_input=has_image_input)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
            approximate='tanh'), nn.Linear(ffn_dim, dim))
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs):
        # Timestep-conditioned modulation: shift/scale/gate for self-attention and for the FFN.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        x = x + self.cross_attn(self.norm3(x), context)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x

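# Modulation layout, for reference: self.modulation + t_mod has shape [batch, 6, dim], and
# chunk(6, dim=1) yields six [batch, 1, dim] tensors. Each broadcasts over the sequence axis,
# so, for example, modulate(self.norm1(x), shift_msa, scale_msa) computes
#     norm1(x) * (1 + scale_msa) + shift_msa
# token by token with a single per-sample shift/scale vector.
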
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.ln_in = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, in_dim)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(in_dim, out_dim)
        self.ln_out = nn.LayerNorm(out_dim)

    def forward(self, x):
        x = self.ln_in(x)
        x = self.fc2(self.activation(self.fc1(x)))
        x = self.ln_out(x)
        return x

class Head(nn.Module):
    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
        x = self.head(self.norm(x) * (1 + scale) + shift)
        return x

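# Size check, for reference: each output token from Head packs one spatio-temporal patch.
# With out_dim = 16 and patch_size = (1, 2, 2), the final linear maps dim -> 16 * 1 * 2 * 2 = 64
# values per token, which WanModel.unpatchify below folds back into a [b, 16, f, 2h, 2w] latent.
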
class WanModel(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        audio_hidden_size: int = 32,
    ):
        super().__init__()
        self.dim = dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size

        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList([
            DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cos_sin_3d(head_dim)

        if has_image_input:
            self.img_emb = MLP(1280, dim)

        if 'use_audio' in args:
            self.use_audio = args.use_audio
        else:
            self.use_audio = False
        if self.use_audio:
            audio_input_dim = 10752
            audio_out_dim = dim
            self.audio_proj = AudioPack(audio_input_dim, [4, 1, 1], audio_hidden_size, layernorm=True)
            # One projection per audio-conditioned block (blocks 2 .. num_layers // 2 in forward()).
            self.audio_cond_projs = nn.ModuleList()
            for _ in range(num_layers // 2 - 1):
                self.audio_cond_projs.append(nn.Linear(audio_hidden_size, audio_out_dim))

    def patchify(self, x: torch.Tensor):
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

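    # Shape walk-through, for reference: patchify flattens a [b, c, f, h, w] feature map into
    # [b, f*h*w, c] tokens and remembers (f, h, w); unpatchify inverts it, expanding each token's
    # channels back into an (x, y, z) patch, e.g. patch_size (1, 2, 2) turns grid (f, h, w) into
    # an output of spatial size (f, 2h, 2w).
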
    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                clip_feature: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None,
                use_gradient_checkpointing: bool = False,
                audio_emb: Optional[torch.Tensor] = None,
                use_gradient_checkpointing_offload: bool = False,
                tea_cache=None,
                **kwargs,
                ):
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        lat_h, lat_w = x.shape[-2], x.shape[-1]

        if audio_emb is not None and self.use_audio:
            # [b, t, c] -> [b, c, t, 1, 1], repeat the first audio frame, then pack into audio tokens.
            audio_emb = audio_emb.permute(0, 2, 1)[:, :, :, None, None]
            audio_emb = torch.cat([audio_emb[:, :, :1].repeat(1, 1, 3, 1, 1), audio_emb], 2)
            audio_emb = self.audio_proj(audio_emb)
            # Stack one projection of the audio tokens per audio-conditioned block.
            audio_emb = torch.concat([audio_cond_proj(audio_emb) for audio_cond_proj in self.audio_cond_projs], 0)

        x = torch.cat([x, y], dim=1)
        x = self.patch_embedding(x)
        x, (f, h, w) = self.patchify(x)

        # Build per-token RoPE tables for the (frame, height, width) grid.
        (fcos, hcos, wcos), (fsin, hsin, wsin) = self.freqs
        cos = torch.cat([
            fcos[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            hcos[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            wcos[:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(f*h*w, 1, -1).to(x.device, dtype=x.dtype)
        sin = torch.cat([
            fsin[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            hsin[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            wsin[:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(f*h*w, 1, -1).to(x.device, dtype=x.dtype)
        freqs = (cos, sin)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        if tea_cache is not None:
            tea_cache_update = tea_cache.check(self, x, t_mod)
        else:
            tea_cache_update = False
        ori_x_len = x.shape[1]
        if tea_cache_update:
            x = tea_cache.update(x)
        else:
            if args.sp_size > 1:
                # Pad the token sequence so it splits evenly across sequence-parallel ranks.
                sp_size = get_sequence_parallel_world_size()
                pad_size = 0
                if ori_x_len % sp_size != 0:
                    pad_size = sp_size - ori_x_len % sp_size
                    x = torch.cat([x, torch.zeros_like(x[:, -1:]).repeat(1, pad_size, 1)], 1)
                x = torch.chunk(x, sp_size, dim=1)[get_sequence_parallel_rank()]

            if self.use_audio:
                audio_emb = audio_emb.reshape(x.shape[0], audio_emb.shape[0] // x.shape[0], -1, *audio_emb.shape[2:])

            for layer_i, block in enumerate(self.blocks):
                if self.use_audio:
                    # Inject audio conditioning into blocks 2 .. num_layers // 2.
                    au_idx = None
                    if layer_i <= len(self.blocks) // 2 and layer_i > 1:
                        au_idx = layer_i - 2
                        audio_emb_tmp = audio_emb[:, au_idx].repeat(1, 1, lat_h // 2, lat_w // 2, 1)
                        audio_cond_tmp = self.patchify(audio_emb_tmp.permute(0, 4, 1, 2, 3))[0]
                        if args.sp_size > 1:
                            if pad_size > 0:
                                audio_cond_tmp = torch.cat([audio_cond_tmp, torch.zeros_like(audio_cond_tmp[:, -1:]).repeat(1, pad_size, 1)], 1)
                            audio_cond_tmp = torch.chunk(audio_cond_tmp, sp_size, dim=1)[get_sequence_parallel_rank()]
                        x = audio_cond_tmp + x

                if self.training and use_gradient_checkpointing:
                    if use_gradient_checkpointing_offload:
                        with torch.autograd.graph.save_on_cpu():
                            x = torch.utils.checkpoint.checkpoint(
                                create_custom_forward(block),
                                x, context, t_mod, freqs,
                                use_reentrant=False,
                            )
                    else:
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs,
                            use_reentrant=False,
                        )
                else:
                    x = block(x, context, t_mod, freqs)

            if tea_cache is not None:
                x_cache = get_sp_group().all_gather(x, dim=1)
                x_cache = x_cache[:, :ori_x_len]
                tea_cache.store(x_cache)

        x = self.head(x, t)
        if args.sp_size > 1:
            x = get_sp_group().all_gather(x, dim=1)
            x = x[:, :ori_x_len]
        x = self.unpatchify(x, (f, h, w))
        return x

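    # Note on tea_cache: forward() only relies on three methods of whatever cache object is
    # passed in -- check(model, x, t_mod) returning a bool, update(x) returning cached tokens,
    # and store(x) recording them. A hypothetical stand-in could look like:
    #
    #     class NoOpTeaCache:
    #         def check(self, model, x, t_mod): return False
    #         def update(self, x): return x
    #         def store(self, x): pass
    #
    # This is an illustrative sketch of the expected interface, not the real TeaCache implementation.
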
    @staticmethod
    def state_dict_converter():
        return WanModelStateDictConverter()


class WanModelStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
            else:
                # Map "blocks.<i>.<...>" through the block-0 template, then restore the block index.
                name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
                if name_ in rename_dict:
                    name_ = rename_dict[name_]
                    name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
                    state_dict_[name_] = param
        if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
            config = {
                "model_type": "t2v",
                "patch_size": (1, 2, 2),
                "text_len": 512,
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "window_size": (-1, -1),
                "qk_norm": True,
                "cross_attn_norm": True,
                "eps": 1e-6,
            }
        else:
            config = {}
        return state_dict_, config

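    # Worked example of the renaming above: a Diffusers key such as
    # "blocks.7.attn1.to_q.weight" is first normalized to the block-0 template
    # "blocks.0.attn1.to_q.weight", looked up to get "blocks.0.self_attn.q.weight",
    # and then the original block index is restored, yielding "blocks.7.self_attn.q.weight".
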
    def from_civitai(self, state_dict):
        if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        else:
            config = {}
        if hasattr(args, "model_config"):
            model_config = args.model_config
            if model_config is not None:
                config.update(model_config)
        return state_dict, config
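

# Loading sketch, for reference only: the converter returns a (state_dict, config) pair that can
# be fed straight into WanModel. The checkpoint path below is hypothetical.
#
#     raw = torch.load("wan_dit_checkpoint.pth", map_location="cpu")  # hypothetical path
#     state_dict, config = WanModel.state_dict_converter().from_civitai(raw)
#     model = WanModel(**config)
#     model.load_state_dict(state_dict, strict=False)
#
# from_civitai only fills `config` for checkpoints whose key hash it recognizes (plus any
# overrides from args.model_config), so unknown checkpoints still need the architecture
# arguments supplied by hand.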