Add support for sparse GEMM with visitor epilogue (#1189)

* Add support for sparse GEMM with visitor epilogue

* Refactor changes at the kernel level
Aleksandar Samardžić 2024-01-04 18:38:11 +01:00 committed by GitHub
parent 8236f30675
commit 5c756eb774
7 changed files with 1304 additions and 33 deletions

View File

@@ -33,3 +33,8 @@ cutlass_example_add_executable(
ampere_sparse_tensorop_gemm.cu
)
cutlass_example_add_executable(
15_ampere_sparse_tensorop_gemm_with_visitor
ampere_sparse_tensorop_gemm_with_visitor.cu
)

View File

@@ -0,0 +1,379 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
Please check examples 07, 08, and 17 for the basics of dense tensor op GEMM kernels. The NVIDIA
Ampere architecture also supports structured sparse tensor ops for tf32, fp16, int8, and int4.
A sparse GEMM kernel takes an additional matrix E that stores the metadata. The metadata format
differs for each data type, and CUTLASS templates can infer it automatically from the types of
inputs A and B; see the code below.
Moreover, matrix E needs to be reordered so that ldmatrix can load it into registers efficiently.
*/
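// In outline, the host-side flow below is: create a compressed A of shape M x (K / kSparse) and a
// metadata matrix E of shape M x (K / kSparse / kElementsPerElementE), reorder E with
// cutlass::reorder_meta(), build an epilogue visitor tree that adds a bias (read from tensor_c) to
// the accumulator and stores the result to tensor_d, run the sparse GEMM, and finally compare the
// result against a host reference computed on the uncompressed A.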
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_sparse_with_visitor.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/host_uncompress.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "helper.h"
// The code section below describes the data types of the input and output matrices and of the
// computation between elements of the input matrices.
using ElementAccumulator = int32_t; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = int8_t; // <- data type of elements in input matrix A
using ElementInputB = int8_t; // <- data type of elements in input matrix B
using ElementOutput = int32_t; // <- data type of elements in output matrix D
// The code section below describes the layouts of the input and output matrices: row-major for
// matrix A, column-major for matrix B, and row-major for matrix C.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// The number of elements per vectorized memory access.
constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
// This code section describes whether to use Tensor Cores or regular SIMT cores on the GPU SMs
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<128, 128, 128>; // <- threadblock tile M = 128, N = 128, K = 128
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 128>; // <- warp tile M = 64, N = 64, K = 128
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 16, N = 8, K = 64
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
using Operator = cutlass::arch::OpMultiplyAdd;
// Number of pipeline stages to use in the mainloop
constexpr int NumStages = 3;
constexpr auto NumEVTEpilogueStages = 1;
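// The epilogue visitor tree (EVT) below is composed bottom-up as
//   EVTOutput = Output( ApplyBias( Accum, Bias ) ),
// i.e. the epilogue fetches the accumulator tile, loads a bias tile (tensor_c in this example),
// adds the two element-wise, and stores the result to the output tensor (tensor_d).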
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ShapeMMAThreadBlock,
ShapeMMAWarp,
ElementComputeEpilogue,
AlignmentComputeEpilogue,
NumEVTEpilogueStages>;
using Bias = cutlass::epilogue::threadblock::VisitorAuxLoad<
BiasTileThreadMap,
ElementComputeEpilogue,
cute::Stride<int64_t, cute::_1, int64_t>>;
using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
ApplyBias,
Accum,
Bias>;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ShapeMMAThreadBlock,
ShapeMMAWarp,
ElementOutput,
AlignmentOutput,
NumEVTEpilogueStages>;
using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput,
cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, cute::_1, int64_t>>;
using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
Output,
EVTApplyBias>;
// Use the element type with the smallest bit width appearing in the EVT as ElementC.
using ElementC = ElementComputeEpilogue;
using LayoutC = LayoutOutput;
using Gemm =
typename cutlass::gemm::device::SparseGemmWithVisitor<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementC, LayoutC,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EVTOutput,
SwizzleThreadBlock,
NumStages,
AlignmentInputA,
AlignmentInputB,
Operator,
NumEVTEpilogueStages>;
// The data type and layout of the metadata matrix E can be inferred from the Gemm template.
using ElementInputE = typename Gemm::GemmKernel::ElementE;
using LayoutInputE = cutlass::layout::RowMajor;
using ReorderedLayoutInputE = typename Gemm::GemmKernel::LayoutE;
// The properties below are defined in include/cutlass/arch/sp_mma_sm80.h
// 50% Sparsity on Ampere
constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of an individual metadata element, in bits
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
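// For the problem below (M = 512, N = 512, K = 1024) with 50% sparsity (kSparse == 2), matrix A is
// stored compressed as 512 x 512, and matrix E holds 512 x (512 / kElementsPerElementE) metadata
// elements.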
int run() {
const int length_m = 512;
const int length_n = 512;
const int length_k = 1024;
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementComputeEpilogue, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// CUTLASS kernel
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// reference kernel
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used for the reference computation.
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Same size as the above; the tensor above is reordered and stored into this one.
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(8),
ElementInputA(-8),
0); // <- Fill matrix A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(8),
ElementInputB(-8),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(8),
ElementOutput(-8),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomSparseMeta(
tensor_e.host_view(),
1,
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
// Reorder the metadata matrix so that ldmatrix can be used to load it for the tensor core
// instructions.
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
{problem_size.m(), problem_size.n(),
problem_size.k() / kSparse / kElementsPerElementE});
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_e_reordered.sync_device();
tensor_ref_d.sync_device();
// Initialize alpha and beta for the reference GEMM computation. With alpha = beta = 1, the
// reference computes D = A * B + C, which matches the bias-add EVT defined above.
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(1);
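// The auxiliary load/store visitors take an (M-stride, N-stride, batch-stride) tuple; for the
// row-major M x N tensors used here that is {N, 1, M * N}. The nested callback_arguments below
// mirror the structure of the EVT defined above.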
typename Bias::Arguments bias_arguments{
tensor_c.device_data(),
ElementComputeEpilogue(0),
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
};
typename Output::Arguments output_arguments{
tensor_d.device_data(),
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
};
typename EVTOutput::Arguments callback_arguments{
{
{}, // Accum
bias_arguments, // Bias
{} // ApplyBias
}, // EVTApplyBias
output_arguments // Output
}; // EVTOutput
// Create a tuple of GEMM kernel arguments. This is later passed to launch the instantiated
// CUTLASS kernel.
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
tensor_a.device_ref(), // <- reference to matrix A on device
tensor_b.device_ref(), // <- reference to matrix B on device
tensor_e_reordered.device_ref(), // <- reference to matrix E on device
callback_arguments}; // <- epilogue arguments
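// Note that, unlike the plain sparse GEMM, no references to matrices C and D are passed here;
// the epilogue's source and destination tensors are described entirely by callback_arguments.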
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
// Check whether the problem size is supported
cutlass::Status status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
// Launch initialized CUTLASS kernel
status = gemm_op();
CUTLASS_CHECK(status);
// Uncompress tensor_a based on the metadata in tensor_e; the uncompressed matrix is needed for
// the reference computation.
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
tensor_e.host_ref(), problem_size.m(), problem_size.k());
// Instantiate the host reference GEMM kernel
cutlass::reference::host::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementComputeEpilogue,
typename Gemm::Operator>
gemm_host;
// Launch host reference gemm kernel
gemm_host(problem_size,
alpha,
tensor_a_uncompressed.host_ref(),
tensor_b.host_ref(),
beta,
tensor_c.host_ref(),
tensor_ref_d.host_ref());
// Copy output data from the CUTLASS kernel back to the host for comparison
tensor_d.sync_host();
// Check whether the outputs of the CUTLASS kernel and the reference kernel are equal
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
std::cout << (passed ? "Passed" : "Failed") << std::endl;
return (passed ? 0 : -1);
}
int main() {
bool notSupported = false;
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 11.1.
//
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major * 10 + props.minor < 80) {
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;
}
if (notSupported) {
// Return zero so this test passes on older Toolkits. In that case, its actions are a no-op.
return 0;
}
return run();
}

View File

@@ -0,0 +1,342 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined sparse GEMM kernel with a visitor epilogue. Does not support batching or split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/sparse_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*! Sparse GEMM with visitor
*/
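// This device-level class follows cutlass::gemm::device::SparseGemm, but replaces the usual
// epilogue output operator with a FusionCallbacks type (an epilogue visitor tree). Its Arguments
// therefore carry no C or D tensor references; the epilogue operands are supplied through
// FusionCallbacks::Arguments instead.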
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue fusion callbacks (visitor tree)
typename FusionCallbacks_ =
typename cutlass::epilogue::threadblock::detail::EmptyCallbacks,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ =
typename threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Number of stages used in the pipelined epilogue
int EpilogueStages = 1>
class SparseGemmWithVisitor {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using FusionCallbacks = FusionCallbacks_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
using MathOperator = Operator;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
/// Define the kernel
using GemmKernel = typename kernel::DefaultSparseGemmWithVisitor<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
FusionCallbacks,
ThreadblockSwizzle,
kStages,
Operator,
EpilogueStages
>::GemmKernel;
using ElementE = typename GemmKernel::ElementE;
using LayoutE = typename GemmKernel::LayoutE;
static int const kAlignmentE = 128 / sizeof_bits<ElementE>::value;
static int const kSparse = GemmKernel::kSparse;
static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits;
static int const kElementsPerElementE = GemmKernel::kElementsPerElementE;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A;
TensorRef<ElementB const, LayoutB> ref_B;
TensorRef<ElementE const, LayoutE> ref_E;
typename FusionCallbacks::Arguments epilogue;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size(0, 0, 0) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementE, LayoutE> ref_E_,
typename FusionCallbacks::Arguments epilogue_ =
typename FusionCallbacks::Arguments()
):
problem_size(problem_size_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_E(ref_E_),
epilogue(epilogue_) {
}
};
private:
/// Kernel parameters object
typename GemmKernel::Params params_;
public:
/// Constructs the GEMM.
SparseGemmWithVisitor() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
Status status = GemmKernel::can_implement(
args.problem_size,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
cutlass::TensorRef<ElementC, LayoutC>(), // It only matters that it's empty.
cutlass::TensorRef<ElementC, LayoutC>(), // Same as above.
args.ref_E.non_const_ref()
);
if (status != Status::kSuccess) {
return status;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
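// No device-side workspace is needed here: this kernel does not support split-K, so zero bytes
// are always reported.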
return bytes;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
constexpr int SplitKSlices = 1;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
SplitKSlices);
// Initialize the Params structure
params_ = typename GemmKernel::Params{
args.problem_size,
grid_shape,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
args.ref_E.non_const_ref(),
args.epilogue
};
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
params_.ref_A.reset(args.ref_A.non_const_ref().data());
params_.ref_B.reset(args.ref_B.non_const_ref().data());
params_.ref_E.reset(args.ref_E.non_const_ref().data());
params_.output_op = args.epilogue;
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
cudaError_t result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,198 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Default sparse GEMM with visitor.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/wmma.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/kernel/default_gemm_sparse.h"
#include "cutlass/gemm/kernel/sparse_gemm_with_visitor.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
#include "cutlass/gemm/threadblock/default_sparse_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
#endif //CUTLASS_ARCH_WMMA_ENABLED
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue fusion callbacks (visitor tree)
typename FusionCallbacks,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Number of stages used in the pipelined epilogue
int EpilogueStages = 1>
struct DefaultSparseGemmWithVisitor;
////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue fusion callbacks (visitor tree)
typename FusionCallbacks,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Number of stages used in the pipelined epilogue
int EpilogueStages>
struct DefaultSparseGemmWithVisitor<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementC, LayoutC, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
FusionCallbacks, ThreadblockSwizzle, Stages, Operator,
EpilogueStages> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
static constexpr int kAlignmentC = 128 / sizeof_bits<ElementC>::value;
using ElementEpilogue = ElementAccumulator;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
using EpilogueOutputOp =
typename epilogue::thread::LinearCombination<
ElementC, kAlignmentC,
ElementAccumulator, ElementEpilogue>;
using BaseEpilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename Mma::Operator, kPartitionsK,
EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;
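// The LinearCombination output op above is only used to size the base epilogue (its thread maps,
// fragments, and tile iterators); the actual output computation is performed by the
// FusionCallbacks visitor wired in below.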
// Define epilogue
using Epilogue = cutlass::epilogue::threadblock::EpilogueWithVisitorCallbacks<
BaseEpilogue,
FusionCallbacks,
EpilogueStages>;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::SparseGemmWithEpilogueVisitor<Mma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,117 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Base functionality for common types of sparse GEMM kernel parameters
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters structure
template <
typename ThreadblockSwizzle,
typename ParamsA,
typename TensorRefA,
typename ParamsB,
typename TensorRefB,
typename ParamsE,
typename TensorRefE>
struct SparseParamsBase
{
//
// Data members
//
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
ParamsA params_A;
TensorRefA ref_A;
ParamsB params_B;
TensorRefB ref_B;
ParamsE params_E;
TensorRefE ref_E;
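// The two members below give the number of mainloop K iterations assigned to each split-K slice
// and the K extent (in elements) that each slice covers; both are derived from the problem size
// in the constructor.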
int gemm_k_iterations;
int gemm_k_size;
//
// Host dispatch API
//
/// Default constructor
CUTLASS_HOST_DEVICE
SparseParamsBase() : swizzle_log_tile(0), gemm_k_iterations(0), gemm_k_size(0) { }
/// Constructor
CUTLASS_HOST_DEVICE
SparseParamsBase(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
TensorRefA ref_A,
TensorRefB ref_B,
TensorRefE ref_E,
int const mma_shape_k)
:
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(ref_A.layout()),
ref_A(ref_A),
params_B(ref_B.layout()),
ref_B(ref_B),
params_E(ref_E.layout()),
ref_E(ref_E)
{
int total_gemm_k_iterations = (problem_size.k() + mma_shape_k - 1) / mma_shape_k;
gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * mma_shape_k;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -37,6 +37,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/params_sparse_base.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
@@ -74,66 +75,58 @@ struct SparseGemm {
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ParamsA = typename Mma::IteratorA::Params;
using TensorRefA = typename Mma::IteratorA::TensorRef;
using ParamsB = typename Mma::IteratorB::Params;
using TensorRefB = typename Mma::IteratorB::TensorRef;
using ParamsE = typename Mma::IteratorE::Params;
using TensorRefE = typename Mma::IteratorE::TensorRef;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
struct Params : public SparseParamsBase<
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
ParamsE, TensorRefE> {
using Base = SparseParamsBase<
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
ParamsE, TensorRefE>;
//
// Data members
//
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename Mma::IteratorE::Params params_E;
typename Mma::IteratorE::TensorRef ref_E;
typename OutputOp::Params output_op;
int *semaphore;
int gemm_k_iterations;
int gemm_k_size;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
Params() { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
TensorRefA ref_A,
TensorRefB ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
typename Mma::IteratorE::TensorRef ref_E,
TensorRefE ref_E,
typename OutputOp::Params output_op = typename OutputOp::Params(),
int *workspace = nullptr
):
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(ref_A.layout()),
ref_A(ref_A),
params_B(ref_B.layout()),
ref_B(ref_B),
Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK),
params_C(ref_C.layout()),
ref_C(ref_C),
params_D(ref_D.layout()),
ref_D(ref_D),
params_E(ref_E.layout()),
ref_E(ref_E),
output_op(output_op) {
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
semaphore = workspace;
output_op(output_op),
semaphore(workspace) {
}
};

View File

@@ -0,0 +1,237 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Sparse GEMM with visitor.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/sparse_gemm.h"
#include "cutlass/gemm/kernel/params_sparse_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Sparse GEMM kernel that computes its epilogue through a visitor functor
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct SparseGemmWithEpilogueVisitor : public SparseGemm<Mma_, Epilogue_, ThreadblockSwizzle_, false> {
using Base = SparseGemm<Mma_, Epilogue_, ThreadblockSwizzle_, false>;
using Mma = Mma_;
using Epilogue = Epilogue_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using FusionCallbacks = typename Epilogue::FusionCallbacks;
using ParamsA = typename Mma::IteratorA::Params;
using TensorRefA = typename Mma::IteratorA::TensorRef;
using ParamsB = typename Mma::IteratorB::Params;
using TensorRefB = typename Mma::IteratorB::TensorRef;
using ParamsE = typename Mma::IteratorE::Params;
using TensorRefE = typename Mma::IteratorE::TensorRef;
static int const kSparse = Base::kSparse;
static int const kElementsPerElementE = Base::kElementsPerElementE;
using SharedStorage = typename Base::SharedStorage;
/// Parameters structure
struct Params : public SparseParamsBase<
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
ParamsE, TensorRefE> {
using Base = SparseParamsBase<
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
ParamsE, TensorRefE>;
//
// Data members
//
typename FusionCallbacks::Params output_op;
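// Problem shape handed to the epilogue visitor; the batch extent is fixed to 1 because this
// kernel does not support batching.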
cute::Shape<int32_t,int32_t,int32_t> problem_shape;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorE::TensorRef ref_E,
typename FusionCallbacks::Arguments output_op = typename FusionCallbacks::Arguments()
):
Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK),
output_op(FusionCallbacks::to_underlying_arguments(problem_size, output_op, nullptr /*workspace*/)),
problem_shape(problem_size.m(), problem_size.n(), 1) {
}
};
//
// Methods
//
CUTLASS_HOST_DEVICE
SparseGemmWithEpilogueVisitor() { }
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Compute initial location in logical coordinates
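// Note that the K offsets for A and E are divided by kSparse because both operands are stored
// compressed along K (only the kept values and their metadata are stored).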
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
};
cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size,
threadblock_tile_offset.n() * Mma::Shape::kN
};
cutlass::MatrixCoord tb_offset_E{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(
params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A, B, and E operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k / kSparse},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
typename Mma::IteratorE iterator_E(
params.params_E, params.ref_E.data(),
{params.problem_size.m(),
problem_size_k / kSparse / kElementsPerElementE},
thread_idx, tb_offset_E);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = canonical_warp_idx_sync();
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (gemm_k_iterations > 0) {
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
}
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Epilogue
//
Epilogue epilogue(
params.output_op,
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(accumulators, threadblock_tile_offset, params.problem_shape, thread_idx);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass