 146d314057
			
		
	
	
		146d314057
		
			
		
	
	
	
	
		
			
			* Update fMHA kernels Upstream recent changes to fMHA that we did in xFormers. Previous version in CUTLASS: facebookresearch/xformers@b6be33a Updating to: facebookresearch/xformers@55a4798 * minor changes * make var work --------- Co-authored-by: danthe3rd <danthe3rd> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
		
			
				
	
	
		
			1043 lines
		
	
	
		
			37 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			1043 lines
		
	
	
		
			37 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| 
 | |
| /*! \file
 | |
|     \brief Grouped FMHA kernel
 | |
| */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "cutlass/cutlass.h"
 | |
| #include "cutlass/fast_math.h"
 | |
| #include "cutlass/gemm/gemm.h"
 | |
| #include "cutlass/matrix_coord.h"
 | |
| #include "cutlass/complex.h"
 | |
| #include "cutlass/semaphore.h"
 | |
| 
 | |
| #include "cutlass/layout/matrix.h"
 | |
| #include "cutlass/trace.h"
 | |
| #include "cutlass/gemm/kernel/gemm_transpose_operands.h"
 | |
| 
 | |
| #include "fmha_grouped_problem_visitor.h"
 | |
| #include "gemm_kernel_utils.h"
 | |
| #include "gemm/mma_accum_lambda_iterator.h"
 | |
| #include "epilogue/epilogue_rescale_output.h"
 | |
| 
 | |
| 
 | |
| namespace {
 | |
|   static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
 | |
|   // source: https://stackoverflow.com/a/51549250
 | |
|   return (value >= 0)
 | |
|       ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
 | |
|       : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
 | |
| }
 | |
