| | import math |
| | import torch |
| |
|
| | from kernels.benchmark import Benchmark |
| |
|
| |
|
| | def _cdiv(a, b): |
| | return (a + b - 1) // b |
| |
|
| |
|
| | def _extract_output(result): |
| | if isinstance(result, tuple): |
| | return result[0] |
| | return result |
| |
|
| |
|
def _reference_mla_decode(q, blocked_k, block_table, cache_seqlens, head_dim_v, causal=False):
    """Compute attention over a paged KV cache with plain PyTorch ops.

    Args:
        q: Query tensor, unpacked as ``(batch, s_q, h_q, d)``.
        blocked_k: Paged cache. Dimension 1 is read as the page size and
            dimension 2 as the number of KV heads; gathered pages are
            flattened into rows of shape ``(h_kv, d)``.
        block_table: Per-batch-row page indices into ``blocked_k``.
        cache_seqlens: Per-batch-row number of valid cache positions.
        head_dim_v: How many leading channels of each cached row are used
            as the value vector (the full row is used as the key).
        causal: When true and ``s_q > 1``, apply a lower-triangular mask
            aligned to the end of the key sequence.

    Returns:
        Tensor of shape ``(batch, s_q, h_q, head_dim_v)`` in ``q.dtype``.
        The computation itself is carried out in float32.
    """
    batch, s_q, h_q, d = q.size()
    page_size, h_kv = blocked_k.size(1), blocked_k.size(2)
    norm = math.sqrt(d)

    result = torch.empty(
        batch, s_q, h_q, head_dim_v, dtype=torch.float32, device=q.device
    )

    for row in range(batch):
        valid_len = int(cache_seqlens[row].item())
        # Gather only the pages that hold valid tokens, then trim the tail
        # of the last (possibly partially filled) page.
        pages = block_table[row][: _cdiv(valid_len, page_size)]
        cache = blocked_k[pages].reshape(-1, h_kv, d)[:valid_len]

        # Move the head axis first so matmul batches over heads.
        queries = q[row].transpose(0, 1).float()
        keys = cache.transpose(0, 1).float()
        if h_kv != h_q:
            # Share each KV head across its group of query heads.
            keys = keys.repeat_interleave(h_q // h_kv, dim=0)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / norm

        s_k = keys.size(1)
        if causal and s_q > 1:
            allowed = torch.tril(
                torch.ones(s_q, s_k, dtype=torch.bool, device=q.device),
                diagonal=s_k - s_q,
            )
            scores = scores.masked_fill(torch.logical_not(allowed), float("-inf"))

        probs = torch.softmax(scores, dim=-1)
        # Values are the first ``head_dim_v`` channels of the cached rows.
        weighted = torch.matmul(probs, keys[..., :head_dim_v])
        result[row] = weighted.transpose(0, 1)

    return result.to(q.dtype)
| |
|
| |
|
| | def _varlen_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False): |
| | batch_size = cu_seqlens_q.shape[0] - 1 |
| | total_tokens_q = q.shape[0] |
| | num_heads = q.shape[1] |
| | head_dim_v = v.shape[2] |
| | scale = q.shape[-1] ** (-0.5) |
| |
|
| | out = torch.zeros( |
| | (total_tokens_q, num_heads, head_dim_v), device=q.device, dtype=q.dtype |
| | ) |
| |
|
| | for b in range(batch_size): |
| | start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1] |
| | start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1] |
| |
|
| | q_b = q[start_q:end_q].transpose(0, 1).float() |
| | k_b = k[start_k:end_k].transpose(0, 1).float() |
| | v_b = v[start_k:end_k].transpose(0, 1).float() |
| |
|
| | attn = q_b @ k_b.transpose(-2, -1) * scale |
| |
|
| | if causal: |
| | seq_q, seq_k = q_b.size(1), k_b.size(1) |
| | mask = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril( |
| | diagonal=seq_k - seq_q |
| | ) |
| | attn.masked_fill_(~mask, float("-inf")) |
| |
|
| | attn = torch.softmax(attn, dim=-1) |
| | result = attn @ v_b |
| | out[start_q:end_q] = result.transpose(0, 1).to(q.dtype) |
| |
|
| | return out |
| |
|
| |
|
| | |
| | _HEAD_DIM = 576 |
| | _HEAD_DIM_V = 512 |
| | _NUM_HEADS_K = 1 |
| | _PAGE_BLOCK_SIZE = 64 |
| |
|
| |
|
def _setup_mla_decode(bench, batch_size, seq_k, num_heads_q):
    """Attach CUDA inputs for one decode benchmark configuration to ``bench``.

    Sets ``q``, ``blocked_k``, ``block_table``, ``cache_seqlens``,
    ``tile_scheduler_metadata`` and ``out`` on ``bench``, in that order.

    Args:
        bench: Object that receives the tensors; must expose
            ``bench.kernel.get_mla_metadata()``.
        batch_size: Number of sequences.
        seq_k: Cached length assigned to every sequence.
        num_heads_q: Number of query heads.
    """
    blocks_per_seq = _cdiv(seq_k, _PAGE_BLOCK_SIZE)
    n_blocks = batch_size * blocks_per_seq
    bf16_cuda = {"device": "cuda", "dtype": torch.bfloat16}
    i32_cuda = {"device": "cuda", "dtype": torch.int32}

    # A single query token per sequence; random values are scaled by 1/10.
    q_shape = (batch_size, 1, num_heads_q, _HEAD_DIM)
    bench.q = torch.randn(*q_shape, **bf16_cuda) / 10

    cache_shape = (n_blocks, _PAGE_BLOCK_SIZE, _NUM_HEADS_K, _HEAD_DIM)
    bench.blocked_k = torch.randn(*cache_shape, **bf16_cuda) / 10

    # Each sequence owns a consecutive run of page indices.
    page_ids = torch.arange(n_blocks, **i32_cuda)
    bench.block_table = page_ids.view(batch_size, blocks_per_seq)

    # Every sequence uses the full ``seq_k`` length.
    bench.cache_seqlens = torch.full((batch_size,), seq_k, **i32_cuda)

    # Only the first element of the returned pair is kept.
    bench.tile_scheduler_metadata, _ = bench.kernel.get_mla_metadata()

    out_shape = (batch_size, 1, num_heads_q, _HEAD_DIM_V)
    bench.out = torch.empty(*out_shape, **bf16_cuda)
