Refactor the test code for attention kernels (#13)

This commit is contained in:
Woosuk Kwon 2023-03-29 18:59:27 -07:00 committed by GitHub
parent 64e0e38314
commit a1b3de86cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
import random
from typing import Optional
from typing import List, Optional
from flash_attn.flash_attention import FlashAttention
import torch
@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
output[i].copy_(out, non_blocking=True)
def ref_multi_query_kv_attention(
cu_seq_lens: List[int],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
head_size = query.shape[-1]
scale = 1.0 / (head_size ** 0.5)
num_seqs = len(cu_seq_lens) - 1
ref_outputs = []
for i in range(num_seqs):
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
seq_len = end_idx - start_idx
# Create attention mask
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
ref_output = ref_masked_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
scale,
attn_mask=attn_mask,
)
ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
return ref_output
def test_single_query_cached_kv_attention(
num_tokens: int,
num_heads: int,
@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
causal=True,
)[0]
ref_outputs = []
for i, seq_len in enumerate(seq_lens):
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
ref_output = ref_masked_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
scale,
attn_mask=attn_mask,
)
ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
cu_seq_lens = cu_seq_lens.cpu().tolist()
ref_output = ref_multi_query_kv_attention(
cu_seq_lens,
query,
key,
value,
dtype,
)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def test_attention() -> None:
def test_attention(seed: int) -> None:
# NOTE(woosuk): Even when the seed is fixed, there is a chance that
# the test fails due to the precision issue. Re-run the test if it fails.
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
for dtype in [torch.half, torch.float]:
for block_size in [8, 16]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
test_single_query_cached_kv_attention(
num_tokens=37,
num_heads=3,
@ -193,6 +225,8 @@ def test_attention() -> None:
for dtype in [torch.half]:
# NOTE(woosuk): FlashAttention does not support head_size > 128.
for head_size in [64, 80, 96, 128]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}')
test_multi_query_kv_attention(
num_seqs=11,
num_heads=3,
@ -202,4 +236,4 @@ def test_attention() -> None:
if __name__ == '__main__':
test_attention()
test_attention(seed=0)