| }
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| namespace cutlass {
 | |
| namespace gemm {
 | |
| namespace kernel {
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| template <
 | |
|   typename MM0_,                           ///! Structure for computing P = Q @ K
 | |
|   typename MM1_,                           ///! Structure for computing O = P @ V
 | |
|   typename scalar_t_,
 | |
|   typename accum_t_,
 | |
|   typename output_t_,
 | |
|   typename output_accum_t_,
 | |
|   bool kKeepOutputInRF,                    ///! Whether the intermediate output from MM0_ should be kept in the register file
 | |
|   GroupScheduleMode GroupScheduleMode_     ///! Type of scheduling to perform
 | |
| >
 | |
| struct FMHAGrouped {
 | |
| public:
 | |
|   using MM0 = MM0_;
 | |
|   using MM1 = MM1_;
 | |
| 
 | |
|   using scalar_t = scalar_t_;
 | |
|   using accum_t = accum_t_;
 | |
|   using output_t = output_t_;
 | |
|   using output_accum_t = output_accum_t_;
 | |
| 
 | |
|   static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
 | |
| 
 | |
|   static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
 | |
|       !cutlass::platform::is_same<output_accum_t, output_t>::value;
 | |
| 
 | |
|   // Parameters to satisfy BaseGrouped
 | |
|   using ElementA = scalar_t;
 | |
|   using ElementB = scalar_t;
 | |
|   using ElementC = accum_t;
 | |
|   using LayoutA = typename MM0::LayoutA;
 | |
|   using LayoutB = typename MM0::ElementB;
 | |
|   using LayoutC = typename MM1::ElementC;
 | |
|   static ComplexTransform const kTransformA = ComplexTransform::kNone;
 | |
|   static ComplexTransform const kTransformB = ComplexTransform::kNone;
 | |
|   static int const kAlignmentA = MM0::kAlignmentA;
 | |
|   static int const kAlignmentB = MM0::kAlignmentB;
 | |
|   static int const kAlignmentC = 1;
 | |
|   using Mma = typename MM1::Mma;
 | |
|   using EpilogueOutputOp = typename MM1::EpilogueOutputOp;
 | |
|   using ThreadblockSwizzle = void;
 | |
|   using Operator = typename MM1::Operator;
 | |
|   using WarpShape = typename MM1::WarpShape;
 | |
|   using InstructionShape = typename MM1::InstructionShape;
 | |
| 
 | |
|   using ElementQ = scalar_t;
 | |
|   using ElementK = scalar_t;
 | |
|   using ElementP = accum_t;
 | |
|   using ElementV = scalar_t;
 | |
|   using ElementO = output_t;
 | |
|   using ElementOAccum = output_accum_t;
 | |
|   using ElementAccumulator = accum_t;
 | |
| 
 | |
|   using LayoutQ = typename MM0::LayoutA;
 | |
|   using LayoutK = typename MM0::LayoutB;
 | |
|   using LayoutP = typename MM0::LayoutC;
 | |
|   using LayoutV = typename MM1::LayoutB;
 | |
|   using LayoutO = typename MM1::LayoutC;
 | |
| 
 | |
|   static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 &&
 | |
|                                  cutlass::sizeof_bits<ElementV>::value == 16);
 | |
| 
 | |
|   static int const kAlignmentQ = MM0::kAlignmentA;
 | |
|   static int const kAlignmentK = MM0::kAlignmentB;
 | |
|   static int const kAlignmentV = 1;
 | |
| 
 | |
|   using ThreadblockShape = typename MM0::ThreadblockShape;
 | |
| 
 | |
|   static int const kQueriesPerBlock = ThreadblockShape::kM;
 | |
|   static int const kKeysPerBlock = ThreadblockShape::kN;
 | |
| 
 | |
|   static constexpr bool kSupportsDropout = false;
 | |
|   static constexpr bool kSupportsBias = false;
 | |
| 
 | |
|   /// Warp count (concept: GemmShape)
 | |
|   using WarpCount = typename MM1::WarpCount;
 | |
|   static int const kThreadsPerWarp = 32;
 | |
|   static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
 | |
| 
 | |
|   static constexpr int kNumWarpsPerBlock =
 | |
|     kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);
 | |
| 
 | |
|   using ProblemVisitor = FMHAGroupedProblemVisitor<
 | |
|                             ThreadblockShape,
 | |
|                             kGroupScheduleMode,
 | |
|                             kThreadCount,
 | |
|                             kThreadCount>;
 | |
| 
 | |
|   //
 | |
|   // Structures
 | |
|   //
 | |
| 
 | |
|   /// Argument structure
 | |
|   struct Arguments {
 | |
| 
 | |
|     //
 | |
|     // Data members
 | |
|     //
 | |
| 
 | |
|     GemmCoord *problem_sizes0;
 | |
|     GemmCoord *problem_sizes1;
 | |
| 
 | |
|     int problem_count;
 | |
|     int threadblock_count;
 | |
| 
 | |
|     ElementQ ** ptr_Q;
 | |
|     ElementK ** ptr_K;
 | |
|     ElementP ** ptr_P;
 | |
|     ElementV ** ptr_V;
 | |
|     ElementO ** ptr_O;
 | |
|     ElementOAccum ** ptr_O_accum;
 | |
| 
 | |
|     typename LayoutQ::Stride::LongIndex *ldq;
 | |
|     typename LayoutK::Stride::LongIndex *ldk;
 | |
|     typename LayoutP::Stride::LongIndex *ldv;
 | |
|     typename LayoutO::Stride::LongIndex *ldo;
 | |
| 
 | |
|     // Scale
 | |
|     ElementAccumulator scale;
 | |
| 
 | |
|     // Whether causal masking is to be performed
 | |
|     bool causal;
 | |
| 
 | |
|     // Only used by device-level operator
 | |
|     GemmCoord *host_problem_sizes;
 | |
| 
 | |
|     //
 | |
|     // Methods
 | |
|     //
 | |
| 
 | |
|     /// Default ctor
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     Arguments():
 | |
|       problem_count(0),
 | |
|       threadblock_count(0),
 | |
|       ptr_Q(nullptr),
 | |
|       ptr_K(nullptr),
 | |
|       ptr_P(nullptr),
 | |
|       ptr_V(nullptr),
 | |
|       ptr_O(nullptr),
 | |
|       ptr_O_accum(nullptr),
 | |
|       ldq(nullptr),
 | |
|       ldk(nullptr),
 | |
|       ldv(nullptr),
 | |
|       ldo(nullptr),
 | |
|       scale(0),
 | |
|       causal(false),
 | |
|       host_problem_sizes(nullptr)
 | |
|     {
 | |
| 
 | |
|     }
 | |
| 
 | |
|     /// Ctor
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     Arguments(
 | |
|       GemmCoord *problem_sizes0,
 | |
|       GemmCoord *problem_sizes1,
 | |
|       int problem_count,
 | |
|       int threadblock_count,
 | |
|       ElementQ ** ptr_Q,
 | |
|       ElementK ** ptr_K,
 | |
|       ElementP ** ptr_P,
 | |
|       ElementV ** ptr_V,
 | |
|       ElementO ** ptr_O,
 | |
|       ElementOAccum ** ptr_O_accum,
 | |
|       typename LayoutQ::Stride::LongIndex *ldq,
 | |
|       typename LayoutK::Stride::LongIndex *ldk,
 | |
|       typename LayoutP::Stride::LongIndex *ldp,
 | |
|       typename LayoutV::Stride::LongIndex *ldv,
 | |
|       typename LayoutO::Stride::LongIndex *ldo,
 | |
|       bool causal,
 | |
|       ElementAccumulator scale,
 | |
|       GemmCoord *host_problem_sizes=nullptr
 | |
|     ):
 | |
|       problem_sizes0(problem_sizes0),
 | |
|       problem_sizes1(problem_sizes1),
 | |
|       problem_count(problem_count),
 | |
|       threadblock_count(threadblock_count),
 | |
|       ptr_Q(ptr_Q),
 | |
|       ptr_K(ptr_K),
 | |
|       ptr_P(ptr_P),
 | |
|       ptr_V(ptr_V),
 | |
|       ptr_O(ptr_O),
 | |
|       ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O),
 | |
|       ldq(ldq),
 | |
|       ldk(ldk),
 | |
|       ldv(ldv),
 | |
|       ldo(ldo),
 | |
|       causal(causal),
 | |
|       scale(scale),
 | |
|       host_problem_sizes(host_problem_sizes)
 | |
|     {
 | |
| 
 | |
|     }
 | |
| 
 | |
|     bool __host__ check_supported() {
 | |
|       CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ);
 | |
|       CHECK_ALIGNED_PTR(ptr_K, kAlignmentK);
 | |
|       CHECK_ALIGNED_PTR(ptr_V, kAlignmentV);
 | |
|       XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned");
 | |
|       XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned");
 | |
|       XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned");
 | |
|       return true;
 | |
|     }
 | |
|   };
 | |
| 
 | |
|   //
 | |
|   // Structure for precomputing values in host memory and passing to kernels
 | |
|   //
 | |
| 
 | |
|   /// Parameters structure
 | |
|   struct Params {
 | |
| 
 | |
|     typename ProblemVisitor::Params problem_visitor;
 | |
|     int threadblock_count;
 | |
| 
 | |
|     ElementQ ** ptr_Q;
 | |
|     ElementK ** ptr_K;
 | |
|     ElementP ** ptr_P;
 | |
|     ElementV ** ptr_V;
 | |
|     ElementO ** ptr_O;
 | |
|     ElementOAccum ** ptr_O_accum;
 | |
| 
 | |
|     typename LayoutQ::Stride::LongIndex *ldq;
 | |
|     typename LayoutK::Stride::LongIndex *ldk;
 | |
|     typename LayoutP::Stride::LongIndex *ldv;
 | |
|     typename LayoutO::Stride::LongIndex *ldo;
 | |
| 
 | |
|     ElementAccumulator scale;
 | |
|     bool causal;
 | |
| 
 | |
|     //
 | |
|     // Methods
 | |
|     //
 | |
| 
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     Params():
 | |
|       ptr_Q(nullptr),
 | |
|       ptr_K(nullptr),
 | |
|       ptr_P(nullptr),
 | |
|       ptr_V(nullptr),
 | |
|       ptr_O(nullptr),
 | |
|       ptr_O_accum(nullptr),
 | |
|       ldq(nullptr),
 | |
|       ldk(nullptr),
 | |
|       ldv(nullptr),
 | |
|       ldo(nullptr),
 | |
|       causal(false),
 | |
|       scale(0)
 | |
|     { }
 | |
| 
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     Params(Arguments const &args,
 | |
|           void *workspace = nullptr,
 | |
|           int tile_count = 0):
 | |
|       problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count),
 | |
