From 050873327e83f0f84d4f5a37912ed5140d7e9d9f Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 2 Jun 2022 14:09:46 -0700
Subject: [PATCH] Remove softmax fp16 max

---
 csrc/flash_attn/src/fmha/softmax.h | 125 ++---------------------------
 csrc/flash_attn/src/fmha/utils.h   |   5 --
 2 files changed, 5 insertions(+), 125 deletions(-)

diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h
index a125b69..2de6761 100644
--- a/csrc/flash_attn/src/fmha/softmax.h
+++ b/csrc/flash_attn/src/fmha/softmax.h
@@ -58,12 +58,6 @@ inline __device__ float apply_exp_(float x, float max) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-inline __device__ __half2 apply_exp_(__half2 x, __half2 max) {
-    return h2exp(__hsub2(x, max));
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 inline __device__ float apply_exp2_(float x, float max) {
     return exp2f(x - max);
     // With fast-math, this produces the same PTX instruction as the assembly below
@@ -75,17 +69,9 @@ inline __device__ float apply_exp2_(float x, float max) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-inline __device__ __half2 apply_exp2_(__half2 x, __half2 max) {
-    return h2exp2(__hsub2(x, max));
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template struct ReadType {};
-template<> struct ReadType<4, false> { using T = float;};
-template<> struct ReadType<8, false> { using T = float2;};
-template<> struct ReadType<4, true> { using T = __half2;};
-template<> struct ReadType<8, true> { using T = float2;};
+template struct ReadType {};
+template<> struct ReadType<4> { using T = float;};
+template<> struct ReadType<8> { using T = float2;};
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -118,8 +104,7 @@ struct Smem_tile_reduce {
     static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
     static_assert(LOOPS == 1);
 
-    using read_t = typename ReadType::T;
-    using read_half_t = typename ReadType::T;
+    using read_t = typename ReadType::T;
 
     __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
 
@@ -152,17 +137,6 @@ struct Smem_tile_reduce {
         }
     }
 
-    __device__ inline void store(__half2 (&frag)[MMAS_M]) {
-        __half2 *smem_write_half_ = reinterpret_cast<__half2 *>(smem_write_);
-        if( qid_ == 0 ) {
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M; mi++ ) {
-                int offset = mi * 16 * WARPS_N;
-                smem_write_half_[offset + 0 * 8 * WARPS_N] = frag[mi];
-            }
-        }
-    }
-
     __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
         #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
@@ -172,15 +146,6 @@ struct Smem_tile_reduce {
         }
     }
 
-    __device__ inline void load(read_half_t (&frag)[MMAS_M]) {
-        read_half_t *smem_read_half_ = reinterpret_cast<read_half_t *>(smem_read_);
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; mi++ ) {
-            int offset = mi * 16 * 4;
-            frag[mi] = smem_read_half_[offset + 0 * 8 * 4];
-        }
-    }
-
     __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
         #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
@@ -304,29 +269,6 @@ struct Softmax_base {
         }
     }
 
-    // Apply the exp to all the elements.
-    inline __device__ void apply_exp(const __half2 (&max)[MMAS_M]) {
-        #pragma unroll
-        for (int mi = 0; mi < MMAS_M; ++mi) {
-            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
-            // max * log_2(e)) This allows the compiler to use the ffma
-            // instruction instead of fadd and fmul separately.
-            constexpr float kLog2e = M_LOG2E;
-            const float2 max_f = __half22float2(max[mi]);
-            const float max0_log2e = max_f.x * kLog2e, max1_log2e = max_f.y * kLog2e;
-            #pragma unroll
-            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
-                float2 elt = __half22float2(elt_half_[mi][ni]);
-                elt_[mi * 2 + 0][ni] = apply_exp2_(elt.x * kLog2e, max0_log2e);
-                elt_[mi * 2 + 1][ni] = apply_exp2_(elt.y * kLog2e, max1_log2e);
-                // __half2 out = apply_exp_(elt_half_[mi][ni], max[mi]);
-                // float2 outf = __half22float2(out);
-                // elt_[mi * 2 + 0][ni] = outf.x;
-                // elt_[mi * 2 + 1][ni] = outf.y;
-            }
-        }
-    }
-
     // Apply the exp to all the elements.
     template
     inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
@@ -527,7 +469,6 @@ struct Softmax_base {
     int tidx_;
     // The elements.
     float elt_[MMAS_M * 2][MMAS_N * 4];
-    __half2 elt_half_[MMAS_M][MMAS_N * 4];
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -638,34 +579,6 @@ struct Softmax : public Softmax_base {
         }
     }
 
-    // Scale FP32 fragments
-    template
-    inline __device__ void unpack_noscale_half_and_apply_mask(const Accumulator (&acc)[MMAS_M][MMAS_N],
-                                                              const Mask &mask) {
-
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            #pragma unroll
-            for( int ni = 0; ni < MMAS_N; ++ni ) {
-                float tmp[2][4];
-                // 1st row - 4 elements per row.
-                tmp[0][0] = mask.is_valid(mi, ni, 0, 0) ? acc[mi][ni].elt(0) : -INFINITY;
-                tmp[0][1] = mask.is_valid(mi, ni, 0, 1) ? acc[mi][ni].elt(1) : -INFINITY;
-                tmp[0][2] = mask.is_valid(mi, ni, 0, 2) ? acc[mi][ni].elt(4) : -INFINITY;
-                tmp[0][3] = mask.is_valid(mi, ni, 0, 3) ? acc[mi][ni].elt(5) : -INFINITY;
-                // 2nd row - 4 elements per row.
-                tmp[1][0] = mask.is_valid(mi, ni, 1, 0) ? acc[mi][ni].elt(2) : -INFINITY;
-                tmp[1][1] = mask.is_valid(mi, ni, 1, 1) ? acc[mi][ni].elt(3) : -INFINITY;
-                tmp[1][2] = mask.is_valid(mi, ni, 1, 2) ? acc[mi][ni].elt(6) : -INFINITY;
-                tmp[1][3] = mask.is_valid(mi, ni, 1, 3) ? acc[mi][ni].elt(7) : -INFINITY;
-                this->elt_half_[mi][4 * ni + 0] = __floats2half2_rn(tmp[0][0], tmp[1][0]);
-                this->elt_half_[mi][4 * ni + 1] = __floats2half2_rn(tmp[0][1], tmp[1][1]);
-                this->elt_half_[mi][4 * ni + 2] = __floats2half2_rn(tmp[0][2], tmp[1][2]);
-                this->elt_half_[mi][4 * ni + 3] = __floats2half2_rn(tmp[0][3], tmp[1][3]);
-            }
-        }
-    }
-
     template
     __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
         #pragma unroll
@@ -678,18 +591,6 @@ struct Softmax : public Softmax_base {
        }
    }

-    template
-    __device__ inline void thread_reduce_(__half2 (&frag)[MMAS_M], Operator &op) {
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; mi++ ) {
-            frag[mi] = this->elt_half_[mi][0];
-            #pragma unroll
-            for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
-                frag[mi] = op(frag[mi], this->elt_half_[mi][ni]);
-            }
-        }
-    }
-
     template
     __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
         thread_reduce_(frag, op);
@@ -701,29 +602,13 @@ struct Softmax : public Softmax_base {
         quad_allreduce(frag, tmp, op);
     }
 
-    template
-    __device__ inline void reduce_(__half2 (&frag)[MMAS_M], Operator &op, Smem_tile_red & smem_red) {
-        thread_reduce_(frag, op);
-        quad_reduce(frag, frag, op);
-        smem_red.store(frag);
-        __syncthreads();
-        typename Smem_tile_red::read_half_t tmp[MMAS_M];
-        smem_red.load(tmp);
-        quad_allreduce(frag, tmp, op);
-    }
-
     template
     __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
         MaxOp<float> max;
         reduce_(frag, max, smem_max_);
     }
 
-    __device__ inline void reduce_max(__half2 (&frag)[MMAS_M]){
-        MaxOp<__half2> max;
-        reduce_(frag, max, smem_max_);
-    }
-
-    __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
+    __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
         SumOp<float> sum;
         reduce_(frag, sum, smem_sum_);
     }
diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h
index 0e55940..1087f80 100644
--- a/csrc/flash_attn/src/fmha/utils.h
+++ b/csrc/flash_attn/src/fmha/utils.h
@@ -1024,11 +1024,6 @@ struct MaxOp {
 __device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
 };
 
-template <>
-struct MaxOp<__half2> {
-__device__ inline __half2 operator()(__half2 const &x, __half2 const &y) { return __hmax2(x, y); }
-};
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
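
Note on the deleted apply_exp(__half2) overload: the comment it carried explains why the code feeds apply_exp2_ with log2(e)-scaled inputs instead of calling exp directly. Because exp(x - max) == exp2(x * log2(e) - max * log2(e)), pre-multiplying both operands by log2(e) lets the compiler fuse the scale and the subtraction into a single FFMA feeding the EX2 instruction, rather than a separate FADD and FMUL before the exponential. A minimal standalone sketch of that rewrite (illustrative only; scale_apply_exp is a name coined for this note, while kLog2e mirrors the constant in the removed code):

#include <cuda_runtime.h>

// Same shape as the apply_exp2_(float, float) kept by this patch:
// with fast math, exp2f lowers to a single EX2 instruction.
__device__ inline float apply_exp2_(float x, float max) {
    return exp2f(x - max);
}

// exp(x - max) == exp2(x * log2(e) - max * log2(e)).
// The caller passes max already multiplied by log2(e), so the argument
// "x * kLog2e - max_log2e" compiles to one FFMA feeding EX2.
__device__ inline float scale_apply_exp(float x, float max_log2e) {
    constexpr float kLog2e = 1.4426950408889634f;  // log2(e), the value of M_LOG2E
    return apply_exp2_(x * kLog2e, max_log2e);
}

After this patch only the float variant remains; the fp16 path that unpacked __half2 elements, applied the same rewrite, and wrote the results into elt_ is gone.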
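
Note on reduce_max / reduce_sum: each softmax row is reduced in stages in this file: thread_reduce_ folds a thread's fragment, quad_reduce combines lanes with shuffles, and Smem_tile_reduce plus quad_allreduce exchange partials across warps through shared memory (helpers from fmha/utils.h). As a rough illustration of the shuffle-reduction idea only, here is a generic full-warp allreduce; warp_allreduce, MaxOpF, and SumOpF are stand-in names for this note, not the patch's quad helpers:

#include <cuda_runtime.h>

// Stand-ins mirroring the float MaxOp / SumOp functors in fmha/utils.h.
struct MaxOpF { __device__ float operator()(float x, float y) const { return fmaxf(x, y); } };
struct SumOpF { __device__ float operator()(float x, float y) const { return x + y; } };

// Butterfly allreduce over the 32 lanes of a warp: after the loop every
// lane holds op(...) applied to all 32 inputs.
template <typename Op>
__device__ inline float warp_allreduce(float val, Op op) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        val = op(val, __shfl_xor_sync(0xffffffffu, val, offset));
    }
    return val;
}

For example, warp_allreduce(x, MaxOpF()) leaves every lane holding the warp-wide maximum, the same role reduce_max plays per row before the exponentials and the row sum are computed.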