Add support for sparse GEMM with row broadcasted bias vector (#951)

This commit is contained in:
Aleksandar Samardžić 2023-05-24 16:25:05 +02:00 committed by GitHub
parent b4ab501767
commit d3e72719b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1847 additions and 7 deletions

View File

@ -0,0 +1,183 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueTensorOpRowBroadcast {
using Shape = Shape_;
using WarpMmaTensorOp = WarpMmaTensorOp_;
static int const kPartitionsK = PartitionsK;
using OutputOp = OutputOp_;
static int const kElementsPerAccess = ElementsPerAccess;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaTensorOp::LayoutC;
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
//
// Thread map
//
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
Shape,
typename WarpMmaTensorOp::Shape,
kPartitionsK,
ElementOutput,
kElementsPerAccess
>::Type;
static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorRowBroadcast<
OutputTileThreadMap,
ElementOutput,
ScatterD,
PermuteDLayout,
UseCUDAStore
>;
using AccumulatorFragmentIterator = typename platform::conditional<is_complex<ElementOutput>::value,
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC>,
cutlass::epilogue::warp::FragmentIteratorTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC> >::type;
/// Support several implementations depending on structure of epilogue
using DefaultIterators = detail::DefaultIteratorsTensorOp<
ElementOutput,
ElementAccumulator,
kElementsPerAccess,
Shape,
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename OutputTileThreadMap::CompactedThreadMap
>;
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
/// Hard-coded padding elements added
using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);
//
// Define the epilogue
//
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
Shape,
WarpMmaTensorOp,
kPartitionsK,
OutputTileIterator,
AccumulatorFragmentIterator,
WarpTileIterator,
SharedLoadIterator,
OutputOp,
Padding,
kFragmentsPerIteration
>;
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,519 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/permute.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load and store output tile from global memory in epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
///
template <
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not
bool UseCUDAStore = false
>
class PredicatedTileIteratorRowBroadcast {
static_assert(!ScatterD);
static_assert(std::is_same<PermuteDLayout, layout::NoPermute>::value);
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0");
static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0");
static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0");
static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0");
/// Fragment object
using Fragment = Array<
Element,
ThreadMap::Iterations::kColumn *
ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup *
ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Uses a non-template class
struct Params : PredicatedTileIteratorParams {
using Base = PredicatedTileIteratorParams;
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Layout const &layout):
PredicatedTileIteratorParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
make_OutputTileThreadMapDesc<ThreadMap>()
)
{ }
CUTLASS_HOST_DEVICE
Params(Base const &base) :
Base(base) { }
};
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kColumn;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() {
enable();
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = false;
}
}
///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask
CUTLASS_DEVICE void enable() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = true;
}
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer.
uint8_t *byte_pointer_;
/// Byte-level pointer for store().
uint8_t *store_byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in rows
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Internal state counter
int state_[3];
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorRowBroadcast(
PredicatedTileIteratorParams const & params,
Element *pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const *indices = nullptr
):
params_(params)
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_column_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] = ((thread_offset.column()
+ ThreadMap::Delta::kColumn * c) < extent.column());
}
// Null pointer performs no accesses
if (!pointer) {
mask_.clear();
}
// Initialize byte_pointer_
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
// store_byte_pointer_ is set to be the same with byte_pointer_
store_byte_pointer_ = byte_pointer_;
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride);
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
store_byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const {
uint8_t *byte_pointer = byte_pointer_;
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
/*
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column],
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess],
guard);
*/
if (guard) {
Element *bias = reinterpret_cast<Element*>(byte_pointer + byte_offset);
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column].fill(*bias);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) const {
load_with_byte_offset(frag, 0);
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const {
uint8_t *byte_pointer = store_byte_pointer_;
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
if (UseCUDAStore) {
if (guard) {
memory_pointer[0] =
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
}
} else {
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void *)&memory_pointer[0],
guard);
}
memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) const {
store_with_byte_offset(frag, 0);
}
CUTLASS_DEVICE
MatrixCoord thread_start() const {
return MatrixCoord(thread_start_row_, thread_start_column_);
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const {
return thread_start_row_;
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const {
return thread_start_column_;
}
/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const {
return extent_row_;
}
/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const {
return extent_column_;
}
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorRowBroadcast &operator++() {
++state_[0];
store_byte_pointer_ += params_.advance_row;
byte_pointer_ += params_.advance_row;
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
byte_pointer_ += params_.advance_group;
store_byte_pointer_ += params_.advance_group;
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
byte_pointer_ += params_.advance_cluster;
store_byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
store_byte_pointer_ += params_.advance_tile;
thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow
* ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile;
}
}
}
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() {
mask_.clear();
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() {
mask_.enable();
}
///< Sets the mask
CUTLASS_DEVICE void get_mask(Mask &mask) const {
mask = mask_;
}
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const &mask) {
mask_ = mask;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,514 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/sparse_gemm_row_broadcast.h"
#include "cutlass/gemm/kernel/default_gemm_sparse_row_broadcast.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may
be invoked from host code.
The contributions of this class are:
1. At compile time, it maps data types and high-level structural parameters onto
specific CUTLASS components.
2. At runtime, it maps logical arguments to GEMM problems to kernel parameters.
3. At runtime, it launches kernels on the device.
The intent is to provide a convenient mechanism for interacting with most plausible GEMM
configurations for each supported architecture. Consequently, not all parameters are exposed
to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy
are selected to tradeoff simplicity of the interface with flexibility. We expect
most configurations to be specified at this level. Applications with more exotic requirements
may construct their kernels of interest using CUTLASS components at the threadblock, warp,
and thread levels of abstraction.
CUTLASS exposes computations using the functor design pattern in which objects compose some
internal state with an overloaded function call operator. This enables decoupling of
initialization from execution, possibly reducing overhead during steady state phases of
application execution.
CUTLASS device-level operators expose an Arguments structure encompassing each logical
input to the computation. This is distinct from the kernel-level Params structure pattern
which contains application-specific precomputed state needed by the device code.
Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN
is as follows:
//
// Instantiate the CUTLASS GEMM operator.
//
cutlass::gemm::device::Gemm<
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor
> gemm_op;
//
// Launch the GEMM operation on the device
//
cutlass::Status status = gemm_op({
{m, n, k}, // GemmCoord problem_size,
{A, lda}, // TensorRef<float, layout::ColumnMajor> ref_A,
{B, ldb}, // TensorRef<float, layout::ColumnMajor> ref_B,
{C, ldc}, // TensorRef<float, layout::ColumnMajor> ref_C,
{D, ldd}, // TensorRef<float, layout::ColumnMajor> ref_D,
{alpha, beta} // EpilogueOutputOp::Params epilogue_op_params
});
A simplified view of the template is listed below.
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages
>
class Gemm;
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ =
typename threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator>
class SparseGemmRowBroadcast {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
using MathOperator = Operator;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static bool const kSplitKSerial = SplitKSerial;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Define the kernel
using GemmKernel = typename kernel::DefaultSparseGemmRowBroadcast<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator
>::GemmKernel;
using ElementE = typename GemmKernel::ElementE;
using LayoutE = typename GemmKernel::LayoutE;
static int const kAlignmentE = 128 / sizeof_bits<ElementE>::value;
static int const kSparse = GemmKernel::kSparse;
static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits;
static int const kElementsPerElementE = GemmKernel::kElementsPerElementE;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A;
TensorRef<ElementB const, LayoutB> ref_B;
TensorRef<ElementC const, LayoutC> ref_C;
TensorRef<ElementC, LayoutC> ref_D;
TensorRef<ElementE const, LayoutE> ref_E;
typename EpilogueOutputOp::Params epilogue;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementC const, LayoutC> ref_C_,
TensorRef<ElementC, LayoutC> ref_D_,
TensorRef<ElementE, LayoutE> ref_E_,
typename EpilogueOutputOp::Params epilogue_ =
typename EpilogueOutputOp::Params(),
int split_k_slices = 1
):
problem_size(problem_size_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_C(ref_C_),
ref_D(ref_D_),
ref_E(ref_E_),
epilogue(epilogue_),
split_k_slices(split_k_slices) {
}
};
private:
/// Kernel parameters object
typename GemmKernel::Params params_;
public:
/// Constructs the GEMM.
SparseGemmRowBroadcast() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = GemmKernel::can_implement(
args.problem_size,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
args.ref_C.non_const_ref(),
args.ref_D,
args.ref_E.non_const_ref()
);
if (status != Status::kSuccess) {
return status;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return bytes;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// Initialize the Params structure
params_ = typename GemmKernel::Params{
args.problem_size,
grid_shape,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
args.ref_C.non_const_ref(),
args.ref_D,
args.ref_E.non_const_ref(),
args.epilogue,
static_cast<int *>(workspace)
};
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A.reset(args.ref_A.non_const_ref().data());
params_.ref_B.reset(args.ref_B.non_const_ref().data());
params_.ref_C.reset(args.ref_C.non_const_ref().data());
params_.ref_D.reset(args.ref_D.data());
params_.ref_E.reset(args.ref_E.non_const_ref().data());
params_.output_op = args.epilogue;
params_.semaphore = static_cast<int *>(workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
cudaError_t result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,191 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/wmma.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/kernel/sparse_gemm_row_broadcast.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
#include "cutlass/gemm/threadblock/default_sparse_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op_row_broadcast.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
#endif //CUTLASS_ARCH_WMMA_ENABLED
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultSparseGemmRowBroadcast;
////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of A matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultSparseGemmRowBroadcast<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpRowBroadcast<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::SparseGemmRowBroadcast<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -0,0 +1,400 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
>
struct SparseGemmRowBroadcast {
using Mma = Mma_;
using Epilogue = Epilogue_;
using OutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
static int const kSparse = Mma::kSparse;
static int const kMetaSizeInBits = Mma::kMetaSizeInBits;
static int const kMaxID2 = Mma::kMaxID2;
static int const kElementsPerElementE = Mma::kElementsPerElementE;
using ElementE = typename Mma::ElementE;
using LayoutE = typename Mma::LayoutE;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename Mma::IteratorE::Params params_E;
typename Mma::IteratorE::TensorRef ref_E;
typename OutputOp::Params output_op;
int *semaphore;
int gemm_k_iterations;
int gemm_k_size;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
typename Mma::IteratorE::TensorRef ref_E,
typename OutputOp::Params output_op = typename OutputOp::Params(),
int *workspace = nullptr
):
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(ref_A.layout()),
ref_A(ref_A),
params_B(ref_B.layout()),
ref_B(ref_B),
params_C(ref_C.layout()),
ref_C(ref_C),
params_D(ref_D.layout()),
ref_D(ref_D),
params_E(ref_E.layout()),
ref_E(ref_E),
output_op(output_op) {
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
semaphore = workspace;
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
SparseGemmRowBroadcast() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
typename Mma::IteratorE::TensorRef ref_E) {
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
static int const kAlignmentE = Mma::IteratorE::AccessType::kElements;
if (!TensorRef_aligned(ref_A, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
// if (!TensorRef_aligned(ref_C, kAlignmentC)) {
// return Status::kErrorMisalignedOperand;
// }
if (!TensorRef_aligned(ref_D, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_E, kAlignmentE)) {
return Status::kErrorMisalignedOperand;
}
if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) ||
(problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) {
return Status::kErrorMisalignedOperand;
}
// The k dimension has to be the multiple of the Threadblock k because out
// of bound meta data would be initialized to 0 by acync.zfill but 0 is not
// a valid meta data.
if (problem_size.k() % Mma::Shape::kK) {
return Status::kErrorMisalignedOperand;
}
// M dimension has to be multiple of 32 (sparse float) or 16 (sparse int)
// because of the row reordering of operand E
static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16;
if (problem_size.m() % kAlignmentM) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
};
cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size,
threadblock_tile_offset.n() * Mma::Shape::kN
};
cutlass::MatrixCoord tb_offset_E{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(
params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A, B, and E operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k / kSparse},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
typename Mma::IteratorE iterator_E(
params.params_E, params.ref_E.data(),
{params.problem_size.m(),
problem_size_k / kSparse / kElementsPerElementE},
thread_idx, tb_offset_E);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = canonical_warp_idx();
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0) {
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
}
//
// Epilogue
//
OutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
params.ref_C.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ref_D.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
__threadfence();
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
__threadfence();
semaphore.release(lock);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -37,6 +37,7 @@
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/device/gemm_sparse_row_broadcast.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
@ -267,6 +268,24 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128)
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_Row_Broadcast_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::SparseGemmRowBroadcast<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>(true));
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)

