Add support for sparse GEMM with visitor epilogue (#1189)
* Add support for sparse GEMM with visitor epilogue * Refactor changes at the kernel level
This commit is contained in:
parent
8236f30675
commit
5c756eb774
@ -33,3 +33,8 @@ cutlass_example_add_executable(
|
|||||||
ampere_sparse_tensorop_gemm.cu
|
ampere_sparse_tensorop_gemm.cu
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cutlass_example_add_executable(
|
||||||
|
15_ampere_sparse_tensorop_gemm_with_visitor
|
||||||
|
ampere_sparse_tensorop_gemm_with_visitor.cu
|
||||||
|
)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,379 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
/**
|
||||||
|
Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere
|
||||||
|
architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4.
|
||||||
|
|
||||||
|
Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of
|
||||||
|
meta data is different for every data types. CUTLASS templates can automatically infer it based on
|
||||||
|
input A and B. Check code below.
|
||||||
|
|
||||||
|
Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers
|
||||||
|
efficiently.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/gemm/device/gemm_sparse_with_visitor.h"
|
||||||
|
|
||||||
|
#include "cutlass/util/host_tensor.h"
|
||||||
|
#include "cutlass/util/reference/host/gemm.h"
|
||||||
|
#include "cutlass/util/host_reorder.h"
|
||||||
|
#include "cutlass/util/host_uncompress.h"
|
||||||
|
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||||
|
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||||
|
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||||
|
#include "cutlass/util/tensor_view_io.h"
|
||||||
|
|
||||||
|
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
|
// The code section below describes datatype for input, output matrices and computation between
|
||||||
|
// elements in input matrices.
|
||||||
|
using ElementAccumulator = int32_t; // <- data type of accumulator
|
||||||
|
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||||
|
using ElementInputA = int8_t; // <- data type of elements in input matrix A
|
||||||
|
using ElementInputB = int8_t; // <- data type of elements in input matrix B
|
||||||
|
using ElementOutput = int32_t; // <- data type of elements in output matrix D
|
||||||
|
|
||||||
|
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||||
|
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||||
|
using LayoutInputA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutOutput = cutlass::layout::RowMajor;
|
||||||
|
|
||||||
|
// The number of elements per vectorized memory access.
|
||||||
|
constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
|
||||||
|
constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
|
||||||
|
constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
|
||||||
|
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||||
|
|
||||||
|
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||||
|
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||||
|
|
||||||
|
// This code section describes CUDA SM architecture number
|
||||||
|
using SmArch = cutlass::arch::Sm80;
|
||||||
|
|
||||||
|
// This code section describes the tile size a thread block will compute
|
||||||
|
using ShapeMMAThreadBlock =
|
||||||
|
cutlass::gemm::GemmShape<128, 128, 128>; // <- threadblock tile M = 128, N = 128, K = 128
|
||||||
|
// This code section describes tile size a warp will compute
|
||||||
|
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 128>; // <- warp tile M = 64, N = 64, K = 128
|
||||||
|
// This code section describes the size of MMA op
|
||||||
|
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 16, N = 8, K = 64
|
||||||
|
|
||||||
|
// This code section describes how threadblocks are scheduled on GPU
|
||||||
|
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||||
|
|
||||||
|
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||||
|
|
||||||
|
// Number of pipelines you want to use
|
||||||
|
constexpr int NumStages = 3;
|
||||||
|
|
||||||
|
constexpr auto NumEVTEpilogueStages = 1;
|
||||||
|
|
||||||
|
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||||
|
|
||||||
|
using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||||
|
ShapeMMAThreadBlock,
|
||||||
|
ShapeMMAWarp,
|
||||||
|
ElementComputeEpilogue,
|
||||||
|
AlignmentComputeEpilogue,
|
||||||
|
NumEVTEpilogueStages>;
|
||||||
|
|
||||||
|
using Bias = cutlass::epilogue::threadblock::VisitorAuxLoad<
|
||||||
|
BiasTileThreadMap,
|
||||||
|
ElementComputeEpilogue,
|
||||||
|
cute::Stride<int64_t, cute::_1, int64_t>>;
|
||||||
|
|
||||||
|
using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
|
||||||
|
ApplyBias,
|
||||||
|
Accum,
|
||||||
|
Bias>;
|
||||||
|
|
||||||
|
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||||
|
ShapeMMAThreadBlock,
|
||||||
|
ShapeMMAWarp,
|
||||||
|
ElementOutput,
|
||||||
|
AlignmentOutput,
|
||||||
|
NumEVTEpilogueStages>;
|
||||||
|
|
||||||
|
using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||||
|
OutputTileThreadMap, ElementOutput,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest,
|
||||||
|
cute::Stride<int64_t, cute::_1, int64_t>>;
|
||||||
|
|
||||||
|
using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
|
||||||
|
Output,
|
||||||
|
EVTApplyBias>;
|
||||||
|
|
||||||
|
// Use element type in EVT with the smallest bitwidth as ElementC.
|
||||||
|
using ElementC = ElementComputeEpilogue;
|
||||||
|
using LayoutC = LayoutOutput;
|
||||||
|
|
||||||
|
using Gemm =
|
||||||
|
typename cutlass::gemm::device::SparseGemmWithVisitor<
|
||||||
|
ElementInputA, LayoutInputA,
|
||||||
|
ElementInputB, LayoutInputB,
|
||||||
|
ElementC, LayoutC,
|
||||||
|
ElementAccumulator,
|
||||||
|
MMAOp,
|
||||||
|
SmArch,
|
||||||
|
ShapeMMAThreadBlock,
|
||||||
|
ShapeMMAWarp,
|
||||||
|
ShapeMMAOp,
|
||||||
|
EVTOutput,
|
||||||
|
SwizzleThreadBlock,
|
||||||
|
NumStages,
|
||||||
|
AlignmentInputA,
|
||||||
|
AlignmentInputB,
|
||||||
|
Operator,
|
||||||
|
NumEVTEpilogueStages>;
|
||||||
|
|
||||||
|
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||||
|
using ElementInputE = typename Gemm::GemmKernel::ElementE;
|
||||||
|
using LayoutInputE = cutlass::layout::RowMajor;
|
||||||
|
using ReorderedLayoutInputE = typename Gemm::GemmKernel::LayoutE;
|
||||||
|
|
||||||
|
// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h
|
||||||
|
// 50% Sparsity on Ampere
|
||||||
|
constexpr int kSparse = Gemm::kSparse;
|
||||||
|
// How many elements of A are covered per ElementE
|
||||||
|
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||||
|
// The size of individual meta data
|
||||||
|
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||||
|
|
||||||
|
int run() {
|
||||||
|
|
||||||
|
const int length_m = 512;
|
||||||
|
const int length_n = 512;
|
||||||
|
const int length_k = 1024;
|
||||||
|
|
||||||
|
// Create a tuple of problem size for matrix multiplication
|
||||||
|
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||||
|
|
||||||
|
// Initialize tensors using CUTLASS helper functions
|
||||||
|
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||||
|
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
|
||||||
|
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
|
||||||
|
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
|
||||||
|
|
||||||
|
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||||
|
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||||
|
cutlass::HostTensor<ElementComputeEpilogue, LayoutOutput> tensor_c(
|
||||||
|
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||||
|
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||||
|
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||||
|
// CUTLASS kernel
|
||||||
|
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||||
|
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||||
|
// reference kernel
|
||||||
|
|
||||||
|
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
|
||||||
|
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||||
|
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||||
|
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||||
|
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
|
||||||
|
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||||
|
|
||||||
|
// Fill input and output matrices on host using CUTLASS helper functions
|
||||||
|
cutlass::reference::host::TensorFillRandomUniform(
|
||||||
|
tensor_a.host_view(),
|
||||||
|
1,
|
||||||
|
ElementInputA(8),
|
||||||
|
ElementInputA(-8),
|
||||||
|
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||||
|
cutlass::reference::host::TensorFillRandomUniform(
|
||||||
|
tensor_b.host_view(),
|
||||||
|
1,
|
||||||
|
ElementInputB(8),
|
||||||
|
ElementInputB(-8),
|
||||||
|
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||||
|
cutlass::reference::host::TensorFillRandomUniform(
|
||||||
|
tensor_c.host_view(),
|
||||||
|
1,
|
||||||
|
ElementOutput(8),
|
||||||
|
ElementOutput(-8),
|
||||||
|
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||||
|
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||||
|
tensor_e.host_view(),
|
||||||
|
1,
|
||||||
|
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
|
||||||
|
cutlass::reference::host::TensorFill(
|
||||||
|
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||||
|
cutlass::reference::host::TensorFill(
|
||||||
|
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||||
|
|
||||||
|
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
|
||||||
|
// instructions.
|
||||||
|
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
|
||||||
|
{problem_size.m(), problem_size.n(),
|
||||||
|
problem_size.k() / kSparse / kElementsPerElementE});
|
||||||
|
|
||||||
|
// Copy data from host to GPU
|
||||||
|
tensor_a.sync_device();
|
||||||
|
tensor_b.sync_device();
|
||||||
|
tensor_c.sync_device();
|
||||||
|
tensor_d.sync_device();
|
||||||
|
tensor_e_reordered.sync_device();
|
||||||
|
tensor_ref_d.sync_device();
|
||||||
|
|
||||||
|
// Initialize alpha and beta for dot product computation
|
||||||
|
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||||
|
ElementComputeEpilogue beta = ElementComputeEpilogue(1);
|
||||||
|
|
||||||
|
typename Bias::Arguments bias_arguments{
|
||||||
|
tensor_c.device_data(),
|
||||||
|
ElementComputeEpilogue(0),
|
||||||
|
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
|
||||||
|
};
|
||||||
|
typename Output::Arguments output_arguments{
|
||||||
|
tensor_d.device_data(),
|
||||||
|
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
|
||||||
|
};
|
||||||
|
typename EVTOutput::Arguments callback_arguments{
|
||||||
|
{
|
||||||
|
{}, // Accum
|
||||||
|
bias_arguments, // Bias
|
||||||
|
{} // ApplyBias
|
||||||
|
}, // EVTApplyBias
|
||||||
|
output_arguments // Output
|
||||||
|
}; // EVTOutput
|
||||||
|
|
||||||
|
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||||
|
// instantiated CUTLASS kernel
|
||||||
|
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
|
||||||
|
tensor_a.device_ref(), // <- reference to matrix A on device
|
||||||
|
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||||
|
tensor_e_reordered.device_ref(), // <- reference to matrix E on device
|
||||||
|
callback_arguments}; // <- epilogue arguments
|
||||||
|
|
||||||
|
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||||
|
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||||
|
|
||||||
|
// Allocate workspace memory
|
||||||
|
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||||
|
|
||||||
|
// Instantiate CUTLASS kernel depending on templates
|
||||||
|
Gemm gemm_op;
|
||||||
|
|
||||||
|
// Check the problem size is supported or not
|
||||||
|
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||||
|
CUTLASS_CHECK(status);
|
||||||
|
|
||||||
|
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||||
|
status = gemm_op.initialize(arguments, workspace.get());
|
||||||
|
CUTLASS_CHECK(status);
|
||||||
|
|
||||||
|
// Launch initialized CUTLASS kernel
|
||||||
|
status = gemm_op();
|
||||||
|
CUTLASS_CHECK(status);
|
||||||
|
|
||||||
|
// uncompress tensor_a based on meta data tensor_e. We need it for reference computing.
|
||||||
|
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
|
||||||
|
tensor_e.host_ref(), problem_size.m(), problem_size.k());
|
||||||
|
|
||||||
|
// Create instantiation for host reference gemm kernel
|
||||||
|
cutlass::reference::host::Gemm<ElementInputA,
|
||||||
|
LayoutInputA,
|
||||||
|
ElementInputB,
|
||||||
|
LayoutInputB,
|
||||||
|
ElementOutput,
|
||||||
|
LayoutOutput,
|
||||||
|
ElementComputeEpilogue,
|
||||||
|
ElementComputeEpilogue,
|
||||||
|
typename Gemm::Operator>
|
||||||
|
gemm_host;
|
||||||
|
|
||||||
|
// Launch host reference gemm kernel
|
||||||
|
gemm_host(problem_size,
|
||||||
|
alpha,
|
||||||
|
tensor_a_uncompressed.host_ref(),
|
||||||
|
tensor_b.host_ref(),
|
||||||
|
beta,
|
||||||
|
tensor_c.host_ref(),
|
||||||
|
tensor_ref_d.host_ref());
|
||||||
|
|
||||||
|
// Copy output data from CUTLASS host for comparison
|
||||||
|
tensor_d.sync_host();
|
||||||
|
|
||||||
|
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||||
|
bool passed = cutlass::reference::host::TensorEquals(
|
||||||
|
tensor_d.host_view(),
|
||||||
|
tensor_ref_d.host_view());
|
||||||
|
|
||||||
|
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||||
|
|
||||||
|
return (passed ? 0 : -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
|
||||||
|
bool notSupported = false;
|
||||||
|
|
||||||
|
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||||
|
// in CUDA 11.1.
|
||||||
|
//
|
||||||
|
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||||
|
|
||||||
|
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||||
|
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||||
|
notSupported = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaDeviceProp props;
|
||||||
|
|
||||||
|
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||||
|
if (error != cudaSuccess) {
|
||||||
|
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (props.major * 10 + props.minor < 80) {
|
||||||
|
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||||
|
<< std::endl;
|
||||||
|
notSupported = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (notSupported) {
|
||||||
|
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return run();
|
||||||
|
}
|
||||||
342
include/cutlass/gemm/device/gemm_sparse_with_visitor.h
Normal file
342
include/cutlass/gemm/device/gemm_sparse_with_visitor.h
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/*! \file
|
||||||
|
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/numeric_types.h"
|
||||||
|
#include "cutlass/arch/arch.h"
|
||||||
|
#include "cutlass/device_kernel.h"
|
||||||
|
|
||||||
|
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||||
|
#include "cutlass/gemm/kernel/sparse_gemm.h"
|
||||||
|
|
||||||
|
#include "cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h"
|
||||||
|
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||||
|
|
||||||
|
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace gemm {
|
||||||
|
namespace device {
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/*! Sparse GEMM with visitor
|
||||||
|
*/
|
||||||
|
template <
|
||||||
|
/// Element type for A matrix operand
|
||||||
|
typename ElementA_,
|
||||||
|
/// Layout type for A matrix operand
|
||||||
|
typename LayoutA_,
|
||||||
|
/// Element type for B matrix operand
|
||||||
|
typename ElementB_,
|
||||||
|
/// Layout type for B matrix operand
|
||||||
|
typename LayoutB_,
|
||||||
|
/// Element type for C and D matrix operands
|
||||||
|
typename ElementC_,
|
||||||
|
/// Layout type for C and D matrix operands
|
||||||
|
typename LayoutC_,
|
||||||
|
/// Element type for internal accumulation
|
||||||
|
typename ElementAccumulator_ = ElementC_,
|
||||||
|
/// Operator class tag
|
||||||
|
typename OperatorClass_ = arch::OpClassSimt,
|
||||||
|
/// Tag indicating architecture to tune for
|
||||||
|
typename ArchTag_ = arch::Sm70,
|
||||||
|
/// Threadblock-level tile size (concept: GemmShape)
|
||||||
|
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
|
||||||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||||
|
ElementAccumulator_>::ThreadblockShape,
|
||||||
|
/// Warp-level tile size (concept: GemmShape)
|
||||||
|
typename WarpShape_ = typename DefaultGemmConfiguration<
|
||||||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||||
|
ElementAccumulator_>::WarpShape,
|
||||||
|
/// Instruction-level tile size (concept: GemmShape)
|
||||||
|
typename InstructionShape_ = typename DefaultGemmConfiguration<
|
||||||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||||
|
ElementAccumulator_>::InstructionShape,
|
||||||
|
/// Epilogue output operator
|
||||||
|
typename FusionCallbacks_ =
|
||||||
|
typename cutlass::epilogue::threadblock::detail::EmptyCallbacks,
|
||||||
|
/// Threadblock-level swizzling operator
|
||||||
|
typename ThreadblockSwizzle_ =
|
||||||
|
typename threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||||
|
/// Number of stages used in the pipelined mainloop
|
||||||
|
int Stages =
|
||||||
|
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||||
|
ElementC_, ElementAccumulator_>::kStages,
|
||||||
|
/// Access granularity of A matrix in units of elements
|
||||||
|
int AlignmentA =
|
||||||
|
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||||
|
ElementC_, ElementAccumulator_>::kAlignmentA,
|
||||||
|
/// Access granularity of B matrix in units of elements
|
||||||
|
int AlignmentB =
|
||||||
|
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||||
|
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||||
|
/// Operation performed by GEMM
|
||||||
|
typename Operator_ = typename DefaultGemmConfiguration<
|
||||||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||||
|
ElementAccumulator_>::Operator,
|
||||||
|
/// Number of stages used in the pipelined epilogue
|
||||||
|
int EpilogueStages = 1>
|
||||||
|
class SparseGemmWithVisitor {
|
||||||
|
public:
|
||||||
|
|
||||||
|
using ElementA = ElementA_;
|
||||||
|
using LayoutA = LayoutA_;
|
||||||
|
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||||
|
using ElementB = ElementB_;
|
||||||
|
using LayoutB = LayoutB_;
|
||||||
|
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||||
|
using ElementC = ElementC_;
|
||||||
|
using LayoutC = LayoutC_;
|
||||||
|
using ElementAccumulator = ElementAccumulator_;
|
||||||
|
using OperatorClass = OperatorClass_;
|
||||||
|
using ArchTag = ArchTag_;
|
||||||
|
using ThreadblockShape = ThreadblockShape_;
|
||||||
|
using WarpShape = WarpShape_;
|
||||||
|
using InstructionShape = InstructionShape_;
|
||||||
|
using FusionCallbacks = FusionCallbacks_;
|
||||||
|
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||||
|
using Operator = Operator_;
|
||||||
|
using MathOperator = Operator;
|
||||||
|
static int const kStages = Stages;
|
||||||
|
static int const kAlignmentA = AlignmentA;
|
||||||
|
static int const kAlignmentB = AlignmentB;
|
||||||
|
|
||||||
|
/// Define the kernel
|
||||||
|
using GemmKernel = typename kernel::DefaultSparseGemmWithVisitor<
|
||||||
|
ElementA,
|
||||||
|
LayoutA,
|
||||||
|
kAlignmentA,
|
||||||
|
ElementB,
|
||||||
|
LayoutB,
|
||||||
|
kAlignmentB,
|
||||||
|
ElementC,
|
||||||
|
LayoutC,
|
||||||
|
ElementAccumulator,
|
||||||
|
OperatorClass,
|
||||||
|
ArchTag,
|
||||||
|
ThreadblockShape,
|
||||||
|
WarpShape,
|
||||||
|
InstructionShape,
|
||||||
|
FusionCallbacks,
|
||||||
|
ThreadblockSwizzle,
|
||||||
|
kStages,
|
||||||
|
Operator,
|
||||||
|
EpilogueStages
|
||||||
|
>::GemmKernel;
|
||||||
|
|
||||||
|
using ElementE = typename GemmKernel::ElementE;
|
||||||
|
|
||||||
|
using LayoutE = typename GemmKernel::LayoutE;
|
||||||
|
|
||||||
|
static int const kAlignmentE = 128 / sizeof_bits<ElementE>::value;
|
||||||
|
|
||||||
|
static int const kSparse = GemmKernel::kSparse;
|
||||||
|
static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits;
|
||||||
|
static int const kElementsPerElementE = GemmKernel::kElementsPerElementE;
|
||||||
|
|
||||||
|
/// Argument structure
|
||||||
|
struct Arguments {
|
||||||
|
|
||||||
|
//
|
||||||
|
// Data members
|
||||||
|
//
|
||||||
|
|
||||||
|
GemmCoord problem_size;
|
||||||
|
TensorRef<ElementA const, LayoutA> ref_A;
|
||||||
|
TensorRef<ElementB const, LayoutB> ref_B;
|
||||||
|
TensorRef<ElementE const, LayoutE> ref_E;
|
||||||
|
typename FusionCallbacks::Arguments epilogue;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Methods
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Default ctor
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Arguments(): problem_size(0, 0, 0) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs an Arguments structure
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Arguments(
|
||||||
|
GemmCoord problem_size_,
|
||||||
|
TensorRef<ElementA const, LayoutA> ref_A_,
|
||||||
|
TensorRef<ElementB const, LayoutB> ref_B_,
|
||||||
|
TensorRef<ElementE, LayoutE> ref_E_,
|
||||||
|
typename FusionCallbacks::Arguments epilogue_ =
|
||||||
|
typename FusionCallbacks::Arguments()
|
||||||
|
):
|
||||||
|
problem_size(problem_size_),
|
||||||
|
ref_A(ref_A_),
|
||||||
|
ref_B(ref_B_),
|
||||||
|
ref_E(ref_E_),
|
||||||
|
epilogue(epilogue_) {
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
/// Kernel parameters object
|
||||||
|
typename GemmKernel::Params params_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
/// Constructs the GEMM.
|
||||||
|
SparseGemmWithVisitor() { }
|
||||||
|
|
||||||
|
/// Determines whether the GEMM can execute the given problem.
|
||||||
|
static Status can_implement(Arguments const &args) {
|
||||||
|
|
||||||
|
Status status = GemmKernel::can_implement(
|
||||||
|
args.problem_size,
|
||||||
|
args.ref_A.non_const_ref(),
|
||||||
|
args.ref_B.non_const_ref(),
|
||||||
|
cutlass::TensorRef<ElementC, LayoutC>(), // It only matters that it's empty.
|
||||||
|
cutlass::TensorRef<ElementC, LayoutC>(), // Same as above.
|
||||||
|
args.ref_E.non_const_ref()
|
||||||
|
);
|
||||||
|
|
||||||
|
if (status != Status::kSuccess) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the workspace size
|
||||||
|
static size_t get_workspace_size(Arguments const &args) {
|
||||||
|
|
||||||
|
size_t bytes = 0;
|
||||||
|
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initializes GEMM state from arguments.
|
||||||
|
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||||
|
|
||||||
|
constexpr int SplitKSlices = 1;
|
||||||
|
|
||||||
|
// Determine grid shape
|
||||||
|
ThreadblockSwizzle threadblock_swizzle;
|
||||||
|
|
||||||
|
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||||
|
args.problem_size,
|
||||||
|
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||||
|
SplitKSlices);
|
||||||
|
|
||||||
|
// Initialize the Params structure
|
||||||
|
params_ = typename GemmKernel::Params{
|
||||||
|
args.problem_size,
|
||||||
|
grid_shape,
|
||||||
|
args.ref_A.non_const_ref(),
|
||||||
|
args.ref_B.non_const_ref(),
|
||||||
|
args.ref_E.non_const_ref(),
|
||||||
|
args.epilogue
|
||||||
|
};
|
||||||
|
|
||||||
|
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||||
|
if (smem_size >= (48 << 10)) {
|
||||||
|
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
smem_size);
|
||||||
|
|
||||||
|
if (result != cudaSuccess) {
|
||||||
|
return Status::kErrorInternal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lightweight update given a subset of arguments
|
||||||
|
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||||
|
|
||||||
|
params_.ref_A.reset(args.ref_A.non_const_ref().data());
|
||||||
|
params_.ref_B.reset(args.ref_B.non_const_ref().data());
|
||||||
|
params_.ref_E.reset(args.ref_E.non_const_ref().data());
|
||||||
|
params_.output_op = args.epilogue;
|
||||||
|
|
||||||
|
return Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs the kernel using initialized state.
|
||||||
|
Status run(cudaStream_t stream = nullptr) {
|
||||||
|
|
||||||
|
ThreadblockSwizzle threadblock_swizzle;
|
||||||
|
|
||||||
|
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||||
|
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||||
|
|
||||||
|
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||||
|
|
||||||
|
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||||
|
|
||||||
|
cudaError_t result = cudaGetLastError();
|
||||||
|
|
||||||
|
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs the kernel using initialized state.
|
||||||
|
Status operator()(cudaStream_t stream = nullptr) {
|
||||||
|
return run(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs the kernel using initialized state.
|
||||||
|
Status operator()(
|
||||||
|
Arguments const &args,
|
||||||
|
void *workspace = nullptr,
|
||||||
|
cudaStream_t stream = nullptr) {
|
||||||
|
|
||||||
|
Status status = initialize(args, workspace, stream);
|
||||||
|
|
||||||
|
if (status == Status::kSuccess) {
|
||||||
|
status = run(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace device
|
||||||
|
} // namespace gemm
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
198
include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h
Normal file
198
include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/*! \file
|
||||||
|
\brief Default sparse GEMM with visitor.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
|
#include "cutlass/layout/matrix.h"
|
||||||
|
#include "cutlass/numeric_types.h"
|
||||||
|
#include "cutlass/arch/wmma.h"
|
||||||
|
|
||||||
|
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||||
|
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||||
|
|
||||||
|
#include "cutlass/gemm/gemm.h"
|
||||||
|
#include "cutlass/gemm/kernel/gemm.h"
|
||||||
|
#include "cutlass/gemm/kernel/default_gemm_sparse.h"
|
||||||
|
#include "cutlass/gemm/kernel/sparse_gemm_with_visitor.h"
|
||||||
|
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
||||||
|
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||||
|
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||||
|
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||||
|
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
|
||||||
|
#include "cutlass/gemm/threadblock/default_sparse_mma.h"
|
||||||
|
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||||
|
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||||
|
|
||||||
|
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||||
|
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||||
|
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||||
|
#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h"
|
||||||
|
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||||
|
|
||||||
|
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
||||||
|
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
|
||||||
|
#endif //CUTLASS_ARCH_WMMA_ENABLED
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace gemm {
|
||||||
|
namespace kernel {
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
/// Element type for A matrix operand
|
||||||
|
typename ElementA,
|
||||||
|
/// Layout type for A matrix operand
|
||||||
|
typename LayoutA,
|
||||||
|
/// Access granularity of A matrix in units of elements
|
||||||
|
int kAlignmentA,
|
||||||
|
/// Element type for B matrix operand
|
||||||
|
typename ElementB,
|
||||||
|
/// Layout type for B matrix operand
|
||||||
|
typename LayoutB,
|
||||||
|
/// Access granularity of B matrix in units of elements
|
||||||
|
int kAlignmentB,
|
||||||
|
/// Element type for C and D matrix operands
|
||||||
|
typename ElementC,
|
||||||
|
/// Layout type for C and D matrix operands
|
||||||
|
typename LayoutC,
|
||||||
|
/// Element type for internal accumulation
|
||||||
|
typename ElementAccumulator,
|
||||||
|
/// Operator class tag
|
||||||
|
typename OperatorClass,
|
||||||
|
/// Tag indicating architecture to tune for
|
||||||
|
typename ArchTag,
|
||||||
|
/// Threadblock-level tile size (concept: GemmShape)
|
||||||
|
typename ThreadblockShape,
|
||||||
|
/// Warp-level tile size (concept: GemmShape)
|
||||||
|
typename WarpShape,
|
||||||
|
/// Warp-level tile size (concept: GemmShape)
|
||||||
|
typename InstructionShape,
|
||||||
|
/// Epilogue output operator
|
||||||
|
typename FusionCallbacks,
|
||||||
|
/// Threadblock-level swizzling operator
|
||||||
|
typename ThreadblockSwizzle,
|
||||||
|
/// Number of stages used in the pipelined mainloop
|
||||||
|
int Stages,
|
||||||
|
/// Operation performed by GEMM
|
||||||
|
typename Operator,
|
||||||
|
/// Number of stages used in the pipelined epilogue
|
||||||
|
int EpilogueStages = 1>
|
||||||
|
struct DefaultSparseGemmWithVisitor;
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/// Partial specialization for Ampere Architecture
|
||||||
|
template <
|
||||||
|
/// Element type for A matrix operand
|
||||||
|
typename ElementA,
|
||||||
|
/// Layout type for A matrix operand
|
||||||
|
typename LayoutA,
|
||||||
|
/// Access granularity of A matrix in units of elements
|
||||||
|
int kAlignmentA,
|
||||||
|
/// Element type for B matrix operand
|
||||||
|
typename ElementB,
|
||||||
|
/// Layout type for B matrix operand
|
||||||
|
typename LayoutB,
|
||||||
|
/// Access granularity of A matrix in units of elements
|
||||||
|
int kAlignmentB,
|
||||||
|
/// Element type for C and D matrix operands
|
||||||
|
typename ElementC,
|
||||||
|
/// Layout type for C and D matrix operands
|
||||||
|
typename LayoutC,
|
||||||
|
/// Element type for internal accumulation
|
||||||
|
typename ElementAccumulator,
|
||||||
|
/// Threadblock-level tile size (concept: GemmShape)
|
||||||
|
typename ThreadblockShape,
|
||||||
|
/// Warp-level tile size (concept: GemmShape)
|
||||||
|
typename WarpShape,
|
||||||
|
/// Warp-level tile size (concept: GemmShape)
|
||||||
|
typename InstructionShape,
|
||||||
|
/// Epilogue output operator
|
||||||
|
typename FusionCallbacks,
|
||||||
|
/// Threadblock-level swizzling operator
|
||||||
|
typename ThreadblockSwizzle,
|
||||||
|
/// Number of stages used in the pipelined mainloop
|
||||||
|
int Stages,
|
||||||
|
/// Operation performed by GEMM
|
||||||
|
typename Operator,
|
||||||
|
/// Number of stages used in the pipelined epilogue
|
||||||
|
int EpilogueStages>
|
||||||
|
struct DefaultSparseGemmWithVisitor<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||||
|
ElementC, LayoutC, ElementAccumulator, arch::OpClassTensorOp,
|
||||||
|
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
|
||||||
|
FusionCallbacks, ThreadblockSwizzle, Stages, Operator,
|
||||||
|
EpilogueStages> {
|
||||||
|
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||||
|
using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma<
|
||||||
|
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||||
|
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||||
|
ThreadblockShape, WarpShape, InstructionShape, Stages,
|
||||||
|
Operator>::ThreadblockMma;
|
||||||
|
|
||||||
|
static constexpr int kAlignmentC = 128 / sizeof_bits<ElementC>::value;;
|
||||||
|
using ElementEpilogue = ElementAccumulator;
|
||||||
|
|
||||||
|
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||||
|
using EpilogueOutputOp =
|
||||||
|
typename epilogue::thread::LinearCombination<
|
||||||
|
ElementC, kAlignmentC,
|
||||||
|
ElementAccumulator, ElementEpilogue>;
|
||||||
|
using BaseEpilogue =
|
||||||
|
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||||
|
ThreadblockShape, typename Mma::Operator, kPartitionsK,
|
||||||
|
EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;
|
||||||
|
|
||||||
|
// Define epilogue
|
||||||
|
using Epilogue = cutlass::epilogue::threadblock::EpilogueWithVisitorCallbacks<
|
||||||
|
BaseEpilogue,
|
||||||
|
FusionCallbacks,
|
||||||
|
EpilogueStages>;
|
||||||
|
|
||||||
|
/// Define the kernel-level GEMM operator.
|
||||||
|
using GemmKernel = kernel::SparseGemmWithEpilogueVisitor<Mma, Epilogue, ThreadblockSwizzle>;
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace gemm
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
117
include/cutlass/gemm/kernel/params_sparse_base.h
Normal file
117
include/cutlass/gemm/kernel/params_sparse_base.h
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
/*! \file
|
||||||
|
\brief Base functionality for common types of sparse GEMM kernel parameters
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace gemm {
|
||||||
|
namespace kernel {
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/// Parameters structure
|
||||||
|
template <
|
||||||
|
typename ThreadblockSwizzle,
|
||||||
|
typename ParamsA,
|
||||||
|
typename TensorRefA,
|
||||||
|
typename ParamsB,
|
||||||
|
typename TensorRefB,
|
||||||
|
typename ParamsE,
|
||||||
|
typename TensorRefE>
|
||||||
|
struct SparseParamsBase
|
||||||
|
{
|
||||||
|
//
|
||||||
|
// Data members
|
||||||
|
//
|
||||||
|
|
||||||
|
cutlass::gemm::GemmCoord problem_size;
|
||||||
|
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||||
|
int swizzle_log_tile;
|
||||||
|
ParamsA params_A;
|
||||||
|
TensorRefA ref_A;
|
||||||
|
ParamsB params_B;
|
||||||
|
TensorRefB ref_B;
|
||||||
|
ParamsE params_E;
|
||||||
|
TensorRefE ref_E;
|
||||||
|
int gemm_k_iterations;
|
||||||
|
int gemm_k_size;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Host dispatch API
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Default constructor
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
SparseParamsBase() : swizzle_log_tile(0), gemm_k_iterations(0), gemm_k_size(0) { }
|
||||||
|
|
||||||
|
|
||||||
|
/// Constructor
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
SparseParamsBase(
|
||||||
|
cutlass::gemm::GemmCoord const & problem_size,
|
||||||
|
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||||
|
TensorRefA ref_A,
|
||||||
|
TensorRefB ref_B,
|
||||||
|
TensorRefE ref_E,
|
||||||
|
int const mma_shape_k)
|
||||||
|
:
|
||||||
|
problem_size(problem_size),
|
||||||
|
grid_tiled_shape(grid_tiled_shape),
|
||||||
|
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||||
|
params_A(ref_A.layout()),
|
||||||
|
ref_A(ref_A),
|
||||||
|
params_B(ref_B.layout()),
|
||||||
|
ref_B(ref_B),
|
||||||
|
params_E(ref_E.layout()),
|
||||||
|
ref_E(ref_E)
|
||||||
|
{
|
||||||
|
int total_gemm_k_iterations = (problem_size.k() + mma_shape_k - 1) / mma_shape_k;
|
||||||
|
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||||
|
|
||||||
|
gemm_k_size = gemm_k_iterations * mma_shape_k;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace gemm
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -37,6 +37,7 @@
|
|||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
#include "cutlass/gemm/gemm.h"
|
#include "cutlass/gemm/gemm.h"
|
||||||
|
#include "cutlass/gemm/kernel/params_sparse_base.h"
|
||||||
#include "cutlass/matrix_coord.h"
|
#include "cutlass/matrix_coord.h"
|
||||||
#include "cutlass/semaphore.h"
|
#include "cutlass/semaphore.h"
|
||||||
|
|
||||||
@ -74,66 +75,58 @@ struct SparseGemm {
|
|||||||
using WarpCount = typename Mma::WarpCount;
|
using WarpCount = typename Mma::WarpCount;
|
||||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||||
|
|
||||||
|
using ParamsA = typename Mma::IteratorA::Params;
|
||||||
|
using TensorRefA = typename Mma::IteratorA::TensorRef;
|
||||||
|
using ParamsB = typename Mma::IteratorB::Params;
|
||||||
|
using TensorRefB = typename Mma::IteratorB::TensorRef;
|
||||||
|
using ParamsE = typename Mma::IteratorE::Params;
|
||||||
|
using TensorRefE = typename Mma::IteratorE::TensorRef;
|
||||||
|
|
||||||
/// Parameters structure
|
/// Parameters structure
|
||||||
struct Params {
|
struct Params : public SparseParamsBase<
|
||||||
cutlass::gemm::GemmCoord problem_size;
|
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
ParamsE, TensorRefE> {
|
||||||
int swizzle_log_tile;
|
|
||||||
typename Mma::IteratorA::Params params_A;
|
using Base = SparseParamsBase<
|
||||||
typename Mma::IteratorA::TensorRef ref_A;
|
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||||
typename Mma::IteratorB::Params params_B;
|
ParamsE, TensorRefE>;
|
||||||
typename Mma::IteratorB::TensorRef ref_B;
|
|
||||||
|
//
|
||||||
|
// Data members
|
||||||
|
//
|
||||||
typename Epilogue::OutputTileIterator::Params params_C;
|
typename Epilogue::OutputTileIterator::Params params_C;
|
||||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||||
typename Epilogue::OutputTileIterator::Params params_D;
|
typename Epilogue::OutputTileIterator::Params params_D;
|
||||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||||
typename Mma::IteratorE::Params params_E;
|
|
||||||
typename Mma::IteratorE::TensorRef ref_E;
|
|
||||||
typename OutputOp::Params output_op;
|
typename OutputOp::Params output_op;
|
||||||
int *semaphore;
|
int *semaphore;
|
||||||
int gemm_k_iterations;
|
|
||||||
int gemm_k_size;
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Methods
|
// Methods
|
||||||
//
|
//
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
|
Params() { }
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Params(
|
Params(
|
||||||
cutlass::gemm::GemmCoord const & problem_size,
|
cutlass::gemm::GemmCoord const & problem_size,
|
||||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||||
typename Mma::IteratorA::TensorRef ref_A,
|
TensorRefA ref_A,
|
||||||
typename Mma::IteratorB::TensorRef ref_B,
|
TensorRefB ref_B,
|
||||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||||
typename Epilogue::OutputTileIterator::TensorRef ref_D,
|
typename Epilogue::OutputTileIterator::TensorRef ref_D,
|
||||||
typename Mma::IteratorE::TensorRef ref_E,
|
TensorRefE ref_E,
|
||||||
typename OutputOp::Params output_op = typename OutputOp::Params(),
|
typename OutputOp::Params output_op = typename OutputOp::Params(),
|
||||||
int *workspace = nullptr
|
int *workspace = nullptr
|
||||||
):
|
):
|
||||||
problem_size(problem_size),
|
Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK),
|
||||||
grid_tiled_shape(grid_tiled_shape),
|
|
||||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
|
||||||
params_A(ref_A.layout()),
|
|
||||||
ref_A(ref_A),
|
|
||||||
params_B(ref_B.layout()),
|
|
||||||
ref_B(ref_B),
|
|
||||||
params_C(ref_C.layout()),
|
params_C(ref_C.layout()),
|
||||||
ref_C(ref_C),
|
ref_C(ref_C),
|
||||||
params_D(ref_D.layout()),
|
params_D(ref_D.layout()),
|
||||||
ref_D(ref_D),
|
ref_D(ref_D),
|
||||||
params_E(ref_E.layout()),
|
output_op(output_op),
|
||||||
ref_E(ref_E),
|
semaphore(workspace) {
|
||||||
output_op(output_op) {
|
|
||||||
|
|
||||||
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
|
||||||
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
|
||||||
|
|
||||||
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
|
|
||||||
|
|
||||||
semaphore = workspace;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
237
include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h
Normal file
237
include/cutlass/gemm/kernel/sparse_gemm_with_visitor.h
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/*! \file
|
||||||
|
\brief Sparse GEMM with visitor.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
|
#include "cutlass/gemm/kernel/sparse_gemm.h"
|
||||||
|
#include "cutlass/gemm/kernel/params_sparse_base.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace gemm {
|
||||||
|
namespace kernel {
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Sparse Gemm that compute the epilogue visitor functor
|
||||||
|
template <
|
||||||
|
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||||
|
typename Epilogue_, ///! Epilogue
|
||||||
|
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||||
|
>
|
||||||
|
struct SparseGemmWithEpilogueVisitor : public SparseGemm<Mma_, Epilogue_, ThreadblockSwizzle_, false> {
|
||||||
|
|
||||||
|
using Base = SparseGemm<Mma_, Epilogue_, ThreadblockSwizzle_, false>;
|
||||||
|
|
||||||
|
using Mma = Mma_;
|
||||||
|
using Epilogue = Epilogue_;
|
||||||
|
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||||
|
|
||||||
|
using FusionCallbacks = typename Epilogue::FusionCallbacks;
|
||||||
|
|
||||||
|
using ParamsA = typename Mma::IteratorA::Params;
|
||||||
|
using TensorRefA = typename Mma::IteratorA::TensorRef;
|
||||||
|
using ParamsB = typename Mma::IteratorB::Params;
|
||||||
|
using TensorRefB = typename Mma::IteratorB::TensorRef;
|
||||||
|
using ParamsE = typename Mma::IteratorE::Params;
|
||||||
|
using TensorRefE = typename Mma::IteratorE::TensorRef;
|
||||||
|
|
||||||
|
static int const kSparse = Base::kSparse;
|
||||||
|
static int const kElementsPerElementE = Base::kElementsPerElementE;
|
||||||
|
using SharedStorage = typename Base::SharedStorage;
|
||||||
|
|
||||||
|
/// Parameters structure
|
||||||
|
struct Params : public SparseParamsBase<
|
||||||
|
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||||
|
ParamsE, TensorRefE> {
|
||||||
|
|
||||||
|
using Base = SparseParamsBase<
|
||||||
|
ThreadblockSwizzle, ParamsA, TensorRefA, ParamsB, TensorRefB,
|
||||||
|
ParamsE, TensorRefE>;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Data members
|
||||||
|
//
|
||||||
|
|
||||||
|
typename FusionCallbacks::Params output_op;
|
||||||
|
cute::Shape<int32_t,int32_t,int32_t> problem_shape;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Methods
|
||||||
|
//
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Params() { }
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Params(
|
||||||
|
cutlass::gemm::GemmCoord const & problem_size,
|
||||||
|
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||||
|
typename Mma::IteratorA::TensorRef ref_A,
|
||||||
|
typename Mma::IteratorB::TensorRef ref_B,
|
||||||
|
typename Mma::IteratorE::TensorRef ref_E,
|
||||||
|
typename FusionCallbacks::Arguments output_op = typename FusionCallbacks::Arguments()
|
||||||
|
):
|
||||||
|
Base(problem_size, grid_tiled_shape, ref_A, ref_B, ref_E, Mma::Shape::kK),
|
||||||
|
output_op(FusionCallbacks::to_underlying_arguments(problem_size, output_op, nullptr /*workspace*/)),
|
||||||
|
problem_shape(problem_size.m(), problem_size.n(), 1) {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// Methods
|
||||||
|
//
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
SparseGemmWithEpilogueVisitor() { }
|
||||||
|
|
||||||
|
/// Executes one GEMM
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||||
|
|
||||||
|
// Compute threadblock location
|
||||||
|
ThreadblockSwizzle threadblock_swizzle;
|
||||||
|
|
||||||
|
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||||
|
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||||
|
|
||||||
|
// Early exit if CTA is out of range
|
||||||
|
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||||
|
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute initial location in logical coordinates
|
||||||
|
cutlass::MatrixCoord tb_offset_A{
|
||||||
|
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||||
|
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
|
||||||
|
};
|
||||||
|
|
||||||
|
cutlass::MatrixCoord tb_offset_B{
|
||||||
|
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||||
|
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||||
|
};
|
||||||
|
|
||||||
|
cutlass::MatrixCoord tb_offset_E{
|
||||||
|
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||||
|
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Problem size is a function of threadblock index in the K dimension
|
||||||
|
int problem_size_k = min(
|
||||||
|
params.problem_size.k(),
|
||||||
|
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
||||||
|
|
||||||
|
// Compute threadblock-scoped matrix multiply-add
|
||||||
|
int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||||
|
|
||||||
|
// Compute position within threadblock
|
||||||
|
int thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
// Construct iterators to A, B, and E operands
|
||||||
|
typename Mma::IteratorA iterator_A(
|
||||||
|
params.params_A,
|
||||||
|
params.ref_A.data(),
|
||||||
|
{params.problem_size.m(), problem_size_k / kSparse},
|
||||||
|
thread_idx,
|
||||||
|
tb_offset_A);
|
||||||
|
|
||||||
|
typename Mma::IteratorB iterator_B(
|
||||||
|
params.params_B,
|
||||||
|
params.ref_B.data(),
|
||||||
|
{problem_size_k, params.problem_size.n()},
|
||||||
|
thread_idx,
|
||||||
|
tb_offset_B);
|
||||||
|
|
||||||
|
typename Mma::IteratorE iterator_E(
|
||||||
|
params.params_E, params.ref_E.data(),
|
||||||
|
{params.problem_size.m(),
|
||||||
|
problem_size_k / kSparse / kElementsPerElementE},
|
||||||
|
thread_idx, tb_offset_E);
|
||||||
|
|
||||||
|
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||||
|
// is compiled as warp-uniform.
|
||||||
|
int warp_idx = canonical_warp_idx_sync();
|
||||||
|
int lane_idx = threadIdx.x % 32;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Main loop
|
||||||
|
//
|
||||||
|
|
||||||
|
// Construct thread-scoped matrix multiply
|
||||||
|
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||||
|
|
||||||
|
typename Mma::FragmentC accumulators;
|
||||||
|
|
||||||
|
accumulators.clear();
|
||||||
|
|
||||||
|
if (gemm_k_iterations > 0) {
|
||||||
|
// Compute threadblock-scoped matrix multiply-add
|
||||||
|
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Masked tile iterators constructed from members
|
||||||
|
//
|
||||||
|
|
||||||
|
threadblock_tile_offset =
|
||||||
|
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||||
|
|
||||||
|
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||||
|
|
||||||
|
//
|
||||||
|
// Epilogue
|
||||||
|
//
|
||||||
|
|
||||||
|
Epilogue epilogue(
|
||||||
|
params.output_op,
|
||||||
|
shared_storage.epilogue,
|
||||||
|
thread_idx,
|
||||||
|
warp_idx,
|
||||||
|
lane_idx);
|
||||||
|
|
||||||
|
// Execute the epilogue operator to update the destination tensor.
|
||||||
|
epilogue(accumulators, threadblock_tile_offset, params.problem_shape, thread_idx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace gemm
|
||||||
|
} // namespace cutlass
|
||||||
Loading…
Reference in New Issue
Block a user