| |
|
| |
|
def _run_mla_decode(bench, causal=False):
    """Run the decode kernel on the inputs stored on ``bench``.

    Calls ``bench.kernel.flash_mla_with_kvcache`` with the tensors prepared
    during setup and stores the first element of the returned pair in
    ``bench.out``; the second element is discarded.

    Args:
        bench: Object carrying the kernel handle and the prepared inputs.
        causal: Forwarded unchanged to the kernel.
    """
    decode = bench.kernel.flash_mla_with_kvcache
    bench.out, _unused_lse = decode(
        q=bench.q,
        k_cache=bench.blocked_k,
        block_table=bench.block_table,
        cache_seqlens=bench.cache_seqlens,
        head_dim_v=_HEAD_DIM_V,
        tile_scheduler_metadata=bench.tile_scheduler_metadata,
        causal=causal,
    )
| |
|
| |
|
def _verify_mla_decode(bench, causal=False):
    """Return the PyTorch reference output for the inputs stored on ``bench``.

    Args:
        bench: Object carrying ``q``, ``blocked_k``, ``block_table`` and
            ``cache_seqlens``.
        causal: Forwarded unchanged to the reference implementation.

    Returns:
        The tensor produced by ``_reference_mla_decode``.
    """
    reference_inputs = (
        bench.q,
        bench.blocked_k,
        bench.block_table,
        bench.cache_seqlens,
    )
    return _reference_mla_decode(*reference_inputs, _HEAD_DIM_V, causal=causal)