|       threadblock_count(args.threadblock_count),
 | |
|       ptr_Q(args.ptr_Q),
 | |
|       ptr_K(args.ptr_K),
 | |
|       ptr_P(args.ptr_P),
 | |
|       ptr_V(args.ptr_V),
 | |
|       ptr_O(args.ptr_O),
 | |
|       ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O),
 | |
|       ldq(args.ldq),
 | |
|       ldk(args.ldk),
 | |
|       ldv(args.ldv),
 | |
|       ldo(args.ldo),
 | |
|       causal(args.causal),
 | |
|       scale(args.scale)
 | |
|     { 
 | |
| 
 | |
|     }
 | |
| 
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     void update(
 | |
|       Arguments const &args,
 | |
|       void *workspace = nullptr,
 | |
|       int tile_count = 0) {
 | |
| 
 | |
|       problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0,
 | |
|                                                         args.problem_sizes1,
 | |
|                                                         args.problem_count,
 | |
|                                                         workspace, tile_count);
 | |
|       threadblock_count = args.threadblock_count;
 | |
|       ptr_Q = args.ptr_Q;
 | |
|       ptr_K = args.ptr_K;
 | |
|       ptr_P = args.ptr_P;
 | |
|       ptr_V = args.ptr_V;
 | |
|       ptr_O = args.ptr_O;
 | |
|       ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O;
 | |
|       ldq = args.ldq;
 | |
|       ldk = args.ldk;
 | |
|       ldv = args.ldv;
 | |
|       ldo = args.ldo;
 | |
|       causal = args.causal;
 | |
|       scale = args.scale;
 | |
|     }
 | |
|   };
 | |
| 
 | |
  // Shared storage - depends on kernel params
  //
  // Per-query-row softmax bookkeeping kept in shared memory; each array has one
  // entry per query row of the tile (kQueriesPerBlock). At the start of every
  // tile, operator() resets s_prime = 0, out_rescale = 1, m_prime = -inf and
  // mi = -inf, and all of them are passed to iterative_softmax. Member order
  // defines the shared-memory layout.
  struct ScalingCoefs {
    // NOTE(review): appears to be the row max carried over from the previous
    // key block (operator() syncs so `m_prime` is "updated from end of prev
    // iter") — confirm against iterative_softmax.
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
    // Presumably the running softmax denominator (row sum) — TODO confirm.
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
    // Row max updated from the current accumulator tile (operator(): "Update
    // `mi` from accum stored in registers").
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
    // Per-row rescaling factor produced by iterative_softmax; presumably used
    // to rescale previously accumulated output — confirm.
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
    // Scratch with one slot per (query row, MM0 warp column); presumably used
    // for cross-warp reductions inside iterative_softmax — confirm.
    cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
        addition_storage;
  };
 | |
| 
 | |
  // Shared-memory layout used when kKeepOutputInRF is true (selected by the
  // SharedStorage alias). The epilogue storage is a member of the same union as
  // the MM0 storage and the post-MM0 storage, so it reuses (and may overwrite)
  // that memory.
  struct SharedStorageEpilogueAtEnd : ScalingCoefs {
    struct SharedStorageAfterMM0 {
      // Everything here might be overwritten during MM0
      // Tile of P written by MM0::B2bGemm::accumToSmem; read back as MM1's A operand.
      typename MM0::AccumulatorSharedStorage si;
      // MM1 mainloop storage; its operand-B area stages V.
      typename MM1::Mma::SharedStorage mm1;
    };

    union {
      typename MM0::Mma::SharedStorage mm0;
      SharedStorageAfterMM0 after_mm0;
      typename MM1::DefaultEpilogue::SharedStorage epilogue;
    };

    // Accessor returning the epilogue's shared storage for this layout.
    CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
    epilogue_shared_storage() {
      return epilogue;
    }

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };
 | |
| 
 | |
  // Shared-memory layout used when kKeepOutputInRF is false (selected by the
  // SharedStorage alias). In that mode the epilogue runs inside the key loop,
  // so its storage sits next to si/mm1 in SharedStorageAfterMM0 (no aliasing
  // with them); only the MM0 storage overlaps, through the union.
  struct SharedStorageEpilogueInLoop : ScalingCoefs {
    struct SharedStorageAfterMM0 {
      // Everything here might be overwritten during MM0
      // Tile of P written by MM0::B2bGemm::accumToSmem; read back as MM1's A operand.
      typename MM0::AccumulatorSharedStorage si;
      // MM1 mainloop storage; its operand-B area stages V.
      typename MM1::Mma::SharedStorage mm1;
      typename MM1::DefaultEpilogue::SharedStorage epilogue;
    };

    union {
      typename MM0::Mma::SharedStorage mm0;
      SharedStorageAfterMM0 after_mm0;
    };

    // Accessor returning the epilogue's shared storage for this layout.
    CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
    epilogue_shared_storage() {
      return after_mm0.epilogue;
    }

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };
 | |
| 
 | |
  // Shared-memory layout for the kernel: epilogue storage unioned with the
  // mainloop storage when the output stays in registers, otherwise kept
  // alongside it because the epilogue runs on every key iteration.
  using SharedStorage = typename cutlass::platform::conditional<
      kKeepOutputInRF,
      SharedStorageEpilogueAtEnd,
      SharedStorageEpilogueInLoop>::type;
 | |
