From d380e87fb6e784fa2e92a3f596581255450bfc45 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 4 Jun 2022 16:01:36 -0700
Subject: [PATCH] Don't use Smem_dp_sum in backward pass

To reduce smem usage for SM75
---
 .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu | 29 ++++++-
 .../src/fmha_dgrad_kernel_1xN_loop.h        | 85 +++++++------------
 .../src/fmha_fprop_fp16_kernel.sm80.cu      | 25 ++++--
 3 files changed, 76 insertions(+), 63 deletions(-)

diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
index 778482f..825898d 100644
--- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -15,15 +15,13 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params &params
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
     constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
 
     using Smem_tile_s = fmha::Smem_tile_mma_transposed;
     constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
     static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
     static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
-    static_assert(smem_size_dp_sum == 16 * 4 * 2);
 
-    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
+    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
 
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
     bool is_causal = params.is_causal;
@@ -41,6 +39,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
             : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel);
     }
 
+    // printf("N = %d, WARPS_N = %d, Smem size = %d\n", N, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
     if( smem_size_dq_dk_dv >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -97,4 +96,28 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
         run_fmha_dgrad_fp16_sm80_loop_(params, stream);
     }
+    // if (params.d == 64) {
+    //     auto dprops = at::cuda::getCurrentDeviceProperties();
+    //     if (dprops->major == 7 && dprops->minor == 5) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //         run_fmha_dgrad_fp16_sm80_loop_(params, stream);
+    //     } else {
+    //         if( params.s == 128 ) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //             run_fmha_dgrad_fp16_sm80_loop_(params, stream);
+    //         } else if( params.s >= 256 ) {
+    //             if (dprops->major == 8 && dprops->minor == 0) {
+    //                 // Don't share smem for K & V, and don't keep V in registers
+    //                 // This speeds things up by 2-3% by avoiding register spills, but it
+    //                 // uses more shared memory, which is fine on A100 but not other GPUs.
+    //                 // For other GPUs, we keep V in registers.
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_(params, stream);
+    //             } else if (dprops->major == 8 && dprops->minor > 0) {
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_(params, stream);
+    //             }
+    //         }
+    //     }
+    // }
 }
\ No newline at end of file
diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
index 45732a1..c41452c 100644
--- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
@@ -12,16 +12,19 @@ namespace fmha {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template
-inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
-                                Smem_dp_sum smem, const int buffer_idx) {
+template
+inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
+                                Gmem_softmax_sum gmem_softmax_d, int tidx) {
+    float sum[M];
+    fmha::SumOp sum_op;
     #pragma unroll
     for (int mi = 0; mi < M; ++mi) {
-        sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
+        sum[mi] = fmha::Allreduce::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
+    }
+    const int dp_sum_row = tidx / THREADS_PER_ROW;
+    if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
+        gmem_softmax_d.store_row(reinterpret_cast(sum), dp_sum_row);
     }
-    static_assert(M == 1);
-    smem.store(sum[0], buffer_idx);
-    // smem.store(sum, buffer_idx);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -101,8 +104,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
 
     using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
 
-    using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
-
     // using Gemm1 = Gemm_Q_K;
     using Gemm1 = Gemm_Q_K;
 
@@ -208,26 +209,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     gmem_softmax_lse.load(reinterpret_cast(p_lse));
     gmem_softmax_lse.move();
 
-    float dp_sum[Mma_tile_p::MMAS_M * 2];
-    if (!Is_first) {
-        gmem_softmax_d.load(reinterpret_cast(dp_sum));
-        gmem_softmax_d.move();
-    }
-
-    float dp_sum_regs[Gmem_tile_do::LDGS];
-    Smem_dp_sum smem_dp_sum(reinterpret_cast(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
-
     if (!Is_first) { __syncthreads(); }
     // Commit the data for Q, dO, and V to shared memory.
     gmem_q.commit(gemm_q_k.smem_q);
     gmem_do.commit(smem_do);
     if (Is_first) {
-        dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
-        const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
-        if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-            gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row);
-        }
-        gmem_softmax_d.move();
+        dot_do_o(
+            gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+        );
     }
 
     // Instead of scaling dP by rp_dropout, we scale V instead
@@ -266,6 +255,10 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }
 
+    float dp_sum[Mma_tile_p::MMAS_M * 2];
+    gmem_softmax_d.load(reinterpret_cast(dp_sum));
+    gmem_softmax_d.move();
+
     // Commit the data for V to shared memory if it has not been done already.
     if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
         // Make sure we are done loading the fragments for K.
@@ -357,21 +350,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     //     __syncthreads();
     // }
 
-    // TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
-    // dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
-    // later. This is because loading dp_sum earlier uses more registers.
     fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
-    if (Is_first) {
-        fmha::Clear_accumulator::apply(acc_dp);
-    } else {
+    #pragma unroll
+    for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
         #pragma unroll
-        for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+        for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
             #pragma unroll
-            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
-                #pragma unroll
-                for (int ii = 0; ii < 8; ++ii) {
-                    acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
-                }
+            for (int ii = 0; ii < 8; ++ii) {
+                acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
             }
         }
     }
@@ -409,12 +395,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
     smem_kt.load(frag_kt[0], 0);
 
-    if (Is_first) {
-        const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
-        const int row[2] = {quad, quad + 8};
-        smem_dp_sum.load(dp_sum, row, l % 2);
-    }
-
     // Trigger the load for the next dO values.
     if( l < steps - 1) {
         smem_do.move_to_next_write_buffer();
@@ -430,7 +410,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
     // // will be zero.
     // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
-    if (Is_first) { softmax.subtract_dp_sum(dp_sum); }
 
     Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
     softmax.pack(frag_dp);
@@ -547,21 +526,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     if(l < steps - 1) {
         gmem_do.commit(smem_do);
         if (Is_first) {
-            // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
-            // smem_dp_sum.move_to_next_write_buffer();
-            dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
-            const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
-            if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-                gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row_1);
-            }
-            gmem_softmax_d.move();
+            dot_do_o(
+                gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+            );
         }
         gmem_softmax_lse.load(reinterpret_cast(p_lse));
         gmem_softmax_lse.move();
-        if (!Is_first) {
-            gmem_softmax_d.load(reinterpret_cast(dp_sum));
-            gmem_softmax_d.move();
-        }
     }
 
     typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
@@ -591,6 +561,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // Make sure dQ is in shared memory.
     __syncthreads();
 
+    if (l < steps - 1) {
+        gmem_softmax_d.load(reinterpret_cast(dp_sum));
+        gmem_softmax_d.move();
+    }
+
     // Load from shared memory.
     smem_dq.template load(dq_out);
 
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index fcc3ecd..ccf9de1 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -120,10 +120,25 @@ void run_fmha_fp16_sm80(Launch_params &l
         run_fmha_fp16_sm80_loop_(launch_params, configure);
     }
     // if (launch_params.params.d == 64) {
-    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-    //     run_fmha_fp16_sm80_loop_(launch_params, configure);
+    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
+    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //     run_fmha_fp16_sm80_loop_(launch_params, configure);
+    // }
+    // if (launch_params.params.d == 64) {
+    //     if( launch_params.params.s == 128 ) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //         run_fmha_fp16_sm80_loop_(launch_params, configure);
+    //     } else if( launch_params.params.s >= 256 ) {
+    //         auto dprops = at::cuda::getCurrentDeviceProperties();
+    //         if (dprops->major == 8 && dprops->minor >= 0) {
+    //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_(launch_params, configure);
+    //         } else if (dprops->major == 7 && dprops->minor == 5) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_(launch_params, configure);
+    //         }
+    //     }
     // }
 }
\ No newline at end of file
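
The heart of the patch is the new dot_do_o: instead of staging dp_sum = rowsum(dO * O) in a shared-memory tile (Smem_dp_sum) and copying it out later, each row's partial products are reduced across the row's threads entirely in registers (fmha::Allreduce over fmha::hmulsum8) and one thread per row writes the result straight to gmem_softmax_d. Dropping the staging tile is what saves the shared memory on SM75 (Turing), which has far less shared memory per SM than A100. Below is a minimal, standalone CUDA sketch of that reduce-then-store pattern; it is not the library's API. The kernel name row_sum_do_o, the template parameters, and the use of plain float rows in place of the real kernel's packed half-precision uint4 fragments are all illustrative assumptions.

#include <cuda_runtime.h>

// Butterfly (XOR) all-reduce over the THREADS_PER_ROW threads that own one row.
// Assumes THREADS_PER_ROW is a power of two <= 32, so a row never spans two warps.
template<int THREADS_PER_ROW>
__device__ float allreduce_sum(float x) {
    #pragma unroll
    for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    return x;
}

// Each block owns ROWS_PER_BLOCK rows; each row is owned by THREADS_PER_ROW consecutive
// threads, each holding ELTS_PER_THREAD elements of dO and O. The per-row result
// dp_sum[row] = dot(dO[row], O[row]) goes straight to global memory, so no shared-memory
// staging buffer (the role Smem_dp_sum used to play) is needed.
// Launch with blockDim.x == ROWS_PER_BLOCK * THREADS_PER_ROW (a multiple of 32).
template<int ROWS_PER_BLOCK, int THREADS_PER_ROW, int ELTS_PER_THREAD>
__global__ void row_sum_do_o(const float *dO, const float *o, float *dp_sum) {
    static_assert((ROWS_PER_BLOCK * THREADS_PER_ROW) % 32 == 0, "block must be whole warps");
    const int tidx        = threadIdx.x;
    const int row_in_blk  = tidx / THREADS_PER_ROW;
    const int lane_in_row = tidx % THREADS_PER_ROW;
    const int row  = blockIdx.x * ROWS_PER_BLOCK + row_in_blk;
    const int cols = THREADS_PER_ROW * ELTS_PER_THREAD;

    const float *dO_row = dO + (long long)row * cols + lane_in_row * ELTS_PER_THREAD;
    const float *o_row  = o  + (long long)row * cols + lane_in_row * ELTS_PER_THREAD;

    // Per-thread partial dot product over this thread's slice of the row
    // (the real kernel does this on packed fp16 fragments via fmha::hmulsum8).
    float partial = 0.f;
    #pragma unroll
    for (int i = 0; i < ELTS_PER_THREAD; ++i) { partial += dO_row[i] * o_row[i]; }

    // All-reduce across the row in registers, then a single lane per row stores the sum
    // (the real kernel writes through Gmem_softmax_sum::store_row instead).
    const float total = allreduce_sum<THREADS_PER_ROW>(partial);
    if (lane_in_row == 0) { dp_sum[row] = total; }
}

// Example launch for 128 rows of 64 columns: 16 rows per block, 8 threads per row,
// 8 elements per thread -> grid of 8 blocks, 128 threads per block.
//   row_sum_do_o<16, 8, 8><<<128 / 16, 16 * 8>>>(dO_dev, o_dev, dp_sum_dev);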