Fix race condition in backward pass (smem_dq)
This commit is contained in:
parent
eeca63a72a
commit
ea38d3d261
@@ -475,6 +475,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
|
||||
static_assert(Gmem_tile_dq::LOOPS == 1);
|
||||
|
||||
// Swizzle the elements and do the final reduction.
|
||||
// Need to syncthreads here, otherwise the smem_dq reads from the previous iteration
|
||||
// might happen after the smem_dq writes in this iteration.
|
||||
__syncthreads();
|
||||
smem_dq.store(acc_dq, 0);
|
||||
|
||||
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
|
||||
|
||||
Loading…
Reference in New Issue
Block a user