Fix splitKV combine function when local LSEs are all -inf

This commit is contained in:
Tri Dao 2023-09-03 11:39:09 -07:00
parent de2949f37d
commit 26d7d92f3d

View File

@ -1124,7 +1124,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
SumOp<float> sum_op;
lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
ElementAccum lse_logsum = logf(lse_sum) + lse_max;
// For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
// Store the scales exp(lse - lse_logsum) in shared memory.