From 2dc1b205f609e83712438631c048988cf5ffce42 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 9 Jul 2022 23:17:14 -0700 Subject: [PATCH] Fix Illegal Memory Access bug in fwd when d=16 --- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 44eb27a..c8dcee8 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -514,13 +514,22 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; } - softmax.reduce_max_after_sync_(p_max_o, rows); + // When d = 16, O only has 16 x 16 = 256 elements, and each of the 128 threads wants + // to write 4 elements, so only half of the threads should deal with O. + bool o_rows_are_valid = + (Kernel_traits::THREADS <= Gmem_tile_o::THREADS_PER_ROW * Gmem_tile_o::ROWS) + || (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS); + if (o_rows_are_valid) { + softmax.reduce_max_after_sync_(p_max_o, rows); + } static_assert(Mma_tile_o::MMAS_M == 1); for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { p_max_o[jj][0] *= params.scale_bmm1f; } float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP]; - if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); } + if ((!Is_first) && o_rows_are_valid) { + smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); + } // if (!Is_first) { // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { // printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]); @@ -537,7 +546,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i static_assert(Mma_tile_o::MMAS_M == 1); float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - softmax.reduce_sum_after_sync_(p_sum_o, 
rows); + if (o_rows_are_valid) { + softmax.reduce_sum_after_sync_(p_sum_o, rows); + } if (!Is_first) { for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); @@ -558,7 +569,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i // printf("p_sum_log=%.6f\n", p_sum_log[jj][0]); // } // } - if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) { + if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && o_rows_are_valid) { gmem_softmax_lse.store_row( reinterpret_cast(p_sum_log[jj]), rows[jj]); }