Fix typo with lse_max == -INFINITY
This commit is contained in:
parent
8a326bbc9e
commit
31920dda5f
@ -1118,7 +1118,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) {
|
|||||||
for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
|
for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
|
||||||
MaxOp<float> max_op;
|
MaxOp<float> max_op;
|
||||||
lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
|
lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
|
||||||
lse_max == lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
|
lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||||
float lse_sum = expf(lse_accum(0) - lse_max);
|
float lse_sum = expf(lse_accum(0) - lse_max);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
|
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user