Fix splitKV combine function when local LSEs are all -inf
This commit is contained in:
parent
de2949f37d
commit
26d7d92f3d
@ -1124,7 +1124,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) {
|
||||
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
|
||||
SumOp<float> sum_op;
|
||||
lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
|
||||
ElementAccum lse_logsum = logf(lse_sum) + lse_max;
|
||||
// For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
|
||||
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
|
||||
ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
|
||||
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
|
||||
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
|
||||
// Store the scales exp(lse - lse_logsum) in shared memory.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user