Add support for sparse GEMM with visitor epilogue (#1189)
* Add support for sparse GEMM with visitor epilogue * Refactor changes at the kernel level
This commit is contained in:
parent
8236f30675
commit
5c756eb774
@ -33,3 +33,8 @@ cutlass_example_add_executable(
|
||||
ampere_sparse_tensorop_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
15_ampere_sparse_tensorop_gemm_with_visitor
|
||||
ampere_sparse_tensorop_gemm_with_visitor.cu
|
||||
)
|
||||
|
||||
|
||||
@ -0,0 +1,379 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere
|
||||
architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4.
|
||||
|
||||
Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of
|
||||
meta data is different for every data types. CUTLASS templates can automatically infer it based on
|
||||
input A and B. Check code below.
|
||||
|
||||
Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers
|
||||
efficiently.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_sparse_with_visitor.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/host_uncompress.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output matrices and computation between
|
||||
// elements in input matrices.
|
||||
using ElementAccumulator = int32_t; // <- data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||
using ElementInputA = int8_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = int8_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = int32_t; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// The number of elements per vectorized memory access.
|
||||
constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
|
||||
constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
|
||||
constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
|
||||
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 128, 128>; // <- threadblock tile M = 128, N = 128, K = 128
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 128>; // <- warp tile M = 64, N = 64, K = 128
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 16, N = 8, K = 64
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
constexpr auto NumEVTEpilogueStages = 1;
|
||||
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ElementComputeEpilogue,
|
||||
AlignmentComputeEpilogue,
|
||||
NumEVTEpilogueStages>;
|
||||
|
||||
using Bias = cutlass::epilogue::threadblock::VisitorAuxLoad<
|
||||
BiasTileThreadMap,
|
||||
ElementComputeEpilogue,
|
||||
cute::Stride<int64_t, cute::_1, int64_t>>;
|
||||
|
||||
using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
|
||||
ApplyBias,
|
||||
Accum,
|
||||
Bias>;
|
||||
|
||||
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ElementOutput,
|
||||
AlignmentOutput,
|
||||
NumEVTEpilogueStages>;
|
||||
|
||||
using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementOutput,
|
||||
cutlass::FloatRoundStyle::round_to_nearest,
|
||||
cute::Stride<int64_t, cute::_1, int64_t>>;
|
||||
|
||||
using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
|
||||
Output,
|
||||
EVTApplyBias>;
|
||||
|
||||
// Use element type in EVT with the smallest bitwidth as ElementC.
|
||||
using ElementC = ElementComputeEpilogue;
|
||||
using LayoutC = LayoutOutput;
|
||||
|
||||
using Gemm =
|
||||
typename cutlass::gemm::device::SparseGemmWithVisitor<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementC, LayoutC,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EVTOutput,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
AlignmentInputA,
|
||||
AlignmentInputB,
|
||||
Operator,
|
||||
NumEVTEpilogueStages>;
|
||||
|
||||
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||
using ElementInputE = typename Gemm::GemmKernel::ElementE;
|
||||
using LayoutInputE = cutlass::layout::RowMajor;
|
||||
using ReorderedLayoutInputE = typename Gemm::GemmKernel::LayoutE;
|
||||
|
||||
// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h
|
||||
// 50% Sparsity on Ampere
|
||||
constexpr int kSparse = Gemm::kSparse;
|
||||
// How many elements of A are covered per ElementE
|
||||
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||
// The size of individual meta data
|
||||
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
|
||||
int run() {
|
||||
|
||||
const int length_m = 512;
|
||||
const int length_n = 512;
|
||||
const int length_k = 1024;
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
|
||||
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
|
||||
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementComputeEpilogue, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(8),
|
||||
ElementInputA(-8),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(8),
|
||||
ElementInputB(-8),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(8),
|
||||
ElementOutput(-8),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||
tensor_e.host_view(),
|
||||
1,
|
||||
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
|
||||
// instructions.
|
||||
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
|
||||
{problem_size.m(), problem_size.n(),
|
||||
problem_size.k() / kSparse / kElementsPerElementE});
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_e_reordered.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(1);
|
||||
|
||||
typename Bias::Arguments bias_arguments{
|
||||
tensor_c.device_data(),
|
||||
ElementComputeEpilogue(0),
|
||||
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
|
||||
};
|
||||
typename Output::Arguments output_arguments{
|
||||
tensor_d.device_data(),
|
||||
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
|
||||
};
|
||||
typename EVTOutput::Arguments callback_arguments{
|
||||
{
|
||||
{}, // Accum
|
||||
bias_arguments, // Bias
|
||||
{} // ApplyBias
|
||||
}, // EVTApplyBias
|
||||
output_arguments // Output
|
||||
}; // EVTOutput
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||
tensor_e_reordered.device_ref(), // <- reference to matrix E on device
|
||||
callback_arguments}; // <- epilogue arguments
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// uncompress tensor_a based on meta data tensor_e. We need it for reference computing.
|
||||
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
|
||||
tensor_e.host_ref(), problem_size.m(), problem_size.k());
|
||||
|
||||
// Create instantiation for host reference gemm kernel
|
||||
cutlass::reference::host::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
typename Gemm::Operator>
|
||||
gemm_host;
|
||||
|
||||
// Launch host reference gemm kernel
|
||||
gemm_host(problem_size,
|
||||
alpha,
|
||||
tensor_a_uncompressed.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
beta,
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref());
|
||||
|
||||
// Copy output data from CUTLASS host for comparison
|
||||
tensor_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major * 10 + props.minor < 80) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
return run();
|
||||
}
|
||||
342
include/cutlass/gemm/device/gemm_sparse_with_visitor.h
Normal file
342
include/cutlass/gemm/device/gemm_sparse_with_visitor.h
Normal file
@ -0,0 +1,342 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/kernel/sparse_gemm.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*! Sparse GEMM with visitor
|
||||
*/
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_ = ElementC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_ = arch::OpClassSimt,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_ = arch::Sm70,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename FusionCallbacks_ =
|
||||
typename cutlass::epilogue::threadblock::detail::EmptyCallbacks,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle_ =
|
||||
typename threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kStages,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentA,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::Operator,
|
||||
/// Number of stages used in the pipelined epilogue
|
||||
int EpilogueStages = 1>
|
||||
class SparseGemmWithVisitor {
|
||||
public:
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = LayoutA_;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = LayoutB_;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using OperatorClass = OperatorClass_;
|
||||
using ArchTag = ArchTag_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using FusionCallbacks = FusionCallbacks_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
using Operator = Operator_;
|
||||
using MathOperator = Operator;
|
||||
static int const kStages = Stages;
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
|
||||
/// Define the kernel
|
||||
using GemmKernel = typename kernel::DefaultSparseGemmWithVisitor<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
FusionCallbacks,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
Operator,
|
||||
EpilogueStages
|
||||
>::GemmKernel;
|
||||
|
||||
using ElementE = typename GemmKernel::ElementE;
|
||||
|
||||
using LayoutE = typename GemmKernel::LayoutE;
|
||||
|
||||
static int const kAlignmentE = 128 / sizeof_bits<ElementE>::value;
|
||||
|
||||
static int const kSparse = GemmKernel::kSparse;
|
||||
static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits;
|
||||
static int const kElementsPerElementE = GemmKernel::kElementsPerElementE;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord problem_size;
|
||||
TensorRef<ElementA const, LayoutA> ref_A;
|
||||
TensorRef<ElementB const, LayoutB> ref_B;
|
||||
TensorRef<ElementE const, LayoutE> ref_E;
|
||||
typename FusionCallbacks::Arguments epilogue;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): problem_size(0, 0, 0) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmCoord problem_size_,
|
||||
TensorRef<ElementA const, LayoutA> ref_A_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B_,
|
||||
TensorRef<ElementE, LayoutE> ref_E_,
|
||||
typename FusionCallbacks::Arguments epilogue_ =
|
||||
typename FusionCallbacks::Arguments()
|
||||
):
|
||||
problem_size(problem_size_),
|
||||
ref_A(ref_A_),
|
||||
ref_B(ref_B_),
|
||||
ref_E(ref_E_),
|
||||
epilogue(epilogue_) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the GEMM.
|
||||
SparseGemmWithVisitor() { }
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
Status status = GemmKernel::can_implement(
|
||||
args.problem_size,
|
||||
args.ref_A.non_const_ref(),
|
||||
args.ref_B.non_const_ref(),
|
||||
cutlass::TensorRef<ElementC, LayoutC>(), // It only matters that it's empty.
|
||||
cutlass::TensorRef<ElementC, LayoutC>(), // Same as above.
|
||||
args.ref_E.non_const_ref()
|
||||
);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t bytes = 0;
|
||||
|
||||
return bytes;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
|
||||
constexpr int SplitKSlices = 1;
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
SplitKSlices);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params{
|
||||
args.problem_size,
|
||||
grid_shape,
|
||||
args.ref_A.non_const_ref(),
|
||||
args.ref_B.non_const_ref(),
|
||||
args.ref_E.non_const_ref(),
|
||||
args.epilogue
|
||||
};
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
params_.ref_A.reset(args.ref_A.non_const_ref().data());
|
||||
params_.ref_B.reset(args.ref_B.non_const_ref().data());
|
||||
params_.ref_E.reset(args.ref_E.non_const_ref().data());
|
||||
params_.output_op = args.epilogue;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
198
include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h
Normal file
198
include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h
Normal file
@ -0,0 +1,198 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Default sparse GEMM with visitor.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/wmma.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_sparse.h"
|
||||
#include "cutlass/gemm/kernel/sparse_gemm_with_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_sparse_mma.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
|
||||
#endif //CUTLASS_ARCH_WMMA_ENABLED
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename FusionCallbacks,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Number of stages used in the pipelined epilogue
|
||||
int EpilogueStages = 1>
|
||||
struct DefaultSparseGemmWithVisitor;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Ampere Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename FusionCallbacks,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Number of stages used in the pipelined epilogue
|
||||
int EpilogueStages>
|
||||
struct DefaultSparseGemmWithVisitor<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementC, LayoutC, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
|
||||
FusionCallbacks, ThreadblockSwizzle, Stages, Operator,
|
||||
EpilogueStages> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape, WarpShape, InstructionShape, Stages,
|
||||
Operator>::ThreadblockMma;
|
||||
|
||||
static constexpr int kAlignmentC = 128 / sizeof_bits<ElementC>::value;;
|
||||
using ElementEpilogue = ElementAccumulator;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
using EpilogueOutputOp =
|
||||
typename epilogue::thread::LinearCombination<
|
||||
ElementC, kAlignmentC,
|
||||
ElementAccumulator, ElementEpilogue>;
|
||||
using BaseEpilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape, typename Mma::Operator, kPartitionsK,
|
||||
EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;
|
||||
|
||||
// Define epilogue
|
||||
using Epilogue = cutlass::epilogue::threadblock::EpilogueWithVisitorCallbacks<
|
||||
BaseEpilogue,
|
||||
FusionCallbacks,
|
||||
EpilogueStages>;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using GemmKernel = kernel::SparseGemmWithEpilogueVisitor<Mma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
117
include/cutlass/gemm/kernel/params_sparse_base.h
Normal file
117
include/cutlass/gemm/kernel/params_sparse_base.h
Normal file
@ -0,0 +1,117 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Base functionality for common types of sparse GEMM kernel parameters
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parameters structure
|
||||
template <
|
||||
typename ThreadblockSwizzle,
|
||||
typename ParamsA,
|
||||
typename TensorRefA,
|
||||
typename ParamsB,
|
||||
typename TensorRefB,
|
||||
typename ParamsE,
|
||||
typename TensorRefE>
|
||||
struct SparseParamsBase
|
||||
{
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
ParamsA params_A;
|
||||
TensorRefA ref_A;
|
||||
ParamsB params_B;
|
||||
TensorRefB ref_B;
|
||||
ParamsE params_E;
|
||||
TensorRefE ref_E;
|
||||
int gemm_k_iterations;
|
||||
int gemm_k_size;
|
||||
|
||||
//
|
||||
// Host dispatch API
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
SparseParamsBase() : swizzle_log_tile(0), gemm_k_iterations(0), gemm_k_size(0) { }
|
||||
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
SparseParamsBase(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
TensorRefA ref_A,
|
||||
TensorRefB ref_B,
|
||||
TensorRefE ref_E,
|
||||
int const mma_shape_k)
|
||||
:
|
||||
problem_size(problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
params_A(ref_A.layout()),
|
||||
ref_A(ref_A),
|
||||
params_B(ref_B.layout()),
|
||||
ref_B(ref_B),
|
||||
params_E(ref_E.layout()),
|
||||
ref_E(ref_E)
|
||||
{
|
||||
int total_gemm_k_iterations = (problem_size.k() + mma_shape_k - 1) / mma_shape_k;
|
||||
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
|
||||
gemm_k_size = gemm_k_iterations * mma_shape_k;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -37,6 +37,7 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/params_sparse_base.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
@ -74,66 +75,58 @@ struct SparseGemm {
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ParamsA = typename Mma::IteratorA::Params;
|
||||
using TensorRefA = typename Mma::IteratorA::TensorRef;
|
||||
using ParamsB = typename Mma::IteratorB::Params;
|
||||
using TensorRefB = typename Mma::IteratorB::TensorRef;
|
||||
using ParamsE = typename Mma::IteratorE::Params;
|
||||
using TensorRefE = typename Mma::IteratorE::TensorRef;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
struct Params : public SparseParamsBase<
|
||||
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||
ParamsE, TensorRefE> {
|
||||
|
||||
using Base = SparseParamsBase<
|
||||
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||
ParamsE, TensorRefE>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
typename Mma::IteratorE::Params params_E;
|
||||
typename Mma::IteratorE::TensorRef ref_E;
|
||||
typename OutputOp::Params output_op;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations;
|
||||
int gemm_k_size;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
typename Mma::IteratorA::TensorRef ref_A,
|
||||
typename Mma::IteratorB::TensorRef ref_B,
|
||||
TensorRefA ref_A,
|
||||
TensorRefB ref_B,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D,
|
||||
typename Mma::IteratorE::TensorRef ref_E,
|
||||
TensorRefE ref_E,
|
||||
typename OutputOp::Params output_op = typename OutputOp::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
problem_size(problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
params_A(ref_A.layout()),
|
||||
ref_A(ref_A),
|
||||
params_B(ref_B.layout()),
|
||||
ref_B(ref_B),
|
||||
Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK),
|
||||
params_C(ref_C.layout()),
|
||||
ref_C(ref_C),
|
||||
params_D(ref_D.layout()),
|
||||
ref_D(ref_D),
|
||||
params_E(ref_E.layout()),
|
||||
ref_E(ref_E),
|
||||
output_op(output_op) {
|
||||
|
||||
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
|
||||
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
|
||||
|
||||
semaphore = workspace;
|
||||
output_op(output_op),
|
||||
semaphore(workspace) {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
237
include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h
Normal file
237
include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h
Normal file
@ -0,0 +1,237 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Sparse GEMM with visitor.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/sparse_gemm.h"
|
||||
#include "cutlass/gemm/kernel/params_sparse_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Sparse Gemm that compute the epilogue visitor functor
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct SparseGemmWithEpilogueVisitor : public SparseGemm<Mma_, Epilogue_, ThreadblockSwizzle_, false> {
|
||||
|
||||
using Base = SparseGemm<Mma_, Epilogue_, ThreadblockSwizzle_, false>;
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using FusionCallbacks = typename Epilogue::FusionCallbacks;
|
||||
|
||||
using ParamsA = typename Mma::IteratorA::Params;
|
||||
using TensorRefA = typename Mma::IteratorA::TensorRef;
|
||||
using ParamsB = typename Mma::IteratorB::Params;
|
||||
using TensorRefB = typename Mma::IteratorB::TensorRef;
|
||||
using ParamsE = typename Mma::IteratorE::Params;
|
||||
using TensorRefE = typename Mma::IteratorE::TensorRef;
|
||||
|
||||
static int const kSparse = Base::kSparse;
|
||||
static int const kElementsPerElementE = Base::kElementsPerElementE;
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params : public SparseParamsBase<
|
||||
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||
ParamsE, TensorRefE> {
|
||||
|
||||
using Base = SparseParamsBase<
|
||||
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||
ParamsE, TensorRefE>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
typename FusionCallbacks::Params output_op;
|
||||
cute::Shape<int32_t,int32_t,int32_t> problem_shape;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
typename Mma::IteratorA::TensorRef ref_A,
|
||||
typename Mma::IteratorB::TensorRef ref_B,
|
||||
typename Mma::IteratorE::TensorRef ref_E,
|
||||
typename FusionCallbacks::Arguments output_op = typename FusionCallbacks::Arguments()
|
||||
):
|
||||
Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK),
|
||||
output_op(FusionCallbacks::to_underlying_arguments(problem_size, output_op, nullptr /*workspace*/)),
|
||||
problem_shape(problem_size.m(), problem_size.n(), 1) {
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SparseGemmWithEpilogueVisitor() { }
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_E{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k = min(
|
||||
params.problem_size.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A, B, and E operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A,
|
||||
params.ref_A.data(),
|
||||
{params.problem_size.m(), problem_size_k / kSparse},
|
||||
thread_idx,
|
||||
tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B,
|
||||
params.ref_B.data(),
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B);
|
||||
|
||||
typename Mma::IteratorE iterator_E(
|
||||
params.params_E, params.ref_E.data(),
|
||||
{params.problem_size.m(),
|
||||
problem_size_k / kSparse / kElementsPerElementE},
|
||||
thread_idx, tb_offset_E);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
if (gemm_k_iterations > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
Epilogue epilogue(
|
||||
params.output_op,
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(accumulators, threadblock_tile_offset, params.problem_shape, thread_idx);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
Loading…
Reference in New Issue
Block a user