From 45567a25a2f74cabe17f9d59f88169ad06d4c874 Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Sat, 15 Apr 2023 06:09:41 +0000
Subject: [PATCH] only 1 thread writes to global mem in fprop

Signed-off-by: Kirthi Shankar Sivamani
---
 csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index 3975c97..ee5d68d 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) {
     const int bidb = blockIdx.x;
     // The block index for the head.
     const int bidh = blockIdx.y;
+    // The block index.
+    const int bidx = gridDim.x * bidh + bidb;
     // The thread index.
     const int tidx = threadIdx.x;
 
@@ -678,8 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
     auto seeds = at::cuda::philox::unpack(params.philox_args);
-    params.rng_state[0] = std::get<0>(seeds);
-    params.rng_state[1] = std::get<1>(seeds);
+    if (bidx == 0 && tidx == 0) {
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
+    }
     Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
     constexpr int M = Kernel_traits::Cta_tile_p::M;
     const int STEPS = (params.seqlen_q + M - 1) / M;
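
Note (illustration, not part of the patch): the change above guards the RNG-state store so that only one thread in the whole grid (block 0, thread 0) writes the Philox seed and offset to global memory, instead of every thread in every block writing the same two values. Below is a minimal standalone CUDA sketch of that single-writer pattern; the kernel name, parameters, and rng_state buffer are hypothetical stand-ins chosen for this example, not the flash-attention API.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel: every thread derives its per-thread RNG subsequence from
// the same (seed, offset) pair, but only block 0 / thread 0 writes that pair
// back to global memory, mirroring the guard added in the patch.
__global__ void rng_state_writer(uint64_t seed, uint64_t offset, uint64_t *rng_state) {
    const int bidx = gridDim.x * blockIdx.y + blockIdx.x;  // flattened block index
    const int tidx = threadIdx.x;                          // thread index

    if (bidx == 0 && tidx == 0) {
        // Single writer: avoids every thread issuing the same redundant stores.
        rng_state[0] = seed;
        rng_state[1] = offset;
    }

    // All threads still consume (seed, offset) locally, e.g. to seed a
    // counter-based RNG with a unique per-thread subsequence.
    uint64_t per_thread_subsequence = offset + (uint64_t)bidx * blockDim.x + tidx;
    (void)per_thread_subsequence;
}

int main() {
    uint64_t *d_rng_state = nullptr;
    cudaMalloc(reinterpret_cast<void **>(&d_rng_state), 2 * sizeof(uint64_t));

    dim3 grid(4, 2);  // e.g. batch x heads, as in the kernel above
    rng_state_writer<<<grid, 128>>>(/*seed=*/1234ULL, /*offset=*/0ULL, d_rng_state);

    uint64_t h_rng_state[2];
    cudaMemcpy(h_rng_state, d_rng_state, sizeof(h_rng_state), cudaMemcpyDeviceToHost);
    printf("seed=%llu offset=%llu\n",
           (unsigned long long)h_rng_state[0], (unsigned long long)h_rng_state[1]);
    cudaFree(d_rng_state);
    return 0;
}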