Performance Issue (Transformers)

#36
by THEMi11 - opened

Running this model with transformers, as in the example, is extremely slow: roughly 30 seconds per image. Batching does not help. Running on an H100.

I encountered the same issue. In my case it was the patch convolution that took minutes to run. I looked inside the code and found that each of the ~22k patches produces exactly one output voxel of shape (1024, 1, 1, 1). So the convolution can easily be replaced with a single matrix multiplication; I suspect cuDNN spends most of the time dispatching thousands of tiny kernels.
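A quick way to convince yourself of the equivalence: when a Conv3d's kernel size equals its stride, each patch is seen exactly once, so the conv reduces to one matmul over flattened patches. A minimal standalone sketch (shapes mirror the ones described above; the module here is freshly constructed, not the model's actual layer):

```python
import torch
import torch.nn.functional as F

# A conv with kernel_size == stride: non-overlapping patches, one output voxel each.
conv = torch.nn.Conv3d(3, 1024, kernel_size=(2, 14, 14), stride=(2, 14, 14))

# 8 patches, each exactly one kernel window: (N, C, T, H, W).
x = torch.randn(8, 3, 2, 14, 14)

out_conv = conv(x).reshape(8, -1)  # (8, 1024, 1, 1, 1) -> (8, 1024)

# Same computation as a single matmul: flatten each patch to 3*2*14*14 = 1176
# features and reuse the conv weight reshaped to (1024, 1176).
out_lin = F.linear(x.reshape(8, -1), conv.weight.reshape(1024, -1), conv.bias)

print(torch.allclose(out_conv, out_lin, atol=1e-5))
```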

Some monkey patching like this should work:

import torch
import torch.nn.functional as F

patch_embed = base_model.visual.patch_embed
proj = patch_embed.proj  # Conv3d(3, 1024, (2, 14, 14), stride=(2, 14, 14))

in_features = (
    patch_embed.in_channels
    * patch_embed.temporal_patch_size
    * patch_embed.patch_size ** 2
)  # 3 * 2 * 14 * 14 = 1176
embed_dim = patch_embed.embed_dim  # 1024
weight = proj.weight  # (1024, 3, 2, 14, 14)
bias = proj.bias      # (1024,)

def _fast_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    # Flatten every patch to a 1176-feature row and apply one big matmul
    # instead of dispatching one conv kernel per patch.
    target_dtype = weight.dtype
    hidden_states = hidden_states.reshape(-1, in_features).to(dtype=target_dtype)
    return F.linear(hidden_states, weight.reshape(embed_dim, -1), bias)

patch_embed.forward = _fast_forward
