Merge pull request #228 from mani-ananth/master

Fix for issue#224 and issue#225
This commit is contained in:
Haicheng Wu 2021-04-08 10:08:13 -04:00 committed by GitHub
commit c805593ebe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 119 additions and 9 deletions

View File

@ -111,7 +111,8 @@ using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
// Data type and layout of meta data matrix E can be inferred from template Gemm.
using ElementInputE = typename Gemm::ElementE;
using LayoutInputE = typename Gemm::LayoutE;
using LayoutInputE = cutlass::layout::RowMajor;
using ReorderedLayoutInputE = typename Gemm::LayoutE;
// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h
// 50% Sparsity on Ampere
@ -151,27 +152,27 @@ int run() {
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Same size as the above. The above one needs to be reordered and stored in this one.
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e_reordered(
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(1),
ElementInputA(-1),
ElementInputA(2),
ElementInputA(-2),
0); // <- Fill matrix A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(1),
ElementInputB(-1),
ElementInputB(2),
ElementInputB(-2),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(1),
ElementOutput(-1),
ElementOutput(2),
ElementOutput(-2),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomSparseMeta(
tensor_e.host_view(),
@ -210,7 +211,7 @@ int run() {
tensor_b.device_ref(), // <- reference to matrix B on device
tensor_c.device_ref(), // <- reference to matrix C on device
tensor_d.device_ref(), // <- reference to matrix D on device
tensor_e.device_ref(), // <- reference to matrix E on device
tensor_e_reordered.device_ref(), // <- reference to matrix E on device
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor

View File

@ -284,6 +284,115 @@ struct DefaultConv3dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm
/// and 2 stage pipeline.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
>
struct DefaultConv3dFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
ThreadMapA
>
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
ThreadMapB
>
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma
using Mma = threadblock::ImplicitGemmPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;
// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
EpilogueOutputOp
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv3dProblemSize
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and multistage
// pipeline.
template <