| 
 | |
| private:
 | |
| 
 | |
|   // Parameters to be used by an individual tile
 | |
|   struct TileParams {
 | |
| 
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     static int query_start(int threadblock_idx) {
 | |
|       return threadblock_idx * kQueriesPerBlock;
 | |
|     }
 | |
| 
 | |
|     // Returns whether this threadblock computes within the number of queries,
 | |
|     // which is determined by the M dimension of problem 0
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) {
 | |
|       return query_start(threadblock_idx) < problem_size0.m();
 | |
|     }
 | |
| 
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) {
 | |
|       return problem_size0.m() - query_start(threadblock_idx);
 | |
|     }
 | |
| 
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) {
 | |
|       int nk = problem_size0.n();
 | |
|       if (causal) {
 | |
|         nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk);
 | |
|       }
 | |
|       return nk;
 | |
|     }
 | |
| 
 | |
|   };
 | |
| 
 | |
| public:
 | |
| 
 | |
|   //
 | |
|   // Methods
 | |
|   //
 | |
| 
 | |
|   CUTLASS_DEVICE
 | |
|   FMHAGrouped() { }
 | |
| 
 | |
|   /// Determines whether kernel satisfies alignment
 | |
|   static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
 | |
|     return Status::kSuccess;
 | |
|   }
 | |
| 
 | |
|   static Status can_implement(Arguments const &args) {
 | |
|     return Status::kSuccess;
 | |
|   }
 | |
| 
 | |
|   static CUTLASS_DEVICE int16_t thread_id() {
 | |
|     return threadIdx.x;
 | |
|   }
 | |
| 
 | |
|   static CUTLASS_DEVICE int8_t warp_id() {
 | |
|     return threadIdx.x / kThreadsPerWarp;
 | |
|   }
 | |
| 
 | |
|   static CUTLASS_DEVICE int8_t lane_id() {
 | |
|     return threadIdx.x % kThreadsPerWarp;
 | |
|   }
 | |
| 
 | |
|   /// Executes one GEMM
 | |
|   CUTLASS_DEVICE
 | |
|   void operator()(Params const ¶ms, SharedStorage &shared_storage) {
 | |
|     auto& m_prime = shared_storage.m_prime;
 | |
|     auto& s_prime = shared_storage.s_prime;
 | |
|     [[maybe_unused]] auto& si = shared_storage.after_mm0.si;
 | |
|     auto& mi = shared_storage.mi;
 | |
|     auto& out_rescale = shared_storage.out_rescale;
 | |
| 
 | |
|     ProblemVisitor problem_visitor(
 | |
|       params.problem_visitor,
 | |
|       shared_storage.problem_visitor,
 | |
|       blockIdx.x);
 | |
| 
 | |
|     // Outer 'persistent' loop to iterate over tiles
 | |
|     while (problem_visitor.next_tile()) {
 | |
| 
 | |
|       GemmCoord problem_size0 = problem_visitor.problem_size0();
 | |
|       GemmCoord problem_size1 = problem_visitor.problem_size1();
 | |
|       const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
 | |
| 
 | |
|       if (!TileParams::can_compute(threadblock_idx, problem_size0)) {
 | |
|         problem_visitor.advance(gridDim.x);
 | |
|         continue;
 | |
|       }
 | |
| 
 | |
|       const int32_t problem_idx = problem_visitor.problem_index();
 | |
| 
 | |
|       if (thread_id() < kQueriesPerBlock) {
 | |
|         s_prime[thread_id()] = ElementAccumulator(0);
 | |
|         out_rescale[thread_id()] = accum_t(1.0);
 | |
|         m_prime[thread_id()] =
 | |
|             -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
 | |
|         mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
 | |
|       }
 | |
| 
 | |
|       ElementO *ptr_O = params.ptr_O[problem_idx]  + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
 | |
|       ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx]  + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
 | |
|       const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0);
 | |
| 
 | |
|       auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
 | |
|         using OutputTileIterator = typename MM1::OutputTileIterator;
 | |
|         return OutputTileIterator(
 | |
|             typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]},
 | |
|             ptr_O,
 | |
|             typename OutputTileIterator::TensorCoord{
 | |
|                 num_queries, problem_size1.n()},
 | |
|             thread_id(),
 | |
|             {0, col});
 | |
|       };
 | |
| 
 | |
|       auto createOutputAccumIter = [&](int col) ->
 | |
|         typename MM1::OutputTileIteratorAccum {
 | |
|           using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
 | |
|           return OutputTileIteratorAccum(
 | |
|               typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]},
 | |
|               ptr_O_accum,
 | |
|               typename OutputTileIteratorAccum::TensorCoord{
 | |
|                   num_queries, problem_size1.n()},
 | |
|               thread_id(),
 | |
|               {0, col});
 | |
|         };
 | |
| 
 | |
|       typename MM1::Mma::FragmentC accum_o;
 | |
|       accum_o.clear();
 | |
| 
 | |
|       const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal);
 | |
| 
 | |
|       for (int32_t iter_key_start = 0; iter_key_start < num_keys;
 | |
|            iter_key_start += kKeysPerBlock) {
 | |
|         int32_t problem_size_0_m =
 | |
|             cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries);
 | |
|         int32_t problem_size_0_n = cutlass::fast_min(
 | |
|             (int32_t)kKeysPerBlock, num_keys - iter_key_start);
 | |
|         int32_t const& problem_size_0_k = problem_size0.k();
 | |
|         int32_t const& problem_size_1_n = problem_size1.n();
 | |
|         int32_t const& problem_size_1_k = problem_size_0_n;
 | |
| 
 | |
