bug fixes and enhancements to gemm reduction-K fusion (#682)
* Add two missing files
* Fix a number of bugs in the GEMM reduce-K fusion and add a device interface
* Small changes

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent cc85b64cf6, commit 012c62c748
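For orientation, the device-level interface this commit adds (include/cutlass/gemm/device/gemm_with_k_reduction.h, shown in full further down) is used in the same spirit as the other cutlass::gemm::device operators. The following is a minimal sketch, not code from this commit: the half-precision element types, layouts, tensor-op class, and Sm80 architecture tag are assumptions chosen so the remaining template parameters can stay at their defaults, and the Arguments contents are only described in a comment (the example diff below shows a real instantiation).

#include "cutlass/gemm/device/gemm_with_k_reduction.h"

// Reduce operand A along K while computing D = alpha * A * B + beta * C.
using Gemm = cutlass::gemm::device::GemmWithKReduction<
    cutlass::half_t, cutlass::layout::RowMajor,     // ElementA, LayoutA
    cutlass::half_t, cutlass::layout::ColumnMajor,  // ElementB, LayoutB
    cutlass::half_t, cutlass::layout::RowMajor,     // ElementC/D, LayoutC
    float,                                          // ElementAccumulator
    cutlass::arch::OpClassTensorOp,                 // OperatorClass
    true,                                           // ReduceKForA: reduce A, not B
    cutlass::arch::Sm80>;                           // remaining parameters defaulted

cutlass::Status run_gemm_with_k_reduction(Gemm::Arguments const &args,
                                          void *workspace,
                                          cudaStream_t stream) {
  // args must carry the problem size, operand pointers/strides, epilogue
  // scalars, and the device pointer that receives the K-reduction vector
  // (see the example file changes in this diff).
  Gemm gemm_op;
  cutlass::Status status = Gemm::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  // workspace must be at least Gemm::get_workspace_size(args) bytes.
  return gemm_op(args, workspace, stream);
}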
@@ -45,7 +45,7 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/device/gemm_with_k_reduction.h"
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
@@ -101,6 +101,12 @@ constexpr int NumStages = 4;
// Reduce A or B operand along the K dimension
constexpr bool ReduceKForA = true;

// Alignment of A operand
constexpr int AlignmentA = 8;

// Alignment of B operand
constexpr int AlignmentB = 8;

// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,               // Data type of output matrix.
@@ -110,9 +116,9 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementAccumulator,          // Data type of accumulator
    ElementComputeEpilogue>;

using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
    ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8,
    ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8,
using Gemm = typename cutlass::gemm::device::GemmWithKReduction<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    MMAOp,
@@ -124,10 +130,12 @@ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
    EpilogueOp,
    SwizzleThreadBlock,
    NumStages,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
    AlignmentA,
    AlignmentB,
    cutlass::arch::OpMultiplyAdd,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone
>;

// Below is the reduction kernel used in the case of parallel split-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
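The split-K reduction types the example builds from these headers typically follow the pattern below. This is a hedged sketch rather than a quote from the example: ReduceGemmSplitKShape comes from the line above, while the ReduceAdd functor and the kernel/device wrappers are assumed to be the usual CUTLASS split-K reduction pieces declared in the reduce_split_k.h headers included earlier in this diff.

// Assumed wiring of the parallel split-K reduction (not shown in this hunk):
using ReduceOp = cutlass::reduction::thread::ReduceAdd<
    ElementAccumulator,                        // accumulator element
    typename EpilogueOp::ElementAccumulator,   // source element
    EpilogueOp::kCount>;                       // elements per access

using ReduceGemmSplitKKernel =
    cutlass::reduction::kernel::ReduceSplitK<ReduceGemmSplitKShape, EpilogueOp, ReduceOp>;

using ReduceGemmSplitK =
    cutlass::reduction::device::ReduceSplitK<ReduceGemmSplitKKernel>;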
@@ -368,21 +376,21 @@ Result profile(Options const &options) {
  // Fill input and output matrices on host using CUTLASS helper functions
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      1997,
      ElementInputA(2),
      ElementInputA(-2),
      0);  // <- Fill tensor A on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      2003,
      ElementInputB(2),
      ElementInputB(-2),
      0);  // <- Fill tensor B on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      2017,
      ElementOutput(2),
      ElementOutput(-2),
      0);  // <- Fill matrix C on host with uniform-distribution random data
@@ -561,7 +569,7 @@ Result profile(Options const &options) {

  tensor_reduction.sync_host();

  // ReduceK in host code
  // Reduce K in host code
  if (ReduceKForA) {
    for (int m = 0; m < options.problem_size.m(); ++m) {
      for (int k = 0; k < options.problem_size.k(); ++k) {
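For context, the loop this hunk touches builds the host reference for the fused reduction. A rough reconstruction of the ReduceKForA branch is shown below; it is an assumption based on the loop structure in this hunk and on the comparison against tensor_ref_reduction later in the diff, and the exact coordinates and element conversions may differ from the example file.

// Assumed host reference: row-sum of A over the K dimension.
for (int m = 0; m < options.problem_size.m(); ++m) {
  ElementAccumulator sum = ElementAccumulator(0);
  for (int k = 0; k < options.problem_size.k(); ++k) {
    sum += ElementAccumulator(tensor_a.at({m, k}));
  }
  tensor_ref_reduction.at({m, 0}) = ElementOutput(sum);
}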
@@ -581,7 +589,7 @@ Result profile(Options const &options) {
  // Check if output from CUTLASS kernel and reference kernel are equal or not
  bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(),
                                                     tensor_ref_d.host_view());

  pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(),
                                                 tensor_reduction.host_view());
@@ -149,13 +149,12 @@ public:

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kIterations / 4; ++i) {
    ElementOutput tmp;
    ElementOutput *source_ptr = reinterpret_cast<ElementOutput *>(&source);
    cutlass::arch::global_load<ElementOutput, sizeof(ElementOutput)>(
        tmp,
        source_ptr[i],
        (void *)(pointer_ + i * 32),
        guard[i] && LoadForSerialSplitK);

    source[i] = tmp;
  }

  FragmentAccumulator sum = gemm_k_with_reduction_accumulation;
include/cutlass/gemm/device/gemm_with_k_reduction.h (new file, 414 lines)
@@ -0,0 +1,414 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_with_k_reduction.h"

#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"

#include "cutlass/layout/permute.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*!
  The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
  batched array variants.
*/
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Reduce A or B operand along the K dimension
    bool ReduceKForA_ = true,
    /// Tag indicating architecture to tune for. This is the minimum SM that
    /// supports the intended feature. The device kernel can be built
    /// targeting any SM larger than this number.
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
    /// Number of stages used in the pipelined mainloop
    int Stages =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int AlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA = ComplexTransform::kNone,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB = ComplexTransform::kNone,
    /// Gather operand A by using an index array
    bool GatherA = false,
    /// Gather operand B by using an index array
    bool GatherB = false,
    /// Scatter result D by using an index array
    bool ScatterD = false,
    /// Permute result D
    typename PermuteDLayout = layout::NoPermute
>
class GemmWithKReduction :
  public GemmUniversalBase<
    typename kernel::DefaultGemmWithKReduction<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ReduceKForA_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_,
      SharedMemoryClearOption::kNone
    >::GemmKernel
  > {

public:

  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static constexpr int kStages = Stages;
  static constexpr int kAlignmentA = AlignmentA;
  static constexpr int kAlignmentB = AlignmentB;
  static constexpr int kAlignmentC = EpilogueOutputOp::kCount;
  static constexpr ComplexTransform kTransformA = TransformA;
  static constexpr ComplexTransform kTransformB = TransformB;

  using Base = GemmUniversalBase<
    typename kernel::DefaultGemmWithKReduction<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ReduceKForA_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_,
      SharedMemoryClearOption::kNone
    >::GemmKernel
  >;

  using Arguments = typename Base::Arguments;
  using GemmKernel = typename Base::GemmKernel;
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output exchanges problem size and operand.
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Operator class tag
    typename OperatorClass_,
    /// Reduce A or B operand along the K dimension
    bool ReduceKForA_,
    /// Tag indicating architecture to tune for. This is the minimum SM that
    /// supports the intended feature. The device kernel can be built
    /// targeting any SM larger than this number.
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Epilogue output operator
    typename EpilogueOutputOp_,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Access granularity of A matrix in units of elements
    int AlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB,
    /// Operation performed by GEMM
    typename Operator_,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB,
    /// Scatter result D by using an index array
    bool ScatterD,
    /// Permute result D
    typename PermuteDLayout
>
class GemmWithKReduction<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                         layout::ColumnMajor,  // partially specialized on LayoutC
                         ElementAccumulator_, OperatorClass_, ReduceKForA_, ArchTag_, ThreadblockShape_,
                         WarpShape_, InstructionShape_, EpilogueOutputOp_,
                         ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
                         Operator_, TransformA, TransformB, GatherA, GatherB, ScatterD, PermuteDLayout> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  using UnderlyingOperator = typename GemmWithKReduction<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    !ReduceKForA_,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA,
    Operator,
    kTransformB,
    kTransformA,
    GatherB,
    GatherA,
    ScatterD,
    PermuteDLayout
  >::Base;

  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  static int const kAlignmentC = EpilogueOutputOp::kCount;

  /// Argument structure
  using Arguments = typename UnderlyingOperator::Arguments;

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmWithKReduction() = default;

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static Arguments to_underlying_arguments(Arguments const &args) {
    return args.transposed_problem();
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
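A note on the column-major specialization above (an editorial clarification, not text from the commit): it relies on the standard transpose identity, so a GEMM whose output is column-major can be run by the row-major-output kernel on swapped, transposed operands:

  D = alpha * A * B + beta * C    <=>    D^T = alpha * B^T * A^T + beta * C^T

Because operands A and B trade places in the transposed problem, the operand being reduced along K trades places too, which is why the specialization passes !ReduceKForA_ (and swaps the alignments, complex transforms, and gather flags) to the underlying operator.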
@@ -91,7 +91,7 @@ template <
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    ///
    /// Reduce A or B along the K dimension
    bool ReduceKForA_,
    /// Tag indicating architecture to tune for
    typename ArchTag,
@@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/pitch_linear.h"

#include "cutlass/trace.h"
@@ -90,7 +90,7 @@ template <
    typename LayoutC,
    /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
    typename OperatorClass,
    ///
    /// Reduce operand A or B along K dimension
    bool ReduceKForA_,
    /// Number of stages
    int Stages = 2,
@@ -61,7 +61,9 @@ template <
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Operator describing the tensor operation
    typename Operator_ = arch::OpMultiplyAdd,
    typename Operator_,
    /// Reduce operand A or B along K dimension
    bool ReduceKForA_,
    /// Number of partitions along K dimension
    int PartitionsK = 1,
    /// Store the accumulators in row major or column major. Row major is used
@@ -78,7 +80,7 @@ struct DefaultMmaWithReductionTensorOp {
  // Define the warp-level tensor op
  using Type = cutlass::gemm::warp::MmaWithReductionTensorOp<
      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
      Policy, PartitionsK, AccumulatorsInRowMajor>;
      Policy, ReduceKForA_, PartitionsK, AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -81,7 +81,7 @@ template <
    typename LayoutC_,
    /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
    typename Policy_,
    ///
    /// Reduce operand A or B along K dimension
    bool ReduceKForA_,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1,