Fix typo with lse_max == -INFINITY

This commit is contained in:
Tri Dao 2023-08-29 21:48:46 -07:00
parent 8a326bbc9e
commit 31920dda5f

View File

@ -1118,7 +1118,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
MaxOp<float> max_op;
lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
lse_max == lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
float lse_sum = expf(lse_accum(0) - lse_max);
#pragma unroll
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }