From bf87484efac9812b5854266025f86bbb21894e4b Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 4 Sep 2023 09:20:06 +0900
Subject: [PATCH] [BugFix] Fix NaN errors in paged attention kernel (#936)

---
 csrc/attention/attention_kernels.cu | 12 ++++++++++++
 csrc/attention/dtype_bfloat16.cuh   | 10 ++++++++++
 csrc/attention/dtype_float16.cuh    | 10 +++++-----
 csrc/attention/dtype_float32.cuh    |  5 +++++
 4 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 568d1fb1..d603f8e4 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -246,6 +246,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     accs[i] = 0.f;
   }
 
+  scalar_t zero_value;
+  zero(zero_value);
   for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
     const int physical_block_number = block_table[block_idx];
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
@@ -261,6 +263,16 @@ __global__ void single_query_cached_kv_attention_kernel(
       if (row_idx < HEAD_SIZE) {
         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
         V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        if (block_idx == num_blocks - 1) {
+          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
+          // we should explicitly zero out the values since they may contain NaNs.
+          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
         accs[i] += dot(logits_vec, v_vec);
       }
     }
diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh
index 70fd064b..2154bfcf 100644
--- a/csrc/attention/dtype_bfloat16.cuh
+++ b/csrc/attention/dtype_bfloat16.cuh
@@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
 #endif
 }
 
+// Zero-out a variable.
+inline __device__ void zero(__nv_bfloat16& dst) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
+  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
+#endif
+}
+
 } // namespace vllm
diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh
index 6ffc30cd..e6792112 100644
--- a/csrc/attention/dtype_float16.cuh
+++ b/csrc/attention/dtype_float16.cuh
@@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
   return sum(c);
 }
 
-// Zero-out a vector.
-inline __device__ void zero(uint16_t& dst) {
-  dst = uint16_t(0);
-}
-
 // From float32 to float16.
 inline __device__ void from_float(uint16_t& dst, float src) {
   dst = float_to_half(src);
@@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
   return tmp;
 }
 
+// Zero-out a variable.
+inline __device__ void zero(uint16_t& dst) {
+  dst = uint16_t(0);
+}
+
 } // namespace vllm
diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh
index 960cf48e..b200d2d2 100644
--- a/csrc/attention/dtype_float32.cuh
+++ b/csrc/attention/dtype_float32.cuh
@@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
   return u;
 }
 
+// Zero-out a variable.
+inline __device__ void zero(float& dst) {
+  dst = 0.f;
+}
+
 } // namespace vllm
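
The following standalone host-side sketch (not part of the patch) illustrates
why the kernel must explicitly zero out-of-context values rather than rely on
their zero logits: 0.f * NaN is still NaN, so a single uninitialized slot in
the last value-cache block poisons the whole accumulator. The names V_VEC_SIZE,
token_idx, and context_len mirror the kernel; everything else is hypothetical.

#include <cmath>
#include <cstdio>

constexpr int V_VEC_SIZE = 4;

int main() {
  // Last value-cache block: two real tokens, then a tail past context_len
  // that may hold uninitialized data (modeled here as NaN).
  float v_vec[V_VEC_SIZE]  = {1.0f, 2.0f, NAN, NAN};
  // The softmax logits for the tail are already zero, but that is not enough.
  float logits[V_VEC_SIZE] = {0.5f, 0.5f, 0.0f, 0.0f};
  const int token_idx = 0, context_len = 2;

  float acc_unmasked = 0.f, acc_masked = 0.f;
  for (int j = 0; j < V_VEC_SIZE; j++) {
    acc_unmasked += logits[j] * v_vec[j];  // NaN propagates into the sum
    // The patch's fix: replace out-of-context values with an explicit zero.
    const float v = (token_idx + j < context_len) ? v_vec[j] : 0.f;
    acc_masked += logits[j] * v;
  }
  printf("unmasked = %f, masked = %f\n", acc_unmasked, acc_masked);
  // Prints: unmasked = nan, masked = 1.500000
  return 0;
}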