add 2stage fprop 3d into default file

This commit is contained in:
Manikandan Ananth 2021-04-07 13:29:32 -07:00
parent d97214987a
commit 4839b6cb61

View File

@ -284,6 +284,115 @@ struct DefaultConv3dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm
/// and 2 stage pipeline.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
>
struct DefaultConv3dFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
ThreadMapA
>
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
ThreadMapB
>
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma
using Mma = threadblock::ImplicitGemmPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;
// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
EpilogueOutputOp
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv3dProblemSize
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and multistage
// pipeline.
template <