add 2stage fprop 3d into default file
This commit is contained in:
parent
d97214987a
commit
4839b6cb61
@ -284,6 +284,115 @@ struct DefaultConv3dFprop <
|
|||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm
|
||||||
|
/// and 2 stage pipeline.
|
||||||
|
template <
|
||||||
|
typename ElementA,
|
||||||
|
typename LayoutA,
|
||||||
|
typename ElementB,
|
||||||
|
typename LayoutB,
|
||||||
|
typename ElementC,
|
||||||
|
typename LayoutC,
|
||||||
|
typename ElementAccumulator,
|
||||||
|
typename ArchTag,
|
||||||
|
typename ThreadblockShape,
|
||||||
|
typename WarpShape,
|
||||||
|
typename InstructionShape,
|
||||||
|
typename EpilogueOutputOp,
|
||||||
|
typename ThreadblockSwizzle,
|
||||||
|
typename MathOperatorTag
|
||||||
|
>
|
||||||
|
struct DefaultConv3dFprop <
|
||||||
|
ElementA,
|
||||||
|
LayoutA,
|
||||||
|
ElementB,
|
||||||
|
LayoutB,
|
||||||
|
ElementC,
|
||||||
|
LayoutC,
|
||||||
|
ElementAccumulator,
|
||||||
|
arch::OpClassTensorOp,
|
||||||
|
ArchTag,
|
||||||
|
ThreadblockShape,
|
||||||
|
WarpShape,
|
||||||
|
InstructionShape,
|
||||||
|
EpilogueOutputOp,
|
||||||
|
ThreadblockSwizzle,
|
||||||
|
2,
|
||||||
|
MathOperatorTag,
|
||||||
|
IteratorAlgorithm::kOptimized
|
||||||
|
> {
|
||||||
|
|
||||||
|
// Define the core components from GEMM
|
||||||
|
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||||
|
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||||
|
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||||
|
2, MathOperatorTag>;
|
||||||
|
|
||||||
|
// Define iterators over tiles from the A operand
|
||||||
|
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||||
|
using IteratorA =
|
||||||
|
cutlass::conv::threadblock::TileIterator<
|
||||||
|
cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized<
|
||||||
|
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||||
|
ElementA,
|
||||||
|
LayoutA,
|
||||||
|
ThreadMapA
|
||||||
|
>
|
||||||
|
>;
|
||||||
|
|
||||||
|
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||||
|
|
||||||
|
// Define iterators over tiles from the B operand
|
||||||
|
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||||
|
using IteratorB =
|
||||||
|
cutlass::conv::threadblock::TileIterator<
|
||||||
|
cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<
|
||||||
|
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||||
|
ElementB,
|
||||||
|
LayoutB,
|
||||||
|
ThreadMapB
|
||||||
|
>
|
||||||
|
>;
|
||||||
|
|
||||||
|
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||||
|
|
||||||
|
// Warp-level GEMM components
|
||||||
|
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||||
|
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||||
|
|
||||||
|
// Define the Mma
|
||||||
|
using Mma = threadblock::ImplicitGemmPipelined<
|
||||||
|
ThreadblockShape,
|
||||||
|
IteratorA,
|
||||||
|
SmemIteratorA,
|
||||||
|
IteratorB,
|
||||||
|
SmemIteratorB,
|
||||||
|
ElementC,
|
||||||
|
LayoutC,
|
||||||
|
MmaPolicy
|
||||||
|
>;
|
||||||
|
|
||||||
|
// Define the epilogue
|
||||||
|
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||||
|
ArchTag,
|
||||||
|
ThreadblockShape,
|
||||||
|
WarpMmaTensorOp,
|
||||||
|
1,
|
||||||
|
EpilogueOutputOp
|
||||||
|
>::Epilogue;
|
||||||
|
|
||||||
|
// Define the kernel
|
||||||
|
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||||
|
Mma,
|
||||||
|
Epilogue,
|
||||||
|
ThreadblockSwizzle,
|
||||||
|
conv::Operator::kFprop,
|
||||||
|
Conv3dProblemSize
|
||||||
|
>;
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage
|
/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage
|
||||||
/// pipeline.
|
/// pipeline.
|
||||||
template <
|
template <
|
||||||
|
Loading…
Reference in New Issue
Block a user