| | from diffsynth.models.wan_video_dit import flash_attention, WanModel |
| | import torch.nn.functional as F |
| | import torch.nn as nn |
| | import torch |
| | import os |
| | from safetensors import safe_open |
| |
|
| |
|
class AudioProjModel(nn.Module):
    """Project audio embeddings into the cross-attention feature space.

    A bias-free linear layer maps the last dimension of the input from
    ``audio_in_dim`` to ``cross_attention_dim``; the result is then
    layer-normalised over that last dimension.
    """

    def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        # The attribute names ``proj`` / ``norm`` determine the state-dict
        # keys of this module, so they must stay as they are.
        self.proj = nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, audio_embeds):
        """Return the layer-normalised linear projection of ``audio_embeds``."""
        return self.norm(self.proj(audio_embeds))
| |
|
| |
|
class WanCrossAttentionProcessor(nn.Module):
    """Cross-attention processor that adds an audio branch to an attention layer.

    The processor owns two bias-free linear layers (``k_proj`` / ``v_proj``)
    that turn projected audio features into keys and values.  When called it
    reuses the query/key/value/output layers of the attention module passed
    in as ``attn`` for the text and image branches, and sums the three
    attention results before the output projection.

    NOTE(review): the entry point is ``__call__`` rather than ``forward``, so
    ``nn.Module`` forward hooks are not dispatched for this module.
    """

    # Number of leading ``context`` tokens treated as image tokens; the rest
    # are treated as text tokens.  Presumably the image-encoder token count
    # used by the surrounding model -- confirm against the caller.
    IMAGE_CONTEXT_TOKENS = 257

    def __init__(self, context_dim, hidden_dim):
        super().__init__()

        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)

        # Both audio projections start at exactly zero, so the audio values
        # (and hence the audio attention output) are zero until the weights
        # are trained or loaded from a checkpoint.
        nn.init.zeros_(self.k_proj.weight)
        nn.init.zeros_(self.v_proj.weight)

    def __call__(
        self,
        attn: nn.Module,
        x: torch.Tensor,
        context: torch.Tensor,
        context_lens: torch.Tensor,
        audio_proj: torch.Tensor,
        audio_context_lens: torch.Tensor,
        latents_num_frames: int = 21,
        audio_scale: float = 1.0,
    ) -> torch.Tensor:
        """Run text, image and audio cross-attention and merge the results.

        Args:
            attn: Attention module providing ``q``, ``k``, ``v``, ``k_img``,
                ``v_img``, ``o``, ``norm_q``, ``norm_k``, ``norm_k_img``,
                ``num_heads`` and ``head_dim``.
            x: Query tokens, ``[B, L1, C]``.
            context: Image tokens followed by text tokens, ``[B, L2, C]``.
            context_lens: Valid text lengths, ``[B]``.
            audio_proj: Either ``[B, F, L3, C]`` (one audio window per latent
                frame; the queries are split into ``latents_num_frames``
                groups so each group only attends to its own window) or
                ``[B, L3, C]`` (all queries attend to the full sequence).
            audio_context_lens: Valid audio lengths passed to the attention
                kernel as ``k_lens`` (``[B * F]`` in the 4-D case).
            latents_num_frames: Number of groups the queries are split into
                in the 4-D case.
            audio_scale: Multiplier applied to the audio attention output.

        Returns:
            The output of ``attn.o`` applied to the summed attention results.

        Raises:
            ValueError: If ``audio_proj`` is neither 3-D nor 4-D.
        """
        # Validate before doing any attention work.  Previously an unexpected
        # rank fell through both branches and crashed later with an
        # UnboundLocalError on ``audio_x``.
        audio_rank = audio_proj.dim()
        if audio_rank not in (3, 4):
            raise ValueError(
                "audio_proj must be a 3-D [B, L3, C] or 4-D [B, F, L3, C] "
                f"tensor, got shape {tuple(audio_proj.shape)}"
            )

        context_img = context[:, : self.IMAGE_CONTEXT_TOKENS]
        context = context[:, self.IMAGE_CONTEXT_TOKENS :]
        b, n, d = x.size(0), attn.num_heads, attn.head_dim

        # Text and image branches share the same queries.
        q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
        k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
        v = attn.v(context).view(b, -1, n, d)
        k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
        v_img = attn.v_img(context_img).view(b, -1, n, d)
        img_x = flash_attention(q, k_img, v_img, k_lens=None)

        x = flash_attention(q, k, v, k_lens=context_lens)
        x = x.flatten(2)
        img_x = img_x.flatten(2)

        if audio_rank == 4:
            # Fold the frame axis into the batch axis so every group of
            # queries attends only to the audio window of its own frame.
            frame_batch = b * latents_num_frames
            audio_q = q.view(frame_batch, -1, n, d)
            ip_key = self.k_proj(audio_proj).view(frame_batch, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(frame_batch, -1, n, d)
            audio_x = flash_attention(
                audio_q, ip_key, ip_value, k_lens=audio_context_lens
            )
            # Restore the original [B, L1, n, d] layout before flattening.
            audio_x = audio_x.view(b, q.size(1), n, d)
        else:
            ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
            audio_x = flash_attention(
                q, ip_key, ip_value, k_lens=audio_context_lens
            )
        audio_x = audio_x.flatten(2)

        x = x + img_x + audio_x * audio_scale
        x = attn.o(x)
        return x
| |
|
| |
|
class FantasyTalkingAudioConditionModel(nn.Module):
    """Audio-conditioning add-on for a Wan DiT model.

    On construction it builds an ``AudioProjModel`` (stored as
    ``proj_model``) and installs one ``WanCrossAttentionProcessor`` per
    attention-processor slot reported by ``wan_dit.attn_processors``.  It also
    provides helpers for loading the matching checkpoint and for cutting an
    audio feature sequence into per-latent-frame windows.
    """

    def __init__(self, wan_dit: "WanModel", audio_in_dim: int, audio_proj_dim: int):
        super().__init__()

        self.audio_in_dim = audio_in_dim
        self.audio_proj_dim = audio_proj_dim

        self.proj_model = self.init_proj(self.audio_proj_dim)
        self.set_audio_processor(wan_dit)

    def init_proj(self, cross_attention_dim=5120):
        """Build the audio projection from ``self.audio_in_dim`` to ``cross_attention_dim``."""
        return AudioProjModel(
            audio_in_dim=self.audio_in_dim,
            cross_attention_dim=cross_attention_dim,
        )

    def set_audio_processor(self, wan_dit):
        """Install a fresh audio cross-attention processor on every slot of ``wan_dit``."""
        attn_procs = {
            name: WanCrossAttentionProcessor(
                context_dim=self.audio_proj_dim, hidden_dim=wan_dit.dim
            )
            for name in wan_dit.attn_processors.keys()
        }
        wan_dit.set_attn_processor(attn_procs)

    def load_audio_processor(self, ip_ckpt: str, wan_dit):
        """Load projection and audio-processor weights from ``ip_ckpt``.

        A ``.safetensors`` file is expected to hold flat keys prefixed with
        ``proj_model.`` or ``audio_processor.``; any other file is read with
        ``torch.load`` and must already be a dict with those two entries.
        ``proj_model`` weights are loaded strictly into ``self.proj_model``;
        ``audio_processor`` weights are loaded into ``wan_dit`` with
        ``strict=False``.
        """
        if os.path.splitext(ip_ckpt)[-1] == ".safetensors":
            state_dict = {"proj_model": {}, "audio_processor": {}}
            with safe_open(ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    for group, weights in state_dict.items():
                        prefix = group + "."
                        if key.startswith(prefix):
                            # Strip only the leading prefix; ``str.replace``
                            # would also remove later occurrences in the key.
                            weights[key[len(prefix):]] = f.get_tensor(key)
                            break
        else:
            # NOTE(security): ``torch.load`` unpickles the file and can run
            # arbitrary code; only use it on checkpoints from a trusted
            # source (or prefer the ``.safetensors`` format).
            state_dict = torch.load(ip_ckpt, map_location="cpu")
        self.proj_model.load_state_dict(state_dict["proj_model"])
        wan_dit.load_state_dict(state_dict["audio_processor"], strict=False)

    def get_proj_fea(self, audio_fea=None):
        """Return ``self.proj_model(audio_fea)``, or ``None`` when no features are given."""
        return self.proj_model(audio_fea) if audio_fea is not None else None

    def split_audio_sequence(self, audio_proj_length, num_frames=81):
        """
        Map the audio feature sequence to corresponding latent frame slices.

        The first latent frame is anchored at token 0; every following latent
        frame covers 4 video frames and is centred on the middle of their
        audio tokens.  Every window spans ``2 * half_tokens`` index steps,
        where ``half_tokens`` is half the number of tokens per latent frame
        (truncated).  The first window is shifted so that it ends where the
        second one starts, which makes it reach into negative indices.

        Args:
            audio_proj_length (int): The total length of the audio feature sequence
                (e.g., 173 in audio_proj[1, 173, 768]).
            num_frames (int): The number of video frames in the training data (default: 81).

        Returns:
            list: A list of [start_idx, end_idx] pairs, one per latent frame. Each pair
                is an index range within the audio feature sequence and may extend past
                either end of it. For 173 tokens and 81 frames the result is
                [[-7, 1], [1, 9], ..., [163, 171]].
        """
        tokens_per_frame = audio_proj_length / num_frames

        # One latent frame corresponds to 4 video frames.
        tokens_per_latent_frame = tokens_per_frame * 4
        half_tokens = int(tokens_per_latent_frame / 2)

        num_latent_frames = int((num_frames - 1) / 4) + 1
        pos_indices = [0]
        for i in range(1, num_latent_frames):
            start_token = tokens_per_frame * ((i - 1) * 4 + 1)
            end_token = tokens_per_frame * (i * 4 + 1)
            pos_indices.append(int((start_token + end_token) / 2) - 1)

        pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]

        # Re-anchor the first window on the start of the second one.  With a
        # single latent frame there is no second window (this used to raise
        # IndexError), so the window centred on token 0 is kept unchanged.
        if len(pos_idx_ranges) > 1:
            second_start = pos_idx_ranges[1][0]
            pos_idx_ranges[0] = [-(half_tokens * 2 - second_start), second_start]

        return pos_idx_ranges

    def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
        """
        Split the input tensor into subsequences based on index ranges, and apply right-side
        zero-padding wherever a range exceeds the input boundaries.

        Each range is inclusive on both ends, so every returned window has
        ``end - start + 1`` rows (after expansion).  The in-bounds rows come first and all
        padding, including the part that compensates for a negative start, is appended
        after them.  A range that lies entirely outside the input yields an all-zero
        window with length 0 in ``k_lens``.

        Args:
            input_tensor (Tensor): Input audio tensor of shape [B, L, C], e.g. [1, L, 768].
            pos_idx_ranges (list): A list of [start, end] index ranges, e.g. [[-7, 1], [1, 9], ...].
                All ranges must have the same length so the windows can be stacked.
            expand_length (int): Number of tokens to expand on both sides of each subsequence.

        Returns:
            sub_sequences (Tensor): A tensor of shape [B, F, W, C], where W is the window
                length after padding. Each element is a padded subsequence.
            k_lens (Tensor): A long tensor of shape [F], representing the actual (unpadded)
                length of each subsequence. Useful for ignoring padding tokens in attention.
        """
        seq_len = input_tensor.size(1)
        sub_sequences = []
        k_lens_list = []
        for start, end in pos_idx_ranges:
            start -= expand_length
            end += expand_length
            window_len = max(end - start + 1, 0)

            # Clamp the inclusive range to the valid index span; an empty
            # slice results when the range misses the sequence entirely.
            valid_start = min(max(start, 0), seq_len)
            valid_stop = max(min(end + 1, seq_len), valid_start)
            valid_part = input_tensor[:, valid_start:valid_stop, :]
            valid_len = valid_part.size(1)

            # Pad up to the nominal window length.  Deriving the padding from
            # the window length (instead of from each boundary separately)
            # keeps fully out-of-range windows the same size as the others.
            padded_subseq = F.pad(
                valid_part,
                (0, 0, 0, window_len - valid_len),
                mode="constant",
                value=0,
            )
            k_lens_list.append(valid_len)
            sub_sequences.append(padded_subseq)

        return torch.stack(sub_sequences, dim=1), torch.tensor(
            k_lens_list, dtype=torch.long
        )
| |
|