diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index fd23f46..fc5724c 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -73,11 +73,9 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom,
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
+template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
-                                Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
-                                const int gdP_col_stride, const float scale) {
+                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
     static_assert(Layout0::rank == 3, "Only support 3D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
@@ -100,7 +98,6 @@ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engi
         dP_sum_cur = Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
         if (threadIdx.x % THREADS_PER_ROW == 0) {
             dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
-            // recast(sdPsum)(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum;
         }
     }
 }
@@ -178,7 +175,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
     // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
     // so that (dP - dP_sum) is on the same scale.
-    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum, dP_sum,
+    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                 Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     if (Clear_dQaccum) {
         Tensor zero = make_fragment_like(tdQgdQaccum);
@@ -517,8 +514,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
     // sP and sdQ share the same memory so be careful
     Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
-    Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
-                                Shape<Int<kBlockM>>{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -733,7 +728,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -930,11 +925,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         #pragma unroll
         for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
         gdPsum.data() = gdPsum.data() + (-int(kBlockM));
-        // if (!Is_first && tidx < kBlockM / 2) {
-        //     sdPsum(tidx) = recast(gdPsum)(tidx);
-        // if (!Is_first && tidx < kBlockM) {
-        //     recast(sdPsum)(tidx) = gdPsum(tidx);
-        // }
     }

     if (!Is_last) {
@@ -976,7 +966,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Is_first && m_block > m_block_min) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -1317,7 +1307,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     Tensor dP_sum = make_fragment_like(lse);
     cute::copy(tdOrdO, tdOsdO);
     dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
-        tdOrdO, tdOrO, sdPsum, sdPsum,
+        tdOrdO, tdOrO, sdPsum,
         Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout
     );
     __syncthreads();
diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h
index c7f2e4b..e7d605f 100644
--- a/csrc/flash_attn/src/kernel_traits.h
+++ b/csrc/flash_attn/src/kernel_traits.h
@@ -321,6 +321,10 @@ struct Flash_bwd_kernel_traits : public Base {
     static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
     static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
     static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+    static constexpr int kSmemSize = kSmemQdOSize +
+        (!Is_V_in_regs
+             ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
+             : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
     static constexpr int kSmemSize1colblock = kSmemQdOSize +
         (!Is_V_in_regs
              ? kSmemKVSize + kSmemdSSize + kSmemPSize
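
Aside on the kernel_traits.h hunk: the new `kSmemSize` follows the same pattern as the existing `kSmemSize1colblock`, and when `Is_V_in_regs` it presumably only keeps half of the K/V buffer (K) resident in shared memory alongside dS and max(P, dQ), overlapping the two phases with `std::max` instead of summing them. Below is a minimal standalone sketch of that expression; `BwdSmemBudgetSketch` and the byte counts are illustrative placeholders, not the values the real `Flash_bwd_kernel_traits` derives from kBlockM/kBlockN/kHeadDim.

```cpp
// Standalone sketch of the kSmemSize expression added in this diff.
// Tile sizes below are hypothetical placeholders (fp16 elements assumed).
#include <algorithm>
#include <cstdio>

template <bool Is_V_in_regs>
struct BwdSmemBudgetSketch {
    static constexpr int kSmemQdOSize = 2 * 64 * 64 * 2;   // Q + dO tiles (placeholder)
    static constexpr int kSmemKVSize  = 2 * 128 * 64 * 2;  // K + V tiles (placeholder)
    static constexpr int kSmemdSSize  = 64 * 128 * 2;      // dS tile (placeholder)
    static constexpr int kSmemPSize   = 64 * 128 * 2;      // P tile (placeholder)
    static constexpr int kSmemdQSize  = 64 * 64 * 2;       // dQ tile (placeholder)

    // Same shape as the expression in the diff: if V stays in registers, only
    // half of the K/V buffer must coexist with dS and max(P, dQ), so the two
    // phases are overlapped via std::max rather than added together.
    static constexpr int kSmemSize = kSmemQdOSize +
        (!Is_V_in_regs
             ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
             : std::max(kSmemKVSize,
                        kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
};

int main() {
    std::printf("smem, V in smem: %d bytes\n", BwdSmemBudgetSketch<false>::kSmemSize);
    std::printf("smem, V in regs: %d bytes\n", BwdSmemBudgetSketch<true>::kSmemSize);
    return 0;
}
```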