# benchmarks/benchmark.py
import math
import torch
from kernels.benchmark import Benchmark
def _cdiv(a, b):
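    """Ceiling division: smallest integer >= a / b."""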
return (a + b - 1) // b
def _extract_output(result):
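    """Unwrap kernel results that come back as `(out, lse, ...)` tuples."""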
if isinstance(result, tuple):
return result[0]
return result
def _reference_mla_decode(q, blocked_k, block_table, cache_seqlens, head_dim_v, causal=False):
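    """Pure-PyTorch reference for paged MLA decode.

    Gathers each sequence's KV blocks through `block_table`, runs attention in
    float32, and, as in MLA, reads V as the first `head_dim_v` channels of the
    shared latent K cache.
    """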
b, s_q, h_q, d = q.size()
block_size = blocked_k.size(1)
h_kv = blocked_k.size(2)
out = torch.empty(b, s_q, h_q, head_dim_v, dtype=torch.float32, device=q.device)
for i in range(b):
cur_len = int(cache_seqlens[i].item())
num_blocks = _cdiv(cur_len, block_size)
cur_blocks = block_table[i][:num_blocks]
kv = blocked_k[cur_blocks].reshape(-1, h_kv, d)[:cur_len]
query = q[i].transpose(0, 1).float() # [h_q, s_q, d]
key_val = kv.transpose(0, 1).float() # [h_kv, s_k, d]
if h_kv != h_q:
key_val = key_val.repeat_interleave(h_q // h_kv, dim=0)
attn = query @ key_val.transpose(-2, -1) / math.sqrt(d)
s_k = key_val.size(1)
if causal and s_q > 1:
mask = torch.ones(s_q, s_k, dtype=torch.bool, device=q.device).tril(
diagonal=s_k - s_q
)
attn.masked_fill_(~mask, float("-inf"))
attn = torch.softmax(attn, dim=-1)
output = attn @ key_val[..., :head_dim_v]
out[i] = output.transpose(0, 1)
return out.to(q.dtype)
def _varlen_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False):
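    """Pure-PyTorch reference for variable-length (packed) attention.

    `q`, `k`, and `v` are packed as [total_tokens, H, D]; `cu_seqlens_q` and
    `cu_seqlens_k` hold the cumulative offsets that delimit each sequence.
    """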
batch_size = cu_seqlens_q.shape[0] - 1
total_tokens_q = q.shape[0]
num_heads = q.shape[1]
head_dim_v = v.shape[2]
scale = q.shape[-1] ** (-0.5)
out = torch.zeros(
(total_tokens_q, num_heads, head_dim_v), device=q.device, dtype=q.dtype
)
for b in range(batch_size):
start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]
q_b = q[start_q:end_q].transpose(0, 1).float() # [H, seq_q, D_qk]
k_b = k[start_k:end_k].transpose(0, 1).float() # [H, seq_k, D_qk]
v_b = v[start_k:end_k].transpose(0, 1).float() # [H, seq_k, D_v]
attn = q_b @ k_b.transpose(-2, -1) * scale
if causal:
seq_q, seq_k = q_b.size(1), k_b.size(1)
mask = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril(
diagonal=seq_k - seq_q
)
attn.masked_fill_(~mask, float("-inf"))
attn = torch.softmax(attn, dim=-1)
result = attn @ v_b # [H, seq_q, D_v]
out[start_q:end_q] = result.transpose(0, 1).to(q.dtype)
return out
# MLA decode constants (DeepSeek V3 architecture)
_HEAD_DIM = 576  # Q/K head dimension (512-dim latent + 64-dim RoPE)
_HEAD_DIM_V = 512  # V head dimension (the latent portion of the cache)
_NUM_HEADS_K = 1  # MLA stores a single shared latent KV head
_PAGE_BLOCK_SIZE = 64  # tokens per paged-KV block
def _setup_mla_decode(bench, batch_size, seq_k, num_heads_q):
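    """Build a single-step (s_q == 1) paged-KV decode workload on `bench`."""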
max_num_blocks = _cdiv(seq_k, _PAGE_BLOCK_SIZE)
total_blocks = batch_size * max_num_blocks
bench.q = (
torch.randn(
batch_size, 1, num_heads_q, _HEAD_DIM, device="cuda", dtype=torch.bfloat16
)
/ 10
)
bench.blocked_k = (
torch.randn(
total_blocks,
_PAGE_BLOCK_SIZE,
_NUM_HEADS_K,
_HEAD_DIM,
device="cuda",
dtype=torch.bfloat16,
)
/ 10
)
bench.block_table = torch.arange(
total_blocks, device="cuda", dtype=torch.int32
).view(batch_size, max_num_blocks)
bench.cache_seqlens = torch.full(
(batch_size,), seq_k, device="cuda", dtype=torch.int32
)
    # FlashMLA's scheduler metadata depends on the sequence lengths and head
    # counts; this assumes the upstream FlashMLA signature
    # get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv), with s_q == 1 here.
    bench.tile_scheduler_metadata, bench.num_splits = bench.kernel.get_mla_metadata(
        bench.cache_seqlens, num_heads_q // _NUM_HEADS_K, _NUM_HEADS_K
    )
bench.out = torch.empty(
batch_size, 1, num_heads_q, _HEAD_DIM_V, device="cuda", dtype=torch.bfloat16
)
def _run_mla_decode(bench, causal=False):
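    """Invoke the FlashMLA decode kernel; the log-sum-exp output is discarded."""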
    out, lse = bench.kernel.flash_mla_with_kvcache(
        q=bench.q,
        k_cache=bench.blocked_k,
        block_table=bench.block_table,
        cache_seqlens=bench.cache_seqlens,
        head_dim_v=_HEAD_DIM_V,
        tile_scheduler_metadata=bench.tile_scheduler_metadata,
        num_splits=bench.num_splits,  # from get_mla_metadata, per upstream API
        causal=causal,
    )
bench.out = out
def _verify_mla_decode(bench, causal=False):
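    """Recompute the expected output with the pure-PyTorch reference."""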
return _reference_mla_decode(
bench.q,
bench.blocked_k,
bench.block_table,
bench.cache_seqlens,
_HEAD_DIM_V,
causal=causal,
)
class FlashMLABenchmark(Benchmark):
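    """Non-causal MLA decode benchmarks over three workload sizes."""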
seed: int = 42
# Workload: small (B=2, S_k=256, H_q=64)
def setup_small(self):
_setup_mla_decode(self, batch_size=2, seq_k=256, num_heads_q=64)
def benchmark_small(self):
_run_mla_decode(self, causal=False)
def verify_small(self) -> torch.Tensor:
return _verify_mla_decode(self, causal=False)
# Workload: medium (B=4, S_k=1024, H_q=64)
def setup_medium(self):
_setup_mla_decode(self, batch_size=4, seq_k=1024, num_heads_q=64)
def benchmark_medium(self):
_run_mla_decode(self, causal=False)
def verify_medium(self) -> torch.Tensor:
return _verify_mla_decode(self, causal=False)
# Workload: large (B=8, S_k=4096, H_q=128)
def setup_large(self):
_setup_mla_decode(self, batch_size=8, seq_k=4096, num_heads_q=128)
def benchmark_large(self):
_run_mla_decode(self, causal=False)
def verify_large(self) -> torch.Tensor:
return _verify_mla_decode(self, causal=False)
class FlashMLACausalBenchmark(Benchmark):
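    """Causal variant of the MLA decode benchmarks."""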
seed: int = 42
# Workload: small (B=2, S_k=256, H_q=64)
def setup_small(self):
_setup_mla_decode(self, batch_size=2, seq_k=256, num_heads_q=64)
def benchmark_small(self):
_run_mla_decode(self, causal=True)
def verify_small(self) -> torch.Tensor:
return _verify_mla_decode(self, causal=True)
# Workload: medium (B=4, S_k=1024, H_q=64)
def setup_medium(self):
_setup_mla_decode(self, batch_size=4, seq_k=1024, num_heads_q=64)
def benchmark_medium(self):
_run_mla_decode(self, causal=True)
def verify_medium(self) -> torch.Tensor:
return _verify_mla_decode(self, causal=True)
# Workload: large (B=8, S_k=4096, H_q=128)
def setup_large(self):
_setup_mla_decode(self, batch_size=8, seq_k=4096, num_heads_q=128)
def benchmark_large(self):
_run_mla_decode(self, causal=True)
def verify_large(self) -> torch.Tensor:
return _verify_mla_decode(self, causal=True)
# class FlashMLAVarlenBenchmark(Benchmark):
# seed: int = 42
# # Workload: small (3 sequences, max_seqlen=64)
# def setup_small(self):
# H, D = 8, 64
# seqlens = [32, 48, 64]
# total = sum(seqlens)
# self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.cu_seqlens = torch.tensor(
#             [0] + torch.cumsum(torch.tensor(seqlens), 0).tolist(),
# device="cuda",
# dtype=torch.int32,
# )
# self.max_seqlen = max(seqlens)
# self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
# def benchmark_small(self):
# self.out = _extract_output(
# self.kernel.flash_attn_varlen_func(
# self.q,
# self.k,
# self.v,
# self.cu_seqlens,
# self.cu_seqlens,
# self.max_seqlen,
# self.max_seqlen,
# )
# )
# def verify_small(self) -> torch.Tensor:
# return _varlen_reference_attention(
# self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
# )
# # Workload: medium (5 sequences, max_seqlen=256)
# def setup_medium(self):
# H, D = 16, 64
# seqlens = [128, 192, 256, 200, 150]
# total = sum(seqlens)
# self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.cu_seqlens = torch.tensor(
#             [0] + torch.cumsum(torch.tensor(seqlens), 0).tolist(),
# device="cuda",
# dtype=torch.int32,
# )
# self.max_seqlen = max(seqlens)
# self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
# def benchmark_medium(self):
# self.out = _extract_output(
# self.kernel.flash_attn_varlen_func(
# self.q,
# self.k,
# self.v,
# self.cu_seqlens,
# self.cu_seqlens,
# self.max_seqlen,
# self.max_seqlen,
# )
# )
# def verify_medium(self) -> torch.Tensor:
# return _varlen_reference_attention(
# self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
# )
# # Workload: large (8 sequences, max_seqlen=512)
# def setup_large(self):
# H, D = 32, 128
# seqlens = [256, 384, 512, 448, 320, 480, 400, 512]
# total = sum(seqlens)
# self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
# self.cu_seqlens = torch.tensor(
#             [0] + torch.cumsum(torch.tensor(seqlens), 0).tolist(),
# device="cuda",
# dtype=torch.int32,
# )
# self.max_seqlen = max(seqlens)
# self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
# def benchmark_large(self):
# self.out = _extract_output(
# self.kernel.flash_attn_varlen_func(
# self.q,
# self.k,
# self.v,
# self.cu_seqlens,
# self.cu_seqlens,
# self.max_seqlen,
# self.max_seqlen,
# )
# )
# def verify_large(self) -> torch.Tensor:
# return _varlen_reference_attention(
# self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
# )
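

# Optional CPU sanity check: not part of the hub benchmark protocol (which
# drives the Benchmark classes and provides `bench.kernel`). It only confirms
# that the paged block gather in `_reference_mla_decode` agrees with a dense
# attention recomputation over the same tokens, on a tiny synthetic workload.
if __name__ == "__main__":
    torch.manual_seed(0)
    b, s_k, h_q, d, d_v = 2, 96, 4, 32, 16
    num_blocks = _cdiv(s_k, _PAGE_BLOCK_SIZE)  # 96 tokens -> 2 blocks of 64
    q = torch.randn(b, 1, h_q, d)
    blocked_k = torch.randn(b * num_blocks, _PAGE_BLOCK_SIZE, 1, d)
    block_table = torch.arange(b * num_blocks, dtype=torch.int32).view(b, num_blocks)
    cache_seqlens = torch.full((b,), s_k, dtype=torch.int32)
    out = _reference_mla_decode(q, blocked_k, block_table, cache_seqlens, d_v)
    for i in range(b):
        # Gather the cached tokens contiguously, then run ordinary softmax
        # attention with V taken as the first d_v channels of K (MLA layout).
        kv = blocked_k[block_table[i].long()].reshape(-1, d)[:s_k]
        attn = torch.softmax(q[i, 0] @ kv.t() / math.sqrt(d), dim=-1)
        ref = attn @ kv[:, :d_v]
        assert torch.allclose(out[i, 0], ref, atol=1e-4)
    print("reference MLA decode smoke test passed")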