Fix Illegal Memory Access bug in fwd when d=16

This commit is contained in:
Tri Dao 2022-07-09 23:17:14 -07:00
parent 5b838a8bef
commit 2dc1b205f6

View File

@ -514,13 +514,22 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
}
softmax.reduce_max_after_sync_(p_max_o, rows);
// When d = 16, O only has 16 x 16 = 256 elements, and each of the 128 threads wants
// to write 4 elements, so only half of the threads should deal with O.
bool o_rows_are_valid =
(Kernel_traits::THREADS <= Gmem_tile_o::THREADS_PER_ROW * Gmem_tile_o::ROWS)
|| (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS);
if (o_rows_are_valid) {
softmax.reduce_max_after_sync_(p_max_o, rows);
}
static_assert(Mma_tile_o::MMAS_M == 1);
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
p_max_o[jj][0] *= params.scale_bmm1f;
}
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); }
if ((!Is_first) && o_rows_are_valid) {
smem_softmax_lse.load(p_prev_scale_o, rows, l % 2);
}
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
@ -537,7 +546,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
static_assert(Mma_tile_o::MMAS_M == 1);
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
softmax.reduce_sum_after_sync_(p_sum_o, rows);
if (o_rows_are_valid) {
softmax.reduce_sum_after_sync_(p_sum_o, rows);
}
if (!Is_first) {
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
@ -558,7 +569,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
// }
// }
if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) {
if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && o_rows_are_valid) {
gmem_softmax_lse.store_row(
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
}