/***************************************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
#include <cute/tensor.hpp>
|
|
|
|
#include <cutlass/cutlass.h>
|
|
#include <cutlass/array.h>
|
|
#include <cutlass/numeric_types.h>
|
|
#include "cutlass/pipeline/pipeline.hpp"
|
|
|
|
#include "flash.h"
|
|
#include "block_info.h"
|
|
#include "kernel_traits.h"
|
|
#include "utils.h"
|
|
|
|
namespace flash {
|
|
|
|
using namespace cute;
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Row-wise dot product of two register fragments (dO and O), reduced across the group of
// THREADS_PER_ROW threads that share a row, then scaled and stored to dP_sum.
//
// Template params:
//   THREADS_PER_ROW - width of the cross-thread reduction done by flash::Allreduce; also the
//                     group size used to pick the writer thread and the output column.
// Params:
//   do_, o          - rank-3 fragments with identical layouts. Mode 1 is treated as the row
//                     index; modes 0 and 2 together form the columns summed over.
//   dP_sum          - rank-1 output; row `r` of this thread is written to
//                     dP_sum(r * gdP_col_stride + threadIdx.x / THREADS_PER_ROW).
//   gdP_col_stride  - distance in dP_sum between consecutive local rows.
//   scale           - multiplier applied after the reduction.
// Only the first thread of each THREADS_PER_ROW group performs the store.
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());

    // Re-view (8, kBlockM / 32, kHeadDim / 64) as (kBlockM / 32, (8, kHeadDim / 64)):
    // mode 1 becomes the row, modes 0 and 2 are folded into one column mode ("page" last).
    auto const rows_by_cols = make_layout(get<1>(do_.layout()),
                                          make_layout(get<0>(do_.layout()), get<2>(do_.layout())));
    Tensor dO_rc = make_tensor(do_.data(), rows_by_cols);
    Tensor O_rc  = make_tensor(o.data(), rows_by_cols);
    Tensor dO_f32 = flash::convert_type<float>(dO_rc);
    Tensor O_f32  = flash::convert_type<float>(O_rc);

    // Allreduce takes the operator by non-const reference, so keep this a mutable lvalue.
    flash::SumOp<float> add_op;
    const bool is_group_leader = (threadIdx.x % THREADS_PER_ROW) == 0;
    const int group_idx = threadIdx.x / THREADS_PER_ROW;

    #pragma unroll
    for (int row = 0; row < size<0>(dO_rc); ++row) {
        // Local partial sum over the columns this thread holds for `row`.
        float partial = dO_f32(row, 0) * O_f32(row, 0);
        #pragma unroll
        for (int col = 1; col < size<1>(dO_rc); ++col) {
            partial += dO_f32(row, col) * O_f32(row, col);
        }
        const float row_dot = flash::Allreduce<THREADS_PER_ROW>::run(partial, add_op) * scale;
        if (is_group_leader) {
            dP_sum(row * gdP_col_stride + group_idx) = row_dot;
        }
    }
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
//
// One thread block handles one kBlockM x kHeadDim tile of dO and O:
//   blockIdx.x = m_block (tile index along seqlen_q), blockIdx.y = batch, blockIdx.z = head.
// Blocks whose tile starts at or beyond actual_seqlen_q return immediately.
//
// Writes kBlockM entries of params.dsoftmax_sum for this (batch, head), each multiplied by
// params.p_dropout. When Clear_dQaccum is true, also zeroes the matching kBlockM x kHeadDim
// tile of params.dq_accum_ptr.
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    // The int-typed sizes (h, d_rounded, seqlen_q_rounded) and block indices are promoted to
    // index_t before being multiplied together so these offsets cannot overflow 32-bit int on
    // large problems. Results are unchanged whenever the old int arithmetic did not overflow.
    const index_t dq_accum_row_stride = static_cast<index_t>(params.h) * params.d_rounded;
    const index_t dq_accum_batch_stride = static_cast<index_t>(params.seqlen_q_rounded) * dq_accum_row_stride;
    // When cu_seqlens_q is set (varlen), the code adds an extra 128 * bidb rows of padding.
    const index_t row_offset_dq_accum = binfo.q_offset(dq_accum_batch_stride, dq_accum_row_stride, bidb)
        + (static_cast<index_t>(m_block) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * dq_accum_row_stride
        + static_cast<index_t>(bidh) * params.d_rounded;
    const index_t row_offset_dpsum = (static_cast<index_t>(bidb) * params.h + bidh) * params.seqlen_q_rounded
        + m_block * kBlockM;

    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  make_stride(params.h * params.d_rounded, _1{}));
    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape<Int<kBlockM>>{}, Stride<_1>{});

    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    // TODO: careful, we're zeroing out dQaccum with type float4, but when
    // we do atomicAdds, we use type float. The layouts are different. Check this.
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);

    // Predicates along the head dimension: true for columns < params.d.
    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
    #pragma unroll
    for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d; }

    // Load dO and O into registers. Rows are bounded by actual_seqlen_q - m_block * kBlockM;
    // with Clear_OOB_MN=true, flash::copy is expected to clear rows outside that bound.
    Tensor tdOrdO = make_fragment_like(tdOgdO);
    Tensor tdOrO = make_fragment_like(tdOgO);
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
    // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
    // so that (dP - dP_sum) is on the same scale.
    // dot_do_o stores local row mi at mi * gdP_col_stride + tid / kGmemThreadsPerRow; the stride
    // here is kNThreadsNonWS / kGmemThreadsPerRow (presumably the row count covered by one pass of
    // GmemTiledCopydO — confirm against kernel_traits.h).
    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                 Kernel_traits::kNThreadsNonWS / (Kernel_traits::kGmemThreadsPerRow),
                                                 params.p_dropout);
    if constexpr (Clear_dQaccum) {
        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
        // do atomicAdds on.
        Tensor zero = make_fragment_like(tdQgdQaccum);
        clear(zero);
        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
    }
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Zeroes one kBlockN x kHeadDim tile of both fp32 accumulators, params.dk_accum_ptr and
// params.dv_accum_ptr.
//
// Grid: blockIdx.x = n_block (tile index along seqlen_k), blockIdx.y = batch,
// blockIdx.z = head (offset is computed with params.h_k heads). Blocks whose tile starts at or
// beyond actual_seqlen_k return immediately.
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    // Offset into a contiguous (b, h_k, seqlen_k_rounded, d_rounded) buffer. Promote to index_t
    // before multiplying: the product of these int sizes can exceed INT_MAX on large problems.
    const index_t row_offset_dkv_accum =
        ((static_cast<index_t>(bidb) * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;

    // NOTE(review): rows of the tile are kHeadDim apart while the offset above steps by
    // d_rounded; this assumes d_rounded == kHeadDim — confirm.
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});

    // The dQaccum gmem copy is reused for both dK and dV accumulators.
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);

    // One zeroed register fragment serves both stores (both partitions share the same shape).
    Tensor zero = make_fragment_like(tdKgdKaccum);
    clear(zero);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
//
// One thread block converts one kBlockM x kHeadDim tile:
//   blockIdx.x = m_block (tile index along seqlen_q), blockIdx.y = batch, blockIdx.z = head.
// Flow:
//   1. One elected thread of warp 0 loads the fp32 dQaccum tile into shared memory with TMA,
//      completion tracked by a transaction barrier placed right after the tile in smem.
//   2. All threads move the tile into MMA-accumulator-shaped registers, multiply by
//      params.scale_softmax_rp_dropout and convert to Element.
//   3. The converted tile is written back into the same (aliased) shared memory, through the
//      SmemLayoutdQt view when Kernel_traits::dQ_swapAB, and then stored to params.dq_ptr for
//      rows < actual_seqlen_q and columns < params.d.
// Dynamic shared memory must hold the fp32 tile (SmemLayoutdQaccTMA) followed by one
// cutlass::arch::ClusterTransactionBarrier.
template<typename Kernel_traits, typename TiledCopydQaccum>
__global__ void convert_dQ(CUTE_GRID_CONSTANT Flash_bwd_params const params,
                           CUTE_GRID_CONSTANT TiledCopydQaccum const tma_load_dQaccum) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    int lane_predicate = cute::elect_one_sync();
    int warp_idx = cutlass::canonical_warp_idx_sync();
    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        cute::prefetch_tma_descriptor(tma_load_dQaccum.get_tma_descriptor());
    }

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    static constexpr bool dQ_swapAB = Kernel_traits::dQ_swapAB;

    // Global dQaccum as addressed through the TMA descriptor, shape (seqlen_q, d, h, b);
    // select this block's (M, K) tile.
    Tensor mdQaccum = tma_load_dQaccum.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
    Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_coord(m_block, _0{}));  // (M, K)

    const BlockInfo binfo(params, bidb);
    // The condition depends only on blockIdx, so the whole block leaves together and no thread
    // is stranded at the __syncthreads() calls below.
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;

    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));

    // All four views alias the start of dynamic shared memory: the fp32 tile is fully read into
    // registers (and a __syncthreads() separates that) before the Element tile overwrites it.
    Tensor sdQTMA = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
                                typename Kernel_traits::SmemLayoutdQaccTMA{});
    Tensor sdQaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)),
                                  typename Kernel_traits::SmemLayoutdQacc{});
    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdQ{});
    Tensor sdQt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdQt{});

    // The TMA transaction barrier sits immediately after the fp32 tile.
    auto &barrier_dQaccum = *reinterpret_cast<cutlass::arch::ClusterTransactionBarrier*>(smem_ + sizeof(ElementAccum) * size(sdQTMA));

    auto block_tma_dQ = tma_load_dQaccum.get_slice(_0{});
    Tensor tdQgdQaccumTMA = block_tma_dQ.partition_S(gdQaccum);  // (TMA, TMA_M, TMA_K)
    Tensor tdQsdQaccumTMA = block_tma_dQ.partition_D(sdQTMA);    // (TMA, TMA_M, TMA_K)

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);

    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);

    // Step 1: a single elected thread initializes the barrier (made visible by __syncthreads),
    // arms it with the expected byte count and issues the TMA load; all threads wait on phase 0.
    constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size<0>(sdQTMA) * size<1>(sdQTMA) * cutlass::sizeof_bits_v<ElementAccum> / 8);
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.init(1 /*numThreads*/);
    }
    __syncthreads();
    if (warp_idx == 0 && lane_predicate) {
        barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
        copy(tma_load_dQaccum.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA);
    }
    barrier_dQaccum.wait(0);

    // Step 2: smem -> registers in the MMA accumulator layout, scale, convert.
    typename Kernel_traits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum;
    auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
    Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_S(sdQaccum);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<!dQ_swapAB ? kBlockM : kHeadDim>, Int<!dQ_swapAB ? kHeadDim : kBlockM>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQsdQaccum));

    Tensor tdQrdQaccum = rmem_thr_copy_dQaccum.retile_D(acc_dq);
    cute::copy(rmem_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to Element (fp16/bf16).
    Tensor rdQ = flash::convert_type<Element>(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)

    // Step 3: registers -> smem (aliased with dQaccum) -> registers -> gmem.
    // dQacc and dQ uses the same shared memory, need to wait for all threads to finish reading smem first
    __syncthreads();
    if constexpr (!dQ_swapAB) {
        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);    // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
    } else {
        Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
    }
    __syncthreads();
    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    // Predicates along the head dimension: true for columns < params.d.
    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
    #pragma unroll
    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
