From c267b1a02c952b68a897c96201f32ad57e0b955e Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 8 Apr 2023 13:36:09 -0700
Subject: [PATCH] Add query stride to multi_query_cached_kv_attention & Add
 kernel benchmark script (#27)

* Add query stride to multi_query_cached_kv_attention

* Add kernel benchmark script
---
 benchmark/benchmark_attention.py | 165 +++++++++++++++++++++++++++++++
 csrc/attention_kernels.cu        |  16 ++-
 tests/kernels/attention.py       |   8 +-
 3 files changed, 181 insertions(+), 8 deletions(-)
 create mode 100644 benchmark/benchmark_attention.py

diff --git a/benchmark/benchmark_attention.py b/benchmark/benchmark_attention.py
new file mode 100644
index 00000000..ac43ddb3
--- /dev/null
+++ b/benchmark/benchmark_attention.py
@@ -0,0 +1,165 @@
+import functools
+import random
+import time
+from typing import List
+
+from flash_attn.flash_attn_interface import _flash_attn_forward
+import torch
+
+from cacheflow import attention_ops
+
+
+def benchmark(name, f, num_warmup = 10, num_iters = 100):
+    for _ in range(num_warmup):
+        f()
+    torch.cuda.synchronize()
+
+    start = time.time()
+    for _ in range(num_iters):
+        f()
+    torch.cuda.synchronize()
+    end = time.time()
+    print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')
+
+
+@torch.inference_mode()
+def benchmark_multi_query_cached_kv_attention(
+    query_lens: List[int],
+    context_lens: List[int],
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+) -> None:
+    print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
+          f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
+          f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
+    # Create query tensor.
+    num_queries = len(query_lens)
+    cu_query_lens = [0]
+    for query_len in query_lens:
+        cu_query_lens.append(cu_query_lens[-1] + query_len)
+    num_total_tokens = cu_query_lens[-1]
+    qkv = torch.randn(
+        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)
+
+    # Create key and value cache.
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_block_shape = (num_heads, head_size // x, block_size, x)
+    key_cache = torch.randn(
+        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    value_block_shape = (num_heads, head_size, block_size)
+    value_cache = torch.randn(
+        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
+
+    # Create block tables.
+    max_context_len = max(context_lens)
+    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    block_tables = []
+    for _ in range(num_queries):
+        block_table = [
+            random.randint(0, num_blocks - 1)
+            for _ in range(max_num_blocks_per_seq)
+        ]
+        block_tables.append(block_table)
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
+
+    # Create input and output data structures.
+    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
+    context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
+    scale = float(1.0 / (head_size ** 0.5))
+    output = torch.empty(
+        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+
+    # Run our implementation.
+    def run_ours():
+        attention_ops.multi_query_cached_kv_attention(
+            cu_query_lens,
+            output,
+            query,
+            key_cache,
+            value_cache,
+            scale,
+            block_tables,
+            context_len_tensor,
+            block_size,
+            max_context_len,
+        )
+    benchmark('Ours', run_ours)
+
+    # Upper bound: Flash attention.
+    # Because Flash attention cannot read our own cache,
+    # we make key and value tensors contiguous.
+    num_kv_tokens = sum(context_lens)
+    cu_context_lens = [0]
+    for context_len in context_lens:
+        cu_context_lens.append(cu_context_lens[-1] + context_len)
+    cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
+    qkv = torch.randn(
+        num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    _, key, value = qkv.unbind(dim=1)
+    ref_output = torch.empty_like(output)
+
+    # Run Flash attention.
+    def run_flash_attn():
+        _flash_attn_forward(
+            query,
+            key,
+            value,
+            ref_output,
+            cu_query_lens,
+            cu_context_lens,
+            max(query_lens),
+            max_context_len,
+            dropout_p=0.0,
+            softmax_scale=scale,
+            causal=True,
+            return_softmax=False,
+        )
+    benchmark('Flash attention', run_flash_attn)
+
+
+if __name__ == '__main__':
+    BLOCK_SIZE = 8
+    NUM_BLOCKS = 1024
+    DTYPE = torch.half
+
+    # LLaMA-13B and OPT-13B
+    NUM_HEADS = 40
+    HEAD_SIZE = 128
+
+    run_benchmark = functools.partial(
+        benchmark_multi_query_cached_kv_attention,
+        num_heads=NUM_HEADS,
+        head_size=HEAD_SIZE,
+        block_size=BLOCK_SIZE,
+        num_blocks=NUM_BLOCKS,
+        dtype=DTYPE,
+    )
+
+    run_benchmark(
+        query_lens=[64] * 1,
+        context_lens=[64] * 1,
+    )
+    run_benchmark(
+        query_lens=[128] * 1,
+        context_lens=[128] * 1,
+    )
+    run_benchmark(
+        query_lens=[64] * 8,
+        context_lens=[64] * 8,
+    )
+    run_benchmark(
+        query_lens=[128] * 8,
+        context_lens=[128] * 8,
+    )
+    run_benchmark(
+        query_lens=[64, 32, 16],
+        context_lens=[128, 256, 64],
+    )
+    run_benchmark(
+        query_lens=[1024],
+        context_lens=[1024],
+    )
diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu
index 73c29b74..fc3a2717 100644
--- a/csrc/attention_kernels.cu
+++ b/csrc/attention_kernels.cu
@@ -271,7 +271,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
     const float scale,
     const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
     const int context_len,
-    const int max_num_blocks_per_seq) {
+    const int max_num_blocks_per_seq,
+    const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
@@ -302,7 +303,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
   // For example, if the thread group size is 4, then the first thread in the group
   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
   // th vectors of the query, and so on.
-  const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   Q_vec q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
@@ -514,7 +516,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
     const float scale,
     const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
     const int* __restrict__ context_lens, // [num_prompts]
-    const int max_num_blocks_per_seq) {
+    const int max_num_blocks_per_seq,
+    const int q_stride) {
   const int seq_idx = blockIdx.y;
   const int prompt_idx = seq_prompt_mapping[seq_idx];
   const int seq_start_idx = cu_query_lens[prompt_idx];
@@ -532,7 +535,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
     scale,
     block_table,
     context_len,
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq,
+    q_stride);
 }

 } // namespace cacheflow
@@ -696,7 +700,8 @@ void single_query_cached_kv_attention(
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
-    max_num_blocks_per_seq);
+    max_num_blocks_per_seq, \
+    query_stride);

 // TODO(woosuk): Tune NUM_THREADS.
@@ -719,6 +724,7 @@ void multi_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
+  int query_stride = query.stride(0);

   int* cu_query_lens_ptr = cu_query_lens.data_ptr();
   int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr();
diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py
index 0747566e..a66f2c3d 100644
--- a/tests/kernels/attention.py
+++ b/tests/kernels/attention.py
@@ -285,8 +285,9 @@ def test_multi_query_cached_kv_attention(
         cu_query_lens.append(cu_query_lens[-1] + query_len)
     num_total_tokens = cu_query_lens[-1]

-    query = torch.randn(
-        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)
     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
     key_cache = torch.randn(
@@ -314,7 +315,8 @@ def test_multi_query_cached_kv_attention(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

     scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty_like(query)
+    output = torch.empty(
+        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

     attention_ops.multi_query_cached_kv_attention(
         cu_query_lens,