| |
|
| |
|
class FlashMLABenchmark(Benchmark):
    """Non-causal decode benchmark at three problem sizes.

    Each size has a ``setup_*`` / ``benchmark_*`` / ``verify_*`` trio:
    setup builds the inputs, benchmark runs the kernel with
    ``causal=False``, and verify returns the PyTorch reference output for
    the same inputs. How the ``Benchmark`` base class invokes these methods
    is not visible in this file.
    """

    # NOTE(review): presumably read by the Benchmark base for RNG seeding —
    # confirm against kernels.benchmark.
    seed: int = 42

    # Small: 2 sequences, 256 cached tokens each, 64 query heads.
    def setup_small(self) -> None:
        """Build inputs for the small configuration."""
        _setup_mla_decode(
            self,
            batch_size=2,
            seq_k=256,
            num_heads_q=64,
        )

    def benchmark_small(self) -> None:
        """Run the kernel on the small configuration."""
        _run_mla_decode(self, causal=False)

    def verify_small(self) -> torch.Tensor:
        """Return the reference output for the small configuration."""
        expected = _verify_mla_decode(self, causal=False)
        return expected

    # Medium: 4 sequences, 1024 cached tokens each, 64 query heads.
    def setup_medium(self) -> None:
        """Build inputs for the medium configuration."""
        _setup_mla_decode(
            self,
            batch_size=4,
            seq_k=1024,
            num_heads_q=64,
        )

    def benchmark_medium(self) -> None:
        """Run the kernel on the medium configuration."""
        _run_mla_decode(self, causal=False)

    def verify_medium(self) -> torch.Tensor:
        """Return the reference output for the medium configuration."""
        expected = _verify_mla_decode(self, causal=False)
        return expected

    # Large: 8 sequences, 4096 cached tokens each, 128 query heads.
    def setup_large(self) -> None:
        """Build inputs for the large configuration."""
        _setup_mla_decode(
            self,
            batch_size=8,
            seq_k=4096,
            num_heads_q=128,
        )

    def benchmark_large(self) -> None:
        """Run the kernel on the large configuration."""
        _run_mla_decode(self, causal=False)

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the large configuration."""
        expected = _verify_mla_decode(self, causal=False)
        return expected
| |
|
| |
|
class FlashMLACausalBenchmark(Benchmark):
    """Causal decode benchmark at three problem sizes.

    Each size has a ``setup_*`` / ``benchmark_*`` / ``verify_*`` trio:
    setup builds the inputs, benchmark runs the kernel with
    ``causal=True``, and verify returns the PyTorch reference output for
    the same inputs. How the ``Benchmark`` base class invokes these methods
    is not visible in this file.
    """

    # NOTE(review): presumably read by the Benchmark base for RNG seeding —
    # confirm against kernels.benchmark.
    seed: int = 42

    # Small: 2 sequences, 256 cached tokens each, 64 query heads.
    def setup_small(self) -> None:
        """Build inputs for the small configuration."""
        _setup_mla_decode(
            self,
            batch_size=2,
            seq_k=256,
            num_heads_q=64,
        )

    def benchmark_small(self) -> None:
        """Run the kernel with causal masking on the small configuration."""
        _run_mla_decode(self, causal=True)

    def verify_small(self) -> torch.Tensor:
        """Return the causal reference output for the small configuration."""
        expected = _verify_mla_decode(self, causal=True)
        return expected

    # Medium: 4 sequences, 1024 cached tokens each, 64 query heads.
    def setup_medium(self) -> None:
        """Build inputs for the medium configuration."""
        _setup_mla_decode(
            self,
            batch_size=4,
            seq_k=1024,
            num_heads_q=64,
        )

    def benchmark_medium(self) -> None:
        """Run the kernel with causal masking on the medium configuration."""
        _run_mla_decode(self, causal=True)

    def verify_medium(self) -> torch.Tensor:
        """Return the causal reference output for the medium configuration."""
        expected = _verify_mla_decode(self, causal=True)
        return expected

    # Large: 8 sequences, 4096 cached tokens each, 128 query heads.
    def setup_large(self) -> None:
        """Build inputs for the large configuration."""
        _setup_mla_decode(
            self,
            batch_size=8,
            seq_k=4096,
            num_heads_q=128,
        )

    def benchmark_large(self) -> None:
        """Run the kernel with causal masking on the large configuration."""
        _run_mla_decode(self, causal=True)

    def verify_large(self) -> torch.Tensor:
        """Return the causal reference output for the large configuration."""
        expected = _verify_mla_decode(self, causal=True)
        return expected
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|