diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 4c02f33c..d8199c8e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention( block_size: int, num_blocks: int, dtype: torch.dtype, + num_kv_heads: int = None, ) -> None: qkv = torch.empty(num_tokens, 3, @@ -202,6 +203,14 @@ def run_single_query_cached_kv_attention( head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda") scale = float(1.0 / (head_size**0.5)) + + num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + assert num_heads % num_kv_heads == 0 + num_queries_per_kv = num_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + output = torch.empty(num_tokens, num_heads, head_size,