Fix race condition in backward pass (smem_dq)

This commit is contained in:
Tri Dao 2022-06-25 18:02:30 -07:00
parent eeca63a72a
commit ea38d3d261

View File

@ -475,6 +475,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
static_assert(Gmem_tile_dq::LOOPS == 1);
// Swizzle the elements and do the final reduction.
// Need to syncthreads here, otherwise the smem_dq reads from the previous iteration
// might happen after the smem_dq writes in this iteration.
__syncthreads();
smem_dq.store(acc_dq, 0);
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];