only 1 thread writes to global mem in fprop

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
This commit is contained in:
Kirthi Shankar Sivamani 2023-04-15 06:09:41 +00:00
parent a0997bc77c
commit 45567a25a2

View File

@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) {
const int bidb = blockIdx.x;
// The block index for the head.
const int bidh = blockIdx.y;
// The block index.
const int bidx = gridDim.x * bidh + bidb;
// The thread index.
const int tidx = threadIdx.x;
@ -678,8 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
auto seeds = at::cuda::philox::unpack(params.philox_args);
params.rng_state[0] = std::get<0>(seeds);
params.rng_state[1] = std::get<1>(seeds);
if (bidx == 0 && tidx == 0) {
params.rng_state[0] = std::get<0>(seeds);
params.rng_state[1] = std::get<1>(seeds);
}
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
constexpr int M = Kernel_traits::Cta_tile_p::M;
const int STEPS = (params.seqlen_q + M - 1) / M;