Refactor the test code for attention kernels (#13)
This commit is contained in:
parent
64e0e38314
commit
a1b3de86cd
@ -1,5 +1,5 @@
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
import torch
|
||||
@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: List[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
head_size = query.shape[-1]
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_seqs):
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
seq_len = end_idx - start_idx
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
return ref_output
|
||||
|
||||
|
||||
def test_single_query_cached_kv_attention(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
|
||||
causal=True,
|
||||
)[0]
|
||||
|
||||
ref_outputs = []
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
cu_seq_lens = cu_seq_lens.cpu().tolist()
|
||||
ref_output = ref_multi_query_kv_attention(
|
||||
cu_seq_lens,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dtype,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_attention() -> None:
|
||||
def test_attention(seed: int) -> None:
|
||||
# NOTE(woosuk): Even when the seed is fixed, there is a chance that
|
||||
# the test fails due to the precision issue. Re-run the test if it fails.
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
for dtype in [torch.half, torch.float]:
|
||||
for block_size in [8, 16]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Testing single_query_cached_kv_attention with '
|
||||
f'dtype={dtype}, block_size={block_size}, '
|
||||
f'head_size={head_size}')
|
||||
test_single_query_cached_kv_attention(
|
||||
num_tokens=37,
|
||||
num_heads=3,
|
||||
@ -193,6 +225,8 @@ def test_attention() -> None:
|
||||
for dtype in [torch.half]:
|
||||
# NOTE(woosuk): FlashAttention does not support head_size > 128.
|
||||
for head_size in [64, 80, 96, 128]:
|
||||
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
|
||||
f'head_size={head_size}')
|
||||
test_multi_query_kv_attention(
|
||||
num_seqs=11,
|
||||
num_heads=3,
|
||||
@ -202,4 +236,4 @@ def test_attention() -> None:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_attention()
|
||||
test_attention(seed=0)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user