Merge pull request #228 from mani-ananth/master
Fix for issue#224 and issue#225
commit c805593ebe
@@ -111,7 +111,8 @@ using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
 
 // Data type and layout of meta data matrix E can be inferred from template Gemm.
 using ElementInputE = typename Gemm::ElementE;
-using LayoutInputE = typename Gemm::LayoutE;
+using LayoutInputE = cutlass::layout::RowMajor;
+using ReorderedLayoutInputE = typename Gemm::LayoutE;
 
 // Below property is defined in include/cutlass/arch/sp_mma_sm80.h
 // 50% Sparsity on Ampere
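Note: the metadata extents used in the next hunk are derived from constants that the example queries off the Gemm template. A minimal sketch of those queries, assuming the typedef names used elsewhere in this example (they are not part of the hunks shown here):

    // kSparse is the structured-sparsity ratio (2 for 2:4 sparsity on Ampere);
    // kElementsPerElementE is how many metadata items are packed into one ElementInputE word.
    constexpr int kSparse = Gemm::kSparse;
    constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
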
@@ -151,27 +152,27 @@ int run() {
   cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
       cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
   // Same size as the above. The above one needs to be reordered and stored in this one.
-  cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e_reordered(
+  cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
       cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
 
   // Fill input and output matrices on host using CUTLASS helper functions
   cutlass::reference::host::TensorFillRandomUniform(
       tensor_a.host_view(),
       1,
-      ElementInputA(1),
-      ElementInputA(-1),
+      ElementInputA(2),
+      ElementInputA(-2),
       0);  // <- Fill matrix A on host with uniform-distribution random data
   cutlass::reference::host::TensorFillRandomUniform(
       tensor_b.host_view(),
       1,
-      ElementInputB(1),
-      ElementInputB(-1),
+      ElementInputB(2),
+      ElementInputB(-2),
       0);  // <- Fill matrix B on host with uniform-distribution random data
   cutlass::reference::host::TensorFillRandomUniform(
       tensor_c.host_view(),
       1,
-      ElementOutput(1),
-      ElementOutput(-1),
+      ElementOutput(2),
+      ElementOutput(-2),
       0);  // <- Fill matrix C on host with uniform-distribution random data
   cutlass::reference::host::TensorFillRandomSparseMeta(
       tensor_e.host_view(),
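As the comment in this hunk says, the row-major tensor_e is only a staging buffer: its contents have to be reordered into tensor_e_reordered before the kernel can consume them. A hedged sketch of that step, assuming the cutlass::reorder_meta helper from cutlass/util/host_reorder.h and the usual host-to-device sync (the call site sits outside the hunks shown here):

    // Reorder the random metadata into the layout the sparse tensor cores expect,
    // then copy the reordered tensor to the device.
    cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
                          {problem_size.m(), problem_size.n(),
                           problem_size.k() / kSparse / kElementsPerElementE});
    tensor_e_reordered.sync_device();
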
@@ -210,7 +211,7 @@ int run() {
       tensor_b.device_ref(),            // <- reference to matrix B on device
       tensor_c.device_ref(),            // <- reference to matrix C on device
       tensor_d.device_ref(),            // <- reference to matrix D on device
-      tensor_e.device_ref(),            // <- reference to matrix E on device
+      tensor_e_reordered.device_ref(),  // <- reference to matrix E on device
       {alpha, beta},                    // <- tuple of alpha and beta
       split_k_slices};                  // <- k-dimension split factor
 
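For context, the argument list patched above feeds the usual device-level launch sequence. A minimal sketch of that flow, assuming the CUTLASS_CHECK error macro the examples define and omitting workspace allocation:

    // Instantiate the device-level sparse GEMM, check the problem is supported,
    // initialize with the arguments built above, then launch the kernel.
    Gemm gemm_op;
    cutlass::Status status = gemm_op.can_implement(arguments);
    CUTLASS_CHECK(status);
    status = gemm_op.initialize(arguments);
    CUTLASS_CHECK(status);
    status = gemm_op();  // <- runs with tensor_e_reordered as the metadata operand
    CUTLASS_CHECK(status);
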
@@ -284,6 +284,115 @@ struct DefaultConv3dFprop <
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm
+/// and 2 stage pipeline.
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename EpilogueOutputOp,
+  typename ThreadblockSwizzle,
+  typename MathOperatorTag
+>
+struct DefaultConv3dFprop <
+  ElementA,
+  LayoutA,
+  ElementB,
+  LayoutB,
+  ElementC,
+  LayoutC,
+  ElementAccumulator,
+  arch::OpClassTensorOp,
+  ArchTag,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  EpilogueOutputOp,
+  ThreadblockSwizzle,
+  2,
+  MathOperatorTag,
+  IteratorAlgorithm::kOptimized
+> {
+
+  // Define the core components from GEMM
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
+      2, MathOperatorTag>;
+
+  // Define iterators over tiles from the A operand
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using IteratorA =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+        ElementA,
+        LayoutA,
+        ThreadMapA
+      >
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // Define iterators over tiles from the B operand
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using IteratorB =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+        ElementB,
+        LayoutB,
+        ThreadMapB
+      >
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components
+  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Define the Mma
+  using Mma = threadblock::ImplicitGemmPipelined<
+    ThreadblockShape,
+    IteratorA,
+    SmemIteratorA,
+    IteratorB,
+    SmemIteratorB,
+    ElementC,
+    LayoutC,
+    MmaPolicy
+  >;
+
+  // Define the epilogue
+  using Epilogue = typename detail::DefaultConvEpilogue<
+    ArchTag,
+    ThreadblockShape,
+    WarpMmaTensorOp,
+    1,
+    EpilogueOutputOp
+  >::Epilogue;
+
+  // Define the kernel
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma,
+    Epilogue,
+    ThreadblockSwizzle,
+    conv::Operator::kFprop,
+    Conv3dProblemSize
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage
 // pipeline.
 template <
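The specialization added above is picked whenever DefaultConv3dFprop is asked for the optimized iterator algorithm with a two-stage pipeline. A hedged sketch of how a caller could reach it through the device-level wrapper; the element types, tile shapes, epilogue, and swizzle below are illustrative assumptions, not values taken from this commit:

    #include "cutlass/conv/kernel/default_conv3d_fprop.h"
    #include "cutlass/conv/device/implicit_gemm_convolution.h"
    #include "cutlass/epilogue/thread/linear_combination.h"
    #include "cutlass/gemm/threadblock/threadblock_swizzle.h"

    // Illustrative 2-stage, optimized-iterator 3-D fprop kernel for half-precision
    // NDHWC tensors on SM80 tensor cores (argument order mirrors the specialization above).
    using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop<
        cutlass::half_t, cutlass::layout::TensorNDHWC,   // ElementA, LayoutA
        cutlass::half_t, cutlass::layout::TensorNDHWC,   // ElementB, LayoutB
        cutlass::half_t, cutlass::layout::TensorNDHWC,   // ElementC, LayoutC
        float,                                           // ElementAccumulator
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,          // ThreadblockShape
        cutlass::gemm::GemmShape<64, 64, 32>,            // WarpShape
        cutlass::gemm::GemmShape<16, 8, 16>,             // InstructionShape
        cutlass::epilogue::thread::LinearCombination<
            cutlass::half_t, 8, float, float>,           // EpilogueOutputOp
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        2,                                               // <- selects the two-stage pipeline
        cutlass::arch::OpMultiplyAdd,
        cutlass::conv::IteratorAlgorithm::kOptimized
    >::Kernel;

    using Conv3dFpropOp = cutlass::conv::device::ImplicitGemmConvolution<Conv3dFpropKernel>;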