// flash-attention/hopper/flash_bwd_kernel.h

/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/barrier.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "flash.h"
#include "utils.h"
#include "softmax.h"
namespace flash {
using namespace cute;
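// compute_dqkv: one CTA per (n_block, head, batch). The K and V tiles for this n_block are loaded
// into SMEM once; the kernel then iterates over the Q / dO tiles (m_block, from last to first),
// accumulating dK and dV in registers and adding each tile's dQ contribution into the fp32
// accumulator params.dq_accum_ptr with atomicAdd. dK and dV are written out via TMA in the epilogue.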
template <typename Ktraits, bool Is_causal, typename TiledCopyQ, typename TiledCopydO,
typename TiledCopyK, typename TiledCopyV, typename TiledCopydK, typename TiledCopydV>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_dqkv(CUTE_GRID_CONSTANT Flash_bwd_params const params,
CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using index_t = typename Ktraits::index_t;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kNThreads = Ktraits::kNThreads;
// static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{});
static constexpr int NumMmaThreads = Ktraits::kNThreads;
static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kStages = Ktraits::kStages;
static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS;
if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); }
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const n_block = blockIdx.x;
int const bidb = blockIdx.z; // The block index for the batch.
int const bidh = blockIdx.y; // The block index for the head.
int lane_predicate = cute::elect_one_sync();
int warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor());
}
Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{}));
Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dq_accum_ptr)),
make_shape(params.seqlen_q, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded));
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
// if (cute::thread0()) { print(tma_load_K); printf("\n"); }
// if (cute::thread0()) { print(mK); printf("\n"); print(gK); printf("\n"); }
typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
// Construct SMEM tensors.
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{});
Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{});
Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{});
Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{});
Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{});
Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{});
Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{});
Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{});
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y);
auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y);
auto block_tma_K = tma_load_K.get_slice(_0{});
auto block_tma_V = tma_load_V.get_slice(_0{});
Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K, k)
Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K, k)
Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K)
Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K)
Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K)
Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K)
// if (cute::thread0()) { print(tQgQ); printf("\n"); print(tQsQ); printf("\n"); }
// if (cute::thread0()) { print(tKgK); printf("\n"); print(tKsK); printf("\n"); }
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesdO = static_cast<uint32_t>(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO);
constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
// Obtain the thread index within the warp group
int thread_idx = int(threadIdx.x);
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
// int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = TmaTransactionBytesQ;
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_K.init(1 /*numThreads*/);
shared_storage.barrier_V.init(1 /*numThreads*/);
}
// cutlass::arch::fence_barrier_init();
// We're counting on pipeline_q to call fence_barrier_init();
MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{});
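// pipeline_q / pipeline_do are kStages-deep TMA -> WGMMA pipelines over the Q and dO tiles:
// the elected lane of warp 0 acts as the producer issuing TMA loads, and all MMA threads are consumers.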
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer blocks in the cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
// State variables used for iterating the circular buffer
// smem_pipe_read / release are used by the consumer of SMEM data, i.e. the MMA
// smem_pipe_write is used by the producer of SMEM data, i.e. the TMA
PipelineState smem_pipe_read_q, smem_pipe_read_do;
PipelineState smem_pipe_release_q, smem_pipe_release_do;
PipelineState smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline>();
// Copy K tile and V tile from GMEM to SMEM.
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK);
copy(tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK);
shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV);
copy(tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV);
}
// if (cute::thread0()) { print_tensor(sQ); printf("\n"); } __syncthreads();
int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1;
uint16_t mcast_mask_qdo = 0;
if constexpr (cute::is_same_v<typename Ktraits::GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{}));
}
}
// Issue TmaLoads (Prologue fetches)
if (warp_idx == 0 && lane_predicate) {
// Issue the prologue loads
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kStages && stage <= m_block; ++stage) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - stage), tQsQ(_, _, _, stage));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - stage), tdOsdO(_, _, _, stage));
++smem_pipe_write_do;
}
}
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
// Initialize matmul objects.
typename Ktraits::TiledMmaSdP tiledMmaSdP;
auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadKV tiledMmadKV;
auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadQ tiledMmadQ;
auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x);
// Allocate accumulator
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x);
if constexpr (!SdP_swapAB) {
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_B(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV);
Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_M,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{});
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and the scores would be zero. With LSE = 0, the probs would all be 1, and when we multiply
// by V (which would be zero), we're fine. However, with ALiBi we might modify these
// scores, and the probs can become NaN. Instead, if we set LSE = inf for OOB rows, the probs are always 0.
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
__syncthreads();
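// Main loop over m_block (highest to lowest). Per iteration: S = Q K^T and dP = dO V^T (WGMMA),
// P is recomputed from S and the stored LSE, dS = P * (dP - dP_sum), then dV += P^T dO and
// dK += dS^T Q, and dQ = dS K is atomically added into the fp32 dQ accumulator in global memory.
// P and dS are staged through SMEM (sP / sdS) so their transposes can feed the dKV WGMMAs.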
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
__syncwarp();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
__syncwarp();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP);
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
int const warp_group_idx = cutlass::canonical_warp_group_idx();
cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/);
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
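// Softmax backward: dS_ij = P_ij * (dP_ij - D_i), where D_i = sum_j P_ij * dP_ij
// (i.e. rowsum(dO * O)), precomputed into params.dsoftmax_sum and loaded into dP_sum above.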
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
// cutlass::arch::NamedBarrier::arrive(kNThreads, 1 /*id*/);
cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/);
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = gLSE(row);
dP_sum(mi) = gdPsum(row);
}
}
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (Mma_dQ_is_RS) {
static_assert(!dQ_swapAB);
Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadQ>(tdPrdP.layout()));
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); }
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (cute::thread0()) { print_tensor(sK); printf("\n"); }
// if (cute::thread0()) { print_tensor(sKt); printf("\n"); } __syncthreads();
// __syncthreads(); // Without this we get a race condition; I thought the barrier would be enough
// SMEM fence to make sure sP is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
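// Each warp group arrived on the NamedBarrier with id == its own warp_group_idx after writing its
// slice of sP above; syncing on the other warp group's id here ensures sPt is complete before the
// dV WGMMA below reads it (this assumes exactly two MMA warp groups).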
cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/);
if constexpr (!dKV_swapAB) {
Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
} else {
Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV);
}
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/);
if constexpr (!Mma_dQ_is_RS) {
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
if constexpr (!dKV_swapAB) {
Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
} else {
Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK);
}
++smem_pipe_read_q;
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
warpgroup_wait<Mma_dQ_is_RS ? 2 : 1>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
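// dQ for this (m_block, n_block) pair is accumulated with atomicAdd into the fp32 gmem buffer,
// since CTAs working on different n_blocks contribute to the same dQ rows; the accumulator is left
// in fp32 here (conversion to the output dtype presumably happens outside this kernel).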
Tensor tdQrdQ_atomic = recast<float4>(tdQrdQ);
Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
#pragma unroll
for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
// for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); }
warpgroup_wait<0>();
pipeline_do.consumer_release(smem_pipe_release_do); // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q); // release Q
++smem_pipe_release_q;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && m_block >= kStages) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index()));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index()));
++smem_pipe_write_do;
}
}
} else { // SdP_swapAB
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_A(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV);
Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_M)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_N, MMA_M); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// cute::fill(lse, 1);
// cute::fill(dP_sum, 1);
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and the scores would be zero. With LSE = 0, the probs would all be 1, and when we multiply
// by V (which would be zero), we're fine. However, with ALiBi we might modify these
// scores, and the probs can become NaN. Instead, if we set LSE = inf for OOB rows, the probs are always 0.
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP);
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
static_assert(!dKV_swapAB);
Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tSrS.layout()));
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
warpgroup_wait<1>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = gLSE(row);
dP_sum(mi) = gdPsum(row);
}
}
Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tdPrdP.layout()));
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
++smem_pipe_read_q;
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
// SMEM fence to make sure sP is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/);
__syncthreads();
static_assert(!Mma_dQ_is_RS);
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
warpgroup_wait<0>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
Tensor tdQrdQ_atomic = recast<float4>(tdQrdQ);
Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
#pragma unroll
for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
// for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); }
pipeline_do.consumer_release(smem_pipe_release_do); // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q); // release Q
++smem_pipe_release_q;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && m_block >= kStages) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index()));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index()));
++smem_pipe_write_do;
}
}
}
// Epilogue
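// Fold the softmax scale into dK (dS was computed without it; dV needs no extra scaling), convert
// dK / dV to the output dtype, stage them through SMEM, and store them with TMA from a single elected thread.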
#pragma unroll
for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; }
Tensor tdKrdK_out = convert_type<Element>(tdKrdK);
Tensor tdVrdV_out = convert_type<Element>(tdVrdV);
Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{});
Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{});
Tensor sdKt = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{});
Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{});
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
__syncthreads();
if constexpr (!dKV_swapAB) {
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
} else {
Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
auto block_tma_dK = tma_store_dK.get_slice(_0{});
auto block_tma_dV = tma_store_dV.get_slice(_0{});
Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_N, TMA_K)
Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_N, TMA_K)
Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_N, TMA_K)
Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_N, TMA_K)
__syncthreads(); // ensure all threads have issued their async fence
lane_predicate = cute::elect_one_sync();
warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
cute::copy(tma_store_dV, tdVsdV, tdVgdV);
cute::copy(tma_store_dK, tdKsdK, tdKgdK);
tma_store_arrive();
}
tma_store_wait<0>();
// To make sure remote SMEM doesn't get destroyed
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive();
cute::cluster_wait();
}
}
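// compute_dqkv_seqqpar: the sequence-parallel-over-Q variant. One CTA per (m_block, head, batch):
// the Q / dO tiles for this m_block are loaded into SMEM once, the kernel iterates over the K / V
// tiles (n_block), keeps dQ in registers across the whole loop, and accumulates dK / dV into the
// fp32 buffers params.dk_accum_ptr / params.dv_accum_ptr with atomicAdd. dQ is stored via TMA in the epilogue.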
template <typename Ktraits, bool Is_causal, typename TiledCopyQ, typename TiledCopydO,
typename TiledCopyK, typename TiledCopyV, typename TiledCopydQ, typename TiledCopydK, typename TiledCopydV>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_dqkv_seqqpar(CUTE_GRID_CONSTANT Flash_bwd_params const params,
CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ,
CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using index_t = typename Ktraits::index_t;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kNThreads = Ktraits::kNThreads;
static constexpr int NumMmaThreads = Ktraits::kNThreads;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kStages = Ktraits::kStages;
static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS;
if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); }
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const m_block = blockIdx.x;
int const bidb = blockIdx.z; // The block index for the batch.
int const bidh = blockIdx.y; // The block index for the head.
int lane_predicate = cute::elect_one_sync();
int warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor());
}
Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{}));
Tensor mdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dk_accum_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k));
Tensor mdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dv_accum_ptr)),
make_shape(params.seqlen_k, params.d, params.h, params.b),
make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k));
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gdKaccum = local_tile(mdKaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gdVaccum = local_tile(mdVaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(threadIdx.x);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
// Construct SMEM tensors.
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{});
Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{});
Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{});
Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{});
Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{});
Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{});
Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{});
Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{});
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_Q = tma_load_Q.get_slice(_0{});
auto block_tma_dO = tma_load_dO.get_slice(_0{});
auto block_tma_K = tma_load_K.get_slice(cluster_local_block_id.x);
auto block_tma_V = tma_load_V.get_slice(cluster_local_block_id.x);
Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K)
Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K)
Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K)
Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K)
Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K, k)
Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K, PIPE)
Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K, k)
Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K, PIPE)
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesdO = static_cast<uint32_t>(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO);
constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
// Obtain the thread index within the warp group
int thread_idx = int(threadIdx.x);
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
// int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = TmaTransactionBytesK;
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
shared_storage.barrier_dO.init(1 /*numThreads*/);
}
// cutlass::arch::fence_barrier_init();
// We're counting on pipeline_k to call fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer blocks in the cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
// State variables used for iterating the circular buffer
// smem_pipe_read / release are used by the consumer of SMEM data, i.e. the MMA
// smem_pipe_write is used by the producer of SMEM data, i.e. the TMA
PipelineState smem_pipe_read_k, smem_pipe_read_v;
PipelineState smem_pipe_release_k, smem_pipe_release_v;
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
// Copy K tile and V tile from GMEM to SMEM.
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
shared_storage.barrier_dO.arrive_and_expect_tx(TmaTransactionBytesdO);
copy(tma_load_dO.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_dO), 0 /*mcast_mask*/), tdOgdO, tdOsdO);
}
// if (cute::thread0()) { print_tensor(sQ); printf("\n"); } __syncthreads();
int n_block = cute::ceil_div(params.seqlen_k, kBlockN) - 1;
uint16_t mcast_mask_kv = 0;
if constexpr (cute::is_same_v<typename Ktraits::GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
}
}
// Issue TmaLoads (Prologue fetches)
if (warp_idx == 0 && lane_predicate) {
// Issue the prologue loads
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kStages && stage <= n_block; ++stage) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - stage), tKsK(_, _, _, stage));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - stage), tVsV(_, _, _, stage));
++smem_pipe_write_v;
}
}
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
// Initialize matmul objects.
typename Ktraits::TiledMmaSdP tiledMmaSdP;
auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadKV tiledMmadKV;
auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x);
typename Ktraits::TiledMmadQ tiledMmadQ;
auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x);
// Allocate accumulator
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
clear(tdQrdQ);
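// Unlike compute_dqkv, dQ is accumulated in registers across the entire n_block loop (cleared once
// above) and only written out in the epilogue.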
auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x);
if constexpr (!SdP_swapAB) {
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_B(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV);
Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_M,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{});
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(lse); printf("\n"); }
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and the scores would be zero. With LSE = 0, the probs would all be 1, and when we multiply
// by V (which would be zero), we're fine. However, with ALiBi we might modify these
// scores, and the probs can become NaN. Instead, if we set LSE = inf for OOB rows, the probs are always 0.
shared_storage.barrier_Q.wait(0);
shared_storage.barrier_dO.wait(0);
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; n_block >= 0; --n_block) {
// Otherwise we might have WG0 still waiting on the NamedBarrier while WG1 has already
// started the next iteration and starts flipping the same NamedBarrier.
__syncthreads();
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_k.consumer_wait(smem_pipe_read_k);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_v.consumer_wait(smem_pipe_read_v);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrdO, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdP);
++smem_pipe_read_v;
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
int const warp_group_idx = cutlass::canonical_warp_group_idx();
cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/);
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/);
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
if constexpr (Mma_dQ_is_RS) {
static_assert(!dQ_swapAB);
Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadQ>(tdPrdP.layout()));
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ);
// if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); }
}
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (cute::thread0()) { print_tensor(sK); printf("\n"); }
// if (cute::thread0()) { print_tensor(sKt); printf("\n"); } __syncthreads();
// if (cute::thread0()) { printf("before barrier sync 0\n"); }
// SMEM fence to make sure sP is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/);
// if (cute::thread0()) { printf("after barrier sync 0\n"); }
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
if constexpr (!dKV_swapAB) {
Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO, tdVrdV);
} else {
Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdVrdO, tdVrP, tdVrdV);
}
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/);
if constexpr (!Mma_dQ_is_RS) {
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ);
}
}
++smem_pipe_read_k;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
if constexpr (!dKV_swapAB) {
Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK);
} else {
Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdKrQ, tdKrdS, tdKrdK);
}
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
warpgroup_wait<Mma_dQ_is_RS ? 1 : 2>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
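// dK / dV contributions from this (m_block, n_block) pair are atomically added into the fp32 gmem
// accumulators, since CTAs working on different m_blocks touch the same dK / dV rows.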
Tensor tdVrdV_atomic = recast<float4>(tdVrdV);
Tensor tdVgdVaccum_atomic = recast<float4>(tdVgdVaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); }
// for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); }
warpgroup_wait<0>();
Tensor tdKrdK_atomic = recast<float4>(tdKrdK);
Tensor tdKgdKaccum_atomic = recast<float4>(tdKgdKaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); }
pipeline_v.consumer_release(smem_pipe_release_v); // release V
++smem_pipe_release_v;
pipeline_k.consumer_release(smem_pipe_release_k); // release K
++smem_pipe_release_k;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && n_block >= kStages) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
} else { // SdP_swapAB
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_A(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV);
Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_M)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_N, MMA_M); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(taccScS_row); printf("\n"); }
// cute::fill(lse, 1);
// cute::fill(dP_sum, 1);
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and the scores would be zero. With LSE = 0, the probs would all be 1, and when we multiply
// by V (which would be zero), we're fine. However, with ALiBi we might modify these
// scores, and the probs can become NaN. Instead, if we set LSE = inf for OOB rows, the probs are always 0.
clear(tdQrdQ);
shared_storage.barrier_Q.wait(0);
shared_storage.barrier_dO.wait(0);
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; n_block >= 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_k.consumer_wait(smem_pipe_read_k);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrK(_, _, _, smem_pipe_read_k.index()), tSrQ, tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_v.consumer_wait(smem_pipe_read_v);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdO, tdPrdP);
++smem_pipe_read_v;
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
// if (cute::thread0()) { print_tensor(lse); printf("\n"); }
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (cute::thread0()) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
static_assert(!dKV_swapAB);
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{}));
Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tSrS.layout()));
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO, tdVrdV);
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }
warpgroup_wait<1>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
// if (cute::thread0()) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{}));
Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tdPrdP.layout()));
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK);
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }
warpgroup_wait<1>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
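// Accumulate this block's dV contribution into the fp32 accumulator in gmem (tdVgdVaccum)
// using vectorized float4 atomics; contributions from other blocks are summed the same way.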
Tensor tdVrdV_atomic = recast<float4>(tdVrdV);
Tensor tdVgdVaccum_atomic = recast<float4>(tdVgdVaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); }
// for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); }
// SMEM fence to make sure sdS is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/);
__syncthreads();
static_assert(!Mma_dQ_is_RS);
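// dQ += dS @ K, accumulated in registers across the n_block loop (zero_init = false);
// it is scaled and written out in the epilogue.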
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ);
}
++smem_pipe_read_k;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
warpgroup_wait<1>();
// if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
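// Same vectorized atomic accumulation for this block's dK contribution.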
Tensor tdKrdK_atomic = recast<float4>(tdKrdK);
Tensor tdKgdKaccum_atomic = recast<float4>(tdKgdKaccum(_, _, _, n_block));
#pragma unroll
for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); }
// for (int i = 0; i < size(tdKrdK_atomic); ++i) { tdKgdKaccum_atomic(i) = tdKrdK_atomic(i); }
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_release_v); // release V
++smem_pipe_release_v;
pipeline_k.consumer_release(smem_pipe_release_k);  // release K
++smem_pipe_release_k;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate && n_block >= kStages) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
}
// Epilogue
#pragma unroll
for (int i = 0; i < size(tdQrdQ); ++i) { tdQrdQ(i) *= params.scale_softmax; }
// if (cute::thread0()) { print_tensor(tdQrdQ); }
Tensor tdQrdQ_out = convert_type<Element>(tdQrdQ);
Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQ{});
Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQt{});
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ_out); // ((Atom,AtomNum), MMA_M, MMA_N)
__syncthreads();
if constexpr (!dQ_swapAB) {
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
} else {
Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
Tensor mdQ = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor gdQ = local_tile(mdQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_dQ = tma_store_dQ.get_slice(_0{});
Tensor tdQgdQ = block_tma_dQ.partition_D(gdQ); // (TMA, TMA_M, TMA_K)
Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K)
__syncthreads(); // ensure all threads have issued their async fence
// if (cute::thread0()) { print_tensor(sdQ); }
lane_predicate = cute::elect_one_sync();
warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ);
tma_store_arrive();
}
tma_store_wait<0>();
// To make sure remote SMEM doesn't get destroyed
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive();
cute::cluster_wait();
}
}
template <typename Ktraits, bool Is_causal, typename TiledCopyQ, typename TiledCopydO,
typename TiledCopyK, typename TiledCopyV, typename TiledCopydK, typename TiledCopydV, typename TiledCopydQ, typename TiledCopyAdddQ>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_dqkv_ws(CUTE_GRID_CONSTANT Flash_bwd_params const params,
CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV,
CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ,
CUTE_GRID_CONSTANT TiledCopyAdddQ const tma_reduce_add_dQ) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using index_t = typename Ktraits::index_t;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static_assert(Ktraits::Is_WS);
// static constexpr int kNThreads = Ktraits::kNThreads;
// static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{});
static constexpr int NumMmaThreads = 256;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kNThreadsdQ = Ktraits::kNThreadsdQ;
// static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
// static constexpr int kStages = Ktraits::kStages;
static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
if constexpr (SdP_swapAB) { static_assert(!dKV_swapAB); }
static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS;
if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); }
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int lane_predicate = cute::elect_one_sync();
int warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_reduce_add_dQ.get_tma_descriptor());
}
// Construct SMEM tensors.
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{});
Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{});
Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{});
Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{});
Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{});
Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{});
Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{});
Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{});
Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc{});
Tensor sdQ2 = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc2{});
Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacct{});
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesdO = static_cast<uint32_t>(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO);
constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v<Element> / 8);
constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v<Element> / 8);
static_assert(TmaTransactionBytesK == TmaTransactionBytesV);
// Obtain thread index and the thread index within the warp group
int thread_idx = int(threadIdx.x);
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
// int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = TmaTransactionBytesQ;
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 0) {
pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
} else {
pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_K.init(1 /*numThreads*/);
shared_storage.barrier_V.init(1 /*numThreads*/);
}
// cutlass::arch::fence_barrier_init();
// We're counting on pipeline_q to call fence_barrier_init();
MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{});
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
if (warp_group_idx == 0) { // Producer
// method in cutlass/arch/reg_reconfig.h
// calls setmaxnreg.dec.sync.aligned.u32
cutlass::arch::warpgroup_reg_dealloc<24>();
int const n_block = blockIdx.x;
int const bidb = blockIdx.z; // The block index for the batch.
int const bidh = blockIdx.y; // The block index for the head.
int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1;
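// Broadcast the warp index from lane 0 so every lane in the warp takes the same branch below.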
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
int lane_predicate = cute::elect_one_sync();
// if (warp_idx_in_warpgroup == 0 && lane_predicate) {
if (warp_idx_in_warpgroup == 0) { // Load K, and do TMA on Q and dO
Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y);
auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y);
auto block_tma_K = tma_load_K.get_slice(_0{});
Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K, k)
Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K, k)
Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K, PIPE)
Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K)
Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K)
PipelineState smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline>();
uint16_t mcast_mask_qdo = 0;
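// If Q/dO loads are multicast across the cluster, set one mask bit per participating CTA.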
if constexpr (cute::is_same_v<typename Ktraits::GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{}));
}
}
if (lane_predicate) {
// Copy K tile from GMEM to SMEM (the V tile is loaded by warp 1 of this warpgroup).
shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK);
copy(tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK);
#pragma unroll 2
for (; m_block >= 0; --m_block) {
pipeline_q.producer_acquire(smem_pipe_write_q);
copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block), tQsQ(_, _, _, smem_pipe_write_q.index()));
++smem_pipe_write_q;
pipeline_do.producer_acquire(smem_pipe_write_do);
copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block), tdOsdO(_, _, _, smem_pipe_write_do.index()));
++smem_pipe_write_do;
}
// Tail loop
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline_q.producer_tail(smem_pipe_write_q);
pipeline_do.producer_tail(smem_pipe_write_do);
}
} else if (warp_idx_in_warpgroup == 1) { // Load V, and do TMA_REDUCE_ADD on dQ
Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
auto block_tma_V = tma_load_V.get_slice(_0{});
Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K)
Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K)
if (lane_predicate) {
shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV);
copy(tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV);
}
Tensor mdQaccum = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
auto block_tma_dQ = tma_store_dQ.get_slice(_0{});
Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K)
Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K)
Tensor tdQsdQ2 = block_tma_dQ.partition_S(sdQ2); // (TMA, TMA_M, TMA_K, 2)
int *lock_ptr = params.dq_semaphore + bidb * params.h + bidh;
using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 1 /*id*/); // sdQ empty, ready to be written to
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + (m_block + 1) % 2 /*id*/); // sdQ empty, ready to be written to
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to
// if (n_block == 0) { // Use TMA_STORE
if (false) { // Use TMA_STORE
#pragma unroll 2
for (; m_block >= 0; --m_block) {
cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem
// cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written to gmem
if (lane_predicate) {
cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block));
// cute::copy(tma_store_dQ, tdQsdQ2(_, _, _, m_block % 2), tdQgdQ(_, _, _, m_block));
tma_store_arrive();
}
tma_store_wait<0>();
Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h);
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to
}
} else { // Use TMA_REDUCE_ADD
#pragma unroll 2
for (; m_block >= 0; --m_block) {
// Barrier::wait_eq(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, n_block);
// Barrier::wait_lt(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, 1);
cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem
// cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written to gmem
if (lane_predicate) {
cute::copy(tma_reduce_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block));
// cute::copy(tma_reduce_add_dQ, tdQsdQ2(_, _, _, m_block % 2), tdQgdQ(_, _, _, m_block));
tma_store_arrive();
}
tma_store_wait<0>();
// Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h);
// cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to
}
}
// } else if (warp_idx_in_warpgroup == 2) { // Load LSE and dPSum
// Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
// make_shape(params.b, params.h, params.seqlen_q),
// make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
// Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
// make_shape(params.b, params.h, params.seqlen_q),
// make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
// Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(_)); // (M, _)
// Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(_)); // (M, _)
// Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kBlockM>>{});
// Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape<Int<kBlockM>>{});
// #pragma unroll 2
// for (; m_block >= 0; --m_block) {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty
// #pragma unroll
// for (int i = 0; i < cute::ceil_div(kBlockM, 32); ++i) {
// int idx = threadIdx.x % 32 + i * 32;
// sLSE(idx) = idx < params.seqlen_q - m_block * kBlockM ? gLSE(idx, m_block) : INFINITY;
// sdPsum(idx) = idx < params.seqlen_q - m_block * kBlockM ? gdPsum(idx, m_block) : 0;
// }
// // sLSE and sdPsum are ready for WG 1
// cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 1 /*id*/);
// // sLSE and sdPsum are ready for WG 2
// cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 2 /*id*/);
// }
}
} else { // Consumers
// method in cutlass/arch/reg_reconfig.h
// calls setmaxnreg.inc.sync.aligned.u32
cutlass::arch::warpgroup_reg_alloc<240>();
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
PipelineState smem_pipe_read_q, smem_pipe_read_do;
PipelineState smem_pipe_release_q, smem_pipe_release_do;
int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1;
const int m_block_max = m_block;
int bidb = blockIdx.z; // The block index for the batch.
int bidh = blockIdx.y; // The block index for the head.
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.dsoftmax_sum)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{}));
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kBlockM>>{});
Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape<Int<kBlockM>>{});
typename Ktraits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum;
// auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ);
auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x - NumCopyThreads);
Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_D(sdQ);
Tensor tdQsdQaccum2 = rmem_thr_copy_dQaccum.partition_D(sdQ2);
// Initialize matmul objects.
typename Ktraits::TiledMmaSdP tiledMmaSdP;
auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x - NumCopyThreads);
typename Ktraits::TiledMmadKV tiledMmadKV;
auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x - NumCopyThreads);
typename Ktraits::TiledMmadQ tiledMmadQ;
// auto threadMmadQ = tiledMmadQ.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ);
auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x - NumCopyThreads);
// Allocate accumulator
Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x - NumCopyThreads);
// auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ);
// auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom<cute::SM90_U32x4_STSM_N, ElementAccum>{}, tiledMmadQ);
// auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom<DefaultCopy, ElementAccum>{}, tiledMmadQ);
// auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x - NumCopyThreads);
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
if constexpr (SdP_swapAB) {
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_A(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV);
Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _);
static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(taccScS_row))::value, 8);
static constexpr bool kStatsDivisibleBy8 = decltype(size(taccScS_row))::value % 8 == 0;
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
// Tensor lse = make_tensor<ElementAccum>(Shape<Int<kStatsPerThread>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// #pragma unroll
// for (int mi = 0; mi < size(lse); ++mi) {
// const int row_idx = mi * 8 + (threadIdx.x % 32) / 4;
// const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? get<1>(taccScS_row(row_idx)) : 0;
// lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
// dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
// }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); }
// Trying to spread LSE and dPSum across threads in a warp but it's slower
// const int row_idx = mi * 8 + (threadIdx.x % 32) / 4;
// const int row = get<1>(taccScS_row(row_idx)); // TODO: what if row_idx is outside the range?
// cute::fill(lse, 1);
// cute::fill(dP_sum, 1);
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
// #pragma unroll 2
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
// Putting this dQ block at the beginning of the loop gives an extra 10 TFLOPs
// It does make the code uglier; not sure if it's worth it.
if (m_block < m_block_max) {
// SMEM fence to make sure sdS is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD)
// int warp_group_idx = cutlass::canonical_warp_group_idx();
// if (warp_group_idx == 1 + (m_block + 1) % 2) {
// // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4);
// } else {
// // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4);
// static_assert(!Mma_dQ_is_RS);
// Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
// if constexpr (!dQ_swapAB) {
// Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// } else {
// Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
// }
// Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2));
// cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem
// }
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/);
static_assert(!Mma_dQ_is_RS);
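// dQ = dS @ K, using the dS staged in smem at the end of the previous iteration; the result is
// copied to sdQ below and warp 1 of the producer warpgroup adds it to gmem via TMA_REDUCE_ADD.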
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
// Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
// Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt);
// Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to
cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2));
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem
}
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP);
// sLSE and sdPsum are done loading for WG 1 or 2
// cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/);
// Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
// #pragma unroll
// for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = sLSE(get<1>(taccScS_row(mi))); }
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// #pragma unroll
// for (int mi = 0; mi < size<0>(lse); ++mi) { lse(mi) *= float(M_LOG2E); }
// #pragma unroll
// for (int mi = 0; mi < size<0>(scores); ++mi) {
// // const float lse_scaled = lse(mi) * float(M_LOG2E);
// const float lse_scaled = __shfl_sync(0xffffffff, lse(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4));
// // const float lse_scaled = __shfl_xor_sync(0xffffffff, lse(mi / 8), 1 << (mi % 4)) * float(M_LOG2E);
// // const float lse_scaled = lse(mi);
// #pragma unroll
// for (int ni = 0; ni < size<1>(scores); ++ni) {
// scores(mi, ni) = exp2f(scores(mi, ni) * params.scale_softmax_log2 - lse_scaled);
// }
// }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); }
// Tensor dP_sum = make_fragment_like(lse);
// sLSE and sdPsum are done loading for WG 1 or 2
// cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/);
// #pragma unroll
// for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<1>(taccScS_row(mi))); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); }
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
// const float dP_sum_cur = __shfl_sync(0xffffffff, dP_sum(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4));
// const float dP_sum_cur = __shfl_xor_sync(0xffffffff, dP_sum(mi / 8), 1 << (mi % 4));
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
// for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); }
}
// sLSE and sdPsum are done processing, can load for the next iteration
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
}
// #pragma unroll
// for (int mi = 0; mi < size(lse); ++mi) {
// // const int row = get<1>(taccScS_row(mi));
// const int row_idx = mi * 8 + (threadIdx.x % 32) / 4;
// const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? get<1>(taccScS_row(row_idx)) : 0;
// lse(mi) = gLSE(row);
// dP_sum(mi) = gdPsum(row);
// }
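// Preload LSE and dP_sum for the next iteration (the pointers were moved back by kBlockM above
// when m_block > 0). Each thread's rows come in consecutive pairs here, so use vectorized float2 loads.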
Tensor lse_float2 = recast<float2>(lse);
Tensor dP_sum_float2 = recast<float2>(dP_sum);
#pragma unroll
for (int mi = 0; mi < size(lse) / 2; ++mi) {
const int row = get<1>(taccScS_row(mi * 2));
lse_float2(mi) = *reinterpret_cast<float2*>(&(gLSE(row)));
dP_sum_float2(mi) = *reinterpret_cast<float2*>(&(gdPsum(row)));
}
static_assert(!dKV_swapAB);
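// dV += P^T @ dO, accumulated in registers across the m_block loop (zero_init = false).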
Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tSrS.layout()));
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); }
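// dK += dS^T @ Q, also accumulated in registers across the m_block loop.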
Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadKV>(tdPrdP.layout()));
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
++smem_pipe_read_q;
// warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); }
pipeline_do.consumer_release(smem_pipe_release_do);  // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q);  // release Q
++smem_pipe_release_q;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
}
{
// SMEM fence to make sure sdS is written before it's read by WGMMA
cutlass::arch::fence_view_async_shared();
// dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD)
// int warp_group_idx = cutlass::canonical_warp_group_idx();
// if (warp_group_idx == 1 + (m_block + 1) % 2) {
// // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4);
// } else {
// // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4);
// static_assert(!Mma_dQ_is_RS);
// Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
// if constexpr (!dQ_swapAB) {
// Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// } else {
// Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
// Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
// flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
// }
// Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2));
// cutlass::arch::fence_view_async_shared();
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
// // if (blockIdx.x == 0 && threadIdx.x == 128) { print(taccdQrdQ); printf("\n"); print(tdQsdQaccum2); printf("\n"); }
// }
cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + 0 % 2 /*id*/);
// cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
static_assert(!Mma_dQ_is_RS);
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (!dQ_swapAB) {
Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
} else {
Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
}
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }
Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to
cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
// cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, 0 % 2));
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + 0 % 2 /*id*/); // sdQ ready to be written to gmem
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(sdQ); printf("\n"); }
}
} else { // !SdP_swapAB
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Allocate "fragments/descriptors"
Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ);
Tensor tSrK = threadMmaSdP.partition_fragment_B(sK);
Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO);
Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV);
Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N)
static_assert(decltype(size<0, 0>(taccScS))::value == 2);
static_assert(decltype(size<0, 1>(taccScS))::value == 2);
// taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{});
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
}
// if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
clear(tdKrdK);
clear(tdVrdV);
shared_storage.barrier_K.wait(0);
shared_storage.barrier_V.wait(0);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block >= 0; --m_block) {
Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_q.consumer_wait(smem_pipe_read_q);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS);
Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
pipeline_do.consumer_wait(smem_pipe_read_do);
// if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { printf("After dO wait\n"); }
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP);
warpgroup_wait<1>();
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
flash::scale_apply_exp2</*Scale=*/true, /*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
// if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(tSrS);
Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/);
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
int const warp_group_idx = cutlass::canonical_warp_group_idx() - 1;
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4 + warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier arrive 4, tidx = %d\n", threadIdx.x); }
warpgroup_wait<0>();
// Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
// if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); }
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
}
// if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); }
Tensor rdS = flash::convert_type<Element>(tdPrdP);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 6 + warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier arrive 6, tidx = %d\n", threadIdx.x); }
if (m_block > 0) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
}
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<1>(taccScS_row(mi));
lse(mi) = gLSE(row);
dP_sum(mi) = gdPsum(row);
}
Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
if constexpr (Mma_dQ_is_RS) {
static_assert(!dQ_swapAB);
Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<typename Ktraits::TiledMmadQ>(tdPrdP.layout()));
Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
// if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); }
}
cutlass::arch::fence_view_async_shared();
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 4, tidx = %d\n", threadIdx.x); }
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4 + 1 - warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 4, tidx = %d\n", threadIdx.x); }
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sPt); printf("\n"); }
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sdOt); printf("\n"); }
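// Without SdP_swapAB, the S/dP accumulator layout cannot be reused directly as the A operand of the
// dK/dV GEMMs, so P and dS are staged through smem; the named barriers above ensure both warpgroups
// have written their portions before the full tile is read.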
if constexpr (!dKV_swapAB) {
Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
} else {
Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt);
Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV);
}
++smem_pipe_read_do;
// warpgroup_wait<0>();
// Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); }
cutlass::arch::fence_view_async_shared();
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 6, tidx = %d\n", threadIdx.x); }
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 6 + 1 - warp_group_idx /*id*/);
// if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 6, tidx = %d\n", threadIdx.x); }
if constexpr (!dKV_swapAB) {
Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
} else {
Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt);
Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK);
}
++smem_pipe_read_q;
warpgroup_wait<0>();
// Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
// if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); }
pipeline_do.consumer_release(smem_pipe_release_do);  // release dO
++smem_pipe_release_do;
pipeline_q.consumer_release(smem_pipe_release_q);  // release Q
++smem_pipe_release_q;
// warpgroup_wait<0>();
// Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
// if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/);
}
}
// Epilogue
Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{});
Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{});
Tensor sdKt = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{});
Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{});
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x - NumCopyThreads);
int n_block = blockIdx.x;
bidb = blockIdx.z; // The block index for the batch.
bidh = blockIdx.y; // The block index for the head.
Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
auto block_tma_dK = tma_store_dK.get_slice(_0{});
auto block_tma_dV = tma_store_dV.get_slice(_0{});
Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
// Very slightly faster to do the smem write and TMA write for dV first, then do the same for dK,
// instead of doing both at the same time.
Tensor tdVrdV_out = convert_type<Element>(tdVrdV);
#pragma unroll
for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; }
Tensor tdKrdK_out = convert_type<Element>(tdKrdK);
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
// Can't use __syncthreads() in WS code
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(NumMmaThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
synchronize();
if constexpr (!dKV_swapAB) {
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
} else {
Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
synchronize();
lane_predicate = cute::elect_one_sync();
warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) {
cute::copy(tma_store_dV, tdVsdV, tdVgdV);
tma_store_arrive();
}
if constexpr (!dKV_swapAB) {
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
} else {
Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
}
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
synchronize();
if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) {
cute::copy(tma_store_dK, tdKsdK, tdKgdK);
tma_store_arrive();
}
tma_store_wait<0>();
}
}
} // namespace flash