Refactor the test code for attention kernels (#13)
commit a1b3de86cd
parent 64e0e38314
@@ -1,5 +1,5 @@
 import random
-from typing import Optional
+from typing import List, Optional
 
 from flash_attn.flash_attention import FlashAttention
 import torch
@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)
 
 
+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    head_size = query.shape[-1]
+    scale = 1.0 / (head_size ** 0.5)
+
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
 def test_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
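A note on the additive causal mask used by the new reference helper: torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5 puts a large negative bias above the diagonal, so after softmax each token attends only to itself and earlier tokens. Below is a minimal, CPU-only sketch of that idea (illustrative only, not part of this commit; the sizes and the stand-in scores are arbitrary):

import torch

n = 4
# Additive causal mask: roughly -1e5 above the diagonal, 0 on and below it.
attn_mask = torch.triu(torch.ones(n, n), diagonal=1) * -1e5

scores = torch.randn(n, n)                      # stand-in attention scores
probs = torch.softmax(scores + attn_mask, dim=-1)

# Masked (future) positions receive essentially zero probability.
assert torch.allclose(probs.triu(diagonal=1), torch.zeros(n, n), atol=1e-6)

In the diff, the helper additionally casts this mask to the test dtype and moves it to the CUDA device before passing it to ref_masked_attention.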
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
         causal=True,
     )[0]
 
-    ref_outputs = []
-    for i, seq_len in enumerate(seq_lens):
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-
+    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    ref_output = ref_multi_query_kv_attention(
+        cu_seq_lens,
+        query,
+        key,
+        value,
+        dtype,
+    )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
 @torch.inference_mode()
-def test_attention() -> None:
+def test_attention(seed: int) -> None:
+    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
+    # the test fails due to the precision issue. Re-run the test if it fails.
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+                print(f'Testing single_query_cached_kv_attention with '
+                      f'dtype={dtype}, block_size={block_size}, '
+                      f'head_size={head_size}')
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -193,6 +225,8 @@ def test_attention() -> None:
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
         for head_size in [64, 80, 96, 128]:
+            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
+                  f'head_size={head_size}')
             test_multi_query_kv_attention(
                 num_seqs=11,
                 num_heads=3,
@@ -202,4 +236,4 @@ def test_attention() -> None:
 
 
 if __name__ == '__main__':
-    test_attention()
+    test_attention(seed=0)
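For context, the refactored helper assumes a packed layout in which all sequences are concatenated along the token dimension and cu_seq_lens holds the cumulative offsets, so query[cu_seq_lens[i]:cu_seq_lens[i + 1]] recovers the i-th sequence. The following is a self-contained, CPU-only sketch of that slicing pattern; the shapes, the seq_lens values, and the naive_causal_attention helper are illustrative assumptions, not code from this repository:

import torch
from itertools import accumulate


def naive_causal_attention(q: torch.Tensor, k: torch.Tensor,
                           v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (seq_len, num_heads, head_size) for a single sequence.
    seq_len, _, head_size = q.shape
    scale = 1.0 / (head_size ** 0.5)
    scores = torch.einsum('qhd,khd->hqk', q, k) * scale
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float('-inf'))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum('hqk,khd->qhd', probs, v)


seq_lens = [3, 5, 2]                                # hypothetical lengths
cu_seq_lens = [0] + list(accumulate(seq_lens))      # [0, 3, 8, 10]
num_heads, head_size = 2, 4
num_tokens = cu_seq_lens[-1]

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)

# Slice the packed tensors per sequence, as ref_multi_query_kv_attention does.
outputs = []
for i in range(len(seq_lens)):
    start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
    outputs.append(naive_causal_attention(query[start:end],
                                          key[start:end],
                                          value[start:end]))
output = torch.cat(outputs, dim=0)
assert output.shape == (num_tokens, num_heads, head_size)

The helper added in this diff does the same per-sequence slicing, but delegates each slice to ref_masked_attention with an additive mask rather than a boolean one.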