/*************************************************************************************************** * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief GEMM kernel to support the epilogue visitor model for customized softmax partial reduction epilogue fusion. This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once its usage has been stabilized. For now, it is included in this example to demonstrate some basic output fusion options. */ #pragma once #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" #include "cutlass/trace.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace gemm { namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// template < typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_ ///! Threadblock swizzling function > struct GemmWithEpilogueVisitor { public: using Mma = Mma_; using Epilogue = Epilogue_; using EpilogueVisitor = typename Epilogue::Visitor; using ThreadblockSwizzle = ThreadblockSwizzle_; using ElementA = typename Mma::IteratorA::Element; using LayoutA = typename Mma::IteratorA::Layout; using TensorRefA = TensorRef; using ElementB = typename Mma::IteratorB::Element; using LayoutB = typename Mma::IteratorB::Layout; using TensorRefB = TensorRef; using ElementC = typename EpilogueVisitor::ElementOutput; using LayoutC = typename Epilogue::Layout; using TensorRefC = TensorRef; static ComplexTransform const kTransformA = Mma::kTransformA; static ComplexTransform const kTransformB = Mma::kTransformB; using Operator = typename Mma::Operator; using OperatorClass = typename Mma::Operator::OperatorClass; using ThreadblockShape = typename Mma::Shape; using WarpShape = typename Mma::Operator::Shape; using InstructionShape = typename Mma::Policy::Operator::InstructionShape; using ArchTag = typename Mma::ArchTag; using ElementNorm = typename EpilogueVisitor::ElementNorm; using ElementSum = typename EpilogueVisitor::ElementSum; static int const kStages = Mma::kStages; static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; /// Warp count (concept: GemmShape) using WarpCount = typename Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; /// Split-K preserves splits that are 128b aligned static int const kSplitKAlignment = const_max( 128 / sizeof_bits::value, 128 / sizeof_bits::value ); // // Structures // /// Argument structure struct Arguments { // // Data members // GemmUniversalMode mode; GemmCoord problem_size; int batch_count; TensorRefA ref_A; TensorRefB ref_B; TensorRefC ref_C; TensorRefC ref_D; ElementNorm *ptr_Max; ElementSum *ptr_Sum; int64_t batch_stride_A; int64_t batch_stride_B; typename EpilogueVisitor::Arguments epilogue_visitor; // // Methods // Arguments(): mode(GemmUniversalMode::kGemm), batch_count(1) { } /// constructs an arguments structure Arguments( GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefC ref_C_, TensorRefC ref_D_, ElementNorm *ptr_Max_, ElementSum *ptr_Sum_, int64_t batch_stride_A_, int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_ ): mode(mode_), problem_size(problem_size_), batch_count(batch_count_), ref_A(ref_A_), ref_B(ref_B_), ref_C(ref_C_), ref_D(ref_D_), ptr_Max(ptr_Max_), ptr_Sum(ptr_Sum_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), epilogue_visitor(epilogue_visitor_) { } }; // // Structure for precomputing values in host memory and passing to kernels // /// Parameters structure struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; int swizzle_log_tile; typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; typename EpilogueVisitor::OutputTileIterator::Params params_C; typename EpilogueVisitor::OutputTileIterator::Params params_D; GemmUniversalMode mode; int batch_count; int gemm_k_size; void * ptr_A; void * ptr_B; ElementC * ptr_C; ElementC * ptr_D; ElementNorm * ptr_Max; ElementSum * ptr_Sum; int64_t batch_stride_A; int64_t batch_stride_B; typename EpilogueVisitor::Params epilogue_visitor; // // Methods // CUTLASS_HOST_DEVICE Params(): swizzle_log_tile(0), params_A(0), params_B(0), params_C(0), params_D(0), batch_count(0), gemm_k_size(0), mode(cutlass::gemm::GemmUniversalMode::kGemm), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_Max(nullptr), ptr_Sum(nullptr), batch_stride_A(0), batch_stride_B(0) { } Params( Arguments const &args ): problem_size(args.problem_size), swizzle_log_tile(0), params_A(args.ref_A.layout()), params_B(args.ref_B.layout()), params_C(args.ref_C.layout()), params_D(args.ref_D.layout()), mode(args.mode), batch_count(args.batch_count), gemm_k_size(args.problem_size.k()), ptr_A(args.ref_A.data()), ptr_B(args.ref_B.data()), ptr_C(args.ref_C.data()), ptr_D(args.ref_D.data()), ptr_Max(args.ptr_Max), ptr_Sum(args.ptr_Sum), batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), epilogue_visitor(args.epilogue_visitor) { ThreadblockSwizzle threadblock_swizzle; grid_tiled_shape = threadblock_swizzle.get_tiled_shape( args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); if (gemm_k_size) { grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); } } swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); } }; /// Shared memory storage structure union SharedStorage { typename Mma::SharedStorage main_loop; struct { typename Epilogue::SharedStorage epilogue; typename EpilogueVisitor::SharedStorage visitor; } epilogue; }; public: // // Methods // CUTLASS_DEVICE GemmWithEpilogueVisitor() { } /// Determines whether kernel satisfies alignment static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; bool isAMisaligned = false; bool isBMisaligned = false; bool isCMisaligned = false; if (platform::is_same::value) { isAMisaligned = problem_size.k() % kAlignmentA; } else if (platform::is_same::value) { isAMisaligned = problem_size.m() % kAlignmentA; } else if (platform::is_same>::value || platform::is_same>::value) { isAMisaligned = problem_size.k() % kAlignmentA; } if (platform::is_same::value) { isBMisaligned = problem_size.n() % kAlignmentB; } else if (platform::is_same::value) { isBMisaligned = problem_size.k() % kAlignmentB; } else if (platform::is_same>::value || platform::is_same>::value) { isBMisaligned = problem_size.k() % kAlignmentB; } if (platform::is_same::value) { isCMisaligned = problem_size.n() % kAlignmentC; } else if (platform::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; } else if (platform::is_same>::value || platform::is_same>::value) { isCMisaligned = problem_size.n() % kAlignmentC; } if (isAMisaligned) { CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); return Status::kErrorMisalignedOperand; } if (isBMisaligned) { CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); return Status::kErrorMisalignedOperand; } if (isCMisaligned) { CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); return Status::kErrorMisalignedOperand; } CUTLASS_TRACE_HOST(" returning kSuccess"); return Status::kSuccess; } static Status can_implement(Arguments const &args) { return can_implement(args.problem_size); } #define SPLIT_K_ENABLED 1 /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { return; } int offset_k = 0; int problem_size_k = params.problem_size.k(); ElementA *ptr_A = static_cast(params.ptr_A); ElementB *ptr_B = static_cast(params.ptr_B); #if SPLIT_K_ENABLED // // Fetch pointers based on mode. // if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; } else if (params.mode == GemmUniversalMode::kBatched) { ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; } else if (params.mode == GemmUniversalMode::kArray) { ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; } #endif // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ threadblock_tile_offset.m() * Mma::Shape::kM, offset_k, }; cutlass::MatrixCoord tb_offset_B{ offset_k, threadblock_tile_offset.n() * Mma::Shape::kN }; // Compute position within threadblock int thread_idx = threadIdx.x; // Construct iterators to A and B operands typename Mma::IteratorA iterator_A( params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); typename Mma::IteratorB iterator_B( params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; // // Main loop // // Construct thread-scoped matrix multiply Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); typename Mma::FragmentC accumulators; accumulators.clear(); // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add mma( gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); // // Masked tile iterators constructed from members // threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN ); int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); // // Construct the epilogue visitor // EpilogueVisitor epilogue_visitor( params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_C, params.params_D, params.ptr_C, params.ptr_D, params.ptr_Max, params.ptr_Sum, threadblock_offset, blockIdx.y *params.problem_size.m() ); if (params.mode == GemmUniversalMode::kGemm) { // Indicate which position in a serial reduction the output operator is currently updating epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); } // Construct the epilogue Epilogue epilogue( shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); // Execute the epilogue operator to update the destination tensor. epilogue(epilogue_visitor, accumulators); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace kernel } // namespace gemm } // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////