|         auto prologueV = [&](int blockN) {
 | |
|           typename MM1::Mma::IteratorB iterator_V(
 | |
|               typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
 | |
|               params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
 | |
|               {problem_size_1_k, problem_size_1_n},
 | |
|               thread_id(),
 | |
|               cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
 | |
| 
 | |
|           MM1::Mma::prologue(
 | |
|               shared_storage.after_mm0.mm1,
 | |
|               iterator_V,
 | |
|               thread_id(),
 | |
|               problem_size_1_k);
 | |
|         };
 | |
| 
 | |
|         __syncthreads(); // Need to have shared memory initialized, and `m_prime`
 | |
|                          // updated from end of prev iter
 | |
| 
 | |
|         //
 | |
|         // MATMUL: Q.K_t
 | |
|         //
 | |
|         // Computes the block-matrix product of:
 | |
|         // (a) query[query_start:query_end, :]
 | |
|         // with
 | |
|         // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
 | |
|         // and stores that into `shared_storage.si`
 | |
|         //
 | |
| 
 | |
|         ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx];
 | |
| 
 | |
|         // Construct iterators to A and B operands
 | |
|         typename MM0::IteratorA iterator_A(
 | |
|           typename MM0::IteratorA::Params(
 | |
|               typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])),
 | |
|           ptr_Q,
 | |
|           {problem_size_0_m, problem_size_0_k},
 | |
|           thread_id(),
 | |
|           {0, 0});
 | |
| 
 | |
|         typename MM0::IteratorB iterator_B(
 | |
|             typename MM0::IteratorB::Params(
 | |
|                 typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])),
 | |
|             params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx],
 | |
|             {problem_size_0_k, problem_size_0_n},
 | |
|             thread_id(),
 | |
|             {0, 0});
 | |
| 
 | |
|         // Construct thread-scoped matrix multiply
 | |
|         typename MM0::Mma mma(
 | |
|             shared_storage.mm0, thread_id(), warp_id(), lane_id());
 | |
| 
 | |
|         typename MM0::Mma::FragmentC accum;
 | |
| 
 | |
|         accum.clear();
 | |
| 
 | |
|         auto gemm_k_iterations =
 | |
|             (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
 | |
| 
 | |
|         // Compute threadblock-scoped matrix multiply-add
 | |
|         mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
 | |
|         __syncthreads();
 | |
| 
 | |
|         if (kPreloadV) {
 | |
|           prologueV(0);
 | |
|         } else {
 | |
|           MM1::Mma::drain_cp_asyncs();
 | |
|         }
 | |
| 
 | |
|         typename MM0::Mma::Operator::IteratorC::TensorCoord
 | |
|           iteratorC_tile_offset = {
 | |
|               (warp_id() % MM0::Mma::WarpCount::kM),
 | |
|               (warp_id() / MM0::Mma::WarpCount::kM)
 | |
|             };
 | |
| 
 | |
|         // Mask out last if causal
 | |
|         if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
 | |
|           auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
 | |
|               lane_id(), warp_id(), iteratorC_tile_offset);
 | |
|           int32_t last_col;
 | |
|           MM0::AccumLambdaIterator::iterateRows(
 | |
|               lane_offset,
 | |
|               [&](int accum_m) {
 | |
|                 last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
 | |
|               },
 | |
|               [&](int accum_m, int accum_n, int idx) {
 | |
|                 if (accum_n > last_col) {
 | |
|                   accum[idx] =
 | |
|                       -cutlass::platform::numeric_limits<accum_t>::infinity();
 | |
|                 }
 | |
|               },
 | |
|               [&](int accum_m) {});
 | |
|         }
 | |
|         // DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
 | |
|         //         DISPATCH_BOOL(
 | |
|         //             num_keys - iter_key_start >= kKeysPerBlock,
 | |
|         //             kFullColumns,
 | |
|         //             ([&] {
 | |
|         //               // Update `mi` from accum stored in registers
 | |
|         //               // Also does accum[i] <- exp(accum[i] - mi)
 | |
|         //               iterative_softmax<
 | |
|         //                   typename MM0::Mma::Operator::IteratorC,
 | |
|         //                   kFullColumns,
 | |
|         //                   kIsFirst>(
 | |
|         //                   accum_o,
 | |
|         //                   accum,
 | |
|         //                   mi,
 | |
|         //                   m_prime,
 | |
|         //                   s_prime,
 | |
|         //                   lane_id(),
 | |
|         //                   thread_id(),
 | |
|         //                   warp_id(),
 | |
|         //                   num_keys - iter_key_start,
 | |
|         //                   iteratorC_tile_offset,
 | |
|         //                   kSupportsBias ? 1.0f : params.scale);
 | |
|         //             }));
 | |
|         //       }));
 | |
| 
 | |
|         // Update `mi` from accum stored in registers
 | |
|         // Also does accum[i] <- exp(accum[i] - mi)
 | |
|         iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
 | |
|             accum_o,
 | |
|             accum,
 | |
|             mi,
 | |
|             m_prime,
 | |
|             s_prime,
 | |
|             out_rescale,
 | |
|             shared_storage.addition_storage,
 | |
|             lane_id(),
 | |
|             thread_id(),
 | |
|             warp_id(),
 | |
|             num_keys - iter_key_start,
 | |
|             iter_key_start == 0,
 | |
|             iteratorC_tile_offset,
 | |
|             kSupportsBias ? 1.0f : params.scale);
 | |
| 
 | |
|         // Output results to shared-memory
 | |
|         int warp_idx_mn_0 = warp_id() %
 | |
