diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index 3975c97..ee5d68d 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) {
     const int bidb = blockIdx.x;
     // The block index for the head.
     const int bidh = blockIdx.y;
+    // The block index.
+    const int bidx = gridDim.x * bidh + bidb;
     // The thread index.
     const int tidx = threadIdx.x;
 
@@ -678,8 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
     auto seeds = at::cuda::philox::unpack(params.philox_args);
-    params.rng_state[0] = std::get<0>(seeds);
-    params.rng_state[1] = std::get<1>(seeds);
+    if (bidx == 0 && tidx == 0) {
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
+    }
     Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
     constexpr int M = Kernel_traits::Cta_tile_p::M;
     const int STEPS = (params.seqlen_q + M - 1) / M;
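
For reference, below is a minimal standalone sketch of the single-writer pattern this patch introduces. It is not part of the patch itself: the kernel name, the `rng_state` buffer, and the stand-in `make_seeds()` helper (replacing `at::cuda::philox::unpack(params.philox_args)`) are hypothetical. The point it illustrates is that every thread derives the same Philox seed/offset, so only the thread with `bidx == 0 && tidx == 0` needs to persist them for the backward pass to replay the same dropout pattern.

```cuda
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for at::cuda::philox::unpack(params.philox_args):
// every thread observes the same (seed, offset) pair.
__device__ void make_seeds(uint64_t &seed, uint64_t &offset) {
    seed = 0x1234abcdULL;
    offset = 42ULL;
}

// Sketch of the pattern in the patch: flatten (blockIdx.x, blockIdx.y) into a
// single block index and let exactly one thread in the grid write the RNG
// state to global memory.
__global__ void save_rng_state_once(uint64_t *rng_state) {
    const int bidb = blockIdx.x;               // batch block
    const int bidh = blockIdx.y;               // head block
    const int bidx = gridDim.x * bidh + bidb;  // flattened block index
    const int tidx = threadIdx.x;

    uint64_t seed, offset;
    make_seeds(seed, offset);

    // Without this guard, every thread would issue the same two redundant
    // global stores (the behavior before the patch).
    if (bidx == 0 && tidx == 0) {
        rng_state[0] = seed;
        rng_state[1] = offset;
    }
}

int main() {
    uint64_t *rng_state;
    cudaMallocManaged(&rng_state, 2 * sizeof(uint64_t));
    save_rng_state_once<<<dim3(4, 8), 128>>>(rng_state);
    cudaDeviceSynchronize();
    printf("seed=%llu offset=%llu\n",
           (unsigned long long)rng_state[0], (unsigned long long)rng_state[1]);
    cudaFree(rng_state);
    return 0;
}
```

Note that the guard only affects the store of `rng_state`; the per-thread `Philox ph(...)` construction in the patch is unchanged, since each thread still needs its own subsequence (`std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32`) to generate its share of the dropout mask.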