diff --git a/include/cutlass/conv/kernel/default_conv2d.h b/include/cutlass/conv/kernel/default_conv2d.h
index 5605162c..5a83586e 100644
--- a/include/cutlass/conv/kernel/default_conv2d.h
+++ b/include/cutlass/conv/kernel/default_conv2d.h
@@ -37,6 +37,10 @@
 #include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
 #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
 #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
+
+#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h"
+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h"
+
 #include "cutlass/conv/convolution.h"
 #include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
 #include "cutlass/conv/threadblock/implicit_gemm_pipelined.h"
@@ -96,6 +100,122 @@ struct DefaultConvEpilogue<
     OutputOp::kCount
   >::Epilogue;
 };
+
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename ArchTag,
+  typename Shape,
+  typename WarpMmaTensorOp,
+  int PartitionsK,
+  typename ElementOutput,
+  typename ElementTensor,
+  typename ElementVector,
+  typename OutputOp,
+  int ElementsPerAccess
+>
+struct DefaultConvEpilogueWithBroadcastTensorOp {
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    ElementOutput,
+    ElementTensor,
+    ElementVector,
+    OutputOp,
+    ElementsPerAccess
+  >::Epilogue;
+};
+
+template <
+  typename Shape,
+  typename WarpMmaTensorOp,
+  int PartitionsK,
+  typename ElementOutput,
+  typename ElementTensor,
+  typename ElementVector,
+  typename OutputOp,
+  int ElementsPerAccess
+>
+struct DefaultConvEpilogueWithBroadcastTensorOp<
+  arch::Sm70,
+  Shape,
+  WarpMmaTensorOp,
+  PartitionsK,
+  ElementOutput,
+  ElementTensor,
+  ElementVector,
+  OutputOp,
+  ElementsPerAccess
+  > {
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    ElementOutput,
+    ElementTensor,
+    ElementVector,
+    OutputOp,
+    ElementsPerAccess
+  >::Epilogue;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename ArchTag,
+  typename Shape,
+  typename WarpMmaTensorOp,
+  int PartitionsK,
+  typename ElementOutput,
+  typename OutputOp,
+  typename ReductionOp,
+  int ElementsPerAccess
+>
+struct DefaultConvEpilogueWithReductionTensorOp {
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    ElementOutput,
+    OutputOp,
+    ReductionOp,
+    ElementsPerAccess
+  >::Epilogue;
+};
+
+template <
+  typename Shape,
+  typename WarpMmaTensorOp,
+  int PartitionsK,
+  typename ElementOutput,
+  typename OutputOp,
+  typename ReductionOp,
+  int ElementsPerAccess
+>
+struct DefaultConvEpilogueWithReductionTensorOp<
+  arch::Sm70,
+  Shape,
+  WarpMmaTensorOp,
+  PartitionsK,
+  ElementOutput,
+  OutputOp,
+  ReductionOp,
+  ElementsPerAccess
+  > {
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    ElementOutput,
+    OutputOp,
+    ReductionOp,
+    ElementsPerAccess
+  >::Epilogue;
+};
+
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 // Defaults for strided Dgrad
diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h
index dfbc98e7..52c403a5 100644
--- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h
+++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h
@@ -94,7 +94,8 @@ struct DefaultConv2dFpropWithBroadcast {
   >::Kernel;
 
   // Replace epilogue
-  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp<
+  using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp<
+    ArchTag,
     typename ImplicitGemmBase::Epilogue::Shape,
     typename ImplicitGemmBase::Epilogue::WarpMmaOperator,
     ImplicitGemmBase::Epilogue::kPartitionsK,
diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h
index 24553a6d..b092b1a3 100644
--- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h
+++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h
@@ -95,7 +95,8 @@ struct DefaultConv2dFpropWithReduction {
   >::Kernel;
 
   // Replace epilogue
-  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp<
+  using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp<
+    ArchTag,
     typename ImplicitGemmBase::Epilogue::Shape,
     typename ImplicitGemmBase::Epilogue::WarpMmaOperator,
     ImplicitGemmBase::Epilogue::kPartitionsK,
diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt
index a59b1873..13805670 100644
--- a/test/unit/conv/device/CMakeLists.txt
+++ b/test/unit/conv/device/CMakeLists.txt
@@ -129,6 +129,7 @@ endif()
 cutlass_test_unit_add_executable(
   cutlass_test_unit_conv_device_tensorop_f32_sm70
 
+  conv2d_fprop_with_broadcast_sm70.cu
   conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu
   conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu
   conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu
diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu
index 4f347f58..e2e4ece9 100644
--- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu
+++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu
@@ -76,6 +76,49 @@ TEST(SM70_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens
   EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
 }
 
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM70_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
+  128x128_32x2_64x64x32) {
+
+  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
+  using ElementA = cutlass::half_t;
+  using ElementB = cutlass::half_t;
+  using ElementC = float;
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+
+  /// Device-level Conv2d instance
+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
+    ElementA, cutlass::layout::TensorNHWC,
+    ElementB, cutlass::layout::TensorNHWC,
+    ElementC, cutlass::layout::TensorNHWC,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm70,
+    cutlass::gemm::GemmShape<128, 128, 32>,
+    cutlass::gemm::GemmShape<64, 64, 32>,
+    cutlass::gemm::GemmShape<8, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementC,
+      128 / cutlass::sizeof_bits<ElementC>::value,
+      ElementAccumulator,
+      ElementCompute
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    2,
+    cutlass::arch::OpMultiplyAdd,
+    cutlass::conv::IteratorAlgorithm::kOptimized
+  >::Kernel;
+
+  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
+
+  /// Run all unit test sizes with device-level Conv2d instance
+  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
+}
+
+
 ////////////////////////////////////////////////////////////////////////////////
 ////////////////////////////////////////////////////////////////////////////////
 #endif  // CUTLASS_ARCH_MMA_SM70_SUPPORTED
diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu
new file mode 100644
index 00000000..511b331c
--- /dev/null
+++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu
@@ -0,0 +1,121 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide Implicit GEMM interface
+*/
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/array.h"
+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
+#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
+#include "cutlass/epilogue/thread/activation.h"
+
+#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
+#include "cutlass/conv/device/implicit_gemm_convolution.h"
+
+#include "conv2d_with_broadcast_testbed.h"
+
+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED)
+
+// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual))
+// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity.
+// This is because the activation needs to be applied to the fully accumulated output of the
+// Conv2d op, which only the last thread block has access to, before applying BinaryOp.
+// The epilogue functor in the last thread block would have to be given three inputs, namely
+// partial outputs, bias, and residual, but this is not supported in the current interface.
+// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp.
+template <
+  typename ElementAccumulator,
+  template<typename T> class ActivationOp,
+  template<typename T> class BinaryOp,
+  template<typename T> class UnaryOp,
+  bool TestSplitK = false
+>
+void TestResidualBlock() {
+  using ElementA = cutlass::half_t;
+  using ElementB = cutlass::half_t;
+  using ElementC = cutlass::half_t;
+  using ElementD = ElementC;
+  using ElementCompute = ElementAccumulator;
+
+  using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
+    ElementD,
+    ElementAccumulator,
+    ElementCompute,
+    ElementC,
+    8,
+    ActivationOp,
+    BinaryOp,
+    UnaryOp
+  >;
+
+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
+    ElementA, cutlass::layout::TensorNHWC,
+    ElementB, cutlass::layout::TensorNHWC,
+    ElementC, cutlass::layout::TensorNHWC,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm70,
+    cutlass::gemm::GemmShape<128, 128, 32>,
+    cutlass::gemm::GemmShape<64, 64, 32>,
+    cutlass::gemm::GemmShape<8, 8, 4>,
+    EpilogueOutputOp,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    2,
+    cutlass::arch::OpMultiplyAdd,
+    cutlass::conv::IteratorAlgorithm::kOptimized
+  >::Kernel;
+
+  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
+
+  struct ReferenceOp {
+    using OutputOp = typename Conv2dFprop::EpilogueOutputOp;
+    using ElementZ = typename OutputOp::ElementZ;
+
+    ActivationOp<ElementCompute> activation;
+    BinaryOp<ElementCompute> binary_op;
+    UnaryOp<ElementCompute> unary_op;
+
+    void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) {
+      Z = ElementZ(unary_op(binary_op(activation(conv2d), residual)));
+    }
+  };
+
+  bool passed = test::conv::device::TestAllConv2dWithBroadcast<Conv2dFprop, ReferenceOp, true, TestSplitK>();
+  EXPECT_TRUE(passed);
+}
+
+TEST(SM70_Device_Conv2d_Fprop_With_Residual_Block_Plus_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
+  128x128_32x2_64x64x32) {
+  // Resnet
+  TestResidualBlock<float, cutlass::epilogue::thread::Identity, cutlass::plus, cutlass::epilogue::thread::ReLu>();
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+#endif  // CUTLASS_ARCH_MMA_SM70_SUPPORTED
+
+////////////////////////////////////////////////////////////////////////////////
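
Editor's addendum (not part of the patch): for context, the sketch below shows how an SM70 fprop kernel like the one built in the new unit test would be driven through the standard cutlass::conv::device::ImplicitGemmConvolution workflow. It assumes the Conv2dFprop alias from conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu; the problem extents, fill values, alpha/beta, and the function name run_sm70_fprop_sketch are illustrative assumptions, not part of this change.

#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/host/tensor_fill.h"

cutlass::Status run_sm70_fprop_sketch() {
  // Illustrative problem: N=1, 56x56x64 activation, 64 filters of 3x3x64,
  // pad 1, stride 1, dilation 1. Any valid Conv2d problem works here.
  cutlass::conv::Conv2dProblemSize problem_size(
      {1, 56, 56, 64},   // activation extent (NHWC)
      {64, 3, 3, 64},    // filter extent (KRSC)
      {1, 1, 1, 1},      // padding
      {1, 1},            // stride
      {1, 1},            // dilation
      cutlass::conv::Mode::kCrossCorrelation,
      1);                // split-k slices

  // Allocate host/device tensors sized from the problem and fill with placeholder data.
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::TensorNHWC> tensor_a(problem_size.activation_extent());
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::TensorNHWC> tensor_b(problem_size.filter_extent());
  cutlass::HostTensor<float, cutlass::layout::TensorNHWC> tensor_c(problem_size.output_extent());
  cutlass::HostTensor<float, cutlass::layout::TensorNHWC> tensor_d(problem_size.output_extent());

  cutlass::reference::host::TensorFill(tensor_a.host_view(), cutlass::half_t(1));
  cutlass::reference::host::TensorFill(tensor_b.host_view(), cutlass::half_t(1));
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();

  // alpha/beta feed the LinearCombination epilogue: D = alpha * conv(A, B) + beta * C.
  typename Conv2dFprop::Arguments arguments{
      problem_size,
      tensor_a.device_ref(),
      tensor_b.device_ref(),
      tensor_c.device_ref(),
      tensor_d.device_ref(),
      {1.0f, 0.0f}};

  Conv2dFprop conv_op;

  cutlass::Status status = conv_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Workspace is only used by split-k reductions; it is zero-sized here but queried anyway.
  cutlass::device_memory::allocation<uint8_t> workspace(conv_op.get_workspace_size(arguments));

  status = conv_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  return conv_op();  // launch the implicit GEMM kernel
}

The with-broadcast kernel added by this patch is driven the same way; its fused epilogue additionally consumes bias and residual tensor pointers, which conv2d_with_broadcast_testbed.h wires up for the unit tests.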