From 31920dda5fe34864adf582db6eec7ce722b11c7e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 29 Aug 2023 21:48:46 -0700 Subject: [PATCH] Fix typo with lse_max == -INFINITY --- csrc/flash_attn/src/flash_fwd_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index b8dce9e..5a2a314 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1118,7 +1118,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) { for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } MaxOp max_op; lse_max = Allreduce::run(lse_max, max_op); - lse_max == lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf float lse_sum = expf(lse_accum(0) - lse_max); #pragma unroll for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }