diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py
index ec493940..b6766e1e 100644
--- a/tests/kernels/attention.py
+++ b/tests/kernels/attention.py
@@ -1,5 +1,5 @@
 import random
-from typing import Optional
+from typing import List, Optional
 
 from flash_attn.flash_attention import FlashAttention
 import torch
@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)
 
 
+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    head_size = query.shape[-1]
+    scale = 1.0 / (head_size ** 0.5)
+
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
 def test_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
         causal=True,
     )[0]
 
-    ref_outputs = []
-    for i, seq_len in enumerate(seq_lens):
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-
+    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    ref_output = ref_multi_query_kv_attention(
+        cu_seq_lens,
+        query,
+        key,
+        value,
+        dtype,
+    )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
 @torch.inference_mode()
-def test_attention() -> None:
+def test_attention(seed: int) -> None:
+    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
+    # the test fails due to the precision issue. Re-run the test if it fails.
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+                print(f'Testing single_query_cached_kv_attention with '
+                      f'dtype={dtype}, block_size={block_size}, '
+                      f'head_size={head_size}')
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -193,6 +225,8 @@ def test_attention() -> None:
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
        for head_size in [64, 80, 96, 128]:
+            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
+                  f'head_size={head_size}')
             test_multi_query_kv_attention(
                 num_seqs=11,
                 num_heads=3,
@@ -202,4 +236,4 @@
 
 
 if __name__ == '__main__':
-    test_attention()
+    test_attention(seed=0)
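
A note for context: the new `ref_multi_query_kv_attention` helper delegates the per-sequence math to `ref_masked_attention`, which is defined earlier in tests/kernels/attention.py and therefore does not appear in this diff. For orientation only, here is a minimal sketch of a reference helper with that call signature; the einsum formulation and the shape comments are assumptions for illustration, not taken from this patch.

```python
from typing import Optional

import torch


def ref_masked_attention(
    query: torch.Tensor,   # [num_query_tokens, num_heads, head_size]
    key: torch.Tensor,     # [num_kv_tokens, num_heads, head_size]
    value: torch.Tensor,   # [num_kv_tokens, num_heads, head_size]
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,  # [num_query_tokens, num_kv_tokens]
) -> torch.Tensor:
    # Plain scaled dot-product attention, computed per head via einsum.
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)
    if attn_mask is not None:
        # The causal mask built in the caller (triu(...) * -1e5) pushes the
        # logits of future positions toward -inf before the softmax.
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn, value)
```

The test compares such a reference output against the FlashAttention result with `atol=1e-3`, which is why the NOTE about occasional precision failures accompanies the fixed seed.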