Get rid of o_rows_are_valid since we don't have headdim=16 anymore
This commit is contained in:
parent
46fd2a20b2
commit
c422fee377
@ -554,20 +554,13 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
|
|||||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||||
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
|
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
|
||||||
}
|
}
|
||||||
// When d = 16, O only has 16 x 16 = 256 elements, and each of the 128 threads wants
|
|
||||||
// to write 4 elements, so only half of the thread should deal with O.
|
|
||||||
bool o_rows_are_valid =
|
|
||||||
(Kernel_traits::THREADS <= Gmem_tile_o::THREADS_PER_ROW * Gmem_tile_o::ROWS)
|
|
||||||
|| (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS);
|
|
||||||
if (o_rows_are_valid) {
|
|
||||||
softmax.reduce_max_after_sync_(p_max_o, rows);
|
softmax.reduce_max_after_sync_(p_max_o, rows);
|
||||||
}
|
|
||||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||||
p_max_o[jj][0] *= params.scale_bmm1f;
|
p_max_o[jj][0] *= params.scale_bmm1f;
|
||||||
}
|
}
|
||||||
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
|
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
|
||||||
if ((!Is_first) && o_rows_are_valid) {
|
if (!Is_first) {
|
||||||
smem_softmax_lse.load(p_prev_scale_o, rows);
|
smem_softmax_lse.load(p_prev_scale_o, rows);
|
||||||
}
|
}
|
||||||
// if (!Is_first) {
|
// if (!Is_first) {
|
||||||
@ -586,9 +579,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
|
|||||||
|
|
||||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||||
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||||
if (o_rows_are_valid) {
|
|
||||||
softmax.reduce_sum_after_sync_(p_sum_o, rows);
|
softmax.reduce_sum_after_sync_(p_sum_o, rows);
|
||||||
}
|
|
||||||
if (!Is_first) {
|
if (!Is_first) {
|
||||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||||
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
|
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
|
||||||
@ -609,7 +600,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
|
|||||||
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
|
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && o_rows_are_valid) {
|
if (tidx % Gmem_tile_o::THREADS_PER_ROW == 0) {
|
||||||
gmem_softmax_lse.store_row(
|
gmem_softmax_lse.store_row(
|
||||||
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
|
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user