diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index f6f325b..2185dc1 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -475,6 +475,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng static_assert(Gmem_tile_dq::LOOPS == 1); // Swizzle the elements and do the final reduction. + // Need to syncthreads here, otherwise the smem_dq reads from the previous iteration + // might happen after the smem_dq writes in this iteration. + __syncthreads(); smem_dq.store(acc_dq, 0); typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];