From 26d7d92f3db1c31bacc8d425ee45833b7ec7c16e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 3 Sep 2023 11:39:09 -0700 Subject: [PATCH] Fix splitKV combine function when local LSEs are all -inf --- csrc/flash_attn/src/flash_fwd_kernel.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 5a2a314..3c6d800 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1124,7 +1124,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) { for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); - ElementAccum lse_logsum = logf(lse_sum) + lse_max; + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } // Store the scales exp(lse - lse_logsum) in shared memory.