bug fixes and enharcement to gemm reductionK fusion (#682)

* add two missing files

* fix bunch of bugs of gemm-reducek fusion and add a device interface

* small changes

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Haicheng Wu 2022-11-03 11:07:50 -04:00 committed by GitHub
parent cc85b64cf6
commit 012c62c748
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 445 additions and 21 deletions

View File

@ -45,7 +45,7 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/device/gemm_with_k_reduction.h"
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
@ -101,6 +101,12 @@ constexpr int NumStages = 4;
// Reduce A or B operand along the K dimension
constexpr bool ReduceKForA = true;
// Alignment of A operand
constexpr int AlignmentA = 8;
// Alignment of B operand
constexpr int AlignmentB = 8;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
@ -110,9 +116,9 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8,
ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8,
using Gemm = typename cutlass::gemm::device::GemmWithKReduction<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
@ -124,10 +130,12 @@ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
AlignmentA,
AlignmentB,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
// Below is the reduction kernel used in the case of parallel split-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
@ -368,21 +376,21 @@ Result profile(Options const &options) {
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
1997,
ElementInputA(2),
ElementInputA(-2),
0); // <- Fill tensor A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
2003,
ElementInputB(2),
ElementInputB(-2),
0); // <- Fill tensor B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
2017,
ElementOutput(2),
ElementOutput(-2),
0); // <- Fill matrix C on host with uniform-distribution random data
@ -561,7 +569,7 @@ Result profile(Options const &options) {
tensor_reduction.sync_host();
// ReduceK in host code
// Reduce K in host code
if (ReduceKForA) {
for (int m = 0; m < options.problem_size.m(); ++m) {
for (int k = 0; k < options.problem_size.k(); ++k) {
@ -581,7 +589,7 @@ Result profile(Options const &options) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(),
tensor_ref_d.host_view());
pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(),
tensor_reduction.host_view());

View File

@ -149,13 +149,12 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations / 4; ++i) {
ElementOutput tmp;
ElementOutput *source_ptr = reinterpret_cast<ElementOutput *>(&source);
cutlass::arch::global_load<ElementOutput, sizeof(ElementOutput)>(
tmp,
source_ptr[i],
(void *)(pointer_ + i * 32),
guard[i] && LoadForSerialSplitK);
source[i] = tmp;
}
FragmentAccumulator sum = gemm_k_with_reduction_accumulation;

View File

@ -0,0 +1,414 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_with_k_reduction.h"
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*!
The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Reduce A or B operand along the K dimension
bool ReduceKForA_ = true,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB = ComplexTransform::kNone,
/// Gather operand A by using an index array
bool GatherA = false,
/// Gather operand B by using an index array
bool GatherB = false,
/// Scatter result D by using an index array
bool ScatterD = false,
/// Permute result D
typename PermuteDLayout = layout::NoPermute
>
class GemmWithKReduction :
public GemmUniversalBase<
typename kernel::DefaultGemmWithKReduction<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ReduceKForA_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_,
SharedMemoryClearOption::kNone
>::GemmKernel
> {
public:
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static constexpr int kStages = Stages;
static constexpr int kAlignmentA = AlignmentA;
static constexpr int kAlignmentB = AlignmentB;
static constexpr int kAlignmentC = EpilogueOutputOp::kCount;
static constexpr ComplexTransform kTransformA = TransformA;
static constexpr ComplexTransform kTransformB = TransformB;
using Base = GemmUniversalBase<
typename kernel::DefaultGemmWithKReduction<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ReduceKForA_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_,
SharedMemoryClearOption::kNone
>::GemmKernel
>;
using Arguments = typename Base::Arguments;
using GemmKernel = typename Base::GemmKernel;
};
////////////////////////////////////////////////////////////////////////////////
/// Parital specialization for column-major output exchanges problem size and operand.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Reduce A or B operand along the K dimension
bool ReduceKForA_,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
/// Operation performed by GEMM
typename Operator_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB,
/// Scatter result D by using an index array
bool ScatterD,
/// Permute result D
typename PermuteDLayout
>
class GemmWithKReduction<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
layout::ColumnMajor, // partially specialized on LayoutC
ElementAccumulator_, OperatorClass_, ReduceKForA_, ArchTag_, ThreadblockShape_,
WarpShape_, InstructionShape_, EpilogueOutputOp_,
ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
Operator_, TransformA, TransformB, GatherA, GatherB, ScatterD, PermuteDLayout> {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using UnderlyingOperator = typename GemmWithKReduction<
ElementB,
typename layout::LayoutTranspose<LayoutB>::type,
ElementA,
typename layout::LayoutTranspose<LayoutA>::type,
ElementC,
layout::RowMajor,
ElementAccumulator,
OperatorClass,
!ReduceKForA_,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
kAlignmentB,
kAlignmentA,
Operator,
kTransformB,
kTransformA,
GatherB,
GatherA,
ScatterD,
PermuteDLayout
>::Base;
using GemmKernel = typename UnderlyingOperator::GemmKernel;
static int const kAlignmentC = EpilogueOutputOp::kCount;
/// Argument structure
using Arguments = typename UnderlyingOperator::Arguments;
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
GemmWithKReduction() = default;
/// Helper to construct a transposed equivalent for the underying GEMM operator
static Arguments to_underlying_arguments(Arguments const &args) {
return args.transposed_problem();
}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -91,7 +91,7 @@ template <
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
///
/// Reduce A or B along the K dimension
bool ReduceKForA_,
/// Tag indicating architecture to tune for
typename ArchTag,

View File

@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/trace.h"

View File

@ -90,7 +90,7 @@ template <
typename LayoutC,
/// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
typename OperatorClass,
///
/// Reduce operand A or B along K dimension
bool ReduceKForA_,
/// Number of stages
int Stages = 2,

View File

@ -61,7 +61,9 @@ template <
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Operator describing the tensor operation
typename Operator_ = arch::OpMultiplyAdd,
typename Operator_,
/// Reduce operand A or B along K dimension
bool ReduceKForA_,
/// Number of partitions along K dimension
int PartitionsK = 1,
/// Store the accumulators in row major or column major. Row major is used
@ -78,7 +80,7 @@ struct DefaultMmaWithReductionTensorOp {
// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaWithReductionTensorOp<
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
Policy, ReduceKForA_, PartitionsK, AccumulatorsInRowMajor>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -81,7 +81,7 @@ template <
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
///
/// Reduce operand A or B along K dimension
bool ReduceKForA_,
/// Number of partitions along K dimension
int PartitionsK_ = 1,