import functools
import math
from dataclasses import dataclass
from typing import Any

import torch

from vsa import video_sparse_attn

# Spatio-temporal tile size (t, h, w) used to partition the video token grid.
VSA_TILE_SIZE = (4, 4, 4)


@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Return indices that regroup a flattened (T, H, W) sequence tile by tile.

    Tokens belonging to the same (ts, hs, ws) tile become contiguous in the
    returned order; boundary tiles are truncated to the sequence extent.
    """
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    indices = torch.arange(T * H * W, device=device,
                           dtype=torch.long).reshape(T, H, W)
    ls = []
    for t in range(math.ceil(T / ts)):
        for h in range(math.ceil(H / hs)):
            for w in range(math.ceil(W / ws)):
                ls.append(indices[t * ts:min(t * ts + ts, T),
                                  h * hs:min(h * hs + hs, H),
                                  w * ws:min(w * ws + ws, W)].flatten())
    index = torch.cat(ls, dim=0)
    return index
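
# Small illustrative example (the shapes here are for exposition only): for
# dit_seq_shape = (1, 2, 4) and tile_size = (1, 2, 2), the raster-order token
# ids
#     0 1 2 3
#     4 5 6 7
# are regrouped so each 1x2x2 tile becomes contiguous:
#     get_tile_partition_indices((1, 2, 4), (1, 2, 2), torch.device("cpu"))
#     -> tensor([0, 1, 4, 5, 2, 3, 6, 7])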


@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Return the permutation that maps the tiled order back to raster order."""
    return torch.argsort(
        get_tile_partition_indices(dit_seq_shape, tile_size, device))
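
# Since the tile-partition index is a permutation p, the reverse index is
# argsort(p), so gathering with p and then with argsort(p) restores raster
# order. Continuing the (1, 2, 4) example above:
#     p   = tensor([0, 1, 4, 5, 2, 3, 6, 7])
#     rev = tensor([0, 1, 4, 5, 2, 3, 6, 7])  # this particular p is its own inverse
#     p[rev] -> tensor([0, 1, 2, 3, 4, 5, 6, 7])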


@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """
    Compute the number of valid (non-padded) tokens inside every
    (ts_t x ts_h x ts_w) tile after padding, flattened in (t-tile, h-tile,
    w-tile) order, i.e. the same tile order as `get_tile_partition_indices`.

    Returns
    -------
    torch.LongTensor      # shape: [n_t * n_h * n_w]
    """
    t, h, w = dit_seq_shape
    ts_t, ts_h, ts_w = VSA_TILE_SIZE
    n_t, n_h, n_w = num_tiles

    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
        """Vector with the size of each tile along one dimension."""
        sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
        # Only the last tile along a dimension can be smaller than `tile`.
        remainder = dim_len - (n_tiles - 1) * tile
        sizes[-1] = remainder if remainder > 0 else tile
        return sizes

    t_sizes = _sizes(t, ts_t, n_t)
    h_sizes = _sizes(h, ts_h, n_h)
    w_sizes = _sizes(w, ts_w, n_w)

    # Each tile's token count is the product of its per-dimension extents.
    block_sizes = (
        t_sizes[:, None, None]
        * h_sizes[None, :, None]
        * w_sizes[None, None, :]
    ).reshape(-1)

    return block_sizes
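
# Illustrative example (shape chosen for exposition): for
# dit_seq_shape = (4, 6, 4) with VSA_TILE_SIZE = (4, 4, 4) and
# num_tiles = (1, 2, 1), the h axis splits into tiles of 4 and 2 rows, so
#     construct_variable_block_sizes((4, 6, 4), (1, 2, 1), torch.device("cpu"))
#     -> tensor([64, 32])   # 4*4*4 full tile, 4*2*4 boundary tile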


@functools.lru_cache(maxsize=10)
def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
) -> torch.LongTensor:
    """Return flat indices of the non-padded token slots in the tiled layout.

    Each tile occupies `max_block_size` slots after padding; only the first
    `variable_block_sizes[i]` slots of tile `i` hold real tokens.
    """
    n_win = variable_block_sizes.shape[0]
    device = variable_block_sizes.device
    starts_pad = torch.arange(n_win, device=device) * max_block_size
    index_pad = starts_pad[:, None] + torch.arange(max_block_size,
                                                   device=device)[None, :]
    index_mask = torch.arange(
        max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
    return index_pad[index_mask]
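
# Illustrative example (values for exposition): with
# variable_block_sizes = tensor([64, 32]) and max_block_size = 64, every tile
# occupies a 64-slot window after padding, and only the leading slots of each
# window hold real tokens:
#     get_non_pad_index(variable_block_sizes, 64)
#     -> tensor([0, 1, ..., 63, 64, 65, ..., 95])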


@dataclass
class VideoSparseAttentionMetadata:
    current_timestep: int
    dit_seq_shape: tuple[int, int, int]
    VSA_sparsity: float
    num_tiles: tuple[int, int, int]
    total_seq_length: int
    tile_partition_indices: torch.LongTensor
    reverse_tile_partition_indices: torch.LongTensor
    variable_block_sizes: torch.LongTensor
    non_pad_index: torch.LongTensor


def build(
    current_timestep: int,
    raw_latent_shape: tuple[int, int, int],
    patch_size: tuple[int, int, int],
    VSA_sparsity: float,
    device: torch.device,
    **kwargs: dict[str, Any],
) -> VideoSparseAttentionMetadata:
    # The DiT sequence grid is the latent grid downsampled by the patch size.
    dit_seq_shape = (raw_latent_shape[0] // patch_size[0],
                     raw_latent_shape[1] // patch_size[1],
                     raw_latent_shape[2] // patch_size[2])

    num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
                 math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
                 math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
    total_seq_length = math.prod(dit_seq_shape)

    tile_partition_indices = get_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    reverse_tile_partition_indices = get_reverse_tile_partition_indices(
        dit_seq_shape, VSA_TILE_SIZE, device)
    variable_block_sizes = construct_variable_block_sizes(
        dit_seq_shape, num_tiles, device)
    non_pad_index = get_non_pad_index(variable_block_sizes,
                                      math.prod(VSA_TILE_SIZE))

    return VideoSparseAttentionMetadata(
        current_timestep=current_timestep,
        dit_seq_shape=dit_seq_shape,
        VSA_sparsity=VSA_sparsity,
        num_tiles=num_tiles,
        total_seq_length=total_seq_length,
        tile_partition_indices=tile_partition_indices,
        reverse_tile_partition_indices=reverse_tile_partition_indices,
        variable_block_sizes=variable_block_sizes,
        non_pad_index=non_pad_index)
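
# Usage sketch (the latent and patch sizes below are hypothetical):
#     meta = build(current_timestep=0,
#                  raw_latent_shape=(8, 16, 16),
#                  patch_size=(1, 2, 2),
#                  VSA_sparsity=0.9,
#                  device=torch.device("cpu"))
#     meta.dit_seq_shape     # (8, 8, 8)
#     meta.num_tiles         # (2, 2, 2)
#     meta.total_seq_length  # 512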


class VideoSparseAttentionImpl:

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Only the prefix is retained; the remaining arguments are currently
        # unused by this implementation.
        self.prefix = prefix

    def tile(self, x: torch.Tensor, num_tiles: tuple[int, int, int],
             tile_partition_indices: torch.LongTensor,
             non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Reorder tokens tile by tile into a zero-padded tiled layout."""
        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]

        x_padded = torch.zeros(
            (x.shape[0], t_padded_size * h_padded_size * w_padded_size,
             x.shape[-2], x.shape[-1]),
            device=x.device,
            dtype=x.dtype)
        x_padded[:, non_pad_index] = x[:, tile_partition_indices]
        return x_padded
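
    # Illustrative padding example (shape for exposition): with
    # dit_seq_shape = (4, 4, 6), the w axis needs two tiles, so the padded grid
    # holds 4 * 4 * 8 = 128 token slots for 96 real tokens. The last 32 slots
    # of the second w-tile stay zero, and non_pad_index marks the 96 slots that
    # receive real tokens.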

    def untile(self, x: torch.Tensor,
               reverse_tile_partition_indices: torch.LongTensor,
               non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Drop the padding slots and restore the original token order."""
        x = x[:, non_pad_index][:, reverse_tile_partition_indices]
        return x

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.tile(qkv, attn_metadata.num_tiles,
                         attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.untile(output,
                           attn_metadata.reverse_tile_partition_indices,
                           attn_metadata.non_pad_index)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        # Swap the sequence and head dimensions for the sparse-attention
        # kernel.
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        VSA_sparsity = attn_metadata.VSA_sparsity

        # Number of key/value tiles each query tile attends to: the top
        # (1 - VSA_sparsity) fraction of all tiles.
        cur_topk = math.ceil(
            (1 - VSA_sparsity) *
            (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))

        hidden_states = video_sparse_attn(
            query,
            key,
            value,
            variable_block_sizes=attn_metadata.variable_block_sizes,
            topk=cur_topk,
            block_size=VSA_TILE_SIZE,
            compress_attn_weight=None).transpose(1, 2)

        return hidden_states
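

if __name__ == "__main__":
    # Minimal self-check sketch (shapes are illustrative, not taken from any
    # real model config). It only exercises the metadata construction and the
    # tile/untile round trip, which run on CPU; forward() additionally needs
    # the `vsa` sparse-attention kernels.
    device = torch.device("cpu")
    meta = build(current_timestep=0,
                 raw_latent_shape=(8, 16, 16),
                 patch_size=(1, 2, 2),
                 VSA_sparsity=0.9,
                 device=device)

    impl = VideoSparseAttentionImpl(num_heads=2,
                                    head_size=8,
                                    causal=False,
                                    softmax_scale=1.0)
    x = torch.randn(1, meta.total_seq_length, 2, 8, device=device)
    tiled = impl.preprocess_qkv(x, meta)
    restored = impl.postprocess_output(tiled, meta)
    assert torch.equal(restored, x)
    print("tile/untile round trip OK, tiled shape:", tuple(tiled.shape))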