From 146d314057c5f193a70c2b36896e739c8c60aef4 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Thu, 13 Jul 2023 04:30:46 +0200 Subject: [PATCH] Update fMHA kernels (#992) * Update fMHA kernels Upstream recent changes to fMHA that we did in xFormers. Previous version in CUTLASS: facebookresearch/xformers@b6be33a Updating to: facebookresearch/xformers@55a4798 * minor changes * make var work --------- Co-authored-by: danthe3rd Co-authored-by: Haicheng Wu --- .../default_fmha_grouped.h | 37 +- .../fmha_backward_test.py | 1 + .../fmha_grouped.h | 194 ++- .../fused_multi_head_attention_backward.cu | 7 +- .../fused_multihead_attention_fixed_seqlen.cu | 13 +- ...sed_multihead_attention_variable_seqlen.cu | 17 +- .../gemm/custom_mma_multistage.h | 8 - .../gemm/custom_mma_pipelined.h | 3 +- .../gemm/mma_from_smem.h | 292 ++-- .../gemm_kernel_utils.h | 13 +- .../default_warp_iterator_from_smem.h | 143 ++ .../iterators/transpose_warp_iterator.h | 8 +- .../iterators/warp_iterator_from_smem.h | 80 +- .../kernel_backward.h | 1185 ++++++++++------- .../kernel_forward.h | 336 +++-- .../transform/tile_smem_loader.h | 2 + 16 files changed, 1419 insertions(+), 920 deletions(-) create mode 100644 examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h diff --git a/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/examples/41_fused_multi_head_attention/default_fmha_grouped.h index b0acc943..481a321c 100644 --- a/examples/41_fused_multi_head_attention/default_fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/default_fmha_grouped.h @@ -30,10 +30,10 @@ **************************************************************************************************/ /*! \file - \brief + \brief Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue. - + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are accommodated by exchanging A and B operands and assuming transposed layouts. Partial specializations here choose 'device::GemmTransposed' to implement this functionality. @@ -50,6 +50,7 @@ #include "fmha_grouped.h" #include "gemm_kernel_utils.h" +#include "gemm/custom_mma.h" #include "gemm/find_default_mma.h" #include "gemm/mma_from_smem.h" @@ -70,7 +71,7 @@ template < bool isAligned_, int kQueriesPerBlock, int kKeysPerBlock, - bool kSingleValueIteration, + int kMaxK = (int)cutlass::platform::numeric_limits::max(), GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly > struct DefaultFMHAGrouped { @@ -85,6 +86,8 @@ struct DefaultFMHAGrouped { using ArchTag = ArchTag_; static bool const kIsAligned = isAligned_; + static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; static int const kWarpSize = 32; static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize); @@ -145,14 +148,20 @@ struct DefaultFMHAGrouped { ThreadblockShape, WarpShape, InstructionShape, - kStages, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, Operator >::DefaultMma; using MmaCore = typename DefaultMma::MmaCore; using IteratorA = typename DefaultMma::IteratorA; using IteratorB = typename DefaultMma::IteratorB; - using Mma = typename DefaultMma::ThreadblockMma; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< typename Mma::Operator::IteratorC, ElementAccumulator, @@ -232,14 +241,24 @@ struct DefaultFMHAGrouped { InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, - kStages, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, kSplitKSerial, Operator>; + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MM0::AccumulatorSharedStorage, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; @@ -256,10 +275,6 @@ struct DefaultFMHAGrouped { typename cutlass::epilogue::threadblock::PredicatedTileIterator< typename DefaultEpilogue::OutputTileIterator::ThreadMap, output_accum_t>; - - struct SharedStorageMM1 { - typename Mma::SharedStorage mm; - }; }; /// Define the kernel in terms of the default kernel diff --git a/examples/41_fused_multi_head_attention/fmha_backward_test.py b/examples/41_fused_multi_head_attention/fmha_backward_test.py index cafd028b..ee0b7934 100644 --- a/examples/41_fused_multi_head_attention/fmha_backward_test.py +++ b/examples/41_fused_multi_head_attention/fmha_backward_test.py @@ -142,6 +142,7 @@ with PipedSubprocess(fmha_bw_binary) as bw_kernel: "custom_mask_type", (1 if causal else 0), "num_batches", B, "repeat_count", repeat_count, + "num_splits_key", (Mkv // 128), ) bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"]) bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"]) diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h index f71ca22b..6365ad66 100644 --- a/examples/41_fused_multi_head_attention/fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -147,6 +147,9 @@ public: static int const kThreadsPerWarp = 32; static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp); + using ProblemVisitor = FMHAGroupedProblemVisitor< ThreadblockShape, kGroupScheduleMode, @@ -369,13 +372,16 @@ public: cutlass::Array m_prime; cutlass::Array s_prime; cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; }; struct SharedStorageEpilogueAtEnd : ScalingCoefs { struct SharedStorageAfterMM0 { // Everything here might be overwritten during MM0 typename MM0::AccumulatorSharedStorage si; - typename MM1::SharedStorageMM1 mm1; + typename MM1::Mma::SharedStorage mm1; }; union { @@ -397,7 +403,7 @@ public: struct SharedStorageAfterMM0 { // Everything 
here might be overwritten during MM0 typename MM0::AccumulatorSharedStorage si; - typename MM1::SharedStorageMM1 mm1; + typename MM1::Mma::SharedStorage mm1; typename MM1::DefaultEpilogue::SharedStorage epilogue; }; @@ -490,6 +496,7 @@ public: auto& s_prime = shared_storage.s_prime; [[maybe_unused]] auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; ProblemVisitor problem_visitor( params.problem_visitor, @@ -512,6 +519,7 @@ public: if (thread_id() < kQueriesPerBlock) { s_prime[thread_id()] = ElementAccumulator(0); + out_rescale[thread_id()] = accum_t(1.0); m_prime[thread_id()] = -cutlass::platform::numeric_limits::infinity(); mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); @@ -568,7 +576,7 @@ public: cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); MM1::Mma::prologue( - shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.mm1, iterator_V, thread_id(), problem_size_1_k); @@ -623,6 +631,8 @@ public: if (kPreloadV) { prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); } typename MM0::Mma::Operator::IteratorC::TensorCoord @@ -649,30 +659,48 @@ public: }, [&](int accum_m) {}); } - DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { - DISPATCH_BOOL( - num_keys - iter_key_start >= kKeysPerBlock, - kFullColumns, - ([&] { - // Update `mi` from accum stored in registers - // Also does accum[i] <- exp(accum[i] - mi) - iterative_softmax< - typename MM0::Mma::Operator::IteratorC, - kFullColumns, - kIsFirst>( - accum_o, - accum, - mi, - m_prime, - s_prime, - lane_id(), - thread_id(), - warp_id(), - num_keys - iter_key_start, - iteratorC_tile_offset, - kSupportsBias ? 1.0f : params.scale); - })); - })); + // DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + // DISPATCH_BOOL( + // num_keys - iter_key_start >= kKeysPerBlock, + // kFullColumns, + // ([&] { + // // Update `mi` from accum stored in registers + // // Also does accum[i] <- exp(accum[i] - mi) + // iterative_softmax< + // typename MM0::Mma::Operator::IteratorC, + // kFullColumns, + // kIsFirst>( + // accum_o, + // accum, + // mi, + // m_prime, + // s_prime, + // lane_id(), + // thread_id(), + // warp_id(), + // num_keys - iter_key_start, + // iteratorC_tile_offset, + // kSupportsBias ? 1.0f : params.scale); + // })); + // })); + + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + lane_id(), + thread_id(), + warp_id(), + num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 
1.0f : params.scale); // Output results to shared-memory int warp_idx_mn_0 = warp_id() % @@ -717,12 +745,14 @@ public: cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); typename MM1::Mma mma_pv( - shared_storage.after_mm0.mm1.mm, - shared_storage.after_mm0.si, + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), (int)thread_id(), (int)warp_id(), - (int)lane_id(), - (int)problem_size_1_k); + (int)lane_id()); mma_pv.set_prologue_done(kPreloadV); if (!kKeepOutputInRF) { @@ -737,6 +767,7 @@ public: } if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); DISPATCH_BOOL( iter_key_start == 0, kIsFirst, ([&] { DISPATCH_BOOL( @@ -787,7 +818,7 @@ public: decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); - EpilogueOutputOp rescale(s_prime, m_prime); + EpilogueOutputOp rescale(s_prime, out_rescale); Epilogue epilogue( shared_storage.epilogue_shared_storage(), thread_id(), @@ -836,34 +867,37 @@ public: typename MM1::OutputTileIteratorAccum // source tile >; auto dest_iter = createOutputIter(0); - EpilogueOutputOp rescale(s_prime, m_prime); + EpilogueOutputOp rescale(s_prime, out_rescale); Epilogue epilogue( shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id()); + MM1::Mma::drain_cp_asyncs(); epilogue(rescale, dest_iter, accum_o); } // Next tile problem_visitor.advance(gridDim.x); + __syncthreads(); // Don't start the next iteration until all threads are done using shared memory. } } - template < - typename WarpIteratorC, - bool kFullColumns, - bool kIsFirst> + template CUTLASS_DEVICE static void iterative_softmax( typename WarpIteratorC::Fragment& frag_o, // output so far typename WarpIteratorC::Fragment& frag, cutlass::Array& mi, cutlass::Array& m_prime, cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, int8_t lane_id, int8_t thread_id, int8_t warp_id, - int16_t max_col, + int max_col, + bool is_first, typename WarpIteratorC::TensorCoord const& tile_offset, float scaling) { /* Iterates on the accumulator and corresponding position on result matrix @@ -884,12 +918,11 @@ public: kThreadsPerWarp>::Iterator; // Convert to `accum_t` (rather than double) constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E - if (!kIsFirst) { - if (thread_id < kQueriesPerBlock) { - m_prime[thread_id] = mi[thread_id]; - } - __syncthreads(); - } + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); auto lane_offset = LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); @@ -903,46 +936,64 @@ public: max = -cutlass::platform::numeric_limits::infinity(); }, [&](int accum_m, int accum_n, int idx) { - if (kFullColumns || accum_n < max_col) { + if (accum_n < max_col) { max = cutlass::fast_max(max, frag[idx]); } }, [&](int accum_m) { // Having 4x atomicMax seems faster than reduce within warp // first... 
- atomicMaxFloat(&mi[accum_m], max * scaling); + atomicMaxFloat(&mi[accum_m], max); }); } - frag = cutlass::multiplies()(scaling * kLog2e, frag); // Make sure we all share the update values for `mi` __syncthreads(); - if (thread_id < kQueriesPerBlock) { - auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); - m_prime[thread_id] = m_prime_exp; - s_prime[thread_id] *= m_prime_exp; + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } } __syncthreads(); // Update output fragments - if (kKeepOutputInRF && !kIsFirst) { - accum_t mp; + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; LambdaIterator::iterateRows( lane_offset, - [&](int accum_m) { mp = m_prime[accum_m]; }, - [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, [&](int accum_m) {}); - __syncthreads(); } // Update accum_m, accum_n, ... { accum_t mi_row, total_row; LambdaIterator::iterateRows( lane_offset, - [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m) { mi_row = mi[accum_m]; }, [&](int accum_m, int accum_n, int idx) { - frag[idx] = (kFullColumns || accum_n < max_col) - ? exp2f(frag[idx] - mi_row) - : accum_t(0.0); + frag[idx] = + (accum_n < max_col) ? 
exp2f(frag[idx] - mi_row) : accum_t(0.0); }, [&](int accum_m) {}); LambdaIterator::iterateRows( @@ -954,10 +1005,31 @@ public: lane_id, total_row, [](accum_t a, accum_t b) { return a + b; })) { - atomicAdd(&s_prime[accum_m], total_row); + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; } }); } + + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } } }; diff --git a/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu b/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu index 09dd1330..84662828 100644 --- a/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu +++ b/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu @@ -65,10 +65,12 @@ struct DefaultKernel { Element, true, // kIsAligned_ false, // kApplyDropout_ - kPreload,// kPreload_ + kPreload, // kPreload_ kBlockSizeI, // kBlockSizeI_, kBlockSizeJ, // kBlockSizeJ_, - kMaxK // kMaxK + kMaxK, // kMaxK + false, // kKeysQueriesAlignedToBlockSize + true // kEnableSplitKeys >; }; @@ -181,6 +183,7 @@ int runKernel() { READ_I64(custom_mask_type); READ_I64(num_batches); int64_t repeat_count = readInt64("repeat_count"); + READ_I64(num_splits_key); READ_TENSOR_AND_STRIDES_BMH(Element, query, q); READ_TENSOR_AND_STRIDES_BMH(Element, key, k); diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu index a0604018..c4bb109d 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu @@ -999,7 +999,7 @@ public: template < int kQueriesPerBlock, int kKeysPerBlock, - bool kSingleValueIteration + int kMaxK > int run_attention(Options& options) { using Attention = AttentionKernel< @@ -1008,7 +1008,7 @@ int run_attention(Options& options) { true, // Memory is aligned kQueriesPerBlock, kKeysPerBlock, - kSingleValueIteration, + kMaxK, false, // Supports dropout false // Supports bias >; @@ -1094,15 +1094,16 @@ int main(int argc, char const **args) { if (options.head_size_v > 64) { static int const kQueriesPerBlock = 32; static int const kKeysPerBlock = 128; - if (options.head_size_v <= kKeysPerBlock) { - return run_attention(options); + if (options.head_size_v <= 128) { + return run_attention(options); } else { - return run_attention(options); + return run_attention(options); } } else { + static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller static int const kQueriesPerBlock = 64; static int const kKeysPerBlock = 64; - return run_attention(options); + return run_attention(options); } } diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu index 
f2568e3a..db7e6846 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu @@ -1061,7 +1061,7 @@ public: template < int kQueriesPerBlock, int kKeysPerBlock, - bool kSingleValueIteration, + int kMaxK, cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ > int run_grouped(Options& options) { @@ -1071,7 +1071,7 @@ int run_grouped(Options& options) { true, // Memory is aligned kQueriesPerBlock, kKeysPerBlock, - kSingleValueIteration, + kMaxK, GroupScheduleMode_ >::FMHAKernel; @@ -1098,18 +1098,18 @@ int run_grouped(Options& options) { template < int kQueriesPerBlock, int kKeysPerBlock, - bool kSingleValueIteration + int kMaxK > int run_attention(Options& options) { if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) { return run_grouped(options); } else { return run_grouped(options); } } @@ -1180,14 +1180,15 @@ int main(int argc, char const **args) { static int const kQueriesPerBlock = 32; static int const kKeysPerBlock = 128; if (options.head_size_v <= kKeysPerBlock) { - return run_attention(options); + return run_attention(options); } else { - return run_attention(options); + return run_attention(options); } } else { + static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller static int const kQueriesPerBlock = 64; static int const kKeysPerBlock = 64; - return run_attention(options); + return run_attention(options); } } diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h index e5cdc88f..5441a0a0 100644 --- a/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase { arch::OpMultiplyAddComplexFastF32>::value) { accum = plus_accum(accum, tmp_accum); } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated cp.async pnz from the GEMM - // mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } } }; diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h index f074fdbd..65743645 100644 --- a/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h @@ -310,7 +310,8 @@ class CustomMmaPipelined : public CustomMmaBase { iterator_B.clear_mask(gemm_k_iterations <= 1); // Issue loads during the first warp-level matrix multiply-add *AFTER* - // issuing shared memory loads (which have the tightest latency requirement). + // issuing shared memory loads (which have the tightest latency + // requirement). // // Mainloop diff --git a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h index 5f38782b..df510d6a 100644 --- a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h +++ b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -30,7 +30,8 @@ * **************************************************************************************************/ /*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+ \brief Tools and utils to store a GEMM output in shmem, and to use that + output as operandA for another GEMM back-to-back */ #pragma once @@ -55,6 +56,7 @@ #include "../epilogue/epilogue_thread_apply_logsumexp.h" #include "../gemm/mma_accum_lambda_iterator.h" #include "../gemm_kernel_utils.h" +#include "../iterators/default_warp_iterator_from_smem.h" #include "../iterators/make_residual_last.h" #include "../iterators/transpose_warp_iterator.h" #include "../iterators/warp_iterator_from_smem.h" @@ -128,18 +130,22 @@ class AccumulatorSharedStorage { template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, - // Maximum value for K - int kMaxK, + // Maximum K dimension - also the dimension of the shared-memory + // holding `OperandA` + int kMaxK_, /// Policy describing tuning details (concept: MmaPolicy) typename Policy_, /// Number of stages, int Stages, + /// Layout in shared-memory of operand A + typename SmemLayoutA, /// Used for partial specialization typename Enable = bool> class MmaBaseFromSharedMemory { public: ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; + static constexpr int kMaxK = kMaxK_; ///< Policy describing tuning details using Policy = Policy_; @@ -175,8 +181,7 @@ class MmaBaseFromSharedMemory { static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; /// Tensor reference to the A operand - using TensorRefA = - TensorRef; + using TensorRefA = TensorRef; /// Tensor reference to the B operand using TensorRefB = @@ -240,14 +245,14 @@ class MmaBaseFromSharedMemory { CUTLASS_DEVICE MmaBaseFromSharedMemory( ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, + TensorRefB& b_tile, ///< ID within the threadblock int thread_idx, ///< ID of warp int warp_idx, ///< ID of each thread within a warp int lane_idx) - : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} + : warp_tile_iterator_B_(b_tile, lane_idx) {} }; namespace { @@ -333,14 +338,13 @@ template < typename Shape_, // BEGIN smem /// Iterates over the intermediate accumulator tile in shared memory - typename WarpIteratorA, + typename WarpIteratorA_, /// whether or not to perform elementwise multiplication of A // by another matrix (A_scale) that is also kept in shared memory prior // to matmul A @ B bool ScaleOperandA_, - // Accumulator type - typename AccumulatorSharedStorage, - // END smem + /// Max GEMM problem size in K dimension + int MaxK, /// Iterates over tiles of B operand in global memory // (concept: ReadableTileIterator | ForwardTileIterator | // MaskedTileIterator) @@ -363,21 +367,24 @@ template < typename Enable = bool> class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< Shape_, - AccumulatorSharedStorage::Shape::kN, + MaxK, Policy_, - 2> { + 2, + typename WarpIteratorA_::Layout> { public: ///< Base class using Base = MmaBaseFromSharedMemory< Shape_, - AccumulatorSharedStorage::Shape::kN, + MaxK, Policy_, - 2>; + 2, + typename WarpIteratorA_::Layout>; using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> static constexpr bool ScaleOperandA = ScaleOperandA_; + using WarpIteratorA = WarpIteratorA_; ///< loads fragments of A_scale from shared memory if operand A scaling is ///< enabled. otherwise no-op. using WarpIteratorAScale = typename cutlass::platform::conditional< @@ -454,19 +461,17 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< /// constructor for MMA with operand A scaling enabled. 
CUTLASS_DEVICE MmaPipelinedFromSharedMemory( - // shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - // warp iterator over A tile held in shared memory - WarpIteratorA warp_iter_a, - // warp iterator over A_scale tile held in shared memory - WarpIteratorAScale warp_iter_a_scale, + typename Base::TensorRefA a, // Operand A in shared memory + typename Base::TensorRefA a_scale, // Operand A_scale in shared memory + typename Base::TensorRefB + b_staging, // staging memory for loading tiles of B int thread_idx, int warp_idx, int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_tile_iterator_A_(warp_iter_a), - warp_tile_iterator_A_scale_(warp_iter_a_scale), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + warp_tile_iterator_A_scale_(a_scale, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: // _m: the warp's position within the threadblock along the M dimension @@ -489,17 +494,14 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< /// Construct from tensor references CUTLASS_DEVICE MmaPipelinedFromSharedMemory( - typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by - ///< threadblock-scoped GEMM - AccumulatorSharedStorage& accumulator_shared_storage, + typename Base::TensorRefA a, ///< Operand A in shared memory + typename Base::TensorRefB b_staging, ///< staging memory for loading B int thread_idx, ///< ID within the threadblock int warp_idx, ///< ID of warp - int lane_idx, ///< ID of each thread within a warp - int problem_size_0_n) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + int lane_idx) ///< ID of each thread within a warp + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: // _m: the warp's position within the threadblock along the M dimension @@ -531,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< int thread_idx, int problem_size_0_n) {} + CUTLASS_DEVICE + static void drain_cp_asyncs() {} + /// Perform a threadblock-scoped matrix multiply-accumulate CUTLASS_DEVICE void operator()( @@ -599,7 +604,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< iterator_B.clear_mask(gemm_k_iterations <= 1); // Issue loads during the first warp-level matrix multiply-add *AFTER* - // issuing shared memory loads (which have the tightest latency requirement). + // issuing shared memory loads (which have the tightest latency + // requirement). 
// // Mainloop @@ -620,8 +626,10 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< bool hasNext = true; if (warp_mma_k == Base::kWarpGemmIterations - 1) { - // Write fragments to shared memory - this->smem_iterator_B_.store(transform_B(tb_frag_B)); + if (gemm_k_iterations > 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + } __syncthreads(); @@ -695,8 +703,6 @@ template < // by another matrix (A_scale) that is also kept in shared memory prior // to matmul A @ B bool ScaleOperandA_, - // Accumulator type - typename AccumulatorSharedStorage, /// Iterates over tiles of B operand in global memory // (concept: ReadableTileIterator | ForwardTileIterator | // MaskedTileIterator) @@ -717,11 +723,20 @@ template < int kMaxK_, /// Used for partial specialization typename Enable = bool> -class MmaMultistageFromSharedMemory - : public MmaBaseFromSharedMemory { +class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout> { public: ///< Base class - using Base = MmaBaseFromSharedMemory; + using Base = MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout>; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape1 = Shape1_; @@ -825,20 +840,16 @@ class MmaMultistageFromSharedMemory /// constructor for MMA with operand A scaling enabled. CUTLASS_DEVICE MmaMultistageFromSharedMemory( - // shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - // warp level iterator over operand A tile kept in shared memory - WarpIteratorA1 warp_tile_iterator_A1, - // warp level iterator over operand A elementwise scale tile kept in - // shared memory. 
- WarpIteratorAScale warp_tile_iterator_A1_scale, + typename Base::TensorRefA a, + typename Base::TensorRefA a_scale, + typename Base::TensorRefB b_tile, int thread_idx, int warp_idx, int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_tile_iterator_A1_(warp_tile_iterator_A1), - warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale), - smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + warp_tile_iterator_A1_scale_(a_scale, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), prologue_done_(false) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: @@ -863,23 +874,17 @@ class MmaMultistageFromSharedMemory /// Construct from tensor references CUTLASS_DEVICE MmaMultistageFromSharedMemory( - typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by - ///< threadblock-scoped GEMM - AccumulatorSharedStorage& accumulator_shared_storage, + typename Base::TensorRefA a, + typename Base::TensorRefB b_tile, ///< ID within the threadblock int thread_idx, ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx, - ///< GEMM0 N is used for accumulator extent - int problem_size_0_n) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_tile_iterator_A1_( - accumulator_shared_storage.accum_ref(), - lane_idx), - smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), prologue_done_(false) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: @@ -919,6 +924,15 @@ class MmaMultistageFromSharedMemory smem_iterator_B1); } + CUTLASS_DEVICE + static void drain_cp_asyncs() { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + CUTLASS_DEVICE void copy_tiles_and_advance_1( IteratorB1& iterator_B1, @@ -1253,100 +1267,11 @@ class MmaMultistageFromSharedMemory } }; -template < - typename WarpShape, - typename InstructionShape, - typename RegularWarpIterator, - typename Policy, - typename Enable = void> -struct DefaultWarpIteratorAFromSharedMemory {}; - -// TensorOp - Ampere half -template -struct DefaultWarpIteratorAFromSharedMemory< - cutlass::gemm::GemmShape<32, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - RegularWarpIterator, - Policy, - typename platform::enable_if<( - sizeof_bits::value == 16 && - Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { - static constexpr auto kWarpSize = 32; - using OpDelta = typename Policy::Operator::Policy::OpDelta; - using WarpShape = cutlass::MatrixShape<32, 32>; - - using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< - cutlass::gemm::Operand::kA, - typename RegularWarpIterator::Element>; -}; - -// TensorOp - Ampere f32 -template -struct DefaultWarpIteratorAFromSharedMemory< - WarpShape, - cutlass::gemm::GemmShape<16, 8, 8>, - RegularWarpIterator, - Policy, - typename platform::enable_if<( - sizeof_bits::value != 16 || - Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - static constexpr auto kWarpSize = 32; - using OpDelta = typename Policy::Operator::Policy::OpDelta; - - using WarpIterator = - 
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< - cutlass::MatrixShape, - cutlass::gemm::Operand::kA, - typename RegularWarpIterator::Element, - cutlass::layout::RowMajor, - cutlass::MatrixShape, - OpDelta::kRow, - kWarpSize>; -}; - -// TensorOp - Volta -template -struct DefaultWarpIteratorAFromSharedMemory< - WarpShape, - cutlass::gemm::GemmShape<16, 16, 4>, - RegularWarpIterator, - Policy> { - using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; - static constexpr auto kWarpSize = 32; - using OpDelta = typename Policy::Operator::Policy::OpDelta; - - using WarpIterator = - cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< - cutlass::MatrixShape<32, 32>, // MatrixShape, - cutlass::gemm::Operand::kA, - typename RegularWarpIterator::Element, - cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, - cutlass::MatrixShape<16, 4>, - OpDelta::kRow, - kWarpSize>; -}; - -// Simt -template -struct DefaultWarpIteratorAFromSharedMemory< - WarpShape, - cutlass::gemm::GemmShape<1, 1, 1>, - RegularWarpIterator, - Policy> { - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - static constexpr auto kWarpSize = 32; - - // We just use the same iterator, as we reproduced the same shared-memory - // schema. Just modify it to handle non-complete tiles. - using WarpIterator = RegularWarpIterator; -}; - // Converts a "regular" Mma into their counterpart from shared memory template < typename Mma_, - typename AccumulatorSharedStorage, + int kMaxK, + typename WarpIteratorA_, /// whether or not to apply elementwise multiplication of operand A by /// another matrix in shared memory before usage in A @ B bool kScaleOperandA, @@ -1364,6 +1289,7 @@ template < /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, + typename WarpIteratorA_, /// Iterates over tiles of B operand in global memory // (concept: ReadableTileIterator | ForwardTileIterator | // MaskedTileIterator) @@ -1381,7 +1307,8 @@ template < typename TransformA_, /// Transformation applied to B operand typename TransformB_, - typename AccumulatorSharedStorage_, + // Max MMA problem size K + int kMaxK, /// whether or not to apply elementwise multiplication of operand A by /// another matrix in shared memory before usage in A @ B bool kScaleOperandA, @@ -1398,12 +1325,10 @@ struct DefaultMmaFromSharedMemory< Policy_, TransformA_, TransformB_>, - AccumulatorSharedStorage_, + kMaxK, + WarpIteratorA_, kScaleOperandA, kTransposeA> { - static constexpr int kWarpSize = 32; - using SmemAccumulatorLayout = cutlass::layout::RowMajor; - using RegularMma = MmaPipelined< Shape_, IteratorA_, @@ -1421,11 +1346,7 @@ struct DefaultMmaFromSharedMemory< using ArchMmaOperator = typename Policy_::Operator; static constexpr bool kIsTransposedA = false; - using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< - WarpShape, - InstructionShape, - typename RegularMma::Operator::IteratorA, - Policy_>::WarpIterator; + using WarpIteratorA = WarpIteratorA_; using IteratorB = typename cutlass::transform::threadblock::MakeIteratorResidualLast< IteratorB_>::Iterator; @@ -1434,7 +1355,7 @@ struct DefaultMmaFromSharedMemory< Shape_, WarpIteratorA, kScaleOperandA, - AccumulatorSharedStorage_, + kMaxK, IteratorB, SmemIteratorB_, ElementC_, @@ -1452,6 +1373,7 @@ template < /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, + typename 
WarpIteratorA_, /// Cache operation for operand A cutlass::arch::CacheOperation::Kind CacheOpA, /// Iterates over tiles of B operand in global memory @@ -1473,7 +1395,7 @@ template < int Stages, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear, - typename AccumulatorSharedStorage_, + int kMaxK, /// whether or not to apply elementwise multiplication of operand A by /// another matrix in shared memory before usage in A @ B bool kScaleOperandA, @@ -1492,11 +1414,10 @@ struct DefaultMmaFromSharedMemory< Policy_, Stages, SharedMemoryClear>, - AccumulatorSharedStorage_, + kMaxK, + WarpIteratorA_, kScaleOperandA, kTransposeA> { - static constexpr int kWarpSize = 32; - using RegularMma = MmaMultistage< Shape_, IteratorA_, @@ -1513,11 +1434,6 @@ struct DefaultMmaFromSharedMemory< using WarpShape = typename Policy_::Operator::Shape; using InstructionShape = typename Policy_::Operator::InstructionShape; - using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory< - WarpShape, - InstructionShape, - typename RegularMma::Operator::IteratorA, - Policy_>::WarpIterator; using WarpIteratorTranspose = TransposeWarpIterator; static constexpr bool kIsTransposedA = WarpIteratorTranspose::kSupportsTranspose && kTransposeA; @@ -1526,9 +1442,6 @@ struct DefaultMmaFromSharedMemory< typename WarpIteratorTranspose::Iterator, WarpIteratorA_>::type; - static int constexpr kMaxK = kIsTransposedA - ? AccumulatorSharedStorage_::Shape::kM - : AccumulatorSharedStorage_::Shape::kN; // Reduce the number of stages if we don't need that many static int constexpr kStagesMax = (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); @@ -1542,7 +1455,6 @@ struct DefaultMmaFromSharedMemory< Shape_, WarpIteratorA, kScaleOperandA, - AccumulatorSharedStorage_, IteratorB, SmemIteratorB_, RegularMma::kCacheOpB, @@ -1750,27 +1662,17 @@ struct B2bGemm< using FragmentC = IteratorC::Fragment; using lse_scalar_t = float; - using SmemAccumulatorLayout = cutlass::layout::RowMajor; - using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< - WarpShape, - cutlass::gemm::GemmShape<32, 32, 4>, - scalar_t, - SmemAccumulatorLayout>; - - // // Storage in shared-memory for Q.Kt + // Storage in shared-memory for Q.Kt + using SmemAccumulatorLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; using AccumulatorSharedStorage = cutlass::gemm::threadblock::AccumulatorSharedStorage< ThreadblockShape, scalar_t, - cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< - 16, - 32>, // typename SmemIteratorD0::TensorLayout, + SmemAccumulatorLayout, cutlass::MatrixShape<0, 0> // Padding >; - - using OutputLayout = - cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; - using TensorRef = cutlass::TensorRef; + using TensorRef = cutlass::TensorRef; using Policy = typename IteratorC::Policy; using Element = accum_t; // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields diff --git a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h index 84f91667..5740cab0 100644 --- a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h +++ b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -228,8 +228,17 @@ struct call_conditional { // The cheapest way to do it is just to broadcast it from lane 0 //////////////////////////////////////////////////////////////////////////////// -CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { - return 
(int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); +template +CUTLASS_DEVICE T warp_uniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; } template diff --git a/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h new file mode 100644 index 00000000..9a0885b6 --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Instanciates the right WarpIterator to read from shared memory + The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading + data dumped with `B2bGemm::accumToSmem`. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" +#include "cutlass/platform/platform.h" + +#include "warp_iterator_from_smem.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, kInstrK>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::MatrixShape>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. 
+ using WarpIterator = RegularWarpIterator; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h index 37c42ea2..1784bd2e 100644 --- a/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h +++ b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -44,10 +44,12 @@ template < cutlass::gemm::Operand Operand, /// Data type of A elements typename Element, + typename InstructionShape, bool kTranspose> struct TransposeWarpIterator< - cutlass::gemm::warp::WarpIteratorFromSmem> { - using Iterator = - cutlass::gemm::warp::WarpIteratorFromSmem; + cutlass::gemm::warp:: + WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp:: + WarpIteratorFromSmem; static bool constexpr kSupportsTranspose = true; }; diff --git a/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h index 37f41699..7e0dc6c7 100644 --- a/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h +++ b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -56,6 +56,7 @@ template < Operand Operand_, /// Data type of A elements typename Element_, + typename InstructionShape_, bool kTranspose = false> class WarpIteratorFromSmem { public: @@ -64,6 +65,9 @@ class WarpIteratorFromSmem { /// Operand tag static Operand const kOperand = Operand_; + static_assert( + kOperand == Operand::kA, + "No support for OperandB at the moment"); /// Basic check static_assert( @@ -78,7 +82,11 @@ class WarpIteratorFromSmem { using Layout = cutlass::layout::RowMajor; /// Shape of one matrix product operation (concept: MatrixShape) - using InstructionShape = cutlass::MatrixShape<16, 8>; + using InstructionShape = InstructionShape_; + static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16"); + static_assert( + InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16, + "Only supports 16x8x8 / 16x8x16"); /// Delta between *MMA operations (in units of *MMA operations, concept: /// MatrixShape) @@ -133,7 +141,9 @@ class WarpIteratorFromSmem { : InstructionShape::kRow); static int constexpr kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + // Number of 32bits tiles to load per `ldmatrix` static int const kTilesPerInstruction = InstructionShape::kRow / 8; + static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8"); private: /// Underlying tensor reference @@ -153,38 +163,28 @@ class WarpIteratorFromSmem { CUTLASS_HOST_DEVICE WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) : ref_(ref), iterations_(0) { + // See also: + // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 + // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4) + // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4) int ldsm_vec_num = (lane_id >> 3); if (kOperand == Operand::kA) { origin_ = MatrixCoord(lane_id % 8, 0); static_assert( - InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4, - ""); - CUTLASS_PRAGMA_UNROLL - for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; - ++inst_m_idx) { - CUTLASS_PRAGMA_UNROLL - for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { - CUTLASS_PRAGMA_UNROLL - for (int access_m_idx = 0; access_m_idx < 
kTilesPerInstruction; - ++access_m_idx) { - int access_idx = access_m_idx + - kTilesPerInstruction * - (inner_idx + kAccessesInner * inst_m_idx); - - MatrixCoord offset( - access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, - inner_idx * 4 * kElementsPerAccess); - - if (access_idx == ldsm_vec_num) { - if (kTranspose) { - offset = MatrixCoord(offset.column(), offset.row()); - } - origin_ += offset; - } - } - } + InstructionCount::kRow * kTilesPerInstruction == 4, + "can't use ldmatrix.x4"); + int access_m_idx = ldsm_vec_num % kTilesPerInstruction; + int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner; + int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner); + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); } + origin_ += offset; } else { + // Note: This is not tested or used origin_ = MatrixCoord(0, lane_id % 8); static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); CUTLASS_PRAGMA_UNROLL @@ -256,17 +256,23 @@ class WarpIteratorFromSmem { using LoadLayout = typename platform:: conditional::type; - MatrixCoord offset; - if (kOperand == Operand::kA) { - offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn); - } else { - offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < + (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4; + ++access_m_idx) { + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord( + access_m_idx * 16, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[access_m_idx], ref_.data() + ref_.offset(offset)); } - if (kTranspose) { - offset = MatrixCoord(offset.column(), offset.row()); - } - cutlass::arch::ldsm( - access_ptr[0], ref_.data() + ref_.offset(offset)); } }; diff --git a/examples/41_fused_multi_head_attention/kernel_backward.h b/examples/41_fused_multi_head_attention/kernel_backward.h index b5e036ad..b2f4ed40 100644 --- a/examples/41_fused_multi_head_attention/kernel_backward.h +++ b/examples/41_fused_multi_head_attention/kernel_backward.h @@ -71,6 +71,7 @@ #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/integer_subbyte.h" #include "cutlass/matrix_shape.h" #include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" @@ -171,10 +172,53 @@ struct GmemTile { sub_fragment, gmem_ptr, true); } } + + CUTLASS_DEVICE void storeAtomicAdd( + FragmentType const& fragment, + int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + float val = fragment[i * AccessType::kElements + j]; + float* ptr = gmem_ptr + j; + atomicAdd(ptr, val); + } + } + } +}; + +struct AtomicLock { + CUTLASS_DEVICE static void acquire( + int32_t* lock, + int set_val, + int thread_id) { + if (thread_id == 0) { + while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + 
__nanosleep(40); +#endif + } + } + __syncthreads(); + } + CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) { + if (thread_id == 0) { + int status = 0; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.gpu.b32 [%0], %1;\n" + : + : "l"(lock), "r"(status)); +#else + asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#endif + } + } }; template -constexpr int getWarpsPerSm() { +constexpr int getWarpsPerSmBw() { bool is_half = !cutlass::platform::is_same::value; if (Arch::kMinComputeCapability >= 80) { return is_half ? 12 : 8; @@ -198,7 +242,13 @@ template < int kBlockSizeI_, int kBlockSizeJ_, // upperbound on `max(value.shape[-1], query.shape[-1])` - int kMaxK_ = (int)cutlass::platform::numeric_limits::max()> + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // assumes that `cu_seqlen` is None, and + // (1) `num_queries % kBlockSizeI == 0` + // (2) `num_keys % kBlockSizeJ == 0` + bool kKeysQueriesAlignedToBlockSize_ = false, + // Allows to parallelize across keys + bool kEnableSplitKeys_ = true> struct AttentionBackwardKernel { enum CustomMaskType { NoCustomMask = 0, @@ -218,253 +268,8 @@ struct AttentionBackwardKernel { static constexpr int kBlockSizeI = kBlockSizeI_; static constexpr int kBlockSizeJ = kBlockSizeJ_; static constexpr int kMaxK = kMaxK_; - - struct Params { - // Input tensors - scalar_t* query_ptr; // [Mq, nH, K] - scalar_t* key_ptr; // [Mk, nH, K] - scalar_t* value_ptr; // [Mk, nH, Kv] - scalar_t* bias_ptr = nullptr; - lse_scalar_t* logsumexp_ptr; // [nH, Mq] - scalar_t* output_ptr; // [Mq, nH, Kv] - scalar_t* grad_output_ptr; // [Mq, nH, Kv] - accum_t* delta_ptr; // [nH, Mq] - int32_t* cu_seqlens_q_ptr = nullptr; - int32_t* cu_seqlens_k_ptr = nullptr; - - // Output tensors - output_t* grad_query_ptr; // [Mq, nH, K] - output_t* grad_key_ptr; // [Mk, nH, K] - output_t* grad_value_ptr; // [Mk, nH, Kv] - output_t* grad_bias_ptr = nullptr; - - // Accumulators - union { - output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] - output_accum_t* workspace_gk; - }; - output_accum_t* workspace_gv; // (will be calculated by the kernel) - output_accum_t* workspace_gq; // (will be calculated by the kernel) - - // Scale - accum_t scale; - - // Dimensions/strides - int32_t head_dim = -1; - int32_t head_dim_value = -1; - int32_t num_queries = -1; - int32_t num_keys = -1; - int32_t num_heads = -1; - uint8_t custom_mask_type = NoCustomMask; - - int32_t q_strideM; - int32_t k_strideM; - int32_t v_strideM; - int32_t bias_strideM = 0; - int32_t gO_strideM; - int32_t gB_strideM; - int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise - -#ifdef HAS_PYTORCH - // dropout - at::PhiloxCudaState rng_engine_inputs; -#endif - // RNG sequence offset based on batch_id and head_id - unsigned long long dropout_batch_head_rng_offset; - float dropout_prob = 0.0f; - - CUTLASS_HOST_DEVICE int32_t o_strideM() const { - return head_dim_value * num_heads; - } - CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { - return gQKV_strideM_multiplier * num_heads * head_dim; - } - CUTLASS_HOST_DEVICE int32_t gK_strideM() const { - return gQKV_strideM_multiplier * num_heads * head_dim; - } - CUTLASS_HOST_DEVICE int32_t gV_strideM() const { - return gQKV_strideM_multiplier * num_heads * head_dim_value; - } - - // Everything below is only used in `advance_to_block` - // and shouldn't use registers - int64_t o_strideH; - int32_t q_strideH; - int32_t k_strideH; - int32_t v_strideH; - int32_t bias_strideH = 0; - 
int64_t o_strideB; - int64_t q_strideB; - int64_t k_strideB; - int64_t v_strideB; - int64_t bias_strideB = 0; - int64_t lse_strideB; - int64_t lse_strideH; - int64_t delta_strideB; - int64_t delta_strideH; - int32_t num_batches; - - int64_t gO_strideB = 0; - int64_t gQ_strideB = 0; - int64_t gK_strideB = 0; - int64_t gV_strideB = 0; - int64_t gB_strideB = 0; - int64_t gO_strideH = 0; - int64_t gQ_strideH = 0; - int64_t gK_strideH = 0; - int64_t gV_strideH = 0; - int64_t gB_strideH = 0; - - CUTLASS_DEVICE bool advance_to_block() { - int64_t batch_id = blockIdx.z; - int32_t head_id = blockIdx.y; - - if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { - assert(workspace_size() == 0 || workspace != nullptr); - - workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); - workspace = warp_uniform(workspace); - workspace_gv = workspace + workspace_elements_gk(); - workspace_gq = workspace_gv + workspace_elements_gv(); - } else { - workspace = nullptr; - } - - // Advance pointers that depend on the total concatenated - // number of queries, as `num_queries` is modified in the block - // below - dropout_batch_head_rng_offset = - batch_id * (num_heads * num_queries * num_keys) + - head_id * (num_queries * num_keys); - logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; - - if (cu_seqlens_q_ptr != nullptr) { - assert(cu_seqlens_k_ptr != nullptr); - cu_seqlens_q_ptr += batch_id; - cu_seqlens_k_ptr += batch_id; - int32_t q_start = cu_seqlens_q_ptr[0]; - int32_t k_start = cu_seqlens_k_ptr[0]; - int64_t q_next_start = cu_seqlens_q_ptr[1]; - int64_t k_next_start = cu_seqlens_k_ptr[1]; - assert(q_next_start - q_start <= num_queries); - assert(k_next_start - k_start <= num_keys); - num_queries = q_next_start - q_start; - num_keys = k_next_start - k_start; - - // Jump manually - batch_id = 0; - - query_ptr += q_start * q_strideM; - key_ptr += k_start * k_strideM; - value_ptr += k_start * v_strideM; - assert(bias_ptr == nullptr); - assert(grad_bias_ptr == nullptr); - output_ptr += q_start * o_strideM(); - grad_output_ptr += q_start * gO_strideM; - delta_ptr += q_start; - - grad_query_ptr += q_start * gQ_strideM(); - grad_key_ptr += k_start * gK_strideM(); - grad_value_ptr += k_start * gV_strideM(); - } - - query_ptr += batch_id * q_strideB + head_id * q_strideH; - key_ptr += batch_id * k_strideB + head_id * k_strideH; - value_ptr += batch_id * v_strideB + head_id * v_strideH; - if (bias_ptr != nullptr) { - bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; - } - output_ptr += batch_id * o_strideB + head_id * o_strideH; - grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; - delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; - - grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; - grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; - grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; - if (grad_bias_ptr != nullptr) { - grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; - } - - head_dim = warp_uniform(head_dim); - head_dim_value = warp_uniform(head_dim_value); - num_queries = warp_uniform(num_queries); - num_keys = warp_uniform(num_keys); - num_heads = warp_uniform(num_heads); - - gO_strideM = warp_uniform(gO_strideM); - gQKV_strideM_multiplier = warp_uniform(gQKV_strideM_multiplier); - q_strideM = warp_uniform(q_strideM); - k_strideM = warp_uniform(k_strideM); - v_strideM = warp_uniform(v_strideM); - - query_ptr = warp_uniform(query_ptr); - key_ptr = warp_uniform(key_ptr); - value_ptr = 
warp_uniform(value_ptr); - bias_ptr = warp_uniform(bias_ptr); - logsumexp_ptr = warp_uniform(logsumexp_ptr); - output_ptr = warp_uniform(output_ptr); - grad_output_ptr = warp_uniform(grad_output_ptr); - delta_ptr = warp_uniform(delta_ptr); - - grad_query_ptr = warp_uniform(grad_query_ptr); - grad_key_ptr = warp_uniform(grad_key_ptr); - grad_value_ptr = warp_uniform(grad_value_ptr); - grad_bias_ptr = warp_uniform(grad_bias_ptr); - custom_mask_type = warp_uniform(custom_mask_type); - -#if 0 - PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f", - int(blockIdx.z), int(blockIdx.y), - float(delta_ptr[0]), - float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), - float(logsumexp_ptr[0]) - ) -#endif - return true; - } - - __host__ dim3 getBlocksGrid() const { - return dim3(1, num_heads, num_batches); - } - __host__ dim3 getThreadsGrid() const { - return dim3(kWarpSize, kNumWarpsPerBlock, 1); - } - CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { - if (!kNeedsAccumGradK) { - return 0; - } - return align_up(num_keys, (int32_t)kBlockSizeJ) * - align_up(head_dim, (int32_t)kBlockSizeI); - } - CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { - if (!kNeedsAccumGradV) { - return 0; - } - return align_up(num_keys, (int32_t)kBlockSizeJ) * - align_up(head_dim_value, (int32_t)kBlockSizeI); - } - CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { - if (!kNeedsAccumGradQ) { - return 0; - } - if (num_keys <= kBlockSizeJ) { - return 0; - } - return align_up(num_queries, (int32_t)kBlockSizeI) * - align_up(head_dim, (int32_t)kBlockSizeJ); - } - CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { - // Aligned on 128bits - return align_up( - workspace_elements_gk() + workspace_elements_gv() + - workspace_elements_gq(), - int64_t(4)); - } - CUTLASS_HOST_DEVICE int64_t workspace_size() const { - // Returns size of buffer we need to run this kernel - return num_batches * num_heads * workspace_strideBH() * sizeof(float); - } - }; + static constexpr bool kKeysQueriesAlignedToBlockSize = + kKeysQueriesAlignedToBlockSize_; static constexpr int64_t kWarpSize = 32; @@ -495,17 +300,10 @@ struct AttentionBackwardKernel { static constexpr bool kKernelComputesDelta = kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); - static constexpr bool kNeedsAccumGradQ = - !cutlass::platform::is_same::value; - static constexpr bool kNeedsAccumGradK = !kOutputInRF && - !cutlass::platform::is_same::value; - static constexpr bool kNeedsAccumGradV = !kOutputInRF && - !cutlass::platform::is_same::value; - // Launch bounds static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; static constexpr int64_t kMinBlocksPerSm = - getWarpsPerSm() / kNumWarpsPerBlock; + getWarpsPerSmBw() / kNumWarpsPerBlock; using GemmType = DefaultGemmType; using DefaultConfig = @@ -625,14 +423,23 @@ struct AttentionBackwardKernel { // same time. 
// if no dropout: // for computing dVj += Pij.T @ dOi + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Operator:: + InstructionShape, // InstructionShape + typename DefaultGemm::Mma::Operator:: + IteratorA, // RegularWarpIterator + typename DefaultGemm::Mma::Policy // Policy + >::WarpIterator; using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MatmulQK::AccumulatorSharedStorage, + MatmulQK::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, kApplyDropout>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; - using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; @@ -693,6 +500,10 @@ struct AttentionBackwardKernel { typename GemmType::Operator, cutlass::gemm::SharedMemoryClearOption::kNone>; using Mma = typename MakeCustomMma::Mma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccum, + kWarpSize>::Iterator; // epilogue used to write bias gradient, which is just the output of this // matmul with some operations applied to the fragment @@ -701,8 +512,8 @@ struct AttentionBackwardKernel { // Epilogue to store to shared-memory in a format that we can use later for // the second matmul using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< - typename Mma::Operator::IteratorC, - typename Mma::Operator, + typename DefaultGemm::Mma::Operator::IteratorC, + typename DefaultGemm::Mma::Operator, scalar_t, WarpShape, ThreadblockShape>; @@ -737,10 +548,17 @@ struct AttentionBackwardKernel { false, // SplitKSerial typename GemmType::Operator>; + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MatmulDOIVJ::AccumulatorSharedStorage, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; @@ -782,15 +600,23 @@ struct AttentionBackwardKernel { false, // SplitKSerial typename GemmType::Operator>; + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; using DefaultMmaFromSmemN = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MatmulQK::AccumulatorSharedStorage, + MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, false>; // kScaleOperandA using DefaultMmaFromSmemT = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MatmulDOIVJ::AccumulatorSharedStorage, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK + WarpIteratorA, false, // kScaleOperandA kPreload>; // kTransposeA using DefaultMmaFromSmem = typename 
cutlass::platform::conditional< @@ -810,6 +636,279 @@ struct AttentionBackwardKernel { using AccumTileGmem = GmemTile; }; + static constexpr bool kEnableSplitKeys = kEnableSplitKeys_; + + static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = !kOutputInRF && + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = !kOutputInRF && + !cutlass::platform::is_same::value; + + struct GradQTempStorage { + int32_t lock; + int32_t counter; + int32_t pad[2]; // pad to 128bits + output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored]; + }; + + struct Params { + // Input tensors + scalar_t* query_ptr = nullptr; // [Mq, nH, K] + scalar_t* key_ptr = nullptr; // [Mk, nH, K] + scalar_t* value_ptr = nullptr; // [Mk, nH, Kv] + scalar_t* bias_ptr = nullptr; + lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq] + scalar_t* output_ptr = nullptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv] + accum_t* delta_ptr = nullptr; // [nH, Mq] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr = nullptr; // [Mq, nH, K] + output_t* grad_key_ptr = nullptr; // [Mk, nH, K] + output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv] + output_t* grad_bias_ptr = nullptr; + + // Accumulators + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gv = + nullptr; // (will be calculated by the kernel) + GradQTempStorage* workspace_gq = + nullptr; // (will be calculated by the kernel) + + // Scale + accum_t scale = 1.0f; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM = -1; + int32_t k_strideM = -1; + int32_t v_strideM = -1; + int32_t bias_strideM = 0; + int32_t gO_strideM = -1; + int32_t gB_strideM = -1; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + +#ifdef HAS_PYTORCH + // dropout + at::PhiloxCudaState rng_engine_inputs = {0, 0}; +#endif + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH = -1; + int32_t q_strideH = -1; + int32_t k_strideH = -1; + int32_t v_strideH = -1; + int64_t bias_strideH = 0; + int64_t o_strideB = -1; + int64_t q_strideB = -1; + int64_t k_strideB = -1; + int64_t v_strideB = -1; + int64_t bias_strideB = 0; + int64_t lse_strideB = -1; + int64_t lse_strideH = -1; + int64_t delta_strideB = -1; + int64_t delta_strideH = -1; + int32_t num_batches = -1; + int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t gK_strideH = 0; + int64_t 
gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_DEVICE int16_t num_splits_key_device() const { + return kEnableSplitKeys ? gridDim.x : 1; + } + CUTLASS_DEVICE int16_t split_key_device() const { + return kEnableSplitKeys ? blockIdx.x : 0; + } + + CUTLASS_DEVICE bool advance_to_block() { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = + (GradQTempStorage*)(workspace_gv + workspace_elements_gv()); + if (kEnableSplitKeys) { + workspace_gv += workspace_elements_gv() * split_key_device() / + num_splits_key_device(); + workspace += workspace_elements_gk() * split_key_device() / + num_splits_key_device(); + } + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = + batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + int32_t q_start = cu_seqlens_q_ptr[0]; + int32_t k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + assert(q_next_start - q_start <= num_queries); + assert(k_next_start - k_start <= num_keys); + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + // Jump manually + batch_id = 0; + + query_ptr += q_start * q_strideM; + key_ptr += k_start * k_strideM; + value_ptr += k_start * v_strideM; + assert(bias_ptr == nullptr); + assert(grad_bias_ptr == nullptr); + output_ptr += q_start * o_strideM(); + grad_output_ptr += q_start * gO_strideM; + delta_ptr += q_start; + + grad_query_ptr += q_start * gQ_strideM(); + grad_key_ptr += k_start * gK_strideM(); + grad_value_ptr += k_start * gV_strideM(); + } + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + if (bias_ptr != nullptr) { + bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; + } + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + if (grad_bias_ptr != nullptr) { + grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; + } + + // Some values are modified above + // Signal to the compiler that they are the same in all threads + // and can be stored in warp-uniform registers (Sm75+) + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + custom_mask_type = warp_uniform(custom_mask_type); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + bias_ptr = warp_uniform(bias_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); 
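// ----------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): a minimal host-side view of how
// the backward workspace introduced above is partitioned. Per (batch, head)
// the kernel reserves `workspace_strideBH()` floats, laid out as a grad_K
// region, then a grad_V region, then the GradQTempStorage area; with
// split-keys enabled, each key split additionally offsets into its own slice
// of the grad_K/grad_V regions. This mirrors the pointer arithmetic at the top
// of `advance_to_block` and the `workspace_elements_*` helpers below; the
// struct and function names in this sketch are hypothetical.
#include <cstdint>

struct WorkspaceSlices {
  int64_t gk_offset; // start of this split's grad_K accumulator (in floats)
  int64_t gv_offset; // start of this split's grad_V accumulator (in floats)
  int64_t gq_offset; // start of the shared GradQTempStorage area (in floats)
};

inline WorkspaceSlices slice_backward_workspace(
    int64_t elements_gk, // == workspace_elements_gk() (already x num_splits)
    int64_t elements_gv, // == workspace_elements_gv() (already x num_splits)
    int64_t stride_bh,   // == workspace_strideBH()
    int batch_id,
    int num_heads,
    int head_id,
    int split_key,
    int num_splits_key) {
  // Base of the (batch, head) slice, in floats.
  int64_t base = (int64_t(batch_id) * num_heads + head_id) * stride_bh;
  WorkspaceSlices s;
  // grad_K accumulators come first; each split owns 1/num_splits_key of them.
  s.gk_offset = base + elements_gk * split_key / num_splits_key;
  // grad_V accumulators follow the whole grad_K region, again sliced per split.
  s.gv_offset = base + elements_gk + elements_gv * split_key / num_splits_key;
  // GradQTempStorage (lock + counter + buffer) is shared by all key splits.
  s.gq_offset = base + elements_gk + elements_gv;
  return s;
}
// ----------------------------------------------------------------------------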
+ grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); + grad_bias_ptr = warp_uniform(grad_bias_ptr); + +#if 0 + PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f", + int(blockIdx.z), int(blockIdx.y), + float(delta_ptr[0]), + float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), + float(logsumexp_ptr[0]) + ) +#endif + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3(num_splits_key, num_heads, num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { + if (!kNeedsAccumGradK) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { + if (!kNeedsAccumGradV) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim_value, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { + if (!kNeedsAccumGradQ) { + return 0; + } + int num_blocks = ceil_div(num_queries, kBlockSizeI); + int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN); + return num_blocks * num_cols * sizeof(GradQTempStorage) / + sizeof(output_accum_t); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + CUTLASS_HOST_DEVICE bool should_zero_workspace() const { + return num_splits_key > 1; + } + }; + // shared storage for keeping Zij matrix. not needed if we aren't using // dropout, in which case we use an empty array to save shared memory using ZijSharedStorage = typename cutlass::platform::conditional< @@ -848,12 +947,7 @@ struct AttentionBackwardKernel { // 10. write to fragment typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; }; - // 5. store Zij. it is needed: - // - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij - // are loaded for the computation of dVj. - // - to compute dPij = (dOi @ Vj.T) * Zij - // 6. used in dVj += (Pij.T * Zij) @ dOi - // 9. used in dPij = dPij_dropped * Zij + // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi ZijSharedStorage zij; union { @@ -936,8 +1030,10 @@ struct AttentionBackwardKernel { printf(" tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage)); printf(" part4: %db\n", FSZ(part4)); printf(" mm_qk_q: %db\n", FSZ(part4.mm_qk_q)); - printf(" gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final)); - printf(" gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final)); + printf( + " gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final)); + printf( + " gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final)); } // =========================================== #define FIELD(INSIDE_STRUCT, FIELDNAME) \ @@ -987,11 +1083,9 @@ struct AttentionBackwardKernel { // - in next step where it is used in dSij = Pij * (dPij - Di) typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; }; - // 3. store Zij. 
it is needed: - // - in this step, where it is used to compute Pij_dropped = Pij * Zij - // on the - // fly as fragments of Pij are loaded for the computation of dVj. - // - later to compute dPij = (dOi @ Vj.T) * Zij + // 3. store Zij. it is needed in this step, where it is used + // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are + // loaded for the computation of dVj. ZijSharedStorage zij; union { @@ -1008,8 +1102,6 @@ struct AttentionBackwardKernel { struct { // (from part2) - Pij for computing dSij = Pij * (dPij - Di) typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; - // (from part2) - Zij for computing dPij = dPij_dropped * Zij - ZijSharedStorage zij; // matmul to compute dOiVj typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; }; @@ -1182,15 +1274,46 @@ struct AttentionBackwardKernel { XFORMERS_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); XFORMERS_CHECK( p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + if (kKeysQueriesAlignedToBlockSize) { + XFORMERS_CHECK( + p.cu_seqlens_k_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.cu_seqlens_q_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.num_queries % kBlockSizeI == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + XFORMERS_CHECK( + p.num_keys % kBlockSizeJ == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + } + XFORMERS_CHECK( + kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); + XFORMERS_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + XFORMERS_CHECK( + p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), + "Invalid `num_splits_key` (too large)"); return true; } - static CUTLASS_DEVICE void attention_kernel(Params const& p) { + static CUTLASS_DEVICE void attention_kernel(Params p) { extern __shared__ char smem_buffer[]; SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + int32_t key_start = p.split_key_device() * kBlockSizeJ; + if (key_start >= p.num_keys) { + return; + } if (kPrologueQK) { - prologueQkNextIteration(shared_storage, p, 0, 0); + int32_t query_start = getQueryStart(p, key_start); + prologueQkNextIteration( + shared_storage, p, query_start, key_start, warp_id, lane_id); } // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` @@ -1200,12 +1323,12 @@ struct AttentionBackwardKernel { if (p.head_dim_value % kOptimalElements == 0) { for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) { - computeDelta(p, query_start); + computeDelta(p, query_start, warp_id, lane_id); } } else { for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) { - computeDelta<1>(p, query_start); + computeDelta<1>(p, query_start, warp_id, lane_id); } } __syncthreads(); @@ -1232,77 +1355,57 @@ struct AttentionBackwardKernel { &rng_state_init); } #endif - - int32_t key_start = 0; - int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; - for (; key_start < key_end; key_start += kBlockSizeJ) { + CUTLASS_PRAGMA_UNROLL + for (; key_start < p.num_keys; + key_start += p.num_splits_key_device() * kBlockSizeJ) { output_frags.clear(); - int32_t query_start = getQueryStart(p, key_start); - int32_t query_end = query_start + - (p.num_queries - query_start) / kBlockSizeI * kBlockSizeI; - for (; query_start < query_end; query_start 
+= kBlockSizeI) { - processBlockIJ( + + CUTLASS_PRAGMA_UNROLL + for (int32_t query_start_shifted = getQueryStart(p, key_start); + query_start_shifted < getQueryStartShift(p) + getQueryEnd(p); + query_start_shifted += kBlockSizeI) { + // This line here + // vvvvvvvvvvvvvv + warp_id = warp_uniform(warp_id); + // ^^^^^^^^^^^^^^ + // ... makes everything use less RF and be 10% faster. Why? + // I don't know. My theory is that it forces `nvcc` to + // re-compute indices, offsets etc... and not keep them + // from the previous iteration, which prevents MASSIVE + // register spilling. + + int32_t query_start = query_start_shifted; + if (query_start >= p.num_queries) { + query_start = query_start % getQueryEnd(p); + } + + processBlockIJ( shared_storage, output_frags, p, query_start, key_start, - rng_state_init); - } - // last (partial) query - if (query_start < p.num_queries) { - processBlockIJ( - shared_storage, - output_frags, - p, - query_start, - key_start, - rng_state_init); + rng_state_init, + warp_id, + lane_id); } if (kOutputInRF) { - writeFragsToGmem(shared_storage, output_frags, p, key_start); + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); } else if (getQueryStart(p, key_start) >= p.num_queries) { - zfillGradKV(p, key_start); + zfillGradKV( + p, key_start, warp_id, lane_id); } __syncthreads(); } - // Last (partial) key - if (key_start != p.num_keys) { - output_frags.clear(); - int32_t query_start = getQueryStart(p, key_start); - for (; query_start < p.num_queries; query_start += kBlockSizeI) { - processBlockIJ( - shared_storage, - output_frags, - p, - query_start, - key_start, - rng_state_init); - } - if (kOutputInRF) { - writeFragsToGmem(shared_storage, output_frags, p, key_start); - } else if (getQueryStart(p, key_start) >= p.num_queries) { - zfillGradKV(p, key_start); - } - } - } - - static CUTLASS_DEVICE void loadDi( - cutlass::Array& di, - Params const& p, - int32_t query_start) { - int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; - if (thread_id < kBlockSizeI) { - accum_t di_rf = accum_t(0); - if (query_start + thread_id < p.num_queries) { - di_rf = p.delta_ptr[query_start + thread_id]; - } - di[thread_id] = di_rf; - } } template - static CUTLASS_DEVICE void zfillGradKV(Params const& p, int32_t key_start) { + static CUTLASS_DEVICE void zfillGradKV( + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { constexpr int kThreadsPerKey = 8; constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; static_assert(kBlockSizeJ % kParallelKeys == 0, ""); @@ -1310,8 +1413,7 @@ struct AttentionBackwardKernel { // It's only used when some keys are "useless" and don't attend to // any query, due to causal masking - int lane_id = get_lane_id(); - int thread_id = get_thread_id(); + int thread_id = 32 * warp_id + lane_id; int k_shift = lane_id % kThreadsPerKey; CUTLASS_PRAGMA_UNROLL @@ -1336,31 +1438,51 @@ struct AttentionBackwardKernel { static CUTLASS_DEVICE void processBlockIJ( SharedStorage& shared_storage, OutputFragments& output_frags, - Params const& p, + Params& p, int32_t query_start, int32_t key_start, - const curandStatePhilox4_32_10_t& curand_state_init) { + const curandStatePhilox4_32_10_t& curand_state_init, + uint8_t warp_id, + uint8_t lane_id) { + cutlass::Array + dropout_keep_mask_doivj; + dropout_keep_mask_doivj.fill(1); + const float dropout_scale = + kApplyDropout ? 
1.0 / (1.0 - p.dropout_prob) : 1.0f; + cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = p.scale; - int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; - int8_t warp_id = warp_uniform(threadIdx.y); - int8_t lane_id = threadIdx.x; + int16_t thread_id = 32 * warp_id + lane_id; + + auto rematerializeThreadIds = [&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; bool isFirstQuery = (query_start == getQueryStart(p, key_start)); int32_t next_query, next_key; incrIteration(p, query_start, key_start, next_query, next_key); bool isLastQuery = next_key != key_start; - __syncthreads(); - loadDi(shared_storage.di(), p, query_start); + + accum_t di_rf = accum_t(0); + if (thread_id < kBlockSizeI) { + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + shared_storage.di()[thread_id] = di_rf; + } int32_t num_queries_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kN - : cutlass::fast_min( - (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start); + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start)); int32_t num_keys_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kM - : cutlass::fast_min( - (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( @@ -1499,11 +1621,7 @@ struct AttentionBackwardKernel { [&](int accum_n) {}, [&](int accum_m, int accum_n, int idx) { // remember we are transposed - if (skipBoundsChecks || - (accum_n < num_queries_in_block && - accum_m < num_keys_in_block)) { - accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); - } + accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); }, [&](int accum_n) {}); } @@ -1549,6 +1667,10 @@ struct AttentionBackwardKernel { warp_id, lane_id, output_tile_coords); +#if 0 + auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref(); + PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT); +#endif // if we are using dropout, compute Zij, writing it to shared memory. // each element of Zij is: @@ -1564,9 +1686,11 @@ struct AttentionBackwardKernel { // number sequence. for Z, the end of a row and the beginning of the // next have adjacent offsets, but for Zij (tile of global matrix), this // is not necessarily the case. 
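// ----------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the Philox offset convention
// used for the dropout tile Zij described above. The offset passed to
// `skipahead` is the global element index (query * num_keys + key), so the
// draw for a given (query, key) pair does not depend on which tile or key
// split generates it and can be reproduced elsewhere (e.g. by the forward
// pass). The kernel below draws four values at a time with `curand_uniform4`;
// this sketch draws a single element for clarity, and its function and
// parameter names are hypothetical.
#include <curand_kernel.h>

__device__ float zij_element(
    curandStatePhilox4_32_10_t state, // seeded per (batch, head) RNG sequence
    int64_t global_query,             // query_start + row inside the tile
    int64_t global_key,               // key_start + column inside the tile
    int64_t num_keys,
    float dropout_prob) {
  // Jump to this element's position in the flat [num_queries, num_keys] grid.
  skipahead(
      static_cast<unsigned long long>(global_query * num_keys + global_key),
      &state);
  // One uniform draw decides keep/drop; kept elements are pre-scaled so the
  // expected value of Zij is 1 (inverted dropout).
  float u = curand_uniform(&state);
  return u > dropout_prob ? 1.0f / (1.0f - dropout_prob) : 0.0f;
}
// ----------------------------------------------------------------------------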
- const int num_threads = blockDim.x * blockDim.y * blockDim.z; + // We must fill the entire `zij` shmem with values (even out of bounds + // on the K-dimension) otherwise we can get NaNs during the GEMM + const int kQueriesPerBlock = kBlockSizeI; const int threads_per_row = cutlass::fast_min( - num_threads / num_queries_in_block, num_keys_in_block); + int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block); const int elts_per_thread = cutlass::round_nearest( cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); @@ -1574,19 +1698,17 @@ struct AttentionBackwardKernel { const int thread_start_j = (thread_id % threads_per_row) * elts_per_thread; - if (thread_i < num_queries_in_block && - thread_start_j < num_keys_in_block) { + if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) { curandStatePhilox4_32_10_t curand_state = curand_state_init; skipahead( (query_start + thread_i) * p.num_keys + (key_start + thread_start_j), &curand_state); - const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); // generate elements of Zij, 4 elements at a time for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < - cutlass::fast_min(thread_start_j + elts_per_thread, - num_keys_in_block); + cutlass::fast_min(thread_start_j + elts_per_thread, + num_keys_in_block); zij_start_col_idx += 4) { const float4 rand_uniform_quad = curand_uniform4(&curand_state); @@ -1594,23 +1716,51 @@ struct AttentionBackwardKernel { for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { // we'll write Zij transposed since attention is also transposed // during the matmul to compute dV. - zij.at({zij_start_col_idx + quad_idx, thread_i}) = - static_cast( - dropout_scale * - ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) = + (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob + ? scalar_t(dropout_scale) + : scalar_t(0); } } } + __syncthreads(); +#if 0 + PRINT_TENSOR4x4_T0_L0("zij", zij); + PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4); +#endif + + // Save mask for later DOIVJ matmul + + int warp_idx_mn_0 = warp_id % + (MatmulDOIVJ::Mma::Base::WarpCount::kM * + MatmulDOIVJ::Mma::Base::WarpCount::kN); + auto output_tile_coords_doivj = cutlass::MatrixCoord{ + warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM}; + auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords_doivj); + MatmulDOIVJ::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m /*q*/, int accum_n /*k*/, int idx) { + if (zij.at({accum_n, accum_m}) == scalar_t(0)) { + dropout_keep_mask_doivj[idx] = cutlass::uint1b_t(0); + } + }, + [&](int accum_m) {}); } __syncthreads(); } + rematerializeThreadIds(); ///////////////////////////////////////////////////////////////////////////////////////////////// // GradV matmul // // grad_v[j_start:j_end] += attn_T @ do_i ///////////////////////////////////////////////////////////////////////////////////////////////// - for (int col = 0; col < (kOutputInRF ? 1 : p.head_dim_value); + constexpr bool kSingleIterationGradV = + kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradV ? 
1 : p.head_dim_value); col += MatmulGradV::ThreadblockShape::kN) { using Mma = typename MatmulGradV::Mma; using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; @@ -1634,14 +1784,15 @@ struct AttentionBackwardKernel { // if dropout: dVj += (Pij.T * Zij) @ dOi // otherwise: dVj += Pij.T @ dOi Mma mma( - shared_storage.mm_gradV(), - // operand A: Pij - typename MatmulGradV::WarpIteratorA( - shared_storage.attn_shared_storage().accum_ref(), lane_id), - // if we're using dropout, operand A is Pij_dropped = Pij * Zij - // which is computed on the fly as fragments of Pij are loaded in - typename Mma::WarpIteratorAScale( - shared_storage.zij().accum_ref(), lane_id), + // operand A: Pij.T + shared_storage.attn_shared_storage().accum_ref(), + // operand A_scale Zij.T: + // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T + // which is computed on the fly as fragments of Pij.T are loaded in + shared_storage.zij().accum_ref(), + // operand B: dOi - which was loaded into shared memory previously + // when we computed dVj + shared_storage.mm_gradV().operand_B_ref(), thread_id, warp_id, lane_id); @@ -1669,7 +1820,7 @@ struct AttentionBackwardKernel { iterator_B, output_frags.gradV); __syncthreads(); - if (kPrologueGV && + if (kPrologueGV && !kSingleIterationGradV && col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { prologueGradV(col + MatmulGradV::ThreadblockShape::kN); } @@ -1682,11 +1833,14 @@ struct AttentionBackwardKernel { shared_storage.gradV_epilogue(), output_frags.gradV, createEpilogueIter(), - isFirstQuery || kNeedsAccumGradV); + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); } } } __syncthreads(); + ///////////////////////////////////////////////////////////////////////////////////////////////// // MatmulDOIVJ ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1739,38 +1893,32 @@ struct AttentionBackwardKernel { // attn_shared_storage [smem] <- tmp.T // tmp_shared_storage [smem] <- tmp { - using LambdaIterator = typename DefaultMmaAccumLambdaIterator< - typename Mma::Operator::IteratorC, - typename MatmulDOIVJ::ElementAccum, - kWarpSize>::Iterator; + using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator; auto lane_offset = LambdaIterator::get_lane_offset( lane_id, warp_id, output_tile_coords); - // if dropout was used, compute dPij = dPij_dropped * Zij - // Zij was written to shared memory earlier, and the elementwise - // multiplication occurs on a fragment of dPij_dropped if (kApplyDropout) { - const auto zij = shared_storage.zij().accum_ref(); - LambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, [&](int accum_m, int accum_n, int idx) { - const int global_query_idx = query_start + accum_m; - const int global_key_idx = key_start + accum_n; - - if (skipBoundsChecks || - (global_query_idx < p.num_queries && - global_key_idx < p.num_keys)) { - accum[idx] *= zij.at({accum_n, accum_m}); + if (dropout_keep_mask_doivj[idx].get()) { + accum[idx] *= dropout_scale; + } else { + accum[idx] = 0; } }, [&](int accum_m) {}); } auto attn_T = shared_storage.attn_shared_storage().accum_ref(); +#if 0 + PRINT_B0_T0("doivj_dropped"); + print_warp_accum(accum, lane_offset, 4, 4); + PRINT_TENSOR4x4_T0_L0("attn_T", attn_T) +#endif accum_t current_di; - typename Mma::FragmentC fragment_attn, fragment_di; + // dSij = (dPij - Di) * Pij LambdaIterator::iterateRows( lane_offset, [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, @@ -1780,17 +1928,15 @@ struct AttentionBackwardKernel { if 
(skipBoundsChecks || (accum_m < num_queries_in_block && accum_n < num_keys_in_block)) { - fragment_attn[idx] = attn_T.at({accum_n, accum_m}); + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; } else { - fragment_attn[idx] = 0; + accum[idx] = 0; } - fragment_di[idx] = current_di; }, [&](int accum_m) { }); - // dSij = (dPij - Di) * Pij - accum = (accum - fragment_di) * fragment_attn; // store bias gradient tile dBij to global memory, // where dBij = dSij = Pij * (dPij - Di) @@ -1818,6 +1964,11 @@ struct AttentionBackwardKernel { accum = accum * scale; +#if 0 + PRINT_B0_T0("(doivj - di) * attn * scale"); + print_warp_accum(accum, lane_offset, 4, 4); +#endif + __syncthreads(); if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); @@ -1839,12 +1990,22 @@ struct AttentionBackwardKernel { output_tile_coords); __syncthreads(); } + // Force `nvcc` to recompute values that depend on the variables just below + // to use less RF and prevent some spilling + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + ///////////////////////////////////////////////////////////////////////////////////////////////// // GradQ matmul // // grad_q[i_start:i_end] += tmp @ k_j ///////////////////////////////////////////////////////////////////////////////////////////////// - for (int col = 0; col < p.head_dim; + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = + kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); col += MatmulGradQ::ThreadblockShape::kN) { using Mma = typename MatmulGradQ::Mma; using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; @@ -1864,24 +2025,36 @@ struct AttentionBackwardKernel { auto a = shared_storage.tmp_shared_storage().accum_ref(); Mma mma( - shared_storage.mm_gradQ(), - shared_storage.tmp_shared_storage(), + // operand A: dSij + shared_storage.tmp_shared_storage().accum_ref(), + // operand B: Kj + shared_storage.mm_gradQ().operand_B_ref(), thread_id, warp_id, - lane_id, - problem_size.k()); + lane_id); typename Mma::FragmentC accum; - bool isFirst = key_start == 0; int col_id = col / MatmulGradQ::ThreadblockShape::kN; - int storage_id = - (col_id + - query_start / kBlockSizeI * - ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN)); - AccumTileGmem gmem_tile{ - p.workspace_gq + storage_id * AccumTileGmem::kElementsStored}; - if (isFirst || !kNeedsAccumGradQ) { + int num_cols = kSingleIterationGradQ + ? 1 + : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + + if (p.num_splits_key_device() > 1) { + AtomicLock::acquire( + &p.workspace_gq[storage_id].lock, + p.split_key_device() + 1, + thread_id); + // Make sure we can see other block's output + __threadfence(); + } + + AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]}; + if (!kNeedsAccumGradQ || + (p.num_splits_key_device() == 1 && key_start == 0)) { + // if we know we are the first to access it, we know it's only zeros. 
+ // Avoids a load from gmem (and gmem init as well) accum.clear(); } else { gmem_tile.load(accum, thread_id); @@ -1895,29 +2068,59 @@ struct AttentionBackwardKernel { mma.set_prologue_done(kPrologueGQ); mma(gemm_k_iterations, accum, iterator_B, accum); __syncthreads(); - bool isLastColumn = col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim; + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); if (kPrologueGQ && !isLastColumn) { prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); } + bool isLast = [&]() { + int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + if (p.num_keys <= next_key) { + return true; + } + if (query_start < getSmallestQueryForKey(p, next_key)) { + return true; + } + return false; + }(); // Output results - int32_t next_query, next_key; - incrIteration(p, p.num_queries, key_start, next_query, next_key); - bool isLast = next_query > query_start || next_key >= p.num_keys; + if (p.num_splits_key_device() > 1) { + int32_t numAddsSoFar = -1; + if (isLast && thread_id == 0) { + numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) + + 1; // `atomicAdd` returns the old value + } + isLast = __syncthreads_or( + numAddsSoFar == getNumParallelBlocksForQuery(p, query_start)); + assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start)); + } if (kNeedsAccumGradQ && !isLast) { gmem_tile.store(accum, thread_id); + if (p.num_splits_key_device() > 1) { + // Make sure everyone wrote before we release the lock + __threadfence(); + __syncthreads(); + AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id); + } } else { + // NOTE: We're not releasing the lock because no one is expected + // to come after us (we're the last one to write) typename MatmulGradQ::OutputTileIterator output_it( typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); + bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 || + (p.num_splits_key_device() > 1); accumulateInGmem( isLastColumn ? shared_storage.gradQ_epilogue_lastIter() : shared_storage.gradQ_epilogue(), accum, output_it, - isFirst || kNeedsAccumGradQ); + storage_contains_zeros, + warp_id, + lane_id); } } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1925,7 +2128,11 @@ struct AttentionBackwardKernel { // // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i ///////////////////////////////////////////////////////////////////////////////////////////////// - for (int col = 0; col < (kOutputInRF ? 1 : p.head_dim); + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = + kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 
1 : p.head_dim); col += MatmulGradK::ThreadblockShape::kN) { using Mma = typename MatmulGradK::Mma; using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; @@ -1962,16 +2169,17 @@ struct AttentionBackwardKernel { decltype(getTmp), decltype(getTmpT)>::apply(getTmp, getTmpT, 0); Mma mma( - shared_storage.mm_gradK(), - opA, + // operand A: dSij.T + opA.accum_ref(), + // operand B: Qi + shared_storage.mm_gradK().operand_B_ref(), thread_id, warp_id, - lane_id, - problem_size.k()); + lane_id); int storage_id = col / MatmulGradK::ThreadblockShape::kN; AccumTileGmem gmem_tile{ - p.workspace_gk + storage_id * AccumTileGmem::kElementsStored}; + p.workspace + storage_id * AccumTileGmem::kElementsStored}; if (!kOutputInRF) { if (isFirstQuery || !kNeedsAccumGradK) { output_frags.gradK.clear(); @@ -1992,7 +2200,8 @@ struct AttentionBackwardKernel { iterator_B, output_frags.gradK); __syncthreads(); - bool isLastColumn = col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; if (kPrologueGK && !isLastColumn) { prologueGradK(col + MatmulGradK::ThreadblockShape::kN); } @@ -2000,10 +2209,11 @@ struct AttentionBackwardKernel { if (kPrologueQK && isLastColumn) { int32_t next_query, next_key; incrIteration(p, query_start, key_start, next_query, next_key); - DISPATCH_BOOL(next_key != key_start, kForceReloadK, ([&]() { - prologueQkNextIteration( - shared_storage, p, next_query, next_key); - })); + DISPATCH_BOOL( + next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); } // Output results @@ -2016,24 +2226,63 @@ struct AttentionBackwardKernel { : shared_storage.gradK_epilogue(), output_frags.gradK, createEpilogueIter(), - isFirstQuery || kNeedsAccumGradK); + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); } } } } + static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) { + if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { + return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); + } + return 0; + } + + // Iteration order logic static CUTLASS_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { + return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); + }; + static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) { + return align_up(p.num_queries, kBlockSizeI); + }; + + static CUTLASS_DEVICE int32_t + getSmallestQueryForKey(Params const& p, int32_t key_start) { if (p.custom_mask_type == CausalFromTopLeft) { return (key_start / kBlockSizeI) * kBlockSizeI; } else if (p.custom_mask_type == CausalFromBottomRight) { int first_query = - cutlass::fast_max(0, key_start - p.num_keys - p.num_queries); + cutlass::fast_max(0, key_start - p.num_keys + p.num_queries); return (first_query / kBlockSizeI) * kBlockSizeI; } return 0; }; + // Returns how many kernel blocks will write to a given block in `grad_query` + // This is usually equal to the number of key splits, but can be different + // for instance in the causal case, or varying seqlen + static CUTLASS_DEVICE int32_t + getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { + int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); + if (p.custom_mask_type == CausalFromTopLeft) { + int32_t last_key_for_block = query_start + kBlockSizeI - 1; + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, 
kBlockSizeJ); + } else if (p.custom_mask_type == CausalFromBottomRight) { + int32_t last_key_for_block = + query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries); + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + } + return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); + }; + + // Returns the next block to process static CUTLASS_DEVICE void incrIteration( Params const& p, int32_t query_start, @@ -2042,10 +2291,26 @@ struct AttentionBackwardKernel { int32_t& next_key) { next_query = query_start + kBlockSizeI; next_key = key_start; - if (next_query >= p.num_queries) { - next_key = key_start + kBlockSizeJ; - next_query = getQueryStart(p, next_key); + auto query_shift = getQueryStartShift(p); + // Wrap around + if (query_shift) { + if (next_query >= p.num_queries) { + next_query = getSmallestQueryForKey(p, key_start); + return; + } else if (query_start < query_shift && query_shift <= next_query) { + // jump to next key + } else { + return; + } + } else { + if (next_query < p.num_queries) { + return; + } + // jump to next key } + // Next key + next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + next_query = getQueryStart(p, next_key); } template @@ -2053,14 +2318,16 @@ struct AttentionBackwardKernel { SharedStorage& shared_storage, Params const& p, int32_t query_start, - int32_t key_start) { + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { if (query_start >= p.num_queries || key_start >= p.num_keys) { return; } static constexpr bool kReloadK = kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; - auto thread_id = get_thread_id(); + int thread_id = 32 * warp_id + lane_id; typename MatmulQK::Mma::IteratorA iterator_A( {int32_t(p.k_strideM)}, p.key_ptr + key_start * p.k_strideM, @@ -2089,7 +2356,10 @@ struct AttentionBackwardKernel { SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, - int32_t key_start) { + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + uint16_t thread_id = 32 * warp_id + lane_id; int32_t num_keys_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kM : cutlass::fast_min( @@ -2098,24 +2368,28 @@ struct AttentionBackwardKernel { typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, - get_thread_id()); + thread_id); accumulateInGmem( shared_storage.gradV_epilogue_final(), output_frags.gradV, outputV_it, - true); + true, + warp_id, + lane_id); typename MatmulGradK::OutputTileIterator outputK_it( typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, p.grad_key_ptr + key_start * p.gK_strideM(), {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim}, - get_thread_id()); + thread_id); accumulateInGmem( shared_storage.gradK_epilogue_final(), output_frags.gradK, outputK_it, - true); + true, + warp_id, + lane_id); } template @@ -2123,10 +2397,13 @@ struct AttentionBackwardKernel { typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, typename MatmulT::Mma::FragmentC const& accum, typename MatmulT::OutputTileIterator output_it, - bool first) { + bool first, + uint8_t warp_id, + uint8_t lane_id) { using DefaultEpilogue = typename MatmulT::DefaultEpilogue; using DefaultOutputOp = typename MatmulT::DefaultOutputOp; using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; DISPATCH_BOOL( first, kIsFirst, ([&]() { static constexpr auto ScaleType = kIsFirst @@ -2154,8 +2431,7 @@ struct AttentionBackwardKernel { true // IterationsUnroll >; EpilogueOutputOp rescale({1, 1}); - Epilogue epilogue( - epilogue_smem, get_thread_id(), get_warp_id(), get_lane_id()); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); epilogue(rescale, output_it, accum, output_it); })); } @@ -2163,17 +2439,18 @@ struct AttentionBackwardKernel { template static CUTLASS_DEVICE void computeDelta( Params const& p, - int32_t query_start) { + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) { // Each thread computes one value for Delta // Depending on warp configuration, we might have multiple // threads of the same warp working on the same row using AccessType = cutlass::Array; static_assert(kNumThreads >= kBlockSizeI, ""); static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; - int16_t thread_id = get_thread_id(); + int16_t thread_id = 32 * warp_id + lane_id; - int16_t laneFirstCol = - kElementsPerAccess * (get_lane_id() % kNumThreadsPerLine); + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); int16_t laneRow = thread_id / kNumThreadsPerLine; bool rowPred = (query_start + laneRow) < p.num_queries; bool pred = rowPred; @@ -2260,16 +2537,6 @@ struct AttentionBackwardKernel { p.delta_ptr[query_start + laneRow] = delta_value; } } - - static CUTLASS_DEVICE int8_t get_lane_id() { - return threadIdx.x; - } - static CUTLASS_DEVICE int8_t get_warp_id() { - return threadIdx.y; - } - static CUTLASS_DEVICE int16_t get_thread_id() { - return threadIdx.x + threadIdx.y * blockDim.x; - } }; template diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h index f0306241..28e6cd0f 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -66,6 +66,7 @@ #include "debug_utils.h" #include "epilogue/epilogue_pipelined.h" #include "epilogue/epilogue_rescale_output.h" +#include "gemm/custom_mma.h" #include "gemm/find_default_mma.h" #include "gemm/mma_from_smem.h" #include "gemm_kernel_utils.h" @@ -77,7 +78,7 @@ using namespace gemm_kernel_utils; namespace { template -constexpr int getWarpsPerSm() { +constexpr int getWarpsPerSmFw() { return ( Arch::kMinComputeCapability >= 80 && !cutlass::platform::is_same::value @@ -92,6 +93,24 @@ static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { } } // namespace +// If ToBatchHookType_ is supplied other than this default (which is +// never the case in the xformers library) then the user is +// defining the logic which each block uses to find its data to work on, +// with the advance_to_batch function with the following signature. 
+// It should return false if there is no work to do for this block. +// In general this will not work with saving for backward due to fixed layout +// for logsumexp and incompatible rngs for dropout, so is likely only useful for +// custom inference. +struct DefaultToBatchHook { + template + CUTLASS_DEVICE static bool advance_to_batch( + Params&, + int64_t& /* q_start */, + int64_t& /* k_start */) { + return true; + } +}; + template < // The datatype of Q/K/V typename scalar_t_, @@ -99,13 +118,15 @@ template < typename ArchTag, // If Q/K/V are correctly aligned in memory and we can run a fast kernel bool isAligned_, - int kQueriesPerBlock, + int kQueriesPerBlock_, int kKeysPerBlock_, - bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock` + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), // This is quite slower on V100 for some reason // Set to false if you know at compile-time you will never need dropout bool kSupportsDropout_ = true, - bool kSupportsBias_ = true> + bool kSupportsBias_ = true, + typename ToBatchHookType_ = DefaultToBatchHook> struct AttentionKernel { enum CustomMaskType { NoCustomMask = 0, @@ -125,11 +146,14 @@ struct AttentionKernel { static constexpr bool kSupportsDropout = kSupportsDropout_; static constexpr bool kSupportsBias = kSupportsBias_; static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr int kQueriesPerBlock = kQueriesPerBlock_; + static constexpr int kMaxK = kMaxK_; static constexpr bool kIsAligned = isAligned_; - static constexpr bool kSingleValueIteration = kSingleValueIteration_; + static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock; static constexpr int32_t kAlignLSE = 32; // block size of backward - static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && - cutlass::sizeof_bits::value == 16; + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && kIsHalf; static constexpr bool kKeepOutputInRF = kSingleValueIteration; static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && !cutlass::platform::is_same::value; @@ -143,66 +167,67 @@ struct AttentionKernel { // Launch bounds static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; static constexpr int kMinBlocksPerSm = - getWarpsPerSm() / kNumWarpsPerBlock; + getWarpsPerSmFw() / kNumWarpsPerBlock; struct Params { // Input tensors - scalar_t* query_ptr; // [num_queries, num_heads, head_dim] - scalar_t* key_ptr; // [num_keys, num_heads, head_dim] - scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] + scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] int32_t* seqstart_q_ptr = nullptr; int32_t* seqstart_k_ptr = nullptr; - int32_t* causal_diagonal_ptr = nullptr; int32_t* seqlen_k_ptr = nullptr; uint32_t causal_diagonal_offset = 0; // Output tensors - output_t* output_ptr; // [num_queries, num_heads, head_dim_value] - output_accum_t* - output_accum_ptr; // [num_queries, num_heads, head_dim_value] - lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null + output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value] + // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr 
= nullptr; + // [num_heads, num_queries] - can be null + lse_scalar_t* logsumexp_ptr = nullptr; // Scale - accum_t scale; + accum_t scale = 0.0; // Dimensions/strides - int32_t head_dim; - int32_t head_dim_value; - int32_t num_queries; - int32_t num_keys; + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; uint8_t custom_mask_type = NoCustomMask; - int32_t q_strideM; - int32_t k_strideM; - int32_t v_strideM; + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; int32_t bias_strideM = 0; int32_t o_strideM = 0; // Everything below is only used in `advance_to_block` // and shouldn't use registers - int32_t q_strideH; - int32_t k_strideH; - int32_t v_strideH; - int32_t bias_strideH = 0; + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; + int64_t bias_strideH = 0; - int64_t q_strideB; - int64_t k_strideB; - int64_t v_strideB; - int32_t bias_strideB = 0; + int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; + int64_t bias_strideB = 0; - int32_t num_batches; - int32_t num_heads; + int32_t num_batches = 0; + int32_t num_heads = 0; // dropout - bool use_dropout; - unsigned long long dropout_batch_head_rng_offset; - float dropout_prob; + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; #ifdef HAS_PYTORCH - at::PhiloxCudaState rng_engine_inputs; + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); #endif // Moves pointers to what we should process @@ -220,9 +245,17 @@ struct AttentionKernel { head_id * num_queries * num_keys; } - int64_t q_start, k_start; + int64_t q_start = 0, k_start = 0; // Advance to current batch - in case of different sequence lengths - if (seqstart_q_ptr != nullptr) { + constexpr bool kToBatchHook = + !cutlass::platform::is_same:: + value; + if (kToBatchHook) { + // Call out to a custom implementation. + if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) { + return false; + } + } else if (seqstart_q_ptr != nullptr) { assert(seqstart_k_ptr != nullptr); seqstart_q_ptr += batch_id; @@ -285,12 +318,12 @@ struct AttentionKernel { } // Custom masking - if (causal_diagonal_ptr) { - causal_diagonal_offset = causal_diagonal_ptr[batch_id]; - } if (custom_mask_type == CausalFromBottomRight) { - causal_diagonal_offset += num_keys - num_queries; + causal_diagonal_offset = num_keys - num_queries; } + // We use num_keys_absolute to index into the rng_state + // We need this index to match between forward and backwards + num_keys_absolute = num_keys; if (custom_mask_type == CausalFromTopLeft || custom_mask_type == CausalFromBottomRight) { // the bottom row of the current block is query_start + kQueriesPerBlock @@ -323,6 +356,7 @@ struct AttentionKernel { // Make sure the compiler knows these variables are the same on all // the threads of the warp. + // Only worth doing if they could have been modified above. 
query_ptr = warp_uniform(query_ptr); key_ptr = warp_uniform(key_ptr); value_ptr = warp_uniform(value_ptr); @@ -335,8 +369,6 @@ struct AttentionKernel { num_queries = warp_uniform(num_queries); num_keys = warp_uniform(num_keys); num_heads = warp_uniform(num_heads); - head_dim = warp_uniform(head_dim); - head_dim_value = warp_uniform(head_dim_value); o_strideM = warp_uniform(o_strideM); custom_mask_type = warp_uniform(custom_mask_type); return true; @@ -395,14 +427,19 @@ struct AttentionKernel { ThreadblockShape, // ThreadblockShape WarpShape, // WarpShape typename GemmType::InstructionShape, // InstructionShape - DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that - // uses too much smem + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, typename GemmType::Operator // Operator >::DefaultMma; using MmaCore = typename DefaultMma::MmaCore; using IteratorA = typename DefaultMma::IteratorA; using IteratorB = typename DefaultMma::IteratorB; - using Mma = typename DefaultMma::ThreadblockMma; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< typename Mma::Operator::IteratorC, accum_t, @@ -475,14 +512,23 @@ struct AttentionKernel { typename GemmType::InstructionShape, typename DefaultConfig::EpilogueOutputOp, void, // ThreadblockSwizzle - not used - DefaultConfig::kStages, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, false, // SplitKSerial typename GemmType::Operator>; + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MM0::AccumulatorSharedStorage, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; @@ -500,10 +546,6 @@ struct AttentionKernel { typename cutlass::epilogue::threadblock::PredicatedTileIterator< typename DefaultEpilogue::OutputTileIterator::ThreadMap, output_accum_t>; - - struct SharedStorageMM1 { - typename Mma::SharedStorage mm; - }; }; static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; @@ -515,6 +557,9 @@ struct AttentionKernel { cutlass::Array m_prime; cutlass::Array s_prime; cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; }; struct SharedStorageEpilogueAtEnd : ScalingCoefs { @@ -524,7 +569,7 @@ struct AttentionKernel { typename MM0::BiasLoader::SmemTile bias; typename MM0::AccumulatorSharedStorage si; }; - typename MM1::SharedStorageMM1 mm1; + typename MM1::Mma::SharedStorage mm1; }; union { @@ -546,7 +591,7 @@ struct AttentionKernel { typename MM0::BiasLoader::SmemTile bias; typename MM0::AccumulatorSharedStorage si; }; - typename MM1::SharedStorageMM1 mm1; + typename MM1::Mma::SharedStorage mm1; typename MM1::DefaultEpilogue::SharedStorage epilogue; }; @@ -600,9 +645,6 @@ struct AttentionKernel { XFORMERS_CHECK( p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, "value is 
not correctly aligned (strideH)"); - XFORMERS_CHECK( - p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask, - "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal"); XFORMERS_CHECK( p.custom_mask_type < NumCustomMaskTypes, "invalid value for `custom_mask_type`"); @@ -619,11 +661,13 @@ struct AttentionKernel { auto& m_prime = shared_storage.m_prime; auto& s_prime = shared_storage.s_prime; auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; const uint32_t query_start = blockIdx.x * kQueriesPerBlock; static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); if (thread_id() < kQueriesPerBlock) { s_prime[thread_id()] = accum_t(0); + out_rescale[thread_id()] = accum_t(1.0); m_prime[thread_id()] = -cutlass::platform::numeric_limits::infinity(); mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); @@ -695,7 +739,7 @@ struct AttentionKernel { thread_id(), cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); MM1::Mma::prologue( - shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.mm1, iterator_V, thread_id(), problem_size_1_k); @@ -739,7 +783,7 @@ struct AttentionKernel { thread_id(), tb_offset_B); - auto my_warp_id = warp_id(); + auto my_warp_id = warp_uniform(warp_id()); auto my_lane_id = lane_id(); // Construct thread-scoped matrix multiply @@ -759,6 +803,8 @@ struct AttentionKernel { if (kPreloadV) { prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); } typename MM0::Mma::Operator::IteratorC::TensorCoord @@ -793,7 +839,7 @@ struct AttentionKernel { // Pij += Bij, Pij is in register fragment and Bij is in shared memory auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( - lane_id(), warp_id(), iteratorC_tile_offset); + my_lane_id, my_warp_id, iteratorC_tile_offset); MM0::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, @@ -817,7 +863,7 @@ struct AttentionKernel { (query_start + p.causal_diagonal_offset)) { auto query_start = blockIdx.x * kQueriesPerBlock; auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( - lane_id(), warp_id(), iteratorC_tile_offset); + my_lane_id, my_warp_id, iteratorC_tile_offset); int32_t last_col; MM0::AccumLambdaIterator::iterateRows( lane_offset, @@ -836,30 +882,23 @@ struct AttentionKernel { }, [&](int accum_m) {}); } - DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { - DISPATCH_BOOL( - p.num_keys - iter_key_start >= kKeysPerBlock, - kFullColumns, - ([&] { - // Update `mi` from accum stored in registers - // Also does accum[i] <- exp(accum[i] - mi) - iterative_softmax< - typename MM0::Mma::Operator::IteratorC, - kFullColumns, - kIsFirst>( - accum_o, - accum, - mi, - m_prime, - s_prime, - lane_id(), - thread_id(), - warp_id(), - p.num_keys - iter_key_start, - iteratorC_tile_offset, - kSupportsBias ? 1.0f : p.scale); - })); - })); + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + my_lane_id, + thread_id(), + my_warp_id, + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 
1.0f : p.scale); // Output results to shared-memory int warp_idx_mn_0 = my_warp_id % @@ -910,7 +949,7 @@ struct AttentionKernel { curandStatePhilox4_32_10_t curand_state = curand_state_init; skipahead( static_cast( - (query_start + thread_i) * p.num_keys + + (query_start + thread_i) * p.num_keys_absolute + (iter_key_start + thread_start_j)), &curand_state); const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); @@ -964,12 +1003,14 @@ struct AttentionKernel { thread_id(), cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); typename MM1::Mma mma_pv( - shared_storage.after_mm0.mm1.mm, - shared_storage.after_mm0.si, + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), (int)thread_id(), - (int)warp_id(), - (int)lane_id(), - (int)problem_size_1_k); + (int)my_warp_id, + (int)my_lane_id); mma_pv.set_prologue_done(kPreloadV); if (!kKeepOutputInRF) { accum_o.clear(); @@ -982,6 +1023,7 @@ struct AttentionKernel { } if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); DISPATCH_BOOL( iter_key_start == 0, kIsFirst, ([&] { DISPATCH_BOOL( @@ -1033,12 +1075,12 @@ struct AttentionKernel { decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); - EpilogueOutputOp rescale(s_prime, m_prime); + EpilogueOutputOp rescale(s_prime, out_rescale); Epilogue epilogue( shared_storage.epilogue_shared_storage(), thread_id(), - warp_id(), - lane_id()); + my_warp_id, + my_lane_id); epilogue(rescale, dest_iter, accum_o, source_iter); })); })); @@ -1082,12 +1124,13 @@ struct AttentionKernel { typename MM1::OutputTileIteratorAccum // source tile >; auto dest_iter = createOutputIter(0); - EpilogueOutputOp rescale(s_prime, m_prime); + EpilogueOutputOp rescale(s_prime, out_rescale); Epilogue epilogue( shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id()); + MM1::Mma::drain_cp_asyncs(); epilogue(rescale, dest_iter, accum_o); } @@ -1097,8 +1140,9 @@ struct AttentionKernel { static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E if (thread_id() < p.num_queries) { - p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) + + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + cutlass::fast_log(accum_t(s_prime[thread_id()])); } else if (thread_id() < lse_dim) { p.logsumexp_ptr[thread_id()] = @@ -1107,20 +1151,21 @@ struct AttentionKernel { } } - template < - typename WarpIteratorC, - bool kFullColumns, - bool kIsFirst> + template CUTLASS_DEVICE static void iterative_softmax( typename WarpIteratorC::Fragment& frag_o, // output so far typename WarpIteratorC::Fragment& frag, cutlass::Array& mi, cutlass::Array& m_prime, cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, int8_t lane_id, int8_t thread_id, int8_t warp_id, - int16_t max_col, + int max_col, + bool is_first, typename WarpIteratorC::TensorCoord const& tile_offset, float scaling) { /* Iterates on the accumulator and corresponding position on result matrix @@ -1141,12 +1186,11 @@ struct AttentionKernel { kWarpSize>::Iterator; // Convert to `accum_t` (rather than double) constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E - if 
(!kIsFirst) { - if (thread_id < kQueriesPerBlock) { - m_prime[thread_id] = mi[thread_id]; - } - __syncthreads(); - } + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); auto lane_offset = LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); @@ -1160,46 +1204,64 @@ struct AttentionKernel { max = -cutlass::platform::numeric_limits::infinity(); }, [&](int accum_m, int accum_n, int idx) { - if (kFullColumns || accum_n < max_col) { + if (accum_n < max_col) { max = cutlass::fast_max(max, frag[idx]); } }, [&](int accum_m) { // Having 4x atomicMax seems faster than reduce within warp // first... - atomicMaxFloat(&mi[accum_m], max * scaling); + atomicMaxFloat(&mi[accum_m], max); }); } - frag = cutlass::multiplies()(scaling * kLog2e, frag); // Make sure we all share the update values for `mi` __syncthreads(); - if (thread_id < kQueriesPerBlock) { - auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); - m_prime[thread_id] = m_prime_exp; - s_prime[thread_id] *= m_prime_exp; + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } } __syncthreads(); // Update output fragments - if (kKeepOutputInRF && !kIsFirst) { - accum_t mp; + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; LambdaIterator::iterateRows( lane_offset, - [&](int accum_m) { mp = m_prime[accum_m]; }, - [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, [&](int accum_m) {}); - __syncthreads(); } // Update accum_m, accum_n, ... { accum_t mi_row, total_row; LambdaIterator::iterateRows( lane_offset, - [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m) { mi_row = mi[accum_m]; }, [&](int accum_m, int accum_n, int idx) { - frag[idx] = (kFullColumns || accum_n < max_col) - ? exp2f(frag[idx] - mi_row) - : accum_t(0.0); + frag[idx] = + (accum_n < max_col) ? 
exp2f(frag[idx] - mi_row) : accum_t(0.0); }, [&](int accum_m) {}); LambdaIterator::iterateRows( @@ -1211,10 +1273,30 @@ struct AttentionKernel { lane_id, total_row, [](accum_t a, accum_t b) { return a + b; })) { - atomicAdd(&s_prime[accum_m], total_row); + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; } }); } + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } } static CUTLASS_DEVICE int8_t lane_id() { diff --git a/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h index 345bc5bb..6c2d1764 100644 --- a/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h +++ b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h @@ -29,6 +29,8 @@ * **************************************************************************************************/ +#pragma once + #include #include "cutlass/aligned_buffer.h" #include "cutlass/array.h"
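A few of the forward-kernel changes above are easier to follow in isolation; the sketches that follow restate them outside of the kernel, and any name or size that does not appear in the diff is a made-up stand-in.

The kernel gains a `ToBatchHookType_` template parameter (defaulting to `DefaultToBatchHook`) that lets a caller override how a block locates its batch's query/key ranges; returning false from the hook means the block has no work, and when a custom hook type is supplied it takes the place of the seqstart-based batch selection. As the comment in the diff notes, this bypasses the fixed logsumexp layout and the dropout RNG bookkeeping, so it is mainly useful for custom inference. A minimal sketch of a custom hook, assuming hypothetical user-added `Params` fields (`my_q_start`, `my_k_start`, `my_has_work`); only the `advance_to_batch()` interface comes from the kernel:

    #include <cstdint>
    #include "cutlass/cutlass.h"  // CUTLASS_DEVICE

    // Illustration only: everything except the advance_to_batch() interface
    // (Params&, int64_t& q_start, int64_t& k_start -> bool) is an assumption.
    struct CustomInferenceHook {
      template <typename Params>
      CUTLASS_DEVICE static bool advance_to_batch(
          Params& p, int64_t& q_start, int64_t& k_start) {
        q_start = p.my_q_start;   // hypothetical user-added fields
        k_start = p.my_k_start;
        return p.my_has_work;     // false => nothing to do for this block
      }
    };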
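The `kSingleValueIteration` flag is no longer a template parameter of `DefaultFMHAGrouped` or `AttentionKernel`; both now take `kMaxK`, an upper bound on `max(value.shape[-1], query.shape[-1])`, and derive `kSingleValueIteration = kMaxK <= kKeysPerBlock` (which in turn drives `kKeepOutputInRF`). A standalone sketch of the compile-time derivation, with arbitrary stand-in tile sizes:

    #include <climits>
    #include <cstdio>

    // Sketch only: mirrors how kSingleValueIteration is derived from kMaxK.
    template <int kKeysPerBlock, int kMaxK = INT_MAX>
    struct KernelTraits {
      // One pass over the value head dimension suffices when the whole head
      // fits inside a single key block.
      static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
      // Keeping the output accumulator in registers is only done in that case.
      static constexpr bool kKeepOutputInRF = kSingleValueIteration;
    };

    int main() {
      printf("%d\n", KernelTraits<64, 64>::kSingleValueIteration);   // 1
      printf("%d\n", KernelTraits<64, 128>::kSingleValueIteration);  // 0
      printf("%d\n", KernelTraits<64>::kSingleValueIteration);       // 0 (no bound given)
      return 0;
    }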
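Dropout bookkeeping now uses `num_keys_absolute`, the batch's full key count recorded before any causal clipping of `num_keys`, so the Philox `skipahead` offset `(query row) * num_keys_absolute + (key column)` is identical in the forward and backward kernels for the same attention element. A tiny sketch of that offset arithmetic; the sizes are made up:

    #include <cstdint>
    #include <cstdio>

    // Per-element RNG offset for the dropout mask. Computed from the absolute
    // key count so it does not depend on how far a causal block iterates.
    uint64_t dropout_rng_offset(int64_t row, int64_t col,
                                int64_t num_keys_absolute) {
      return static_cast<uint64_t>(row * num_keys_absolute + col);
    }

    int main() {
      const int64_t num_keys_absolute = 128;  // full key count of the batch
      const int64_t num_keys_clipped  = 64;   // what a causal block might see
      const int64_t row = 5, col = 42;
      printf("absolute: %llu\n",
             (unsigned long long)dropout_rng_offset(row, col, num_keys_absolute));
      // Using the clipped count instead would give a different offset, and the
      // backward pass would then draw a different random number for (row, col):
      printf("clipped:  %lld\n", (long long)(row * num_keys_clipped + col));
      return 0;
    }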
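The online softmax now works in base-2 units throughout: scores are pre-multiplied by `scale * log2(e)`, the running row maximum `mi` and all rescales use `exp2f`, and the logsumexp written at the end converts back with `mi / kLog2e + log(s_prime)`. A host-side sketch of the bookkeeping, showing why that final conversion recovers the ordinary natural-log logsumexp (the logits are made up):

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
      const float kLog2e = 1.4426950408889634074f;       // log2(e)
      const float scores[4] = {0.3f, -1.2f, 2.5f, 0.7f}; // made-up logits

      // Running max / sum in base-2 units (what mi and s_prime track per row).
      float mi = -INFINITY, s_prime = 0.f;
      for (float s : scores) {
        float s2 = s * kLog2e;                  // pre-scaled score
        float new_mi = fmaxf(mi, s2);
        s_prime = s_prime * exp2f(mi - new_mi)  // rescale the previous sum
                + exp2f(s2 - new_mi);           // add the current term
        mi = new_mi;
      }
      // Logsumexp in natural-log units, as written out by the kernel:
      float lse = mi / kLog2e + logf(s_prime);

      // Reference computed directly in natural-log units.
      float ref = 0.f;
      for (float s : scores) ref += expf(s);
      ref = logf(ref);

      printf("lse=%f ref=%f\n", lse, ref);
      assert(fabsf(lse - ref) < 1e-4f);
      return 0;
    }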
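The per-row `exp2f` rescale of `m_prime`/`s_prime` is no longer done by the first `kQueriesPerBlock` threads alone; it is split across warps, with lane `l` of warp `w` handling row `w * kLinesPerWarp + l` for `l < kLinesPerWarp`, where `kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock`. A quick host-side check that this indexing touches every row exactly once, using stand-in sizes:

    #include <cassert>
    #include <cstdio>

    constexpr int kQueriesPerBlock  = 64;  // stand-in values; the real kernel
    constexpr int kNumWarpsPerBlock = 4;   // requires an even division
    constexpr int kWarpSize         = 32;
    constexpr int kLinesPerWarp     = kQueriesPerBlock / kNumWarpsPerBlock;

    int main() {
      static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
      int covered[kQueriesPerBlock] = {};
      for (int warp_id = 0; warp_id < kNumWarpsPerBlock; ++warp_id) {
        for (int lane_id = 0; lane_id < kWarpSize; ++lane_id) {
          if (lane_id < kLinesPerWarp) {  // only the first lanes of each warp
            ++covered[warp_id * kLinesPerWarp + lane_id];
          }
        }
      }
      for (int row = 0; row < kQueriesPerBlock; ++row) {
        assert(covered[row] == 1);        // each row updated by exactly one thread
      }
      printf("all %d rows covered exactly once\n", kQueriesPerBlock);
      return 0;
    }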
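When the bias masks out every key a row has seen so far, both `m_prime` and the new `mi` are `-inf`, and a naive rescale `exp2f(m_prime - mi)` would produce `exp2f(-inf - (-inf)) = NaN`. The updated kernel guards against this: the rescale factor is only computed when the row maximum actually increased, and `mi` is temporarily set to 0 for fully masked rows (and restored afterwards) so the subsequent `exp2f(frag - mi)` also stays finite. The sketch below shows just the rescale-factor guard, mirroring the comparison used in the kernel:

    #include <cmath>
    #include <cstdio>

    // Rescale factor for one softmax row, given the previous running maximum
    // (m_prime) and the new one (mi), both in base-2 units. When nothing
    // changed -- including the all-masked case where both are -inf -- the
    // factor stays 1 instead of becoming exp2f(-inf - (-inf)) == NaN.
    float row_rescale(float m_prime, float mi) {
      bool changed = m_prime < mi;  // false if both are -inf
      return changed ? exp2f(m_prime - mi) : 1.0f;
    }

    int main() {
      printf("%f\n", row_rescale(-INFINITY, -INFINITY));  // 1.0, not NaN
      printf("%f\n", row_rescale(1.0f, 3.0f));            // 0.25
      printf("%f\n", row_rescale(-INFINITY, 2.0f));       // 0.0
      return 0;
    }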
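Partial row sums are no longer combined with `atomicAdd(&s_prime[row], total_row)`: each warp column writes its partial into its own slot of the new `addition_storage` array, and after a `__syncthreads()` one thread per row adds the slots in a fixed order, which is both cheaper and deterministic. A host-side sketch of the same two-phase reduction, with stand-in sizes for `kQueriesPerBlock` and `MM0::MmaCore::WarpCount::kN`:

    #include <cstdio>

    constexpr int kRows     = 4;  // stands in for kQueriesPerBlock
    constexpr int kWarpCols = 2;  // stands in for MM0::MmaCore::WarpCount::kN

    int main() {
      // Phase 1 (done in parallel in the kernel): warp column `col` writes its
      // partial sum for `row` into addition_storage[row + kRows * col].
      float addition_storage[kRows * kWarpCols] = {
          1.f, 2.f, 3.f, 4.f,    // partials from warp column 0
          .5f, .5f, .5f, .5f};   // partials from warp column 1
      float s_prime[kRows] = {0.f, 0.f, 0.f, 0.f};  // carried-over row sums

      // Phase 2 (one thread per row in the kernel): add the slots in a fixed
      // order -- no atomics, so the result is bitwise reproducible.
      for (int row = 0; row < kRows; ++row) {
        float total = s_prime[row];
        for (int col = 0; col < kWarpCols; ++col) {
          total += addition_storage[row + kRows * col];
        }
        s_prime[row] = total;
      }
      for (int row = 0; row < kRows; ++row) {
        printf("s_prime[%d] = %f\n", row, s_prime[row]);
      }
      return 0;
    }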