/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Gemm kernel with an epilogue defined under the epilogue visitor concept
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/gemm_universal.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
// Gemm that computes the epilogue visitor functor
template <
  typename Mma,                 ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue,            ///! Epilogue
  typename ThreadblockSwizzle_  ///! Threadblock swizzling function
>
class GemmWithEpilogueVisitor : public GemmUniversal<Mma, Epilogue, ThreadblockSwizzle_> {
public:

  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Base = GemmUniversal<Mma, Epilogue, ThreadblockSwizzle_>;
  using Base::Base;

  using FusionCallbacks = typename Epilogue::FusionCallbacks;

  using ElementA = typename Base::ElementA;
  using LayoutA = typename Base::LayoutA;
  using ElementB = typename Base::ElementB;
  using LayoutB = typename Base::LayoutB;
  using ElementC = typename Base::ElementC;
  using LayoutC = typename Base::LayoutC;

  using ThreadblockShape = typename Mma::Shape;

  //
  // Structures
  //

  using SharedStorage = typename Base::SharedStorage;
  using Arguments = typename Base::Arguments;

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params : UniversalParamsBase<
    ThreadblockSwizzle,
    ThreadblockShape,
    ElementA,
    ElementB,
    ElementC,
    LayoutA,
    LayoutB>
  {
    using ParamsBase = UniversalParamsBase<
      ThreadblockSwizzle,
      ThreadblockShape,
      ElementA,
      ElementB,
      ElementC,
      LayoutA,
      LayoutB>;

    //
    // Data members
    //

    cute::Shape<int32_t, int32_t, int32_t> problem_shape;
    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename FusionCallbacks::Params output_op;
    void * ptr_A;
    void * ptr_B;
    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int * ptr_gather_A_indices;
    int * ptr_gather_B_indices;

    //
    // Host dispatch API
    //

    /// Default constructor
    Params() = default;

    /// Constructor
    Params(
      Arguments const &args,  /// GEMM application arguments
      int device_sms,         /// Number of SMs on the device
      int sm_occupancy)       /// Kernel SM occupancy (in thread blocks)
    :
      ParamsBase(args, device_sms, sm_occupancy),
      params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
      params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
      output_op(FusionCallbacks::to_underlying_arguments(args.problem_size, args.epilogue, nullptr /*workspace*/)),
      problem_shape({args.problem_size.m(), args.problem_size.n(), args.batch_count}),
      ptr_A(const_cast<void *>(args.ptr_A)),
      ptr_B(const_cast<void *>(args.ptr_B)),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
      ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices))
    {
      // Raise error on unsupported modes
      assert(args.mode != GemmUniversalMode::kGemmSplitKParallel && "Sm80 EVT does not support SplitKParallel.");
      assert(!(args.mode == GemmUniversalMode::kGemm && this->grid_tiled_shape.k() > 1) && "Sm80 EVT does not support SplitKSerial.");
      assert(args.mode != GemmUniversalMode::kArray && "Sm80 EVT does not support Array Gemm.");
    }
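
    // Host-side usage sketch (illustrative only; `GemmKernel`, `device_sms`,
    // and `sm_occupancy` are hypothetical placeholders for a fully specialized
    // kernel type and values queried via the CUDA runtime/occupancy APIs):
    //
    //   typename GemmKernel::Params params(args, device_sms, sm_occupancy);
    //   ...
    //   params.update(args_with_new_pointers);  // reuse tiling; swap pointers only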
    /// Lightweight update given a subset of arguments.
    void update(Arguments const &args)
    {
      CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::Params::update()");

      // Update input pointers
      ptr_A = const_cast<void *>(args.ptr_A);
      ptr_B = const_cast<void *>(args.ptr_B);
      batch_stride_A = args.batch_stride_A;
      batch_stride_B = args.batch_stride_B;
      this->batch_stride_D = args.batch_stride_D;
      ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
      ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);

      output_op = FusionCallbacks::to_underlying_arguments(args.problem_size, args.epilogue, nullptr /*workspace*/);
      problem_shape = make_shape(args.problem_size.m(), args.problem_size.n(), args.batch_count);
    }
  };

public:

  //
  // Device-only API
  //

  // Factory invocation
  CUTLASS_DEVICE
  static void invoke(
    Params const &params,
    SharedStorage &shared_storage)
  {
    GemmWithEpilogueVisitor op;
    op(params, shared_storage);
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {
    ThreadblockSwizzle threadblock_swizzle;
    run_with_swizzle(params, shared_storage, threadblock_swizzle);
  }

  /// Executes one GEMM with an externally-provided swizzling function
  CUTLASS_DEVICE
  void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle &threadblock_swizzle) {

    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
      return;
    }

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm) {

      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    }
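
    // Note on the k-partitioning arithmetic above: the Params constructor
    // asserts against split-K modes, so in kGemm mode grid_tiled_shape.k() is
    // expected to be 1; offset_k then stays 0 and problem_size_k spans the
    // full K extent. The logic is retained from the split-K-capable base kernel.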
    __syncthreads();

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      ptr_A,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A,
      params.ptr_gather_A_indices);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      ptr_B,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B,
      params.ptr_gather_B_indices);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = canonical_warp_idx_sync();

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Number of threadblock-scoped mainloop iterations over K
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A,
      iterator_B,
      accumulators);

    //
    // Epilogue
    //

    threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    Epilogue epilogue(
      params.output_op,
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Execute the epilogue operator to update the destination tensor.
    epilogue(accumulators, threadblock_tile_offset, params.problem_shape, thread_idx);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
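
// Launch sketch (illustrative, not part of this header). Assuming a fully
// specialized kernel type `GemmKernel` built from this template, the host
// side typically builds Params and launches through cutlass::Kernel<> from
// "cutlass/device_kernel.h"; `args`, `device_sms`, `sm_occupancy`, and
// `stream` are placeholders supplied by the caller:
//
//   using GemmKernel = cutlass::gemm::kernel::GemmWithEpilogueVisitor<
//       Mma, Epilogue, ThreadblockSwizzle>;  // concrete collective types assumed
//
//   typename GemmKernel::Params params(args, device_sms, sm_occupancy);
//
//   dim3 grid = ThreadblockSwizzle().get_grid_shape(params.grid_tiled_shape);
//   dim3 block(GemmKernel::kThreadCount, 1, 1);
//   int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//   // smem_size above 48 KB requires cudaFuncSetAttribute with
//   // cudaFuncAttributeMaxDynamicSharedMemorySize before the first launch.
//
//   cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);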