From c2ee13a0fe99241b0e798ce647acf98e237f1d0c Mon Sep 17 00:00:00 2001
From: masahi
Date: Thu, 30 Dec 2021 12:53:40 +0900
Subject: [PATCH] Add epilogue functor for residual block fusion (#391)

* Add epilogue functor for residual block fusion

* Do not run split-k tests when ActivationOp is not Identity

* explain TestSplitK param

* return early
---
 include/cutlass/epilogue/thread/activation.h  |  22 ++-
 .../linear_combination_residual_block.h       | 163 ++++++++++++++++++
 .../conv2d_fprop_with_broadcast_sm75.cu       |  88 +++++++++-
 .../device/conv2d_with_broadcast_testbed.h    |  54 ++++--
 4 files changed, 302 insertions(+), 25 deletions(-)
 create mode 100644 include/cutlass/epilogue/thread/linear_combination_residual_block.h

diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index 38597ccb..96a0ee40 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -68,8 +68,8 @@ struct ReLu {
   }
   CUTLASS_HOST_DEVICE
   T operator()(T value) const {
-    if (value < T()) {
-      value = T();
+    if (value < T(0)) {
+      value = T(0);
     }
     return value;
   }
@@ -91,6 +91,21 @@ struct ReLu<Array<T, N>> {
     }
     return result;
   }
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &frag) const {
+    Array<T, N> result;
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      T value = frag[i];
+      if (value < T(0)) {
+        value = T(0);
+      }
+      result[i] = value;
+    }
+    return result;
+  }
+
 };

 // Sigmoid operator
@@ -151,7 +166,8 @@ template <typename T>
 struct SiLu {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
-    return scalar * Sigmoid<T>(scalar);
+    Sigmoid<T> sigmoid;
+    return scalar * sigmoid(scalar);
   }
 };

diff --git a/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/include/cutlass/epilogue/thread/linear_combination_residual_block.h
new file mode 100644
index 00000000..263bdcc2
--- /dev/null
+++ b/include/cutlass/epilogue/thread/linear_combination_residual_block.h
@@ -0,0 +1,163 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+  \brief Epilogue functor specialized for residual blocks in deep neural networks.
+*/
+
+#pragma once
+
+#include "cutlass/array.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
+template <typename ElementOutput_, typename ElementAccumulator_,
+          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
+          template <typename T> class ActivationOp_,
+          template <typename T> class BinaryOp_,
+          template <typename T> class UnaryOp_>
+class LinearCombinationResidualBlock {
+public:
+
+  using ElementOutput = ElementC_;
+  using ElementC = ElementC_;
+  using ElementAccumulator = ElementAccumulator_;
+  using ElementCompute = ElementCompute_;
+  static int const kElementsPerAccess = ElementsPerAccess;
+  static int const kCount = kElementsPerAccess;
+
+  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
+  using BinaryOp = BinaryOp_<Array<ElementCompute, kCount>>;
+  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
+
+  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
+  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
+  using FragmentC = Array<ElementC, kElementsPerAccess>;
+  using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
+
+  using ElementZ = ElementOutput_;
+  using ElementT = ElementZ;
+  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
+  using FragmentT = Array<ElementT, kElementsPerAccess>;
+
+  static bool const kIsHeavy = true;
+  static bool const kStoreZ = true;
+  static bool const kStoreT = false;
+
+  /// Host-constructable parameters structure
+  struct Params {
+
+    ElementCompute alpha;                     ///< scales accumulators
+    ElementCompute beta;                      ///< scales residual input
+    ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
+    ElementCompute const *beta_ptr{nullptr};  ///< pointer to residual scalar - if not null, loads it from memory
+
+    CUTLASS_HOST_DEVICE
+    Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}
+
+    CUTLASS_HOST_DEVICE
+    Params(ElementCompute alpha, ElementCompute beta)
+        : alpha(alpha), beta(beta) {}
+
+    CUTLASS_HOST_DEVICE
+    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
+        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
+  };
+
+private:
+
+  ElementCompute alpha_;
+  ElementCompute beta_;
+  bool skip_elementwise_;
+
+public:
+
+  /// Constructor from Params
+  CUTLASS_HOST_DEVICE
+  LinearCombinationResidualBlock(Params const &params) {
+    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
+    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
+    skip_elementwise_ = false;
+  }
+
+  /// The "source" tensor corresponds to the residual input
+  CUTLASS_HOST_DEVICE
+  bool is_source_needed() const { return true; }
+
+  /// Functionally required for serial reduction in the epilogue
+  /// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
+  CUTLASS_HOST_DEVICE
+  void set_k_partition(int k_partition, int k_partition_count) {
+    if (k_partition) {
+      beta_ = ElementCompute(1);
+    }
+
+    if (k_partition != k_partition_count - 1) {
+      skip_elementwise_ = true;
+    }
+  }
+
+  /// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
+                  FragmentC const &residual,
+                  FragmentCompute const &bias) const {
+    UnaryOp unary_op;
+    BinaryOp binary_op;
+    ActivationOp activation;
+
+    FragmentCompute tmp_Accum =
+        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
+    FragmentCompute tmp_residual =
+        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);
+
+    FragmentCompute z =
+        binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
+    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
+
+    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
+    frag_Z = convert_z(result_Z);
+  }
+
+  /// Should never be called
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
+                  FragmentCompute const &) const {}
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu
index 64710edb..2d3505c6 100644
--- a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu
+++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu
@@ -28,15 +28,16 @@
 #include "../../common/cutlass_unit_test.h"

 #include "cutlass/cutlass.h"
-
+#include "cutlass/array.h"
 #include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
-#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
+#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
+#include "cutlass/epilogue/thread/activation.h"

 #include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
 #include "cutlass/conv/device/implicit_gemm_convolution.h"

 #include "conv2d_with_broadcast_testbed.h"
-
+ 
 #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)

 TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,

@@ -83,6 +84,87 @@ TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nh
   EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast<Conv2dFprop>());
 }

+// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual))
+// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity.
+// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op,
+// which only the last thread block has access to, before applying BinaryOp.
+// The epilogue functor in the last thread block would have to be given three inputs, namely
+// partial outputs, bias, and residual, but this is not supported in the current interface.
+// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp.
+template <
+  typename ElementAccumulator,
+  template <typename T> class ActivationOp,
+  template <typename T> class BinaryOp,
+  template <typename T> class UnaryOp,
+  bool TestSplitK = true
+>
+void TestResidualBlock() {
+  using ElementA = cutlass::half_t;
+  using ElementB = cutlass::half_t;
+  using ElementC = cutlass::half_t;
+  using ElementD = ElementC;
+  using ElementCompute = ElementAccumulator;
+
+  using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
+    ElementD,
+    ElementAccumulator,
+    ElementCompute,
+    ElementC,
+    8,
+    ActivationOp,
+    BinaryOp,
+    UnaryOp
+  >;
+
+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
+    ElementA, cutlass::layout::TensorNHWC,
+    ElementB, cutlass::layout::TensorNHWC,
+    ElementC, cutlass::layout::TensorNHWC,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm75,
+    cutlass::gemm::GemmShape<128, 128, 32>,
+    cutlass::gemm::GemmShape<64, 64, 32>,
+    cutlass::gemm::GemmShape<16, 8, 8>,
+    EpilogueOutputOp,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    2,
+    cutlass::arch::OpMultiplyAdd,
+    cutlass::conv::IteratorAlgorithm::kAnalytic
+  >::Kernel;
+
+  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
+
+  struct ReferenceOp {
+    using OutputOp = typename Conv2dFprop::EpilogueOutputOp;
+    using ElementZ = typename OutputOp::ElementZ;
+
+    ActivationOp<ElementCompute> activation;
+    BinaryOp<ElementCompute> binary_op;
+    UnaryOp<ElementCompute> unary_op;
+
+    void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) {
+      Z = ElementZ(unary_op(binary_op(activation(conv2d), residual)));
+    }
+  };
+
+  bool passed = test::conv::device::TestAllConv2dWithBroadcast<Conv2dFprop, ReferenceOp, true, TestSplitK>();
+  EXPECT_TRUE(passed);
+}
+
+TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
+  128x128_32x2_64x64x32) {
+  // ResNet
+  TestResidualBlock<float, cutlass::epilogue::thread::Identity, cutlass::plus, cutlass::epilogue::thread::ReLu>();
+}
+
+TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Multiply_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
+  128x128_32x2_64x64x32) {
+  // EfficientNet V2
+  // Do not run split-K tests since the activation op is not Identity.
+  TestResidualBlock<float, cutlass::epilogue::thread::SiLu, cutlass::multiplies, cutlass::epilogue::thread::Identity, false>();
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 #endif  // CUTLASS_ARCH_MMA_SM75_SUPPORTED
diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h
index bd9596a7..eee8c162 100644
--- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h
+++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h
@@ -95,7 +95,8 @@ struct Conv2dWithBroadcastReferenceOp {
 
 template <
   typename Conv2d,
-  typename ReferenceOp = Conv2dWithBroadcastReferenceOp<Conv2d>
+  typename ReferenceOp,
+  bool AddBroadcastFirst = false
 >
 class TestbedConv2dWithBroadcast {
 public:
@@ -113,7 +114,8 @@ public:
   using ElementT = typename EpilogueOutputOp::ElementT;
 
   static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
-
+  static const bool kAddBroadcastFirst = AddBroadcastFirst;
+  static const bool kStoreT = EpilogueOutputOp::kStoreT;
 
 public:
 
   /// Initialization
@@ -270,7 +272,7 @@ public:
     cutlass::conv::Conv2dProblemSize const &problem_size,
     cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
     ElementCompute alpha = ElementCompute(1),
-    ElementCompute beta = ElementCompute(0)) {
+    ElementCompute beta = ElementCompute(1)) {
 
     // Waive test if insufficient CUDA device
     if (!sufficient()) {
@@ -300,7 +302,7 @@ public:
       {alpha, beta},
       split_k_mode,
       tensor_Broadcast.device_data(),
-      tensor_T_computed.device_data(),
+      kStoreT ? tensor_T_computed.device_data() : nullptr,
       0,  // This must be zero
       implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
     );
@@ -338,7 +340,8 @@ public:
     //
     // Reference check
     //
-
+    // When kAddBroadcastFirst is true, add bias on the host
+    ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta;
 
 #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
     cutlass::reference::device::Conv2d<
 
@@ -358,7 +361,7 @@ public:
       tensor_C_reference.device_ref(),
       tensor_Y_reference.device_ref(),
       alpha,
-      beta);
+      beta_ref);
 
     // sync host (copy device data to host) for dumping error output in case of mismatches
     tensor_Y_reference.sync_host();
 
@@ -382,7 +385,7 @@ public:
       tensor_C_reference.host_ref(),
       tensor_Y_reference.host_ref(),
       alpha,
-      beta);
+      beta_ref);
 
 #endif
 
     ReferenceOp reference_op;
@@ -395,9 +398,16 @@ public:
 
             ElementZ z;
             ElementT t;
-
-            reference_op(z, t, tensor_Y_reference.at({n, p, q, k}), tensor_Broadcast.at({0, 0, 0, k}));
-
+            ElementCompute accum = tensor_Y_reference.at({n, p, q, k});
+            ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k}));
+
+            if (kAddBroadcastFirst) {
+              reference_op(z, t, accum + bias,
+                           beta * ElementCompute(tensor_C_reference.at({n, p, q, k})));
+            } else {
+              reference_op(z, t, accum, bias);
+            }
+
             tensor_Z_reference.at({n, p, q, k}) = z;
             tensor_T_reference.at({n, p, q, k}) = t;
           }
        }
      }
    }
 
-    passed = cutlass::reference::host::TensorEquals(
-      tensor_T_computed.host_view(),
-      tensor_T_reference.host_view());
-
-    EXPECT_TRUE(passed);
+    if (kStoreT) {
+      passed = cutlass::reference::host::TensorEquals(
+        tensor_T_computed.host_view(), tensor_T_reference.host_view());
+      EXPECT_TRUE(passed);
+    }
 
     passed = cutlass::reference::host::TensorEquals(
       tensor_Z_computed.host_view(),
 
@@ -479,10 +489,13 @@ public:
 // Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
 // (conv_blacklist_sizes)
 /////////////////////////////////////////////////////////////////////////////////////////////////////////////
-template <typename ImplicitGemm>
+template <typename ImplicitGemm, typename ReferenceOp = Conv2dWithBroadcastReferenceOp<ImplicitGemm>,
+          bool AddBroadcastFirst = false,
+          bool TestSplitK = true>
 bool TestAllConv2dWithBroadcast(
-  const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
-  const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {
+  const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(),
+  const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) {
 
   bool passed = true;
 
@@ -490,7 +503,7 @@ bool TestAllConv2dWithBroadcast(
   //
   // Testbed object
   //
 
-  TestbedConv2dWithBroadcast<ImplicitGemm> testbed;
+  TestbedConv2dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
 
   //
   // Get conv problem sizes to run conv operator
   //
@@ -597,6 +610,9 @@ bool TestAllConv2dWithBroadcast(
     return passed;
   }
 
+  if (!TestSplitK)
+    return passed;
+
  // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for
  // a single conv2d problem size.  Convolution unit tests take a long time to run so only sweep parameters
  // which are absolutely necessary to catch functional bugs. The below code does provide option to sweep
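
Usage sketch: the new functor plugs into a kernel like any other thread-level epilogue.
The following mirrors the "Plus" test above; the element types and access width are taken
from that test and are not the only valid choices, and the alias name is illustrative:

  #include "cutlass/epilogue/thread/linear_combination_residual_block.h"
  #include "cutlass/epilogue/thread/activation.h"
  #include "cutlass/functional.h"

  // Fuses Z = ReLu((alpha * conv2d(X) + bias) + beta * residual), the ResNet-style
  // pattern. ActivationOp is Identity here, so the split-k modes remain usable.
  using ResidualAddRelu = cutlass::epilogue::thread::LinearCombinationResidualBlock<
      cutlass::half_t,                     // ElementOutput_ (ElementZ)
      float,                               // ElementAccumulator_
      float,                               // ElementCompute_
      cutlass::half_t,                     // ElementC_ (type of the residual input)
      8,                                   // ElementsPerAccess
      cutlass::epilogue::thread::Identity, // ActivationOp_ (applied to conv + bias)
      cutlass::plus,                       // BinaryOp_ (adds the residual)
      cutlass::epilogue::thread::ReLu>;    // UnaryOp_ (applied last)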
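The per-element computation the kernel fuses, and the updated testbed verifies, can be
modeled on the host as below. This is a minimal scalar sketch assuming ActivationOp =
Identity, BinaryOp = plus, and UnaryOp = ReLu; the function name is hypothetical:

  #include <algorithm>

  // Scalar model of UnaryOp(BinaryOp(ActivationOp(alpha * AB + bias), beta * residual)).
  float residual_block_ref(float accum, float bias, float residual,
                           float alpha, float beta) {
    float activated = alpha * accum + bias;        // ActivationOp = Identity
    float combined  = activated + beta * residual; // BinaryOp = plus
    return std::max(combined, 0.0f);               // UnaryOp = ReLu
  }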
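When wiring a new fusion into the testbed, the reference functor and the two new flags are
passed explicitly, as in the sketch below (Conv2dFprop and ReferenceOp as defined in the
test file above). AddBroadcastFirst = true makes the host reference add the broadcast bias
before ActivationOp, matching the fused formula, and TestSplitK = false skips the split-k
sweep, which is only valid when ActivationOp is Identity:

  bool passed = test::conv::device::TestAllConv2dWithBroadcast<
      Conv2dFprop, ReferenceOp, /*AddBroadcastFirst=*/true, /*TestSplitK=*/false>();
  EXPECT_TRUE(passed);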