diff --git a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt
index 3932234c..a20fa2b1 100644
--- a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt
+++ b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt
@@ -33,3 +33,8 @@ cutlass_example_add_executable(
   ampere_sparse_tensorop_gemm.cu
   )
 
+cutlass_example_add_executable(
+  15_ampere_sparse_tensorop_gemm_with_visitor
+  ampere_sparse_tensorop_gemm_with_visitor.cu
+  )
+
diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu
new file mode 100644
index 00000000..972f8b69
--- /dev/null
+++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu
@@ -0,0 +1,379 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+Please check examples 07, 08 and 17 for the basics of dense tensor op GEMM kernels. The NVIDIA
+Ampere architecture also supports structured sparse tensor ops for tf32, fp16, int8 and int4.
+
+Sparse GEMM kernels need to take an additional matrix E which stores the metadata. The metadata
+format is different for each data type; the CUTLASS templates can infer it automatically from the
+types of inputs A and B. See the code below.
+
+Moreover, matrix E needs to be preprocessed (reordered) so that ldmatrix can load it into registers
+efficiently.
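+
+Compared to the original example in this directory, this variant demonstrates the epilogue visitor
+tree (EVT): the epilogue callbacks load a bias matrix (tensor C), add it to the accumulator, and
+store the sum to tensor D, i.e. roughly D = accum + C. The visitor nodes are composed with Sm80EVT,
+sketched here and defined in full below:
+
+  using EVTApplyBias = Sm80EVT<ApplyBias, Accum, Bias>;   // accum + bias
+  using EVTOutput    = Sm80EVT<Output, EVTApplyBias>;     // store the result to D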
+*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse_with_visitor.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" + +#include "helper.h" + +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. +using ElementAccumulator = int32_t; // <- data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations +using ElementInputA = int8_t; // <- data type of elements in input matrix A +using ElementInputB = int8_t; // <- data type of elements in input matrix B +using ElementOutput = int32_t; // <- data type of elements in output matrix D + +// The code section below describes matrix layout of input and output matrices. Row Major for +// Matrix A, Column Major for Matrix B and Row Major for Matrix C +using LayoutInputA = cutlass::layout::RowMajor; +using LayoutInputB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +// The number of elements per vectorized memory access. +constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 128>; // <- threadblock tile M = 128, N = 128, K = 128 +// This code section describes tile size a warp will compute +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 128>; // <- warp tile M = 64, N = 64, K = 128 +// This code section describes the size of MMA op +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 16, N = 8, K = 64 + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +using Operator = cutlass::arch::OpMultiplyAdd; + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +constexpr auto NumEVTEpilogueStages = 1; + +using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + +using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ShapeMMAThreadBlock, + ShapeMMAWarp, + ElementComputeEpilogue, + AlignmentComputeEpilogue, + NumEVTEpilogueStages>; + +using Bias = cutlass::epilogue::threadblock::VisitorAuxLoad< + BiasTileThreadMap, + ElementComputeEpilogue, + cute::Stride>; + +using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + +using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT< + ApplyBias, + Accum, + Bias>; + +using OutputTileThreadMap = 
cutlass::epilogue::threadblock::OutputTileThreadLayout< + ShapeMMAThreadBlock, + ShapeMMAWarp, + ElementOutput, + AlignmentOutput, + NumEVTEpilogueStages>; + +using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride>; + +using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< + Output, + EVTApplyBias>; + +// Use element type in EVT with the smallest bitwidth as ElementC. +using ElementC = ElementComputeEpilogue; +using LayoutC = LayoutOutput; + +using Gemm = + typename cutlass::gemm::device::SparseGemmWithVisitor< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementC, LayoutC, + ElementAccumulator, + MMAOp, + SmArch, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EVTOutput, + SwizzleThreadBlock, + NumStages, + AlignmentInputA, + AlignmentInputB, + Operator, + NumEVTEpilogueStages>; + +// Data type and layout of meta data matrix E can be inferred from template Gemm. +using ElementInputE = typename Gemm::GemmKernel::ElementE; +using LayoutInputE = cutlass::layout::RowMajor; +using ReorderedLayoutInputE = typename Gemm::GemmKernel::LayoutE; + +// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h +// 50% Sparsity on Ampere +constexpr int kSparse = Gemm::kSparse; +// How many elements of A are covered per ElementE +constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; +// The size of individual meta data +constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + +int run() { + + const int length_m = 512; + const int length_n = 512; + const int length_k = 1024; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2) + cutlass::HostTensor tensor_a_uncompressed( + problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing + + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create matrix B with dimensions K x N + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing. + cutlass::HostTensor tensor_e( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + // Same size as the above. The above one needs to be reordered and stored in this one. 
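+  // reorder_meta() below rearranges tensor_e into tensor_e_reordered so that the kernel can fetch
+  // the metadata with ldmatrix; the kernel consumes only the reordered copy, while the row-major
+  // tensor_e stays on the host for uncompressing A in the reference check.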
+ cutlass::HostTensor tensor_e_reordered( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(8), + ElementInputA(-8), + 0); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(8), + ElementInputB(-8), + 0); // <- Fill matrix B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(8), + ElementOutput(-8), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_e.host_view(), + 1, + kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core + // instructions. + cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_e_reordered.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(1); + + typename Bias::Arguments bias_arguments{ + tensor_c.device_data(), + ElementComputeEpilogue(0), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename Output::Arguments output_arguments{ + tensor_d.device_data(), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename EVTOutput::Arguments callback_arguments{ + { + {}, // Accum + bias_arguments, // Bias + {} // ApplyBias + }, // EVTApplyBias + output_arguments // Output + }; // EVTOutput + + // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + tensor_a.device_ref(), // <- reference to matrix A on device + tensor_b.device_ref(), // <- reference to matrix B on device + tensor_e_reordered.device_ref(), // <- reference to matrix E on device + callback_arguments}; // <- epilogue arguments + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. + cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), + tensor_e.host_ref(), problem_size.m(), problem_size.k()); + + // Create instantiation for host reference gemm kernel + cutlass::reference::host::Gemm + gemm_host; + + // Launch host reference gemm kernel + gemm_host(problem_size, + alpha, + tensor_a_uncompressed.host_ref(), + tensor_b.host_ref(), + beta, + tensor_c.host_ref(), + tensor_ref_d.host_ref()); + + // Copy output data from CUTLASS host for comparison + tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + + bool notSupported = false; + + // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.1. + // + // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. + + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major * 10 + props.minor < 80) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + notSupported = true; + } + + if (notSupported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + return run(); +} diff --git a/include/cutlass/gemm/device/gemm_sparse_with_visitor.h b/include/cutlass/gemm/device/gemm_sparse_with_visitor.h new file mode 100644 index 00000000..942ba1f5 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_sparse_with_visitor.h @@ -0,0 +1,342 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/sparse_gemm.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! 
Sparse GEMM with visitor + */ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename FusionCallbacks_ = + typename cutlass::epilogue::threadblock::detail::EmptyCallbacks, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Number of stages used in the pipelined epilogue + int EpilogueStages = 1> +class SparseGemmWithVisitor { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using FusionCallbacks = FusionCallbacks_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using MathOperator = Operator; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultSparseGemmWithVisitor< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + FusionCallbacks, + ThreadblockSwizzle, + kStages, + Operator, + EpilogueStages + >::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + 
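+  // Illustrative host-side usage, mirroring
+  // examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu
+  // (tensor references and EVT callback arguments are assumed to be prepared as in that example):
+  //
+  //   Gemm gemm_op;
+  //   typename Gemm::Arguments args{problem_size, ref_A, ref_B, ref_E_reordered, callback_args};
+  //   cutlass::Status status = gemm_op.can_implement(args);
+  //   status = gemm_op.initialize(args, workspace.get());
+  //   status = gemm_op();   // launches the kernel
+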
static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_E; + typename FusionCallbacks::Arguments epilogue; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_E_, + typename FusionCallbacks::Arguments epilogue_ = + typename FusionCallbacks::Arguments() + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_E(ref_E_), + epilogue(epilogue_) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + SparseGemmWithVisitor() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + cutlass::TensorRef(), // It only matters that it's empty. + cutlass::TensorRef(), // Same as above. + args.ref_E.non_const_ref() + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + constexpr int SplitKSlices = 1; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + SplitKSlices); + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_E.non_const_ref(), + args.epilogue + }; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_E.reset(args.ref_E.non_const_ref().data()); + params_.output_op = args.epilogue; + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? 
Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h b/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h new file mode 100644 index 00000000..a0d602da --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default sparse GEMM with visitor. 
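+
+           DefaultSparseGemmWithVisitor composes a threadblock-scoped DefaultSparseMma with an
+           EpilogueWithVisitorCallbacks epilogue and exposes the result as
+           kernel::SparseGemmWithEpilogueVisitor; see the partial specialization for the Ampere
+           architecture below.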
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm.h" +#include "cutlass/gemm/kernel/default_gemm_sparse.h" +#include "cutlass/gemm/kernel/sparse_gemm_with_visitor.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#include "cutlass/gemm/threadblock/default_sparse_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename FusionCallbacks, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Number of stages used in the pipelined epilogue + int EpilogueStages = 1> +struct DefaultSparseGemmWithVisitor; + +//////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename 
LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename FusionCallbacks, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Number of stages used in the pipelined epilogue + int EpilogueStages> +struct DefaultSparseGemmWithVisitor { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + static constexpr int kAlignmentC = 128 / sizeof_bits::value;; + using ElementEpilogue = ElementAccumulator; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + using EpilogueOutputOp = + typename epilogue::thread::LinearCombination< + ElementC, kAlignmentC, + ElementAccumulator, ElementEpilogue>; + using BaseEpilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, + EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue; + + // Define epilogue + using Epilogue = cutlass::epilogue::threadblock::EpilogueWithVisitorCallbacks< + BaseEpilogue, + FusionCallbacks, + EpilogueStages>; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::SparseGemmWithEpilogueVisitor; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/include/cutlass/gemm/kernel/params_sparse_base.h b/include/cutlass/gemm/kernel/params_sparse_base.h new file mode 100644 index 00000000..743954da --- /dev/null +++ b/include/cutlass/gemm/kernel/params_sparse_base.h @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Base functionality for common types of sparse GEMM kernel parameters +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template < + typename ThreadblockSwizzle, + typename ParamsA, + typename TensorRefA, + typename ParamsB, + typename TensorRefB, + typename ParamsE, + typename TensorRefE> +struct SparseParamsBase +{ + // + // Data members + // + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + ParamsA params_A; + TensorRefA ref_A; + ParamsB params_B; + TensorRefB ref_B; + ParamsE params_E; + TensorRefE ref_E; + int gemm_k_iterations; + int gemm_k_size; + + // + // Host dispatch API + // + + /// Default constructor + CUTLASS_HOST_DEVICE + SparseParamsBase() : swizzle_log_tile(0), gemm_k_iterations(0), gemm_k_size(0) { } + + + /// Constructor + CUTLASS_HOST_DEVICE + SparseParamsBase( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + TensorRefA ref_A, + TensorRefB ref_B, + TensorRefE ref_E, + int const mma_shape_k) + : + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_E(ref_E.layout()), + ref_E(ref_E) + { + int total_gemm_k_iterations = (problem_size.k() + mma_shape_k - 1) / mma_shape_k; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * mma_shape_k; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sparse_gemm.h b/include/cutlass/gemm/kernel/sparse_gemm.h index 1964fba8..c87f2098 100644 --- a/include/cutlass/gemm/kernel/sparse_gemm.h +++ b/include/cutlass/gemm/kernel/sparse_gemm.h @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/params_sparse_base.h" #include "cutlass/matrix_coord.h" #include "cutlass/semaphore.h" @@ -74,66 +75,58 @@ struct SparseGemm { using WarpCount = typename 
Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; + using ParamsA = typename Mma::IteratorA::Params; + using TensorRefA = typename Mma::IteratorA::TensorRef; + using ParamsB = typename Mma::IteratorB::Params; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using ParamsE = typename Mma::IteratorE::Params; + using TensorRefE = typename Mma::IteratorE::TensorRef; + /// Parameters structure - struct Params { - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorB::TensorRef ref_B; + struct Params : public SparseParamsBase< + ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, + ParamsE, TensorRefE> { + + using Base = SparseParamsBase< + ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, + ParamsE, TensorRefE>; + + // + // Data members + // typename Epilogue::OutputTileIterator::Params params_C; typename Epilogue::OutputTileIterator::TensorRef ref_C; typename Epilogue::OutputTileIterator::Params params_D; typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename Mma::IteratorE::Params params_E; - typename Mma::IteratorE::TensorRef ref_E; typename OutputOp::Params output_op; int *semaphore; - int gemm_k_iterations; - int gemm_k_size; // // Methods // CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } + Params() { } CUTLASS_HOST_DEVICE Params( cutlass::gemm::GemmCoord const & problem_size, cutlass::gemm::GemmCoord const & grid_tiled_shape, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, + TensorRefA ref_A, + TensorRefB ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D, - typename Mma::IteratorE::TensorRef ref_E, + TensorRefE ref_E, typename OutputOp::Params output_op = typename OutputOp::Params(), int *workspace = nullptr ): - problem_size(problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(ref_A.layout()), - ref_A(ref_A), - params_B(ref_B.layout()), - ref_B(ref_B), + Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK), params_C(ref_C.layout()), ref_C(ref_C), params_D(ref_D.layout()), ref_D(ref_D), - params_E(ref_E.layout()), - ref_E(ref_E), - output_op(output_op) { - - int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); - - gemm_k_size = gemm_k_iterations * Mma::Shape::kK; - - semaphore = workspace; + output_op(output_op), + semaphore(workspace) { } }; diff --git a/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h b/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h new file mode 100644 index 00000000..990a6c36 --- /dev/null +++ b/include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Sparse GEMM with visitor. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/sparse_gemm.h" +#include "cutlass/gemm/kernel/params_sparse_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Sparse Gemm that compute the epilogue visitor functor +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct SparseGemmWithEpilogueVisitor : public SparseGemm { + + using Base = SparseGemm; + + using Mma = Mma_; + using Epilogue = Epilogue_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using FusionCallbacks = typename Epilogue::FusionCallbacks; + + using ParamsA = typename Mma::IteratorA::Params; + using TensorRefA = typename Mma::IteratorA::TensorRef; + using ParamsB = typename Mma::IteratorB::Params; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using ParamsE = typename Mma::IteratorE::Params; + using TensorRefE = typename Mma::IteratorE::TensorRef; + + static int const kSparse = Base::kSparse; + static int const kElementsPerElementE = Base::kElementsPerElementE; + using SharedStorage = typename Base::SharedStorage; + + /// Parameters structure + struct Params : public SparseParamsBase< + ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, + ParamsE, TensorRefE> { + + using Base = SparseParamsBase< + ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB, + ParamsE, TensorRefE>; + + // + // Data members + // + + typename FusionCallbacks::Params output_op; + cute::Shape problem_shape; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorE::TensorRef ref_E, + typename FusionCallbacks::Arguments output_op = typename FusionCallbacks::Arguments() + ): + Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK), + output_op(FusionCallbacks::to_underlying_arguments(problem_size, output_op, nullptr /*workspace*/)), + problem_shape(problem_size.m(), problem_size.n(), 1) { + } + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + SparseGemmWithEpilogueVisitor() { } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size / kSparse, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_E{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size / kSparse, + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A, B, and E operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k / kSparse}, + thread_idx, + 
tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorE iterator_E( + params.params_E, params.ref_E.data(), + {params.problem_size.m(), + problem_size_k / kSparse / kElementsPerElementE}, + thread_idx, tb_offset_E); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators); + } + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Epilogue + // + + Epilogue epilogue( + params.output_op, + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(accumulators, threadblock_tile_offset, params.problem_shape, thread_idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass