diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop.h b/include/cutlass/conv/kernel/default_conv3d_fprop.h
index 56604588..c928de33 100644
--- a/include/cutlass/conv/kernel/default_conv3d_fprop.h
+++ b/include/cutlass/conv/kernel/default_conv3d_fprop.h
@@ -284,6 +284,123 @@ struct DefaultConv3dFprop <
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm
+/// and 2 stage pipeline.
+///
+/// This partial specialization applies when the operator class is arch::OpClassTensorOp,
+/// the stage count is exactly 2, and the iterator algorithm is IteratorAlgorithm::kOptimized.
+/// The nested aliases are declared in dependency order and end with Kernel.
+template <
+  typename ElementA,            // element type given to MmaCore and the activation (A) iterator
+  typename LayoutA,             // layout given to the activation (A) tile access iterator
+  typename ElementB,            // element type given to MmaCore and the filter (B) iterator
+  typename LayoutB,             // layout given to the filter (B) tile access iterator
+  typename ElementC,            // forwarded to threadblock::ImplicitGemmPipelined
+  typename LayoutC,             // forwarded to threadblock::ImplicitGemmPipelined
+  typename ElementAccumulator,  // accumulator element type given to DefaultMmaCore
+  typename ArchTag,             // forwarded to detail::DefaultConvEpilogue
+  typename ThreadblockShape,    // supplies kM/kN/kK; also given to MmaCore, Mma and Epilogue
+  typename WarpShape,           // forwarded to DefaultMmaCore
+  typename InstructionShape,    // forwarded to DefaultMmaCore
+  typename EpilogueOutputOp,    // forwarded to detail::DefaultConvEpilogue
+  typename ThreadblockSwizzle,  // forwarded to ImplicitGemmConvolution
+  typename MathOperatorTag      // forwarded to DefaultMmaCore
+>
+struct DefaultConv3dFprop <
+  ElementA,
+  LayoutA,
+  ElementB,
+  LayoutB,
+  ElementC,
+  LayoutC,
+  ElementAccumulator,
+  arch::OpClassTensorOp,          // fixed by this specialization
+  ArchTag,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  EpilogueOutputOp,
+  ThreadblockSwizzle,
+  2,                              // fixed by this specialization (the 2 stage pipeline)
+  MathOperatorTag,
+  IteratorAlgorithm::kOptimized   // fixed by this specialization
+> {
+
+  // Define the core components from GEMM.
+  // The operand layouts are fixed here (RowMajor A, ColumnMajor B, RowMajor accumulator)
+  // instead of being taken from LayoutA/LayoutB; the stage count 2 matches this specialization.
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
+      2, MathOperatorTag>;
+
+  // Define iterators over tiles from the A operand.
+  // The access iterator covers a ThreadblockShape::kM x ThreadblockShape::kK tile.
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using IteratorA =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+        ElementA,
+        LayoutA,
+        ThreadMapA
+      >
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // Define iterators over tiles from the B operand.
+  // The access iterator covers a ThreadblockShape::kK x ThreadblockShape::kN tile.
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using IteratorB =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+        ElementB,
+        LayoutB,
+        ThreadMapB
+      >
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components (both taken unchanged from MmaCore)
+  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Define the Mma (pipelined implicit GEMM assembled from the iterators and policy above)
+  using Mma = threadblock::ImplicitGemmPipelined<
+    ThreadblockShape,
+    IteratorA,
+    SmemIteratorA,
+    IteratorB,
+    SmemIteratorB,
+    ElementC,
+    LayoutC,
+    MmaPolicy
+  >;
+
+  // Define the epilogue
+  using Epilogue = typename detail::DefaultConvEpilogue<
+    ArchTag,
+    ThreadblockShape,
+    WarpMmaTensorOp,
+    1,                 // NOTE(review): presumably the K-partition count - confirm
+    EpilogueOutputOp
+  >::Epilogue;
+
+  // Define the kernel (forward propagation over a Conv3dProblemSize)
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma,
+    Epilogue,
+    ThreadblockSwizzle,
+    conv::Operator::kFprop,
+    Conv3dProblemSize
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and multistage
 // pipeline.
 template <