//
// One thread block converts one kBlockN x kHeadDim tile of each accumulator:
//   blockIdx.x = n_block (tile index along seqlen_k), blockIdx.y = batch, blockIdx.z = head.
// Blocks whose tile starts at or beyond actual_seqlen_k return immediately.
// dK is multiplied by params.scale_softmax_rp_dropout and dV by params.rp_dropout before the
// conversion. Both tiles are staged through dynamic shared memory (dK tile followed by dV
// tile, so smem must hold two SmemLayoutdKV tiles of Element) and then stored to
// params.dk_ptr / params.dv_ptr for rows < actual_seqlen_k and columns < params.d.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    static constexpr bool dKV_swapAB = Kernel_traits::dKV_swapAB;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    // Offset into a contiguous (b, h_k, seqlen_k_rounded, d_rounded) accumulator. Promote to
    // index_t before multiplying: the product of these int sizes can exceed INT_MAX.
    const index_t row_offset_dkv_accum = ((static_cast<index_t>(bidb) * params.h_k + bidh) * params.seqlen_k_rounded
        + n_block * kBlockN) * params.d_rounded;

    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dv_row_stride, _1{}));
    // NOTE(review): accumulator rows are kHeadDim apart while the offset above steps by
    // d_rounded; this assumes d_rounded == kHeadDim — confirm.
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});

    // dK tile at the start of smem, dV tile right after it; the *t views are used when
    // dKV_swapAB. sdV and sdVt are built the same way (sdK.data() is already an smem pointer).
    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdKV{});
    Tensor sdKt = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdKVt{});
    Tensor sdV = make_tensor(sdK.data() + size(sdK),
                             typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
    Tensor sdVt = make_tensor(sdK.data() + size(sdK),
                              typename Kernel_traits::SmemLayoutdKVt{});

    // The dQ / dQaccum gmem copies are reused for dK/dV and their accumulators.
    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);

    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{});  // MMA, MMA_N, MMA_K
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<!dKV_swapAB ? kBlockN : kHeadDim>, Int<!dKV_swapAB ? kHeadDim : kBlockN>>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

    // Load the fp32 accumulators and scale them element-wise into the MMA fragments.
    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) {
        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dv); ++i) {
        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
    }
    // Convert acc_dk and acc_dv from fp32 to Element (fp16/bf16).
    Tensor rdK = flash::convert_type<Element>(acc_dk);
    Tensor rdV = flash::convert_type<Element>(acc_dv);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
    if constexpr (!dKV_swapAB) {
        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);    // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);    // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    } else {
        Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
    }
    // Smem was written in the MMA layout and is read back in the gmem-copy layout.
    __syncthreads();
    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);

    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (BLK_N,BLK_K) -> (blk_n,blk_k)
    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
    // Predicates along the head dimension: true for columns < params.d (shared by dK and dV).
    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
}
|
|
|
|
} // namespace flash
|