View File

@ -164,14 +164,19 @@ struct SparseTestbed {
}
/// Initializes data structures
void initialize(cutlass::gemm::GemmCoord problem_size) {
void initialize(cutlass::gemm::GemmCoord problem_size, bool tensor_C_row_broadcast = false) {
//
// Allocate the GEMM workspace
//
tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse));
tensor_A_uncompressed.resize(problem_size.mk());
tensor_B.resize(problem_size.kn());
tensor_C.resize(problem_size.mn());
if (tensor_C_row_broadcast) {
tensor_C.resize({problem_size.m(), 1});
}
else {
tensor_C.resize(problem_size.mn());
}
tensor_D.resize(problem_size.mn());
reference_D.resize(problem_size.mn(), false);
tensor_E.resize(cutlass::make_Coord(
@ -206,7 +211,14 @@ struct SparseTestbed {
tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1);
tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1);
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
if (tensor_C_row_broadcast) {
for (int i = 0; i < problem_size.m(); ++i)
for (int j = 0; j < problem_size.n(); ++j)
reference_D.host_view().at({i, j}) = tensor_C.host_view().at({i, 0});
}
else {
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
}
tensor_A.sync_device();
tensor_B.sync_device();
@ -338,7 +350,8 @@ struct SparseTestbed {
cutlass::gemm::GemmCoord problem_size,
int split_k_slices = 1,
ElementCompute alpha = ElementCompute(1),
ElementCompute beta = ElementCompute(0)) {
ElementCompute beta = ElementCompute(0),
bool tensor_C_row_broadcast = false) {
// Waive test if insufficient CUDA device
if (!sufficient()) {
@ -348,7 +361,7 @@ struct SparseTestbed {
return true;
}
this->initialize(problem_size);
this->initialize(problem_size, tensor_C_row_broadcast);
//
// Initialize the GEMM operator
@ -403,7 +416,7 @@ struct SparseTestbed {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
bool TestAllSparseGemm() {
bool TestAllSparseGemm(bool tensor_C_row_broadcast = false) {
bool passed = true;
int const kMinimumOperandElementSize =
@ -463,7 +476,8 @@ bool TestAllSparseGemm() {
problem_size,
split_k,
cutlass::from_real<ElementCompute>(alpha),
cutlass::from_real<ElementCompute>(beta)
cutlass::from_real<ElementCompute>(beta),
tensor_C_row_broadcast
);
if (!passed) {