From a5559a0e752ec9f8ce7157e1d4db9f08593638c3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 3 Jul 2022 17:52:05 -0700 Subject: [PATCH] Do P * dP (pointwise) in the bwd in fp32 instead of fp16 --- csrc/flash_attn/src/fmha.h | 2 +- .../src/fmha_dgrad_kernel_1xN_loop.h | 59 ++++++------------- flash_attn/flash_attn_interface.py | 4 +- 3 files changed, 22 insertions(+), 43 deletions(-) diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index d60df45..4d85b21 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -96,7 +96,7 @@ struct FMHA_fprop_params : public Qkv_params { void * __restrict__ softmax_lse_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, d, seqlen_q_rounded; + int b, seqlen_q, seqlen_k, d; // The scaling factors for the kernel. float scale_bmm1f; diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index e9c0312..c5755d9 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -389,6 +389,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng } } + auto pointwise_mult = [](float p, float dp, float d) { + return p * ((!Is_dropout) || p >= 0.f ? dp : d); + }; + #pragma unroll + for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) { + #pragma unroll + for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) { + softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]); + softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]); + softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]); + softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]); + } + } + // Load the fragments for K^T. typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N]; smem_kt.load(frag_kt[0], 0); @@ -404,46 +422,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng } } - softmax.unpack_noscale(acc_dp); - // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax - // // will be zero. - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; } - - Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; - softmax.pack(frag_dp); - - if (!Is_dropout) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - frag_p[mi][ni].hmul(frag_dp[mi][ni]); - } - } - } else { - __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2]; - for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { - dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]); - } - const __half zero_h = __half(0.f); - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - #pragma unroll - for (int ii = 0; ii < 4; ++ii) { - const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii); - const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii)); - // If this element is dropped, then frag_p stores -p instead of p. - // So pd holds -p * dp_sum in that case. - const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]); - const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd); - const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd); - frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high); - } - } - } - } + softmax.pack(frag_p); // Store dp to smem for transpose smem_dp.store(frag_p); diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index d111c7a..1076648 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -215,8 +215,8 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, """dropout_p should be set to 0.0 during evaluation Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch. + k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths