/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

// NOTE(review): the six bare `#include` lines below, and the template argument lists throughout this
// file (e.g. after `template`, `reinterpret_cast`, `static_cast`, `flash::gemm`, `convert_type`,
// `recast`, `select`, `Shape`, `Layout`, `warpgroup_wait`, `is_same_v`, `sizeof_bits_v`,
// `make_producer_start_state`), appear to have lost their `<...>` contents in this copy of the file.
// Restore them from the upstream source before building.
#include
#include
#include
#include
#include
#include

#include "cutlass/pipeline/pipeline.hpp"

#include "flash.h"
#include "utils.h"
#include "softmax.h"

namespace flash {

using namespace cute;

// compute_dqkv: attention backward kernel. Each CTA owns one (n_block, head, batch) =
// (blockIdx.x, blockIdx.y, blockIdx.z). It loads the K and V tiles of its N block once via TMA, then
// walks every M block of Q / dO from the last one down to 0 and, per M block:
//   * recomputes S = Q K^T (tSrS) and dP = dO V^T (tdPrdP),
//   * rebuilds P from S and the saved softmax LSE (flash::scale_apply_exp2; see softmax.h for the
//     exact formula),
//   * forms dS = P * (dP - dP_sum) element-wise,
//   * accumulates dV (from P and dO) and dK (from dS and Q) in registers across all M blocks,
//   * computes this M block's dQ contribution (from dS and K) and atomicAdds it into
//     params.dq_accum_ptr, because every N-block CTA of the same (head, batch) adds to the same rows.
// The epilogue scales dK by params.scale_softmax, converts dK / dV, stages them in shared memory and
// writes them with TMA stores. dQ is not scaled in this kernel.
// Q and dO tiles stream through kStages-deep TMA pipelines (pipeline_q / pipeline_do). The elected
// lane of warp 0 is the only producer; all kNThreads threads are consumers.
template
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
    compute_dqkv(CUTE_GRID_CONSTANT Flash_bwd_params const params,
                 CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
                 CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
                 CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
                 CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
                 CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
                 CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) {

    using Element = typename Ktraits::Element;
    using ElementAccum = typename Ktraits::ElementAccum;
    using SoftType = ElementAccum;
    using index_t = typename Ktraits::index_t;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int kNThreads = Ktraits::kNThreads;
    // static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{});
    // Every thread in the CTA acts as a pipeline consumer (see pipeline_params.num_consumers).
    static constexpr int NumMmaThreads = Ktraits::kNThreads;
    static constexpr int kBlockM = Ktraits::kBlockM;
    // static constexpr int kBlockN = Ktraits::kBlockN;
    // constexpr int kHeadDim = Ktraits::kHeadDim;
    static constexpr int kStages = Ktraits::kStages;
    static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
    static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
    static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
    static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS;
    // A register-sourced (RS) dQ MMA is only supported without the A/B operand swap.
    if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); }

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast(shared_memory);

    int const n_block = blockIdx.x;
    int const bidb = blockIdx.z;  // The block index for the batch.
    int const bidh = blockIdx.y;  // The block index for the head.

    // Not const: both are re-evaluated before the epilogue TMA store.
    int lane_predicate = cute::elect_one_sync();
    int warp_idx = cutlass::canonical_warp_idx_sync();

    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor());
        cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor());
        cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor());
        cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor());
    }

    // Global tensors, all indexed (seqlen, d, h, b) through their TMA descriptors.
    Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
    Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b));
    Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
    Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
    // LSE is a contiguous (b, h, seqlen_q) array.
    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)),
                              make_shape(params.b, params.h, params.seqlen_q),
                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
    // dP_sum uses seqlen_q_rounded as its row pitch, unlike LSE above.
    Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)),
                                make_shape(params.b, params.h, params.seqlen_q),
                                make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{}));
    // dQ accumulator: memory order (b, seqlen_q_rounded, h, d) with d contiguous, viewed as (seqlen_q, d, h, b).
    Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr)),
                                  make_shape(params.seqlen_q, params.d, params.h, params.b),
                                  make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded));

    Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
    Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
    Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
    Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
    Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
    // if (cute::thread0()) { print(tma_load_K); printf("\n"); }
    // if (cute::thread0()) { print(mK); printf("\n"); print(gK); printf("\n"); }

    typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

    // Construct SMEM tensors.
    // The *t tensors are alternate (transposed) layouts over the same shared-memory buffers as their
    // non-t counterparts; they do not allocate extra storage.
    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{});
    Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{});
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{});
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{});
    Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{});
    Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{});
    Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{});
    Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{});
    Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{});
    Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{});
    Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{});

    // Prepare the TMA loads
    uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    // Q / dO are sliced by the CTA's cluster y-coordinate (they may be multicast, see mcast_mask_qdo);
    // K / V are per-CTA and use slice 0.
    auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y);
    auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y);
    auto block_tma_K = tma_load_K.get_slice(_0{});
    auto block_tma_V = tma_load_V.get_slice(_0{});

    Tensor tQgQ = block_tma_Q.partition_S(gQ);  // (TMA, TMA_M, TMA_K, k)
    Tensor tQsQ = block_tma_Q.partition_D(sQ);  // (TMA, TMA_M, TMA_K, PIPE)
    Tensor tdOgdO = block_tma_dO.partition_S(gdO);  // (TMA, TMA_M, TMA_K, k)
    Tensor tdOsdO = block_tma_dO.partition_D(sdO);  // (TMA, TMA_M, TMA_K, PIPE)
    Tensor tKgK = block_tma_K.partition_S(gK);  // (TMA, TMA_N, TMA_K)
    Tensor tKsK = block_tma_K.partition_D(sK);  // (TMA, TMA_N, TMA_K)
    Tensor tVgV = block_tma_V.partition_S(gV);  // (TMA, TMA_N, TMA_K)
    Tensor tVsV = block_tma_V.partition_D(sV);  // (TMA, TMA_N, TMA_K)
    // if (cute::thread0()) { print(tQgQ); printf("\n"); print(tQsQ); printf("\n"); }
    // if (cute::thread0()) { print(tKgK); printf("\n"); print(tKsK); printf("\n"); }

    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    // (one pipeline stage of Q / dO, i.e. modes 0 and 1 of the smem tensor).
    constexpr uint32_t TmaTransactionBytesQ = static_cast(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v / 8);
    constexpr uint32_t TmaTransactionBytesdO = static_cast(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v / 8);
    // Q and dO share one PipelineParams (transaction_bytes), so their stage sizes must match.
    static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO);
    constexpr uint32_t TmaTransactionBytesK = static_cast(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v / 8);
    constexpr uint32_t TmaTransactionBytesV = static_cast(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v / 8);
    static_assert(TmaTransactionBytesK == TmaTransactionBytesV);

    // Thread index within its warpgroup; thread 0 of each warpgroup is the pipeline leader.
    int thread_idx = int(threadIdx.x);
    int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
    // int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

    // Shared by pipeline_q and pipeline_do: every thread both produces (only warp 0's elected lane
    // actually issues TMA) and consumes.
    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = TmaTransactionBytesQ;
    pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NumMmaThreads;

    // K and V are loaded once, each guarded by its own transaction barrier with a single arrival
    // (the producer's arrive_and_expect_tx below).
    if (warp_idx == 0 && lane_predicate) {
        shared_storage.barrier_K.init(1 /*numThreads*/);
        shared_storage.barrier_V.init(1 /*numThreads*/);
    }
    // cutlass::arch::fence_barrier_init();
    // We're counting on pipeline_q to call fence_barrier_init();
    MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{});
    MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{});

    // We need this to guarantee that the Pipeline init is visible
    // To all producers and consumer blocks in the Cluster
    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_arrive_relaxed();
        cute::cluster_wait();
    } else {
        __syncthreads();
    }

    // State variables used for iterating the circular buffer
    // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
    // smem_pipe_write is used by the producer of SMEM data - i.e TMA
    PipelineState smem_pipe_read_q, smem_pipe_read_do;
    PipelineState smem_pipe_release_q, smem_pipe_release_do;
    PipelineState smem_pipe_write_q = cutlass::make_producer_start_state();
    PipelineState smem_pipe_write_do = cutlass::make_producer_start_state();

    // Copy K tile and V tile from GMEM to SMEM.
    if (warp_idx == 0 && lane_predicate) {
        shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK);
        copy(tma_load_K.with(reinterpret_cast(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK);
        shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV);
        copy(tma_load_V.with(reinterpret_cast(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV);
    }
    // if (cute::thread0()) { print_tensor(sQ); printf("\n"); }
    __syncthreads();

    // M blocks are processed from the last (possibly partial) one down to 0.
    int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1;

    // Multicast mask for the Q / dO TMA loads; stays 0 (no multicast) unless the (stripped)
    // is_same_v condition holds.
    // NOTE(review): the loop bound is size<1>(block_layout) but the loop index is passed as the
    // first coordinate, with cluster_local_block_id.x as the second; confirm this matches how the
    // cluster maps onto blockIdx.x (= n_block) before enabling multicast.
    uint16_t mcast_mask_qdo = 0;
    if constexpr (cute::is_same_v) {
        auto block_layout = Layout{}; // (m,n) -> block_id
        for (int n = 0; n < size<1>(block_layout); ++n) {
            mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{}));
        }
    }
    // Issue TmaLoads (Prologue fetches)
    if (warp_idx == 0 && lane_predicate) {
        // Issue the prologue loads: fill up to kStages stages with M blocks m_block, m_block - 1, ...
        // (the same descending order the main loop consumes them in).
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < kStages && stage <= m_block; ++stage) {
            pipeline_q.producer_acquire(smem_pipe_write_q);
            copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - stage), tQsQ(_, _, _, stage));
            ++smem_pipe_write_q;
            pipeline_do.producer_acquire(smem_pipe_write_do);
            copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - stage), tdOsdO(_, _, _, stage));
            ++smem_pipe_write_do;
        }
    }

    // LSE / dP_sum slices for the current M block; the pointers are stepped back by kBlockM per iteration.
    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block));
    Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(m_block));

    // Initialize matmul objects.
    typename Ktraits::TiledMmaSdP tiledMmaSdP;
    auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x);
    typename Ktraits::TiledMmadKV tiledMmadKV;
    auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x);
    typename Ktraits::TiledMmadQ tiledMmadQ;
    auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x);

    // Allocate accumulator
    // (dK and dV register accumulators, cleared once and accumulated over every M block).
    Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{}));
    Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{}));

    auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP);
    auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x);

    if constexpr (!SdP_swapAB) {
        // Non-swapped path: S and dP are (M, N) accumulators. Both P and dS are written to shared
        // memory and every dV / dK / dQ GEMM reads them from there.
        Tensor tPsP = smem_thr_copy_PdS.partition_D(sP);      // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);      // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // Allocate "fragments/descriptors"
        Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ);
        Tensor tSrK = threadMmaSdP.partition_fragment_B(sK);
        Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO);
        Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV);

        Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{}));    // (BLK_M,BLK_N) -> (blk_m,blk_n)
        Tensor taccScS = threadMmaSdP.partition_C(caccS);                           // (MMA,MMA_M,MMA_N)
        static_assert(decltype(size<0, 0>(taccScS))::value == 2);
        static_assert(decltype(size<0, 1>(taccScS))::value == 2);
        // taccScS has shape ((2, 2, V), MMA_M, MMA_N); we only take the row indices.
        Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{});
        Tensor lse = make_tensor(Shape>{});
        Tensor dP_sum = make_fragment_like(lse);
        // Rows at or beyond seqlen_q get LSE = INFINITY and dP_sum = 0 (rationale below).
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccScS_row(mi));
            lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
            dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
        }
        // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
        // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
        // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
        // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
        // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.

        clear(tdKrdK);
        clear(tdVrdV);

        // Wait for the one-time K / V loads to land.
        shared_storage.barrier_K.wait(0);
        shared_storage.barrier_V.wait(0);
        __syncthreads();

        // #pragma unroll 2
        CUTLASS_PRAGMA_NO_UNROLL
        for (; m_block >= 0; --m_block) {
            // S = Q K^T and dP = dO V^T for this M block, each issued once its smem stage has arrived.
            Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
            pipeline_q.consumer_wait(smem_pipe_read_q);
            __syncwarp();
            flash::gemm(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS);
            Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{}));
            pipeline_do.consumer_wait(smem_pipe_read_do);
            __syncwarp();
            flash::gemm(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP);
            // Allow one WGMMA group (presumably the dP GEMM) to stay in flight, so S is ready here.
            // Assumes each flash::gemm call commits one group — NOTE(review): confirm in utils.h.
            warpgroup_wait<1>();
            // Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
            Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
            // Recompute P in place from S and the saved LSE.
            flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2);
            // if (cute::thread0()) { print_tensor(scores); printf("\n"); }
            // Convert scores from fp32 to fp16/bf16
            Tensor rP = flash::convert_type(tSrS);
            Tensor tPaP = smem_thr_copy_PdS.retile_S(rP);     // ((Atom,AtomNum), MMA_M, MMA_N)
            cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
            // Signal, without waiting, that this warpgroup's share of sP is written. The matching
            // NamedBarrier::sync on `1 - warp_group_idx` below makes each warpgroup wait for the
            // *other* warpgroup's writes. Barrier ids 0/1 are used for sP and 2/3 for sdS, which
            // assumes exactly two warpgroups participate (kNThreads threads per barrier).
            int const warp_group_idx = cutlass::canonical_warp_group_idx();
            cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/);
            warpgroup_wait<0>();
            // Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
            Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
            // if (cute::thread0()) { print_tensor(dS); printf("\n"); }
            // dS = P * (dP - dP_sum), computed in place over the dP accumulator.
            #pragma unroll
            for (int mi = 0; mi < size<0>(dS); ++mi) {
                #pragma unroll
                for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
            }
            Tensor rdS = flash::convert_type(tdPrdP);
            Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS);     // ((Atom,AtomNum), MMA_M, MMA_N)
            cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
            // cutlass::arch::NamedBarrier::arrive(kNThreads, 1 /*id*/);
            // Signal that this warpgroup's share of sdS is written (pairs with the sync on 2 + 1 - warp_group_idx).
            cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/);
            // if (cute::thread0()) { print_tensor(dS); printf("\n"); }

            // Prefetch LSE / dP_sum for the next (lower) M block. Only the first block processed
            // (the highest one) can extend past seqlen_q, so no bounds check is needed here.
            if (m_block > 0) {
                gLSE.data() = gLSE.data() + (-int(kBlockM));
                gdPsum.data() = gdPsum.data() + (-int(kBlockM));
                #pragma unroll
                for (int mi = 0; mi < size(lse); ++mi) {
                    const int row = get<0>(taccScS_row(mi));
                    lse(mi) = gLSE(row);
                    dP_sum(mi) = gdPsum(row);
                }
            }

            Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{}));
            if constexpr (Mma_dQ_is_RS) {
                // RS variant: dQ = dS K with dS taken straight from registers.
                static_assert(!dQ_swapAB);
                Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout()));
                Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
                flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
                // if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); }
            }

            // warpgroup_wait<0>();
            // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
            // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
            // if (cute::thread0()) { print_tensor(sK); printf("\n"); }
            // if (cute::thread0()) { print_tensor(sKt); printf("\n"); }

            // NOTE(review): the author reports a race condition without this __syncthreads(), even with
            // the named barriers below; keep it unless that race is understood.
            __syncthreads();
            // __syncthreads(); // Without this I'm getting race condition, I thought the barrier would be enough
            // SMEM fence to make sure sP is written before it's read by WGMMA
            cutlass::arch::fence_view_async_shared();
            // Wait until the other warpgroup has also written its share of sP.
            cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/);
            // dV += P^T dO, reading P through the transposed view sPt.
            if constexpr (!dKV_swapAB) {
                Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt);
                Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
                flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
            } else {
                Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt);
                Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt);
                flash::gemm(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV);
            }
            ++smem_pipe_read_do;
            // warpgroup_wait<0>();
            // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
            // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }

            // Make sdS visible to WGMMA and wait for the other warpgroup's share of sdS.
            cutlass::arch::fence_view_async_shared();
            cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/);
            // SS variant: dQ = dS K with dS read from shared memory.
            if constexpr (!Mma_dQ_is_RS) {
                if constexpr (!dQ_swapAB) {
                    Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
                    Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
                    flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
                } else {
                    Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
                    Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
                    flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
                }
            }
            // warpgroup_wait<0>();
            // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
            // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }
            // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); }

            // dK += dS^T Q, reading dS through the transposed view sdSt.
            if constexpr (!dKV_swapAB) {
                Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt);
                Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
                flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
            } else {
                Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt);
                Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt);
                flash::gemm(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK);
            }
            ++smem_pipe_read_q;
            // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
            // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }

            // Wait for the dQ GEMM (the wait depth is in the stripped template argument).
            warpgroup_wait();
            // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
            // Add this M block's dQ contribution into global dq_accum. Atomics are required because
            // every N-block CTA of this (head, batch) adds into the same dQ rows.
            Tensor tdQrdQ_atomic = recast(tdQrdQ);
            Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block));
            #pragma unroll
            for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
            // for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); }

            // All GEMMs reading this Q / dO stage have completed; hand the stages back to the producer.
            warpgroup_wait<0>();

            pipeline_do.consumer_release(smem_pipe_release_do);  // release this dO stage
            ++smem_pipe_release_do;
            pipeline_q.consumer_release(smem_pipe_release_q);  // release this Q stage
            ++smem_pipe_release_q;

            // Refill: the elected lane of warp 0 loads M block (m_block - kStages) into the next
            // write stage, kStages iterations ahead of its use. These locals shadow the outer
            // lane_predicate / warp_idx.
            int const lane_predicate = cute::elect_one_sync();
            int const warp_idx = cutlass::canonical_warp_idx_sync();
            if (warp_idx == 0 && lane_predicate && m_block >= kStages) {
                pipeline_q.producer_acquire(smem_pipe_write_q);
                copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index()));
                ++smem_pipe_write_q;
                pipeline_do.producer_acquire(smem_pipe_write_do);
                copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index()));
                ++smem_pipe_write_do;
            }
        }

    } else {  // SdP_swapAB
        // Swapped path: S and dP are computed as (N, M) accumulators (K / V are the A operand).
        // P stays in registers for the dV GEMM. Only dS is written to shared memory (through the
        // transposed view sdSt), for the dQ GEMM.
        Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt);      // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // Allocate "fragments/descriptors"
        Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ);
        Tensor tSrK = threadMmaSdP.partition_fragment_A(sK);
        Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO);
        Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV);

        Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{}));    // (BLK_N,BLK_M) -> (blk_n,blk_m)
        Tensor taccScS = threadMmaSdP.partition_C(caccS);                           // (MMA,MMA_N,MMA_M)
        static_assert(decltype(size<0, 0>(taccScS))::value == 2);
        static_assert(decltype(size<0, 1>(taccScS))::value == 2);
        // Here caccS is indexed (n, m), so taccScS has shape ((2, 2, V), MMA_N, MMA_M). We keep the
        // entries that vary along the query (M) dimension and read the query row with get<1> below.
        Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _);
        Tensor lse = make_tensor(Shape>{});
        Tensor dP_sum = make_fragment_like(lse);
        // Rows at or beyond seqlen_q get LSE = INFINITY and dP_sum = 0 (rationale below).
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<1>(taccScS_row(mi));
            lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
            dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0;
        }
        // cute::fill(lse, 1);
        // cute::fill(dP_sum, 1);
        // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); }
        // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
        // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
        // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
        // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.

        clear(tdKrdK);
        clear(tdVrdV);

        // Wait for the one-time K / V loads to land.
        shared_storage.barrier_K.wait(0);
        shared_storage.barrier_V.wait(0);

        // #pragma unroll 2
        CUTLASS_PRAGMA_NO_UNROLL
        for (; m_block >= 0; --m_block) {
            // S^T = K Q^T and dP^T = V dO^T for this M block.
            Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
            pipeline_q.consumer_wait(smem_pipe_read_q);
            flash::gemm(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS);
            Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{}));
            pipeline_do.consumer_wait(smem_pipe_read_do);
            flash::gemm(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP);
            // Allow one WGMMA group (presumably the dP GEMM) to stay in flight, so S is ready here.
            warpgroup_wait<1>();
            // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
            Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
            // Recompute P in place from S and the saved LSE.
            flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2);
            // if (cute::thread0()) { print_tensor(scores); printf("\n"); }
            // Convert scores from fp32 to fp16/bf16
            Tensor rP = flash::convert_type(tSrS);

            // dV += P^T dO with P^T taken straight from registers (RS GEMM).
            static_assert(!dKV_swapAB);
            Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout()));
            Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt);
            flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV);
            ++smem_pipe_read_do;
            // warpgroup_wait<0>();
            // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout()));
            // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); }

            // Allow only the dV GEMM to remain in flight, so dP is complete here.
            warpgroup_wait<1>();
            // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
            Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
            // dS = P * (dP - dP_sum), computed in place over the dP accumulator.
            #pragma unroll
            for (int mi = 0; mi < size<0>(dS); ++mi) {
                #pragma unroll
                for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); }
            }
            // if (cute::thread0()) { print_tensor(dS); printf("\n"); }
            Tensor rdS = flash::convert_type(tdPrdP);

            Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS);     // ((Atom,AtomNum), MMA_N, MMA_M)
            cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);

            // Prefetch LSE / dP_sum for the next (lower) M block. Only the first block processed
            // (the highest one) can extend past seqlen_q, so no bounds check is needed here.
            if (m_block > 0) {
                gLSE.data() = gLSE.data() + (-int(kBlockM));
                gdPsum.data() = gdPsum.data() + (-int(kBlockM));
                #pragma unroll
                for (int mi = 0; mi < size(lse); ++mi) {
                    const int row = get<1>(taccScS_row(mi));
                    lse(mi) = gLSE(row);
                    dP_sum(mi) = gdPsum(row);
                }
            }

            // dK += dS^T Q with dS^T taken straight from registers (RS GEMM).
            Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout()));
            Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt);
            flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK);
            ++smem_pipe_read_q;
            // warpgroup_wait<0>();
            // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout()));
            // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); }

            // SMEM fence to make sure the dS tile written above (via sdSt) is visible before the dQ
            // WGMMA reads it through sdS. P is not staged in shared memory on this path.
            cutlass::arch::fence_view_async_shared();
            // cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/);
            __syncthreads();
            // dQ = dS K, with dS read from shared memory.
            static_assert(!Mma_dQ_is_RS);
            Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{}));
            if constexpr (!dQ_swapAB) {
                Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS);
                Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt);
                flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ);
            } else {
                Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS);
                Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt);
                flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ);
            }
            // warpgroup_wait<0>();
            // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout()));
            // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); }

            // All outstanding GEMMs (dK, dQ) complete.
            warpgroup_wait<0>();
            // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); }
            // Add this M block's dQ contribution into global dq_accum. Atomics are required because
            // every N-block CTA of this (head, batch) adds into the same dQ rows.
            Tensor tdQrdQ_atomic = recast(tdQrdQ);
            Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block));
            #pragma unroll
            for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
            // for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); }

            pipeline_do.consumer_release(smem_pipe_release_do);  // release this dO stage
            ++smem_pipe_release_do;
            pipeline_q.consumer_release(smem_pipe_release_q);  // release this Q stage
            ++smem_pipe_release_q;

            // Refill: the elected lane of warp 0 loads M block (m_block - kStages) into the next
            // write stage, kStages iterations ahead of its use. These locals shadow the outer
            // lane_predicate / warp_idx.
            int const lane_predicate = cute::elect_one_sync();
            int const warp_idx = cutlass::canonical_warp_idx_sync();
            if (warp_idx == 0 && lane_predicate && m_block >= kStages) {
                pipeline_q.producer_acquire(smem_pipe_write_q);
                copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index()));
                ++smem_pipe_write_q;
                pipeline_do.producer_acquire(smem_pipe_write_do);
                copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index()));
                ++smem_pipe_write_do;
            }
        }
    }

    // Epilogue
    // Scale dK by the softmax scale (dV is written unscaled). Then convert both accumulators, stage
    // them in shared memory and TMA-store this CTA's N-block tiles of dK and dV.

    #pragma unroll
    for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; }

    Tensor tdKrdK_out = convert_type(tdKrdK);
    Tensor tdVrdV_out = convert_type(tdVrdV);

    Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{});
    Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{});
    Tensor sdKt = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{});
    Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{});

    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out);        // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out);        // ((Atom,AtomNum), MMA_M, MMA_N)

    __syncthreads();
    // With dKV_swapAB the accumulators are transposed, so they are written through the *t views of
    // the same buffers. The TMA store below always reads sdK / sdV.
    if constexpr (!dKV_swapAB) {
        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    } else {
        Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt);
    }
    cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA

    Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
    Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b));
    Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
    Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
    auto block_tma_dK = tma_store_dK.get_slice(_0{});
    auto block_tma_dV = tma_store_dV.get_slice(_0{});
    Tensor tdKgdK = block_tma_dK.partition_D(gdK);  // (TMA, TMA_N, TMA_K)
    Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_N, TMA_K)
    Tensor tdVgdV = block_tma_dV.partition_D(gdV);  // (TMA, TMA_N, TMA_K)
    Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_N, TMA_K)

    __syncthreads(); // ensure all threads have issued their async fence

    // A single elected thread issues both TMA stores and commits them as one bulk group.
    lane_predicate = cute::elect_one_sync();
    warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == 0 && lane_predicate) {
        cute::copy(tma_store_dV, tdVsdV, tdVgdV);
        cute::copy(tma_store_dK, tdKsdK, tdKgdK);
        tma_store_arrive();
    }
    tma_store_wait<0>();

    // To make sure remote SMEM doesn't get destroyed
    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_arrive();
        cute::cluster_wait();
    }

}

// compute_dqkv_seqqpar: its definition continues on the lines that follow this point.
template
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
    compute_dqkv_seqqpar(CUTE_GRID_CONSTANT Flash_bwd_params const params,
                         CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q,
                         CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO,
                         CUTE_GRID_CONSTANT TiledCopyK const tma_load_K,
                         CUTE_GRID_CONSTANT TiledCopyV const tma_load_V,
                         CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ,
                         CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK,
                         CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) {

    using Element = typename Ktraits::Element;
    using ElementAccum = typename Ktraits::ElementAccum;
    using SoftType = ElementAccum;
    using index_t = typename Ktraits::index_t;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int kNThreads = Ktraits::kNThreads;
    static constexpr int NumMmaThreads = Ktraits::kNThreads;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    // constexpr int kHeadDim = Ktraits::kHeadDim;
    static constexpr int kStages = Ktraits::kStages;
    static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB;
    static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB;
    static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB;
static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS; if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); } using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename MainloopPipeline::PipelineState; extern __shared__ char shared_memory[]; auto &shared_storage = *reinterpret_cast(shared_memory); int const m_block = blockIdx.x; int const bidb = blockIdx.z; // The block index for the batch. int const bidh = blockIdx.y; // The block index for the head. int lane_predicate = cute::elect_one_sync(); int warp_idx = cutlass::canonical_warp_idx_sync(); // Issue Tma Descriptor Prefetch from a single thread if (warp_idx == 0 && lane_predicate) { cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor()); } Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), make_shape(params.b, params.h, params.seqlen_q), make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)), make_shape(params.b, params.h, params.seqlen_q), make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{})); Tensor mdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr)), make_shape(params.seqlen_k, params.d, 
params.h, params.b), make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k)); Tensor mdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr)), make_shape(params.seqlen_k, params.d, params.h, params.b), make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k)); Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gdKaccum = local_tile(mdKaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gdVaccum = local_tile(mdVaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(threadIdx.x); Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); // Construct SMEM tensors. 
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{}); Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{}); Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{}); Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{}); Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{}); Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{}); Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{}); // Prepare the TMA loads uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; auto block_tma_Q = tma_load_Q.get_slice(_0{}); auto block_tma_dO = tma_load_dO.get_slice(_0{}); auto block_tma_K = tma_load_K.get_slice(cluster_local_block_id.x); auto block_tma_V = tma_load_V.get_slice(cluster_local_block_id.x); Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K) Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K) Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K) Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K) Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K, k) Tensor 
tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K, PIPE) Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K, k) Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K, PIPE) // Set the bytes transferred in this TMA transaction (may involve multiple issues) constexpr uint32_t TmaTransactionBytesQ = static_cast(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v / 8); constexpr uint32_t TmaTransactionBytesdO = static_cast(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v / 8); static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO); constexpr uint32_t TmaTransactionBytesK = static_cast(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v / 8); constexpr uint32_t TmaTransactionBytesV = static_cast(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v / 8); static_assert(TmaTransactionBytesK == TmaTransactionBytesV); // Obtain warp index int thread_idx = int(threadIdx.x); int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; // int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; PipelineParams pipeline_params; pipeline_params.transaction_bytes = TmaTransactionBytesK; pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.num_consumers = NumMmaThreads; if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_Q.init(1 /*numThreads*/); shared_storage.barrier_dO.init(1 /*numThreads*/); } // cutlass::arch::fence_barrier_init(); // We're counting on pipeline_k to call fence_barrier_init(); MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{}); // We need this to guarantee that the Pipeline init is visible // To all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { cute::cluster_arrive_relaxed(); cute::cluster_wait(); } else { 
__syncthreads(); } // State variables used for iterating the circular buffer // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA // smem_pipe_write is used by the producer of SMEM data - i.e TMA PipelineState smem_pipe_read_k, smem_pipe_read_v; PipelineState smem_pipe_release_k, smem_pipe_release_v; PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); // Copy K tile and V tile from GMEM to SMEM. if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); copy(tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); shared_storage.barrier_dO.arrive_and_expect_tx(TmaTransactionBytesdO); copy(tma_load_dO.with(reinterpret_cast(shared_storage.barrier_dO), 0 /*mcast_mask*/), tdOgdO, tdOsdO); } // if (cute::thread0()) { print_tensor(sQ); printf("\n"); } __syncthreads(); int n_block = cute::ceil_div(params.seqlen_k, kBlockN) - 1; uint16_t mcast_mask_kv = 0; if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); } } // Issue TmaLoads (Prologue fetches) if (warp_idx == 0 && lane_predicate) { // Issue the prologue loads CUTLASS_PRAGMA_UNROLL for (int stage = 0; stage < kStages && stage <= n_block; ++stage) { pipeline_k.producer_acquire(smem_pipe_write_k); copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - stage), tKsK(_, _, _, stage)); ++smem_pipe_write_k; pipeline_v.producer_acquire(smem_pipe_write_v); copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - stage), tVsV(_, _, _, stage)); ++smem_pipe_write_v; } } Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); Tensor gdPsum = 
local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(m_block)); // Initialize matmul objects. typename Ktraits::TiledMmaSdP tiledMmaSdP; auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x); typename Ktraits::TiledMmadKV tiledMmadKV; auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x); typename Ktraits::TiledMmadQ tiledMmadQ; auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x); // Allocate accumulator Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); clear(tdQrdQ); auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x); if constexpr (!SdP_swapAB) { Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Allocate "fragments/descriptors" Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ); Tensor tSrK = threadMmaSdP.partition_fragment_B(sK); Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO); Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV); Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) static_assert(decltype(size<0, 0>(taccScS))::value == 2); static_assert(decltype(size<0, 1>(taccScS))::value == 2); // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{}); Tensor lse = make_tensor(Shape>{}); Tensor dP_sum = make_fragment_like(lse); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccScS_row(mi)); lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? 
gdPsum(row) : 0; } // if (cute::thread0()) { print_tensor(lse); printf("\n"); } // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply // with V (which would be zero), we're fine. However, with ALiBi, we might modify these // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. shared_storage.barrier_Q.wait(0); shared_storage.barrier_dO.wait(0); // #pragma unroll 2 CUTLASS_PRAGMA_NO_UNROLL for (; n_block >= 0; --n_block) { // Otherwise we might have WG0 still wating on NamedBarrier but WG1 already // started the next iteration and start flipping the same NamedBarrier. __syncthreads(); Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); pipeline_k.consumer_wait(smem_pipe_read_k); flash::gemm(tiledMmaSdP, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); pipeline_v.consumer_wait(smem_pipe_read_v); flash::gemm(tiledMmaSdP, tdPrdO, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdP); ++smem_pipe_read_v; warpgroup_wait<1>(); // Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); // if (cute::thread0()) { print_tensor(scores); printf("\n"); } // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(tSrS); Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); int const warp_group_idx = cutlass::canonical_warp_group_idx(); cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/); warpgroup_wait<0>(); // Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to 
(nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); // if (cute::thread0()) { print_tensor(dS); printf("\n"); } #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } } Tensor rdS = flash::convert_type(tdPrdP); Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/); // if (cute::thread0()) { print_tensor(dS); printf("\n"); } if constexpr (Mma_dQ_is_RS) { static_assert(!dQ_swapAB); Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); flash::gemm(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ); // if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); } } // warpgroup_wait<0>(); // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } // if (cute::thread0()) { print_tensor(sK); printf("\n"); } // if (cute::thread0()) { print_tensor(sKt); printf("\n"); } __syncthreads(); // if (cute::thread0()) { printf("before barrier sync 0\n"); } // SMEM fence to make sure sP is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/); // if (cute::thread0()) { printf("after barrier sync 0\n"); } Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); if constexpr (!dKV_swapAB) { Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt); Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); flash::gemm(tiledMmadKV, tdVrP, tdVrdO, tdVrdV); } else { Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt); 
Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt); flash::gemm(tiledMmadKV, tdVrdO, tdVrP, tdVrdV); } // warpgroup_wait<0>(); // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); } cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/); if constexpr (!Mma_dQ_is_RS) { if constexpr (!dQ_swapAB) { Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); flash::gemm(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ); } else { Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); flash::gemm(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ); } } ++smem_pipe_read_k; // warpgroup_wait<0>(); // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); } Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); if constexpr (!dKV_swapAB) { Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt); Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); flash::gemm(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK); } else { Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt); Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt); flash::gemm(tiledMmadKV, tdKrQ, tdKrdS, tdKrdK); } // warpgroup_wait<0>(); // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); } warpgroup_wait(); // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } Tensor tdVrdV_atomic = recast(tdVrdV); Tensor tdVgdVaccum_atomic = recast(tdVgdVaccum(_, _, _, n_block)); 
#pragma unroll for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); } // for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); } warpgroup_wait<0>(); Tensor tdKrdK_atomic = recast(tdKrdK); Tensor tdKgdKaccum_atomic = recast(tdKgdKaccum(_, _, _, n_block)); #pragma unroll for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); } pipeline_v.consumer_release(smem_pipe_release_v); // release V ++smem_pipe_release_v; pipeline_k.consumer_release(smem_pipe_release_k); // release V ++smem_pipe_release_k; int const lane_predicate = cute::elect_one_sync(); int const warp_idx = cutlass::canonical_warp_idx_sync(); if (warp_idx == 0 && lane_predicate && n_block >= kStages) { pipeline_k.producer_acquire(smem_pipe_write_k); copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index())); ++smem_pipe_write_k; pipeline_v.producer_acquire(smem_pipe_write_v); copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index())); ++smem_pipe_write_v; } } } else { // SdP_swapAB Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Allocate "fragments/descriptors" Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ); Tensor tSrK = threadMmaSdP.partition_fragment_A(sK); Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO); Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV); Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) static_assert(decltype(size<0, 0>(taccScS))::value == 2); static_assert(decltype(size<0, 1>(taccScS))::value == 2); // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row 
indices. Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _); Tensor lse = make_tensor(Shape>{}); Tensor dP_sum = make_fragment_like(lse); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<1>(taccScS_row(mi)); lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; } // if (cute::thread0()) { print_tensor(taccScS_row); printf("\n"); } // cute::fill(lse, 1); // cute::fill(dP_sum, 1); // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply // with V (which would be zero), we're fine. However, with ALiBi, we might modify these // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. clear(tdQrdQ); shared_storage.barrier_Q.wait(0); shared_storage.barrier_dO.wait(0); // #pragma unroll 2 CUTLASS_PRAGMA_NO_UNROLL for (; n_block >= 0; --n_block) { Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); pipeline_k.consumer_wait(smem_pipe_read_k); flash::gemm(tiledMmaSdP, tSrK(_, _, _, smem_pipe_read_k.index()), tSrQ, tSrS); Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); pipeline_v.consumer_wait(smem_pipe_read_v); flash::gemm(tiledMmaSdP, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdO, tdPrdP); ++smem_pipe_read_v; warpgroup_wait<1>(); // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); // if (cute::thread0()) { print_tensor(lse); printf("\n"); } flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); // if (cute::thread0()) { print_tensor(scores); printf("\n"); } // Convert scores from fp32 to fp16/bf16 Tensor rP = 
flash::convert_type(tSrS); static_assert(!dKV_swapAB); Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{})); Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); flash::gemm(tiledMmadKV, tdVrP, tdVrdO, tdVrdV); // warpgroup_wait<0>(); // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); } warpgroup_wait<1>(); // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } } // if (cute::thread0()) { print_tensor(dS); printf("\n"); } Tensor rdS = flash::convert_type(tdPrdP); Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{})); Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); flash::gemm(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK); // warpgroup_wait<0>(); // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); } warpgroup_wait<1>(); // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } Tensor tdVrdV_atomic = recast(tdVrdV); Tensor tdVgdVaccum_atomic = recast(tdVgdVaccum(_, _, _, n_block)); #pragma unroll for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); } // for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); } // SMEM fence to make 
sure sP is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); // cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/); __syncthreads(); static_assert(!Mma_dQ_is_RS); if constexpr (!dQ_swapAB) { Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); flash::gemm(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ); } else { Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); flash::gemm(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ); } ++smem_pipe_read_k; // warpgroup_wait<0>(); // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } warpgroup_wait<1>(); // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } Tensor tdKrdK_atomic = recast(tdKrdK); Tensor tdKgdKaccum_atomic = recast(tdKgdKaccum(_, _, _, n_block)); #pragma unroll for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); } // for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_release_v); // release V ++smem_pipe_release_v; pipeline_k.consumer_release(smem_pipe_release_k); // release V ++smem_pipe_release_k; int const lane_predicate = cute::elect_one_sync(); int const warp_idx = cutlass::canonical_warp_idx_sync(); if (warp_idx == 0 && lane_predicate && n_block >= kStages) { pipeline_k.producer_acquire(smem_pipe_write_k); copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index())); ++smem_pipe_write_k; pipeline_v.producer_acquire(smem_pipe_write_v); copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, 
_, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index())); ++smem_pipe_write_v; } } } // Epilogue #pragma unroll for (int i = 0; i < size(tdQrdQ); ++i) { tdQrdQ(i) *= params.scale_softmax; } // if (cute::thread0()) { print_tensor(tdQrdQ); } Tensor tdQrdQ_out = convert_type(tdQrdQ); Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQ{}); Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQt{}); auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ); auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x); Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ_out); // ((Atom,AtomNum), MMA_M, MMA_N) __syncthreads(); if constexpr (!dQ_swapAB) { Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } else { Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt); } cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA Tensor mdQ = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); Tensor gdQ = local_tile(mdQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_dQ = tma_store_dQ.get_slice(_0{}); Tensor tdQgdQ = block_tma_dQ.partition_D(gdQ); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) __syncthreads(); // ensure all threads have issued their async fence // if (cute::thread0()) { print_tensor(sdQ); } lane_predicate = cute::elect_one_sync(); warp_idx = cutlass::canonical_warp_idx_sync(); if (warp_idx == 0 && lane_predicate) { cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ); tma_store_arrive(); } tma_store_wait<0>(); // To make sure remote SMEM doesn't get destroyed if constexpr 
(size(ClusterShape{}) > 1) { cute::cluster_arrive(); cute::cluster_wait(); } } template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) compute_dqkv_ws(CUTE_GRID_CONSTANT Flash_bwd_params const params, CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q, CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO, CUTE_GRID_CONSTANT TiledCopyK const tma_load_K, CUTE_GRID_CONSTANT TiledCopyV const tma_load_V, CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK, CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV, CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ, CUTE_GRID_CONSTANT TiledCopyAdddQ const tma_reduce_add_dQ) { using Element = typename Ktraits::Element; using ElementAccum = typename Ktraits::ElementAccum; using SoftType = ElementAccum; using index_t = typename Ktraits::index_t; using TileShape_MNK = typename Ktraits::TileShape_MNK; using ClusterShape = typename Ktraits::ClusterShape_MNK; static_assert(Ktraits::Is_WS); // static constexpr int kNThreads = Ktraits::kNThreads; // static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{}); static constexpr int NumMmaThreads = 256; static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kNThreadsdQ = Ktraits::kNThreadsdQ; // static constexpr int kBlockN = Ktraits::kBlockN; // constexpr int kHeadDim = Ktraits::kHeadDim; // static constexpr int kStages = Ktraits::kStages; static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB; static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB; static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB; if constexpr (SdP_swapAB) { static_assert(!dKV_swapAB); } static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS; if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); } using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename MainloopPipeline::PipelineState; extern 
__shared__ char shared_memory[]; auto &shared_storage = *reinterpret_cast(shared_memory); int lane_predicate = cute::elect_one_sync(); int warp_idx = cutlass::canonical_warp_idx_sync(); // Issue Tma Descriptor Prefetch from a single thread if (warp_idx == 0 && lane_predicate) { cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor()); cute::prefetch_tma_descriptor(tma_reduce_add_dQ.get_tma_descriptor()); } // Construct SMEM tensors. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{}); Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{}); Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{}); Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{}); Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{}); Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{}); Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{}); Tensor sdQ 
= make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc{}); Tensor sdQ2 = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc2{}); Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacct{}); // Set the bytes transferred in this TMA transaction (may involve multiple issues) constexpr uint32_t TmaTransactionBytesQ = static_cast(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v / 8); constexpr uint32_t TmaTransactionBytesdO = static_cast(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v / 8); static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO); constexpr uint32_t TmaTransactionBytesK = static_cast(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v / 8); constexpr uint32_t TmaTransactionBytesV = static_cast(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v / 8); static_assert(TmaTransactionBytesK == TmaTransactionBytesV); // Obtain warp index int thread_idx = int(threadIdx.x); int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; // int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; PipelineParams pipeline_params; pipeline_params.transaction_bytes = TmaTransactionBytesQ; int warp_group_idx = cutlass::canonical_warp_group_idx(); if (warp_group_idx == 0) { pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } else { pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; } pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.num_consumers = NumMmaThreads; if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_K.init(1 /*numThreads*/); shared_storage.barrier_V.init(1 /*numThreads*/); } // cutlass::arch::fence_barrier_init(); // We're counting on pipeline_q to call fence_barrier_init(); MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{}); MainloopPipeline 
pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{}); // We need this to guarantee that the Pipeline init is visible // To all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { cute::cluster_arrive_relaxed(); cute::cluster_wait(); } else { __syncthreads(); } if (warp_group_idx == 0) { // Producer // method in cutlass/arch/reg_reconfig.h // calls setmaxnreg.dec.sync.aligned.u32 cutlass::arch::warpgroup_reg_dealloc<24>(); int const n_block = blockIdx.x; int const bidb = blockIdx.z; // The block index for the batch. int const bidh = blockIdx.y; // The block index for the head. int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1; int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); int lane_predicate = cute::elect_one_sync(); // if (warp_idx_in_warpgroup == 0 && lane_predicate) { if (warp_idx_in_warpgroup == 0) { // Load K, and do TMA on Q and dO Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) // Prepare the TMA loads uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y); auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y); auto block_tma_K = 
tma_load_K.get_slice(_0{}); Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K, k) Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K, PIPE) Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K, k) Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K, PIPE) Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K) Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K) PipelineState smem_pipe_write_q = cutlass::make_producer_start_state(); PipelineState smem_pipe_write_do = cutlass::make_producer_start_state(); uint16_t mcast_mask_qdo = 0; if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{})); } } if (lane_predicate) { // Copy K tile and V tile from GMEM to SMEM. shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK); copy(tma_load_K.with(reinterpret_cast(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK); #pragma unroll 2 for (; m_block >= 0; --m_block) { pipeline_q.producer_acquire(smem_pipe_write_q); copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block), tQsQ(_, _, _, smem_pipe_write_q.index())); ++smem_pipe_write_q; pipeline_do.producer_acquire(smem_pipe_write_do); copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block), tdOsdO(_, _, _, smem_pipe_write_do.index())); ++smem_pipe_write_do; } // Tail loop /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline_q.producer_tail(smem_pipe_write_q); pipeline_do.producer_tail(smem_pipe_write_do); } } else if (warp_idx_in_warpgroup == 1) { 
// Load V, and do TMA_REDUCE_ADD on dQ Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) auto block_tma_V = tma_load_V.get_slice(_0{}); Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K) Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K) if (lane_predicate) { shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV); copy(tma_load_V.with(reinterpret_cast(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV); } Tensor mdQaccum = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) auto block_tma_dQ = tma_store_dQ.get_slice(_0{}); Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ2 = block_tma_dQ.partition_S(sdQ2); // (TMA, TMA_M, TMA_K, 2) int *lock_ptr = params.dq_semaphore + bidb * params.h + bidh; using Barrier = cutlass::GenericBarrier; // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 1 /*id*/); // sdQ empty, ready to be written to cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + (m_block + 1) % 2 /*id*/); // sdQ empty, ready to be written to // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to // if (n_block == 0) { // Use TMA_STORE if (false) { // Use TMA_STORE #pragma unroll 2 for (; m_block >= 0; --m_block) { cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem // cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written 
to gmem if (lane_predicate) { cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block)); // cute::copy(tma_store_dQ, tdQsdQ2(_, _, _, m_block % 2), tdQgdQ(_, _, _, m_block)); tma_store_arrive(); } tma_store_wait<0>(); Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h); cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to } } else { // Use TMA_REDUCE_ADD #pragma unroll 2 for (; m_block >= 0; --m_block) { // Barrier::wait_eq(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, n_block); // Barrier::wait_lt(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, 1); cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem // cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written to gmem if (lane_predicate) { cute::copy(tma_reduce_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block)); // cute::copy(tma_reduce_add_dQ, tdQsdQ2(_, _, _, m_block % 2), tdQgdQ(_, _, _, m_block)); tma_store_arrive(); } tma_store_wait<0>(); // Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h); // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to } } // } else if (warp_idx_in_warpgroup == 2) { // Load LSE and dPSum // Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), // make_shape(params.b, params.h, params.seqlen_q), // make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); // Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)), // make_shape(params.b, params.h, params.seqlen_q), // make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); // 
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(_)); // (M, _) // Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(_)); // (M, _) // Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape>{}); // Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape>{}); // #pragma unroll 2 // for (; m_block >= 0; --m_block) { // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty // #pragma unroll // for (int i = 0; i < cute::ceil_div(kBlockM, 32); ++i) { // int idx = threadIdx.x % 32 + i * 32; // sLSE(idx) = idx < params.seqlen_q - m_block * kBlockM ? gLSE(idx, m_block) : INFINITY; // sdPsum(idx) = idx < params.seqlen_q - m_block * kBlockM ? gdPsum(idx, m_block) : 0; // } // // sLSE and sdPsum are ready for WG 1 // cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 1 /*id*/); // // sLSE and sdPsum are ready for WG 2 // cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 2 /*id*/); // } } } else { // Consumers // method in cutlass/arch/reg_reconfig.h // calls setmaxnreg.inc.sync.aligned.u32 cutlass::arch::warpgroup_reg_alloc<240>(); // State variables used for iterating the circular buffer // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA // smem_pipe_write is used by the producer of SMEM data - i.e TMA PipelineState smem_pipe_read_q, smem_pipe_read_do; PipelineState smem_pipe_release_q, smem_pipe_release_do; int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1; const int m_block_max = m_block; int bidb = blockIdx.z; // The block index for the batch. int bidh = blockIdx.y; // The block index for the head. 
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), make_shape(params.b, params.h, params.seqlen_q), make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)), make_shape(params.b, params.h, params.seqlen_q), make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{})); Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(m_block)); Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape>{}); Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape>{}); typename Ktraits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum; // auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ); auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x - NumCopyThreads); Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_D(sdQ); Tensor tdQsdQaccum2 = rmem_thr_copy_dQaccum.partition_D(sdQ2); // Initialize matmul objects. 
typename Ktraits::TiledMmaSdP tiledMmaSdP; auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x - NumCopyThreads); typename Ktraits::TiledMmadKV tiledMmadKV; auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x - NumCopyThreads); typename Ktraits::TiledMmadQ tiledMmadQ; // auto threadMmadQ = tiledMmadQ.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ); auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x - NumCopyThreads); // Allocate accumulator Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x - NumCopyThreads); // auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ); // auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom{}, tiledMmadQ); // auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom{}, tiledMmadQ); // auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x - NumCopyThreads); Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N) if constexpr (SdP_swapAB) { // Allocate "fragments/descriptors" Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ); Tensor tSrK = threadMmaSdP.partition_fragment_A(sK); Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO); Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV); Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) static_assert(decltype(size<0, 0>(taccScS))::value == 2); static_assert(decltype(size<0, 1>(taccScS))::value == 2); // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. 
Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _); static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(taccScS_row))::value, 8); static constexpr bool kStatsDivisibleBy8 = decltype(size(taccScS_row))::value % 8 == 0; Tensor lse = make_tensor(Shape>{}); // Tensor lse = make_tensor(Shape>{}); Tensor dP_sum = make_fragment_like(lse); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<1>(taccScS_row(mi)); lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; } // #pragma unroll // for (int mi = 0; mi < size(lse); ++mi) { // const int row_idx = mi * 8 + (threadIdx.x % 32) / 4; // const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? get<1>(taccScS_row(row_idx)) : 0; // lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; // dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; // } // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); } // Trying to spread LSE and dPSum across threads in a warp but it's slower // const int row_idx = mi * 8 + (threadIdx.x % 32) / 4; // const int row = get<1>(taccScS_row(row_idx)); // TODO: what if row_idx is outside the range? // cute::fill(lse, 1); // cute::fill(dP_sum, 1); // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply // with V (which would be zero), we're fine. However, with ALiBi, we might modify these // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. 
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty clear(tdKrdK); clear(tdVrdV); shared_storage.barrier_K.wait(0); shared_storage.barrier_V.wait(0); // #pragma unroll 2 CUTLASS_PRAGMA_NO_UNROLL for (; m_block >= 0; --m_block) { // Putting this dQ block at the beginning of the loop gives an extra 10 TFLOPs // It does make the code uglier, idk if it's worth it. if (m_block < m_block_max) { // SMEM fence to make sure sP is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); // dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD) // int warp_group_idx = cutlass::canonical_warp_group_idx(); // if (warp_group_idx == 1 + (m_block + 1) % 2) { // // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); // cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4); // } else { // // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); // cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4); // static_assert(!Mma_dQ_is_RS); // Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); // if constexpr (!dQ_swapAB) { // Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); // Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); // flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); // } else { // Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); // Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); // flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); // } // Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) // cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/); // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2)); // cutlass::arch::fence_view_async_shared(); // cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be 
written to gmem // } // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/); static_assert(!Mma_dQ_is_RS); Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); // Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) if constexpr (!dQ_swapAB) { Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); } else { Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); } // Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N) // cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt); // Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) // cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2)); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem } Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); pipeline_q.consumer_wait(smem_pipe_read_q); flash::gemm(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS); Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); pipeline_do.consumer_wait(smem_pipe_read_do); 
flash::gemm(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP); // sLSE and sdPsum are done loading for WG 1 or 2 // cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/); // Tensor lse = make_tensor(Shape>{}); // #pragma unroll // for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = sLSE(get<1>(taccScS_row(mi))); } warpgroup_wait<1>(); // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); // #pragma unroll // for (int mi = 0; mi < size<0>(lse); ++mi) { lse(mi) *= float(M_LOG2E); } // #pragma unroll // for (int mi = 0; mi < size<0>(scores); ++mi) { // // const float lse_scaled = lse(mi) * float(M_LOG2E); // const float lse_scaled = __shfl_sync(0xffffffff, lse(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4)); // // const float lse_scaled = __shfl_xor_sync(0xffffffff, lse(mi / 8), 1 << (mi % 4)) * float(M_LOG2E); // // const float lse_scaled = lse(mi); // #pragma unroll // for (int ni = 0; ni < size<1>(scores); ++ni) { // scores(mi, ni) = exp2f(scores(mi, ni) * params.scale_softmax_log2 - lse_scaled); // } // } // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); } // Tensor dP_sum = make_fragment_like(lse); // sLSE and sdPsum are done loading for WG 1 or 2 // cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/); // #pragma unroll // for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<1>(taccScS_row(mi))); } // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(tSrS); warpgroup_wait<0>(); // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { 
print_tensor(dS); printf("\n"); } // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); } #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { // const float dP_sum_cur = __shfl_sync(0xffffffff, dP_sum(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4)); // const float dP_sum_cur = __shfl_xor_sync(0xffffffff, dP_sum(mi / 8), 1 << (mi % 4)); #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } // for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); } } // sLSE and sdPsum are done processing, can load for the next iteration // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/); // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); } Tensor rdS = flash::convert_type(tdPrdP); Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); if (m_block > 0) { gLSE.data() = gLSE.data() + (-int(kBlockM)); gdPsum.data() = gdPsum.data() + (-int(kBlockM)); } // #pragma unroll // for (int mi = 0; mi < size(lse); ++mi) { // // const int row = get<1>(taccScS_row(mi)); // const int row_idx = mi * 8 + (threadIdx.x % 32) / 4; // const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? 
get<1>(taccScS_row(row_idx)) : 0; // lse(mi) = gLSE(row); // dP_sum(mi) = gdPsum(row); // } Tensor lse_float2 = recast(lse); Tensor dP_sum_float2 = recast(dP_sum); #pragma unroll for (int mi = 0; mi < size(lse) / 2; ++mi) { const int row = get<1>(taccScS_row(mi * 2)); lse_float2(mi) = *reinterpret_cast(&(gLSE(row))); dP_sum_float2(mi) = *reinterpret_cast(&(gdPsum(row))); } static_assert(!dKV_swapAB); Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV); ++smem_pipe_read_do; // warpgroup_wait<0>(); // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); } Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK); ++smem_pipe_read_q; // warpgroup_wait<0>(); // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); } pipeline_do.consumer_release(smem_pipe_release_do); // release V ++smem_pipe_release_do; pipeline_q.consumer_release(smem_pipe_release_q); // release V ++smem_pipe_release_q; // warpgroup_wait<0>(); // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } } { // SMEM fence to make sure sP is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); // dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD) // int warp_group_idx = cutlass::canonical_warp_group_idx(); // if (warp_group_idx == 1 + (m_block + 1) % 2) { // // 
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); // cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4); // } else { // // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); // cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4); // static_assert(!Mma_dQ_is_RS); // Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); // if constexpr (!dQ_swapAB) { // Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); // Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); // flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); // } else { // Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); // Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); // flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); // } // Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) // cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/); // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2)); // cutlass::arch::fence_view_async_shared(); // cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); // // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); } // // if (blockIdx.x == 0 && threadIdx.x == 128) { print(taccdQrdQ); printf("\n"); print(tdQsdQaccum2); printf("\n"); } // } cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/); // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + 0 % 2 /*id*/); // cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/); static_assert(!Mma_dQ_is_RS); Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); if constexpr (!dQ_swapAB) { Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); Tensor tdQrK = 
threadMmadQ.partition_fragment_B(sKt); flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); } else { Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); } // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); } Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, 0 % 2)); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + 0 % 2 /*id*/); // sdQ ready to be written to gmem // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(sdQ); printf("\n"); } } } else { // !SdP_swapAB Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Allocate "fragments/descriptors" Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ); Tensor tSrK = threadMmaSdP.partition_fragment_B(sK); Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO); Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV); Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) static_assert(decltype(size<0, 0>(taccScS))::value == 2); static_assert(decltype(size<0, 1>(taccScS))::value == 2); // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. 
Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{}); Tensor lse = make_tensor(Shape>{}); Tensor dP_sum = make_fragment_like(lse); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccScS_row(mi)); lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; } // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply // with V (which would be zero), we're fine. However, with ALiBi, we might modify these // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. clear(tdKrdK); clear(tdVrdV); shared_storage.barrier_K.wait(0); shared_storage.barrier_V.wait(0); CUTLASS_PRAGMA_NO_UNROLL for (; m_block >= 0; --m_block) { Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); pipeline_q.consumer_wait(smem_pipe_read_q); flash::gemm(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS); Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); pipeline_do.consumer_wait(smem_pipe_read_do); // if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { printf("After dO wait\n"); } flash::gemm(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP); warpgroup_wait<1>(); // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); // if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); } // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(tSrS); Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // 
((Atom,AtomNum), MMA_N, MMA_N) cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/); cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); int const warp_group_idx = cutlass::canonical_warp_group_idx() - 1; cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4 + warp_group_idx /*id*/); // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier arrive 4, tidx = %d\n", threadIdx.x); } warpgroup_wait<0>(); // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); } // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); } #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } } // if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); } Tensor rdS = flash::convert_type(tdPrdP); Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 6 + warp_group_idx /*id*/); // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier arrive 6, tidx = %d\n", threadIdx.x); } if (m_block > 0) { gLSE.data() = gLSE.data() + (-int(kBlockM)); gdPsum.data() = gdPsum.data() + (-int(kBlockM)); } #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<1>(taccScS_row(mi)); lse(mi) = gLSE(row); dP_sum(mi) = gdPsum(row); } Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); if constexpr (Mma_dQ_is_RS) { static_assert(!dQ_swapAB); Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); Tensor tdQrK = 
threadMmadQ.partition_fragment_B(sKt); flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); // if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); } } cutlass::arch::fence_view_async_shared(); // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 4, tidx = %d\n", threadIdx.x); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4 + 1 - warp_group_idx /*id*/); // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 4, tidx = %d\n", threadIdx.x); } // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sPt); printf("\n"); } // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sdOt); printf("\n"); } if constexpr (!dKV_swapAB) { Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt); Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV); } else { Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt); Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt); flash::gemm(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV); } ++smem_pipe_read_do; // warpgroup_wait<0>(); // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); } cutlass::arch::fence_view_async_shared(); // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 6, tidx = %d\n", threadIdx.x); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, 6 + 1 - warp_group_idx /*id*/); // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 6, tidx = %d\n", threadIdx.x); } if constexpr (!dKV_swapAB) { Tensor tdKrdS = 
threadMmadKV.partition_fragment_A(sdSt); Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK); } else { Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt); Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt); flash::gemm(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK); } ++smem_pipe_read_q; warpgroup_wait<0>(); // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); } pipeline_do.consumer_release(smem_pipe_release_do); // release V ++smem_pipe_release_do; pipeline_q.consumer_release(smem_pipe_release_q); // release V ++smem_pipe_release_q; // warpgroup_wait<0>(); // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/); } } // Epilogue Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{}); Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{}); Tensor sdKt = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{}); Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{}); auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV); auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x - NumCopyThreads); int n_block = blockIdx.x; bidb = blockIdx.z; // The block index for the batch. bidh = blockIdx.y; // The block index for the head. 
Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) auto block_tma_dK = tma_store_dK.get_slice(_0{}); auto block_tma_dV = tma_store_dV.get_slice(_0{}); Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K) Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K) Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K) Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K) // Very slightly faster to do the smem write and TMA write for dV first, then do the same for dK, // Instead of doing both at the same time. Tensor tdVrdV_out = convert_type(tdVrdV); #pragma unroll for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; } Tensor tdKrdK_out = convert_type(tdKrdK); Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) // Can't use __syncthreads() in WS code auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(NumMmaThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; synchronize(); if constexpr (!dKV_swapAB) { Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); } else { Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt); } cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA synchronize(); lane_predicate = cute::elect_one_sync(); warp_idx = 
cutlass::canonical_warp_idx_sync(); if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) { cute::copy(tma_store_dV, tdVsdV, tdVgdV); tma_store_arrive(); } if constexpr (!dKV_swapAB) { Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); } else { Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt); } cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA synchronize(); if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) { cute::copy(tma_store_dK, tdKsdK, tdKgdK); tma_store_arrive(); } tma_store_wait<0>(); } } } // namespace flash