|             (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
 | |
|         auto output_tile_coords = cutlass::MatrixCoord{
 | |
|             warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
 | |
|             warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
 | |
| 
 | |
|         MM0::B2bGemm::accumToSmem(
 | |
|             shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords);
 | |
| 
 | |
|         __syncthreads();
 | |
| 
 | |
|         //
 | |
|         // MATMUL: Attn . V
 | |
|         // Run the matmul `attn @ V` for a block of attn and V.
 | |
|         // `attn` is read from shared memory (in `shared_storage_si`)
 | |
|         // `V` is read from global memory (with iterator_B)
 | |
|         //
 | |
| 
 | |
|         const int64_t nBlockN = kKeepOutputInRF ? 1
 | |
|                                                 : ceil_div(
 | |
|                                                       (int64_t)problem_size_1_n,
 | |
|                                                       int64_t(MM1::ThreadblockShape::kN));
 | |
| 
 | |
|         // Iterate over the N dimension of GEMM1
 | |
|         for (int blockN = 0; blockN < nBlockN; ++blockN) {
 | |
|           int gemm_k_iterations =
 | |
|               (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
 | |
| 
 | |
|           // Compute threadblock-scoped matrix multiply-add and store it in accum
 | |
|           // (in registers)
 | |
|           if (!kPreloadV) {
 | |
|             __syncthreads(); // we share shmem between mma and epilogue
 | |
|           }
 | |
| 
 | |
|           typename MM1::Mma::IteratorB iterator_V(
 | |
|             typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
 | |
|             params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
 | |
|             {problem_size_1_k, problem_size_1_n},
 | |
|             thread_id(),
 | |
|             cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
 | |
| 
 | |
|           typename MM1::Mma mma_pv(
 | |
|             // operand A: Pij_dropped in shared memory
 | |
|             shared_storage.after_mm0.si.accum_ref(),
 | |
|             // operand B: shared memory staging area for Vj, which is loaded
 | |
|             // from global memory
 | |
|             shared_storage.after_mm0.mm1.operand_B_ref(),
 | |
|             (int)thread_id(),
 | |
|             (int)warp_id(),
 | |
|             (int)lane_id());
 | |
| 
 | |
|           mma_pv.set_prologue_done(kPreloadV);
 | |
|           if (!kKeepOutputInRF) {
 | |
|             accum_o.clear();
 | |
|           }
 | |
| 
 | |
|           mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
 | |
|           __syncthreads();
 | |
| 
 | |
|           if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) {
 | |
|             prologueV(blockN + 1);
 | |
|           }
 | |
| 
 | |
|           if (!kKeepOutputInRF) {
 | |
|             MM1::Mma::drain_cp_asyncs();
 | |
|             DISPATCH_BOOL(
 | |
|                 iter_key_start == 0, kIsFirst, ([&] {
 | |
|                   DISPATCH_BOOL(
 | |
|                       (iter_key_start + kKeysPerBlock) >= num_keys,
 | |
|                       kIsLast,
 | |
|                       ([&] {
 | |
|                         using DefaultEpilogue = typename MM1::DefaultEpilogue;
 | |
|                         using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
 | |
|                         using ElementCompute = typename DefaultOp::ElementCompute;
 | |
|                         using EpilogueOutputOp = typename cutlass::epilogue::
 | |
|                             thread::MemoryEfficientAttentionNormalize<
 | |
|                                 typename cutlass::platform::conditional<
 | |
|                                     kIsLast,
 | |
|                                     output_t,
 | |
|                                     output_accum_t>::type,
 | |
|                                 output_accum_t,
 | |
|                                 DefaultOp::kCount,
 | |
|                                 typename DefaultOp::ElementAccumulator,
 | |
|                                 output_accum_t,
 | |
|                                 kIsFirst,
 | |
|                                 kIsLast,
 | |
|                                 cutlass::Array<ElementCompute, kQueriesPerBlock>>;
 | |
|                         using Epilogue = typename cutlass::epilogue::threadblock::
 | |
|                             EpiloguePipelined<
 | |
|                                 typename DefaultEpilogue::Shape,
 | |
|                                 typename MM1::Mma::Operator,
 | |
|                                 DefaultEpilogue::kPartitionsK,
 | |
|                                 typename cutlass::platform::conditional<
 | |
|                                     kIsLast,
 | |
|                                     typename MM1::OutputTileIterator,
 | |
|                                     typename MM1::OutputTileIteratorAccum>::type,
 | |
|                                 typename DefaultEpilogue::
 | |
|                                     AccumulatorFragmentIterator,
 | |
|                                 typename DefaultEpilogue::WarpTileIterator,
 | |
|                                 typename DefaultEpilogue::SharedLoadIterator,
 | |
|                                 EpilogueOutputOp,
 | |
|                                 typename DefaultEpilogue::Padding,
 | |
|                                 DefaultEpilogue::kFragmentsPerIteration,
 | |
|                                 true, // IterationsUnroll
 | |
|                                 typename MM1::OutputTileIteratorAccum // Read
 | |
|                                                                       // iterator
 | |
|                                 >;
 | |
| 
 | |
|                         int col = blockN * MM1::Mma::Shape::kN;
 | |
|                         auto source_iter = createOutputAccumIter(col);
 | |
|                         auto dest_iter = gemm_kernel_utils::call_conditional<
 | |
|                             kIsLast,
 | |
|                             decltype(createOutputIter),
 | |
|                             decltype(createOutputAccumIter)>::
 | |
|                             apply(createOutputIter, createOutputAccumIter, col);
 | |
|                         EpilogueOutputOp rescale(s_prime, out_rescale);
 | |
|                         Epilogue epilogue(
 | |
|                             shared_storage.epilogue_shared_storage(),
 | |
|                             thread_id(),
 | |
|                             warp_id(),
 | |
|                             lane_id());
 | |
|                         epilogue(rescale, dest_iter, accum_o, source_iter);
 | |
|                       }));
 | |
|                 }));
 | |
|             if (!kKeepOutputInRF) {
 | |
|               __syncthreads();
 | |
|             }
 | |
|           }
 | |
|         }
 | |
|          __syncthreads(); // we modify `m_prime` after
 | |
|       }
 | |
| 
 | |
|       if (kKeepOutputInRF) {
 | |
|         const bool kIsFirst = true;
 | |
|         const bool kIsLast = true;
 | |
|         using DefaultEpilogue = typename MM1::DefaultEpilogue;
 | |
|         using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
 | |
|         using ElementCompute = typename DefaultOp::ElementCompute;
 | |
|         using EpilogueOutputOp =
 | |
|             typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
 | |
|                 output_t,       // output
 | |
|                 output_accum_t, // source
 | |
|                 DefaultOp::kCount,
 | |
|                 typename DefaultOp::ElementAccumulator, // accum
 | |
|                 output_accum_t, // compute
 | |
|                 kIsFirst,
 | |
|                 kIsLast,
 | |
|                 cutlass::Array<ElementCompute, kQueriesPerBlock>>;
 | |
|         using Epilogue =
 | |
|             typename cutlass::epilogue::threadblock::EpiloguePipelined<
 | |
|                 typename DefaultEpilogue::Shape,
 | |
|                 typename MM1::Mma::Operator,
 | |
|                 DefaultEpilogue::kPartitionsK,
 | |
|                 typename MM1::OutputTileIterator, // destination
 | |
|                 typename DefaultEpilogue::AccumulatorFragmentIterator,
 | |
|                 typename DefaultEpilogue::WarpTileIterator,
 | |
|                 typename DefaultEpilogue::SharedLoadIterator,
 | |
|                 EpilogueOutputOp,
 | |
|                 typename DefaultEpilogue::Padding,
 | |
|                 DefaultEpilogue::kFragmentsPerIteration,
 | |
|                 true, // IterationsUnroll
 | |
|                 typename MM1::OutputTileIteratorAccum // source tile
 | |
|                 >;
 | |
|         auto dest_iter = createOutputIter(0);
 | |
|         EpilogueOutputOp rescale(s_prime, out_rescale);
 | |
|         Epilogue epilogue(
 | |
|             shared_storage.epilogue_shared_storage(),
 | |
|             thread_id(),
 | |
|             warp_id(),
 | |
|             lane_id());
 | |
|         MM1::Mma::drain_cp_asyncs();
 | |
|         epilogue(rescale, dest_iter, accum_o);
 | |
|       }
 | |
| 
 | |
|       // Next tile
 | |
|       problem_visitor.advance(gridDim.x);
 | |
|       __syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   template <typename WarpIteratorC>
 | |
|   CUTLASS_DEVICE static void iterative_softmax(
 | |
|       typename WarpIteratorC::Fragment& frag_o, // output so far; rescaled in place by out_rescale when kKeepOutputInRF && !is_first
 | |
|       typename WarpIteratorC::Fragment& frag, // in: this tile's scores; out: exp2(score * scaling * log2(e) - mi), 0 for columns >= max_col
 | |
|       cutlass::Array<accum_t, kQueriesPerBlock>& mi, // per-row max of the scaled scores; raised in place via atomicMaxFloat (not reset here)
 | |
|       cutlass::Array<accum_t, kQueriesPerBlock>& m_prime, // per-row max before this call; set to mi on exit (except in the -inf bias case)
 | |
|       cutlass::Array<accum_t, kQueriesPerBlock>& s_prime, // per-row sum of exponentials; rescaled, then increased by this tile's row sums
 | |
|       cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale, // out: exp2(old m_prime - new mi) per row, or 1 if the max did not grow
 | |
|       cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
 | |
|           addition_storage, // scratch: one partial row sum per (row, warp column) slot
 | |
|       int8_t lane_id,
 | |
|       int8_t thread_id, // not used by this function
 | |
|       int8_t warp_id,
 | |
|       int max_col, // columns >= max_col are masked: excluded from the max, exp set to 0
 | |
|       bool is_first, // when true, frag_o is not rescaled
 | |
|       typename WarpIteratorC::TensorCoord const& tile_offset, // passed to get_lane_offset; column() selects the addition_storage slot
 | |
|       float scaling) { // multiplier applied to the scores (together with log2(e))
 | |
|     /* Online-softmax update for one tile of scores.
 | |
| 
 | |
|     Iterates on the accumulator and the corresponding position on the result
 | |
|     matrix. The scores are first multiplied by `scaling * log2(e)`, so every
 | |
|     exponential below can be computed with exp2f. The work is split into
 | |
|     phases separated by block-wide __syncthreads() barriers:
 | |
| 
 | |
|     (1) Update `mi[r]` to the max value of row `r`, over columns < max_col.
 | |
|         This uses atomicMaxFloat, so several warps can contribute to one row.
 | |
|     (2) The thread owning row `r` (a lane < kLinesPerWarp) computes
 | |
|             out_rescale[r] <- exp2(m_prime[r] - mi[r])  (1 if mi[r] did not grow)
 | |
|             s_prime[r]     <- s_prime[r] * out_rescale[r]
 | |
|         If kSupportsBias is set and mi[r] is still -inf (the whole row is
 | |
|         masked so far), mi[r] is temporarily set to 0. This avoids
 | |
|         exp2f(-inf - (-inf)) = nan, and mi[r] is restored in step (4).
 | |
|         Then, if kKeepOutputInRF && !is_first, every thread multiplies its
 | |
|         rows of `frag_o` by out_rescale.
 | |
|     (3) accum <- exp2(accum - mi), or 0 for columns >= max_col. Each row's
 | |
|         partial sum is combined across lanes with reduceSameRow and stored
 | |
|         in addition_storage[r + kQueriesPerBlock * tile_offset.column()].
 | |
|     (4) The row owner sets m_prime[r] <- mi[r] (or restores mi[r] to -inf)
 | |
|         and adds the partial sums of all MM0::MmaCore::WarpCount::kN warp
 | |
|         columns to s_prime[r].
 | |
| 
 | |
|     `frag` is updated in registers. Per the original design, it is later
 | |
|     stored to shared memory for the next matmul with Value.
 | |
| 
 | |
|     `mi`, `m_prime`, `s_prime`, `out_rescale` and `addition_storage` are
 | |
|     written by some threads and read by others across the barriers. They
 | |
|     therefore must be visible to the whole threadblock (presumably shared
 | |
|     memory; confirm at the call site). Because this function calls
 | |
|     __syncthreads(), every thread of the block must call it.
 | |
|     */
 | |
|     using Fragment = typename WarpIteratorC::Fragment;
 | |
|     using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
 | |
|         WarpIteratorC,
 | |
|         accum_t,
 | |
|         kThreadsPerWarp>::Iterator;
 | |
|     // log2(e) as a float (rather than double). Folding it into the score scale lets the exponentials below use exp2f
 | |
|     constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
 | |
| 
 | |
|     static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); // every warp owns the same number of rows
 | |
|     static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
 | |
| 
 | |
|     frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
 | |
| 
 | |
|     auto lane_offset =
 | |
|         LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
 | |
| 
 | |
|     // Phase (1): update `mi` to the max per row (only columns < max_col count)
 | |
|     {
 | |
|       accum_t max;
 | |
|       LambdaIterator::iterateRows(
 | |
|           lane_offset,
 | |
|           [&](int accum_m) {
 | |
|             max = -cutlass::platform::numeric_limits<accum_t>::infinity();
 | |
|           },
 | |
|           [&](int accum_m, int accum_n, int idx) {
 | |
|             if (accum_n < max_col) {
 | |
|               max = cutlass::fast_max(max, frag[idx]);
 | |
|             }
 | |
|           },
 | |
|           [&](int accum_m) {
 | |
|             // Having 4x atomicMax seems faster than reducing within the warp
 | |
|             // first...
 | |
|             atomicMaxFloat(&mi[accum_m], max);
 | |
|           });
 | |
|     }
 | |
| 
 | |
|     // Make sure every thread sees the final per-row values of `mi`
 | |
|     __syncthreads();
 | |
| 
 | |
|     // Phase (2): this `exp` is quite expensive, so split it across the warps.
 | |
|     // The first kLinesPerWarp lanes of each warp own one row each.
 | |
|     bool restore_mi_to_minus_inf = false; // set by the kSupportsBias branch below
 | |
|     if (lane_id < kLinesPerWarp) {
 | |
|       int id = warp_id * kLinesPerWarp + lane_id; // row owned by this thread
 | |
|       auto m_prime_id = m_prime[id];
 | |
|       auto mi_id = mi[id];
 | |
|       bool changed = m_prime_id < mi_id; // `false` if both are -inf
 | |
|       if (changed) {
 | |
|         auto m_prime_exp = exp2f(m_prime_id - mi_id);
 | |
|         out_rescale[id] = m_prime_exp;
 | |
|         s_prime[id] *= m_prime_exp;
 | |
|       } else {
 | |
|         // Only when bias is enabled can all the first values of attention
 | |
|         // be masked to `-inf`. In that case we want to avoid
 | |
|         // `nan = exp2f(-inf - (-inf))`, so we temporarily set `mi` to 0
 | |
|         if (kSupportsBias &&
 | |
|             mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
 | |
|           restore_mi_to_minus_inf = true;
 | |
|           mi[id] = 0.0f;
 | |
|         }
 | |
|         out_rescale[id] = 1.0f;
 | |
|       }
 | |
|     }
 | |
|     __syncthreads(); // make out_rescale and mi (possibly set to 0 above) visible before they are read below
 | |
|     if (kKeepOutputInRF && !is_first) {
 | |
|       accum_t line_rescale;
 | |
|       LambdaIterator::iterateRows(
 | |
|           lane_offset,
 | |
|           [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
 | |
|           [&](int accum_m, int accum_n, int idx) {
 | |
|             frag_o[idx] = frag_o[idx] * line_rescale;
 | |
|           },
 | |
|           [&](int accum_m) {});
 | |
|     }
 | |
|     // Phase (3): frag <- exp2(frag - mi) for columns < max_col (0 otherwise), then per-row partial sums into addition_storage
 | |
|     {
 | |
|       accum_t mi_row, total_row;
 | |
|       LambdaIterator::iterateRows(
 | |
|           lane_offset,
 | |
|           [&](int accum_m) { mi_row = mi[accum_m]; },
 | |
|           [&](int accum_m, int accum_n, int idx) {
 | |
|             frag[idx] =
 | |
|                 (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
 | |
|           },
 | |
|           [&](int accum_m) {});
 | |
|       LambdaIterator::iterateRows(
 | |
|           lane_offset,
 | |
|           [&](int accum_m) { total_row = 0.0; },
 | |
|           [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
 | |
|           [&](int accum_m) {
 | |
|             if (LambdaIterator::reduceSameRow(
 | |
|                     lane_id, total_row, [](accum_t a, accum_t b) {
 | |
|                       return a + b;
 | |
|                     })) {
 | |
|               // NOTE: we could atomically add `total_row` to `s_prime`, but
 | |
|               // it's faster (and deterministic) to avoid atomics here
 | |
|               addition_storage
 | |
|                   [accum_m + kQueriesPerBlock * tile_offset.column()] =
 | |
|                       total_row;
 | |
|             }
 | |
|           });
 | |
|     }
 | |
| 
 | |
|     __syncthreads(); // addition_storage is complete, and the phase (3) reads of mi are done
 | |
|     if (lane_id < kLinesPerWarp) {
 | |
|       int id = warp_id * kLinesPerWarp + lane_id; // same row ownership as phase (2)
 | |
|       accum_t total_row = s_prime[id]; // already rescaled in phase (2)
 | |
|       if (restore_mi_to_minus_inf) { // m_prime[id] is intentionally left unchanged
 | |
|         // Restore `mi`; see above where we set `restore_mi_to_minus_inf=true`
 | |
|         mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
 | |
|       } else {
 | |
|         m_prime[id] = mi[id]; // reference max for the next call
 | |
|       }
 | |
|       CUTLASS_PRAGMA_UNROLL
 | |
|       for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { // add the partial sum of each warp column
 | |
|         total_row += addition_storage[id + kQueriesPerBlock * i];
 | |
|       }
 | |
|       s_prime[id] = total_row;
 | |
|     }
 | |
|   }
 | |
| };
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| } // namespace kernel
 | |
| } // namespace gemm
 | |
| } // namespace cutlass
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 |