From 1274ec3e7ee9fb6c4cd727c7858991d4e11044ff Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 14 Jan 2024 12:19:17 -0800
Subject: [PATCH] Move dropout to a separate file (dropout.h)

---
 csrc/flash_attn/src/dropout.h          | 91 ++++++++++++++++++++++++++
 csrc/flash_attn/src/flash_bwd_kernel.h | 19 +++---
 csrc/flash_attn/src/flash_fwd_kernel.h | 27 ++++----
 csrc/flash_attn/src/softmax.h          | 70 --------------------
 4 files changed, 112 insertions(+), 95 deletions(-)
 create mode 100644 csrc/flash_attn/src/dropout.h

diff --git a/csrc/flash_attn/src/dropout.h b/csrc/flash_attn/src/dropout.h
new file mode 100644
index 0000000..43c35c2
--- /dev/null
+++ b/csrc/flash_attn/src/dropout.h
@@ -0,0 +1,91 @@
+#pragma once
+
+#include "philox.cuh"
+#include "utils.h"
+
+namespace flash {
+
+struct Dropout {
+
+    const unsigned long long seed, offset;
+    const uint8_t p_dropout_in_uint8_t;
+
+    inline __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
+                              const uint8_t p_dropout_in_uint8_t,
+                              const int bid, const int hid, const int tid, const int nheads)
+            : seed(seed)
+            , offset(offset + (bid * nheads + hid) * 32 + tid % 32)
+            , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
+    }
+
+    template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
+    inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
+                                         int block_row_start, int block_col_start, int block_row_stride) {
+        // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
+        // tensor has shape (8, MMA_M, MMA_N / 2)
+        using T = typename Engine::value_type;
+        auto encode_dropout = [](bool keep, T val) {
+            return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
+        };
+        static_assert(decltype(size<2>(tensor))::value % 2 == 0);
+        const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
+        const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
+        // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
+        #pragma unroll
+        for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
+            uint2 rowcol = make_uint2(block_row_start, block_col_start);
+            #pragma unroll
+            for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
+                // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
+                uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
+                // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
+                uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
+                // Special implementation for 16-bit types: we duplicate the threshold to the
+                // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
+                // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
+                // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
+                // the random value is less than the threshold.
+                // We then do a bit-wise AND between the mask and the original value (in 32-bit).
+                // We're exploiting the fact that floating point comparison is equivalent to integer
+                // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
+                if (!encode_dropout_in_sign_bit
+                    && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
+                    uint16_t rnd_16[16];
+                    #pragma unroll
+                    for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
+                    uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
+                    #pragma unroll
+                    for (int j = 0; j < 2; j++) {
+                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+                        // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
+                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+                        #pragma unroll
+                        for (int i = 0; i < 4; i++) {
+                            uint32_t mask;
+                            asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
+                            tensor_uint32(i) &= mask;
+                        }
+                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+                    }
+                } else {
+                    #pragma unroll
+                    for (int j = 0; j < 2; j++) {
+                        #pragma unroll
+                        for (int i = 0; i < 8; i++) {
+                            tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
+                        }
+                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+                    }
+                }
+                // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+                // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
+                // // }
+            }
+        }
+    }
+
+};
+
+} // namespace flash
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index df0296b..ac9da62 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "dropout.h"
 
 #include "alibi.h"
 
@@ -796,8 +797,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }
 
-    auto seed = params.rng_state[0];
-    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
+    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
+                           bidb, bidh, tidx, params.h);
 
     clear(acc_dv);
     clear(acc_dk);
@@ -886,9 +887,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
            static_assert(MMA_N_SdP % 2 == 0);
            int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                scores, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, AtomLayoutMS
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                scores, block_row_idx, block_col_idx, AtomLayoutMS
             );
         }
         // Convert scores from fp32 to fp16/bf16
@@ -1395,8 +1395,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); }
 
-    auto seed = params.rng_state[0];
-    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
+    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
+                           bidb, bidh, tidx, params.h);
 
     clear(acc_dq);
 
@@ -1445,9 +1445,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
            static_assert(MMA_N_SdP % 2 == 0);
            int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                scores, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, AtomLayoutMS
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                scores, block_row_idx, block_col_idx, AtomLayoutMS
             );
         }
         // Convert scores from fp32 to fp16/bf16
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index ca58e66..0ea195b 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "dropout.h"
 
 #include "alibi.h"
 
@@ -75,15 +76,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
 
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
+    flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
+                           bidb, bidh, tidx, params.h);
 
     // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
     // exit early and no one saves the rng states.
     if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-        params.rng_state[0] = seed;
-        params.rng_state[1] = std::get<1>(seeds);
+        params.rng_state[0] = std::get<0>(seed_offset);
+        params.rng_state[1] = std::get<1>(seed_offset);
     }
 
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
@@ -404,16 +405,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         if (Return_softmax) {
             Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
             Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, kNWarps
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
             );
             cute::copy(acc_s_f16, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
-            flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
-                                 block_row_idx, block_col_idx, kNWarps);
+            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
         }
 
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -489,16 +488,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         if (Return_softmax) {
             Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
             Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, kNWarps
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
             );
             cute::copy(acc_s_f16, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
-            flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
-                                 block_row_idx, block_col_idx, kNWarps);
+            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
         }
 
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index df449aa..5d6609b 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -212,74 +212,4 @@ inline __device__ void apply_mask_causal_w_idx(
     }
 }
 
-template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
-inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_, uint8_t p_dropout_in_uint8_t,
-                                     unsigned long long seed, unsigned long long offset,
-                                     int block_row_start, int block_col_start,
-                                     int block_row_stride) {
-    // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
-    Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
-    // tensor has shape (8, MMA_M, MMA_N / 2)
-    using T = typename Engine::value_type;
-    auto encode_dropout = [](bool keep, T val) {
-        return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
-    };
-    static_assert(decltype(size<2>(tensor))::value % 2 == 0);
-    const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
-    const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
-    // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
-    #pragma unroll
-    for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
-        uint2 rowcol = make_uint2(block_row_start, block_col_start);
-        #pragma unroll
-        for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
-            // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
-            uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
-            // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
-            uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
-            // Special implementation for 16-bit types: we duplicate the threshold to the
-            // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
-            // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
-            // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
-            // the random value is less than the threshold.
-            // We then do a bit-wise AND between the mask and the original value (in 32-bit).
-            // We're exploiting the fact that floating point comparison is equivalent to integer
-            // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
-            if (!encode_dropout_in_sign_bit
-                && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
-                uint16_t rnd_16[16];
-                #pragma unroll
-                for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
-                uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
-                #pragma unroll
-                for (int j = 0; j < 2; j++) {
-                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
-                    // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
-                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
-                    #pragma unroll
-                    for (int i = 0; i < 4; i++) {
-                        uint32_t mask;
-                        asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
-                        tensor_uint32(i) &= mask;
-                    }
-                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
-                }
-            } else {
-                #pragma unroll
-                for (int j = 0; j < 2; j++) {
-                    #pragma unroll
-                    for (int i = 0; i < 8; i++) {
-                        tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
-                    }
-                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
-                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
-                }
-            }
-            // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-            // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
-            // // }
-        }
-    }
-}
-
 } // namespace flash
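
Since the patch itself carries no prose beyond the subject line, here is a small standalone host-side C++ sketch (not part of the commit) of the per-element keep/drop decision that flash::Dropout::apply_dropout performs on the GPU. The keep probability, the sample random bytes, and the packed fp16 constants below are invented for illustration; the real kernel draws its 16 bytes from flash::philox and uses the PTX instruction set.le.u32.f16x2 instead of the explicit half-word comparisons emulated here. It also assumes p_dropout_in_uint8_t holds the keep probability scaled to 8 bits, so an element survives when its random byte is <= that threshold.

// Host-side sketch (not part of the commit): emulates the dropout thresholding logic.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    // Assumption: the 8-bit threshold is the keep probability scaled to 255.
    const float p_keep = 0.9f;  // hypothetical keep probability
    const uint8_t threshold = static_cast<uint8_t>(std::floor(p_keep * 255.0f));

    // Pretend these 16 bytes came from one Philox draw (one uint4 reinterpreted as bytes).
    const uint8_t rnd_8[16] = {3, 250, 17, 200, 99, 255, 0, 128,
                               42, 7, 180, 230, 64, 12, 240, 77};

    // Scalar path (fp32, or when dropout is encoded in the sign bit):
    // keep ? val : (encode_in_sign_bit ? -val : 0). Here we only zero dropped values.
    float vals[16];
    int kept = 0;
    for (int i = 0; i < 16; ++i) {
        vals[i] = 1.0f;
        const bool keep = rnd_8[i] <= threshold;
        vals[i] = keep ? vals[i] : 0.0f;
        kept += keep;
    }

    // f16x2-style path: duplicate the 8-bit threshold into both halves of a 32-bit word,
    // compare two zero-extended random bytes at once, and AND the resulting
    // 0xffff/0x0000 half-masks into two packed fp16 values. Comparing the raw integers
    // is equivalent to the fp16 comparison because both operands are small non-negative
    // integers (top 8 bits zero) reinterpreted as fp16 bit patterns.
    const uint32_t threshold2 = (uint32_t(threshold) << 16) | uint32_t(threshold);
    const uint32_t rnd2 = (uint32_t(rnd_8[1]) << 16) | uint32_t(rnd_8[0]);
    uint32_t mask = 0;
    if ((rnd2 & 0xffffu) <= (threshold2 & 0xffffu)) { mask |= 0x0000ffffu; }
    if ((rnd2 >> 16) <= (threshold2 >> 16)) { mask |= 0xffff0000u; }
    uint32_t packed_vals = 0x3c003c00u;  // two fp16 values of 1.0
    packed_vals &= mask;                 // dropped halves become +0.0

    std::printf("threshold = %u, kept %d of 16 scalar elements, packed result = 0x%08x\n",
                threshold, kept, packed_vals);
    return 0;
}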