diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index ab10a83..2ca50b8 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -74,362 +74,6 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
-                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
-    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
-    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
-    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
-    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
-    // The last coordinate is the "page".
-    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
-                                                             make_layout(get<0>(do_.layout()),
-                                                                         get<2>(do_.layout()))));
-    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
-    Tensor do_fp32 = flash::convert_type<float>(do_reshaped);
-    Tensor o_fp32 = flash::convert_type<float>(o_reshaped);
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
-        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
-        #pragma unroll
-        for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
-            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
-        }
-        flash::SumOp<float> sum_op;
-        dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
-        if (threadIdx.x % THREADS_PER_ROW == 0) {
-            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
-        }
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
-// This is used in the case where we want to parallelize the backward across seqlen_k.
-template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
-inline __device__ void compute_dot_do_o(const Params &params) {
-    using Element = typename Kernel_traits::Element;
-    using ElementAccum = typename Kernel_traits::ElementAccum;
-    using index_t = typename Kernel_traits::index_t;
-
-    const int m_block = blockIdx.x;
-    // The block index for the batch.
-    const int bidb = blockIdx.y;
-    // The block index for the head.
-    const int bidh = blockIdx.z;
-    // The thread index.
-    const int tidx = threadIdx.x;
-
-    constexpr int kBlockM = Kernel_traits::kBlockM;
-    constexpr int kHeadDim = Kernel_traits::kHeadDim;
-
-    const BlockInfo binfo(params, bidb);
-    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
-
-    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
-        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
-    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
-    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
-
-    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
-                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                             make_stride(params.do_row_stride, _1{}));
-    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.o_row_stride, _1{}));
-    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
-                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                  make_stride(params.h * params.d_rounded, _1{}));
-    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
-                                Shape<Int<kBlockM>>{}, Stride<_1>{});
-
-    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
-    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
-    // TODO: careful, we're zeroing out dQaccum with type float4, but when
-    // we do atomicAdds, we use type float. The layouts are different. Check this.
-    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
-    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
-
-    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
-    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
-
-    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
-    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
-
-    // Allocate predicate tensors for k
-    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
-    // Set predicates for k bounds
-    #pragma unroll
-    for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
-
-    Tensor tdOrdO = make_fragment_like(tdOgdO);
-    Tensor tdOrO = make_fragment_like(tdOgO);
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
-        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
-    );
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
-        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
-    );
-    // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
-    // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
-    // so that (dP - dP_sum) is on the same scale.
-    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
-                                                Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
-    if (Clear_dQaccum) {
-        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
-        // do atomicAdds on.
-        Tensor zero = make_fragment_like(tdQgdQaccum);
-        clear(zero);
-        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename Kernel_traits, typename Params>
-inline __device__ void clear_dKVaccum(const Params &params) {
-    using ElementAccum = typename Kernel_traits::ElementAccum;
-    using index_t = typename Kernel_traits::index_t;
-
-    const int n_block = blockIdx.x;
-    // The block index for the batch.
-    const int bidb = blockIdx.y;
-    // The block index for the head.
-    const int bidh = blockIdx.z;
-    // The thread index.
-    const int tidx = threadIdx.x;
-
-    constexpr int kBlockN = Kernel_traits::kBlockN;
-    constexpr int kHeadDim = Kernel_traits::kHeadDim;
-
-    const BlockInfo binfo(params, bidb);
-    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
-
-    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
-
-    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
-                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
-    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
-                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
-
-    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
-    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
-    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
-    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
-    Tensor zero = make_fragment_like(tdKgdKaccum);
-    clear(zero);
-    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
-    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert dQ from dQaccum (in float) to fp16/bf16.
-// This is used in the case where we want to parallelize the backward across seqlen_k.
-template<typename Kernel_traits, typename Params>
-inline __device__ void convert_dQ(const Params &params, const int nsplits) {
-    using Element = typename Kernel_traits::Element;
-    using ElementAccum = typename Kernel_traits::ElementAccum;
-    using index_t = typename Kernel_traits::index_t;
-
-    // Shared memory.
-    extern __shared__ char smem_[];
-
-    const int m_block = blockIdx.x;
-    // The block index for the batch.
-    const int bidb = blockIdx.y;
-    // The block index for the head.
-    const int bidh = blockIdx.z;
-    // The thread index.
-    const int tidx = threadIdx.x;
-
-    constexpr int kBlockM = Kernel_traits::kBlockM;
-    constexpr int kHeadDim = Kernel_traits::kHeadDim;
-
-    const BlockInfo binfo(params, bidb);
-    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
-
-    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
-        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
-    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
-
-    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
-                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                             make_stride(params.dq_row_stride, _1{}));
-    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
-                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                  make_stride(params.h * params.d_rounded, _1{}));
-
-    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
-                             typename Kernel_traits::SmemLayoutdQ{});
-
-    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
-    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
-    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
-
-    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
-    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
-    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
-    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
-
-    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
-    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
-
-    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
-    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
-
-    Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
-    clear(acc_dq);
-    for (int s = 0; s < nsplits; ++s) {
-        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
-        #pragma unroll
-        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
-        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
-    }
-    #pragma unroll
-    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
-    // Convert acc_dq from fp32 to fp16
-    Tensor rdQ = flash::convert_type<Element>(acc_dq);
-    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
-    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
-    __syncthreads();
-    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
-    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
-
-    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
-    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
-    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
-    #pragma unroll
-    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
-    // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
-    );
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
-// This is used in the case where we want to parallelize the backward across seqlen_q.
-template<typename Kernel_traits, typename Params>
-inline __device__ void convert_dKV(const Params &params) {
-    using Element = typename Kernel_traits::Element;
-    using ElementAccum = typename Kernel_traits::ElementAccum;
-    using index_t = typename Kernel_traits::index_t;
-
-    // Shared memory.
-    extern __shared__ char smem_[];
-
-    const int n_block = blockIdx.x;
-    // The block index for the batch.
-    const int bidb = blockIdx.y;
-    // The block index for the head.
-    const int bidh = blockIdx.z;
-    // The thread index.
-    const int tidx = threadIdx.x;
-
-    constexpr int kBlockN = Kernel_traits::kBlockN;
-    constexpr int kHeadDim = Kernel_traits::kHeadDim;
-
-    const BlockInfo binfo(params, bidb);
-    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
-
-    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
-        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
-    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
-        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
-    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
-        + n_block * kBlockN) * params.d_rounded;
-
-    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
-                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                             make_stride(params.dk_row_stride, _1{}));
-    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
-                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                             make_stride(params.dv_row_stride, _1{}));
-    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
-                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                  Stride<Int<kHeadDim>, _1>{});
-    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
-                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                  Stride<Int<kHeadDim>, _1>{});
-
-    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
-                             typename Kernel_traits::SmemLayoutdKV{});
-    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
-
-    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
-    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
-    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
-
-    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
-    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
-    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
-    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
-    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
-
-    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
-    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
-    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
-    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
-    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
-    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
-
-    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
-    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
-    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
-    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
-
-    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
-    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
-    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
-    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
-    #pragma unroll
-    for (int i = 0; i < size(acc_dk); ++i) {
-        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
-    }
-    #pragma unroll
-    for (int i = 0; i < size(acc_dv); ++i) {
-        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
-    }
-    // Convert acc_dk from fp32 to fp16
-    Tensor rdK = flash::convert_type<Element>(acc_dk);
-    Tensor rdV = flash::convert_type<Element>(acc_dv);
-    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
-    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
-    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
-    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
-    __syncthreads();
-    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
-    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
-    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
-
-    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
-    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
-    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
-    #pragma unroll
-    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
-    // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
-    );
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
-    );
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index 51cbdd7..a28a108 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -8,6 +8,7 @@
 
 #include "static_switch.h"
 #include "flash.h"
+#include "flash_bwd_preprocess_kernel.h"
 #include "flash_bwd_kernel.h"
 
 template<bool Clear_dQaccum=true, typename Kernel_traits>
diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
new file mode 100644
index 0000000..6582d81
--- /dev/null
+++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
@@ -0,0 +1,377 @@
+/***************************************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+
+#include "block_info.h"
+#include "kernel_traits.h"
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
+                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
+    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
+    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
+    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
+    // The last coordinate is the "page".
+    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
+                                                             make_layout(get<0>(do_.layout()),
+                                                                         get<2>(do_.layout()))));
+    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
+    Tensor do_fp32 = flash::convert_type<float>(do_reshaped);
+    Tensor o_fp32 = flash::convert_type<float>(o_reshaped);
+    #pragma unroll
+    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
+        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
+        #pragma unroll
+        for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
+            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
+        }
+        flash::SumOp<float> sum_op;
+        dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
+        if (threadIdx.x % THREADS_PER_ROW == 0) {
+            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
+// This is used in the case where we want to parallelize the backward across seqlen_k.
+template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
+inline __device__ void compute_dot_do_o(const Params &params) {
+    using Element = typename Kernel_traits::Element;
+    using ElementAccum = typename Kernel_traits::ElementAccum;
+    using index_t = typename Kernel_traits::index_t;
+
+    const int m_block = blockIdx.x;
+    // The block index for the batch.
+    const int bidb = blockIdx.y;
+    // The block index for the head.
+    const int bidh = blockIdx.z;
+    // The thread index.
+    const int tidx = threadIdx.x;
+
+    constexpr int kBlockM = Kernel_traits::kBlockM;
+    constexpr int kHeadDim = Kernel_traits::kHeadDim;
+
+    const BlockInfo binfo(params, bidb);
+    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
+    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
+
+    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
+                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                             make_stride(params.do_row_stride, _1{}));
+    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                            make_stride(params.o_row_stride, _1{}));
+    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(params.h * params.d_rounded, _1{}));
+    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
+                                Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
+    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
+    // TODO: careful, we're zeroing out dQaccum with type float4, but when
+    // we do atomicAdds, we use type float. The layouts are different. Check this.
+    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
+
+    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
+    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
+
+    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
+
+    // Allocate predicate tensors for k
+    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
+    // Set predicates for k bounds
+    #pragma unroll
+    for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
+
+    Tensor tdOrdO = make_fragment_like(tdOgdO);
+    Tensor tdOrO = make_fragment_like(tdOgO);
+    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
+        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
+    );
+    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
+        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
+    );
+    // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
+    // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
+    // so that (dP - dP_sum) is on the same scale.
+    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
+                                                Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
+    if (Clear_dQaccum) {
+        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
+        // do atomicAdds on.
+        Tensor zero = make_fragment_like(tdQgdQaccum);
+        clear(zero);
+        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Kernel_traits, typename Params>
+inline __device__ void clear_dKVaccum(const Params &params) {
+    using ElementAccum = typename Kernel_traits::ElementAccum;
+    using index_t = typename Kernel_traits::index_t;
+
+    const int n_block = blockIdx.x;
+    // The block index for the batch.
+    const int bidb = blockIdx.y;
+    // The block index for the head.
+    const int bidh = blockIdx.z;
+    // The thread index.
+    const int tidx = threadIdx.x;
+
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+    constexpr int kHeadDim = Kernel_traits::kHeadDim;
+
+    const BlockInfo binfo(params, bidb);
+    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
+
+    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
+
+    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
+    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
+
+    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
+    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
+    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
+    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
+    Tensor zero = make_fragment_like(tdKgdKaccum);
+    clear(zero);
+    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert dQ from dQaccum (in float) to fp16/bf16.
+// This is used in the case where we want to parallelize the backward across seqlen_k.
+template<typename Kernel_traits, typename Params>
+inline __device__ void convert_dQ(const Params &params, const int nsplits) {
+    using Element = typename Kernel_traits::Element;
+    using ElementAccum = typename Kernel_traits::ElementAccum;
+    using index_t = typename Kernel_traits::index_t;
+
+    // Shared memory.
+    extern __shared__ char smem_[];
+
+    const int m_block = blockIdx.x;
+    // The block index for the batch.
+    const int bidb = blockIdx.y;
+    // The block index for the head.
+    const int bidh = blockIdx.z;
+    // The thread index.
+    const int tidx = threadIdx.x;
+
+    constexpr int kBlockM = Kernel_traits::kBlockM;
+    constexpr int kHeadDim = Kernel_traits::kHeadDim;
+
+    const BlockInfo binfo(params, bidb);
+    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+
+    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
+                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                             make_stride(params.dq_row_stride, _1{}));
+    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(params.h * params.d_rounded, _1{}));
+
+    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
+                             typename Kernel_traits::SmemLayoutdQ{});
+
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
+    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
+
+    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
+    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
+    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
+    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
+    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
+
+    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
+    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
+
+    Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
+    clear(acc_dq);
+    for (int s = 0; s < nsplits; ++s) {
+        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
+        #pragma unroll
+        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
+        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
+    }
+    #pragma unroll
+    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
+    // Convert acc_dq from fp32 to fp16
+    Tensor rdQ = flash::convert_type<Element>(acc_dq);
+    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
+    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
+    __syncthreads();
+    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
+    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
+
+    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
+    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
+    #pragma unroll
+    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
+    );
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
+// This is used in the case where we want to parallelize the backward across seqlen_q.
+template<typename Kernel_traits, typename Params>
+inline __device__ void convert_dKV(const Params &params) {
+    using Element = typename Kernel_traits::Element;
+    using ElementAccum = typename Kernel_traits::ElementAccum;
+    using index_t = typename Kernel_traits::index_t;
+
+    // Shared memory.
+    extern __shared__ char smem_[];
+
+    const int n_block = blockIdx.x;
+    // The block index for the batch.
+    const int bidb = blockIdx.y;
+    // The block index for the head.
+    const int bidh = blockIdx.z;
+    // The thread index.
+    const int tidx = threadIdx.x;
+
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+    constexpr int kHeadDim = Kernel_traits::kHeadDim;
+
+    const BlockInfo binfo(params, bidb);
+    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
+
+    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
+    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
+    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
+        + n_block * kBlockN) * params.d_rounded;
+
+    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
+                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                             make_stride(params.dk_row_stride, _1{}));
+    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
+                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                             make_stride(params.dv_row_stride, _1{}));
+    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                  Stride<Int<kHeadDim>, _1>{});
+    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                  Stride<Int<kHeadDim>, _1>{});
+
+    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
+                             typename Kernel_traits::SmemLayoutdKV{});
+    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
+
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
+    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
+    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
+
+    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
+    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
+    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
+    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
+    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
+    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
+    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
+    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
+    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
+
+    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
+    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
+    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
+    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
+
+    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
+    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
+    #pragma unroll
+    for (int i = 0; i < size(acc_dk); ++i) {
+        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
+    }
+    #pragma unroll
+    for (int i = 0; i < size(acc_dv); ++i) {
+        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
+    }
+    // Convert acc_dk from fp32 to fp16
+    Tensor rdK = flash::convert_type<Element>(acc_dk);
+    Tensor rdV = flash::convert_type<Element>(acc_dv);
+    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
+    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
+    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
+    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
+    __syncthreads();
+    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
+    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
+    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
+    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
+
+    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
+    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
+    #pragma unroll
+    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+    );
+    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+    );
+}
+
+}  // namespace flash
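
Editor's note: for readers who want the gist of the moved preprocessing kernels without the CuTe layout machinery, the following standalone sketch is illustrative only (it is not part of the patch, and the kernel names and launch parameters are hypothetical). The first kernel computes softmax_d[row] = scale * dot(dO[row], O[row]), the row-wise reduction that compute_dot_do_o produces via dot_do_o, using one warp per row and a shuffle reduction in place of flash::Allreduce. The second mirrors the core of convert_dQ: sum the fp32 partial-dQ buffers written by the seqlen_k-parallel backward, rescale, and narrow to fp16.

#include <cuda_fp16.h>

// One warp per query row: each lane accumulates a strided slice of
// dot(dO[row], O[row]) in fp32, then a butterfly shuffle reduction combines
// the 32 partial sums. Launch with block = dim3(32, rowsPerBlock) and
// grid = (seqlen_q + rowsPerBlock - 1) / rowsPerBlock. Hypothetical sketch.
__global__ void naive_dot_do_o(const __half* dO, const __half* O, float* softmax_d,
                               int seqlen_q, int head_dim, float scale) {
    int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= seqlen_q) return;                      // uniform per warp
    float acc = 0.f;
    for (int k = threadIdx.x; k < head_dim; k += 32) {
        acc += __half2float(dO[(long long)row * head_dim + k])
             * __half2float(O[(long long)row * head_dim + k]);
    }
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {  // warp-wide sum
        acc += __shfl_xor_sync(0xffffffff, acc, offset);
    }
    // scale plays the role of p_dropout in the patch: dP_sum stays scaled
    // down so (dP - dP_sum) is on the same scale as the unscaled dP.
    if (threadIdx.x == 0) softmax_d[row] = acc * scale;
}

// Elementwise reduction over nsplits fp32 partial-dQ buffers spaced
// split_stride elements apart, then rescale and narrow to fp16, as in
// convert_dQ (minus the smem staging and tiled copies). scale plays the
// role of scale_softmax_rp_dropout. Hypothetical sketch.
__global__ void naive_convert_dq(const float* dq_accum, __half* dq, long long numel,
                                 long long split_stride, int nsplits, float scale) {
    long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) return;
    float acc = 0.f;
    for (int s = 0; s < nsplits; ++s) { acc += dq_accum[i + s * split_stride]; }
    dq[i] = __float2half(acc * scale);
}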