From c413c41cda0f9359e7a12bb674c0f87bf41798c5 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 18 Feb 2023 19:22:57 +0000
Subject: [PATCH] Add reshape_and_cache op

---
 csrc/cache.cpp        | 11 +++++++
 csrc/cache_kernels.cu | 72 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+)

diff --git a/csrc/cache.cpp b/csrc/cache.cpp
index ab33ee12..e20d1c3f 100644
--- a/csrc/cache.cpp
+++ b/csrc/cache.cpp
@@ -5,9 +5,20 @@ void copy_blocks(
   torch::Tensor& dst,
   const std::map<int64_t, int64_t>& block_mapping);
 
+void reshape_and_cache(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "copy_cache_blocks",
     &copy_blocks,
     "Copy the cache blocks from src to dst");
+  m.def(
+    "reshape_and_cache",
+    &reshape_and_cache,
+    "Reshape the key and value tensors and cache them");
 }
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 7a8befc0..1cce1bc7 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -2,6 +2,7 @@
 
 #include <ATen/cuda/CUDAContext.h>
 
+#include <algorithm>
 #include <map>
 #include <vector>
 
@@ -41,3 +42,74 @@
       stream);
   }
 }
+
+template<typename scalar_t>
+__global__ void reshape_and_cache_kernel(
+  const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,   // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, block_size, head_size]
+  const int* __restrict__ slot_mapping, // [num_tokens]
+  const int num_heads,
+  const int head_size,
+  const int block_size,
+  const int x) {
+  const int token_idx = blockIdx.x;
+  const int slot_idx = slot_mapping[token_idx];
+  const int block_idx = slot_idx / block_size;
+  const int block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int src_idx = token_idx * n + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int x_idx = head_offset / x;
+    const int x_offset = head_offset % x;
+
+    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                          + head_idx * (head_size / x) * block_size * x
+                          + x_idx * block_size * x
+                          + block_offset * x
+                          + x_offset;
+    const int tgt_value_idx = block_idx * num_heads * block_size * head_size
+                            + head_idx * block_size * head_size
+                            + block_offset * head_size
+                            + head_offset;
+    key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
+    value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
+  }
+}
+
+void reshape_and_cache(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping) {
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key.scalar_type(),
+    "reshape_and_cache_kernel",
+    [&] {
+      reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int>(),
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
+}
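Note (not part of the patch): the key cache's trailing dimension splits each head vector into head_size/x chunks of width x, so that the x elements a block needs at once sit contiguously in memory; the value cache simply keeps each head's vector contiguous per slot. To make that index arithmetic concrete, below is a minimal, self-contained CPU sketch that mirrors the kernel's addressing on the host. The function name reshape_and_cache_cpu and the toy sizes are invented for illustration only.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical CPU reference mirroring the kernel's index arithmetic.
// key/value:   [num_tokens, num_heads, head_size], row-major
// key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
// value_cache: [num_blocks, num_heads, block_size, head_size]
void reshape_and_cache_cpu(
    const std::vector<float>& key,
    const std::vector<float>& value,
    std::vector<float>& key_cache,
    std::vector<float>& value_cache,
    const std::vector<int>& slot_mapping,
    int num_heads, int head_size, int block_size, int x) {
  const int num_tokens = static_cast<int>(slot_mapping.size());
  const int n = num_heads * head_size;
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    const int slot_idx = slot_mapping[token_idx];
    const int block_idx = slot_idx / block_size;
    const int block_offset = slot_idx % block_size;
    for (int i = 0; i < n; ++i) {
      const int src_idx = token_idx * n + i;
      const int head_idx = i / head_size;
      const int head_offset = i % head_size;
      // Split the head dimension into chunks of width x.
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;
      // Same expression as the kernel's tgt_key_idx, factored with Horner's rule.
      const int tgt_key_idx =
          ((block_idx * num_heads + head_idx) * (head_size / x) + x_idx)
              * block_size * x
          + block_offset * x + x_offset;
      // Same expression as the kernel's tgt_value_idx.
      const int tgt_value_idx =
          ((block_idx * num_heads + head_idx) * block_size + block_offset)
              * head_size
          + head_offset;
      key_cache[tgt_key_idx] = key[src_idx];
      value_cache[tgt_value_idx] = value[src_idx];
    }
  }
}

int main() {
  // Tiny example: 2 tokens scattered into a 2-block, 4-slot-per-block cache.
  const int num_heads = 1, head_size = 8, block_size = 4, x = 4, num_blocks = 2;
  // token 0 -> slot 5 (block 1, offset 1); token 1 -> slot 2 (block 0, offset 2)
  std::vector<int> slot_mapping = {5, 2};
  std::vector<float> key(2 * num_heads * head_size), value(key.size());
  for (size_t i = 0; i < key.size(); ++i) { key[i] = float(i); value[i] = -float(i); }
  std::vector<float> key_cache(num_blocks * num_heads * (head_size / x) * block_size * x, 0.f);
  std::vector<float> value_cache(num_blocks * num_heads * block_size * head_size, 0.f);

  reshape_and_cache_cpu(key, value, key_cache, value_cache, slot_mapping,
                        num_heads, head_size, block_size, x);

  // Token 0, head 0, head_offset 6 lands at block 1, x chunk 1, block_offset 1, lane 2.
  const int hs_x = head_size / x;
  assert(key_cache[((1 * num_heads + 0) * hs_x + 1) * block_size * x + 1 * x + 2] == key[6]);
  assert(value_cache[((1 * num_heads + 0) * block_size + 1) * head_size + 6] == value[6]);
  std::printf("layout check passed\n");
  return 0;
}

Built with a plain C++ compiler (e.g. g++ -std=c++17), the program scatters two tokens into the paged cache and asserts that one element of each cache lands exactly where the kernel's formulas place it.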