Add epilogue functor for residual block fusion (#391)
* Add epilogue functor for residual block fusion * Do not run split-k tests when ActivationOp is not Identity * explain TestSplitK param * return early
This commit is contained in:
parent
f78994bb40
commit
c2ee13a0fe
@ -68,8 +68,8 @@ struct ReLu {
|
|||||||
}
|
}
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T value) const {
|
T operator()(T value) const {
|
||||||
if (value < T()) {
|
if (value < T(0)) {
|
||||||
value = T();
|
value = T(0);
|
||||||
}
|
}
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
@ -91,6 +91,21 @@ struct ReLu<Array<T, N>> {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||||
|
Array<T, N> result;
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
T value = frag[i];
|
||||||
|
if (value < T(0)) {
|
||||||
|
value = T(0);
|
||||||
|
}
|
||||||
|
result[i] = value;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Sigmoid operator
|
// Sigmoid operator
|
||||||
@ -151,7 +166,8 @@ template <typename T>
|
|||||||
struct SiLu {
|
struct SiLu {
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &scalar) const {
|
T operator()(T const &scalar) const {
|
||||||
return scalar * Sigmoid<T>(scalar);
|
Sigmoid<T> sigmoid;
|
||||||
|
return scalar * sigmoid(scalar);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,163 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
|
* provided that the following conditions are met:
|
||||||
|
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||||
|
* conditions and the following disclaimer.
|
||||||
|
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||||
|
* conditions and the following disclaimer in the documentation and/or other materials
|
||||||
|
* provided with the distribution.
|
||||||
|
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||||
|
* to endorse or promote products derived from this software without specific prior written
|
||||||
|
* permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||||
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||||
|
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||||
|
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||||
|
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
/*! \file
|
||||||
|
\brief Epilogue functor specialized for residual blocks in deep neural network.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/array.h"
|
||||||
|
#include "cutlass/functional.h"
|
||||||
|
#include "cutlass/numeric_conversion.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace epilogue {
|
||||||
|
namespace thread {
|
||||||
|
|
||||||
|
// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
|
||||||
|
template <typename ElementOutput_, typename ElementAccumulator_,
|
||||||
|
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
||||||
|
template <typename T> class ActivationOp_,
|
||||||
|
template <typename T> class BinaryOp_,
|
||||||
|
template <typename T> class UnaryOp_>
|
||||||
|
class LinearCombinationResidualBlock {
|
||||||
|
public:
|
||||||
|
|
||||||
|
using ElementOutput = ElementC_;
|
||||||
|
using ElementC = ElementC_;
|
||||||
|
using ElementAccumulator = ElementAccumulator_;
|
||||||
|
using ElementCompute = ElementCompute_;
|
||||||
|
static int const kElementsPerAccess = ElementsPerAccess;
|
||||||
|
static int const kCount = kElementsPerAccess;
|
||||||
|
|
||||||
|
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
|
||||||
|
using BinaryOp = BinaryOp_<Array<ElementCompute, kCount>>;
|
||||||
|
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
|
||||||
|
|
||||||
|
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
||||||
|
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
||||||
|
using FragmentC = Array<ElementC, kElementsPerAccess>;
|
||||||
|
using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
|
||||||
|
|
||||||
|
using ElementZ = ElementOutput_;
|
||||||
|
using ElementT = ElementZ;
|
||||||
|
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
||||||
|
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
||||||
|
|
||||||
|
static bool const kIsHeavy = true;
|
||||||
|
static bool const kStoreZ = true;
|
||||||
|
static bool const kStoreT = false;
|
||||||
|
|
||||||
|
/// Host-constructable parameters structure
|
||||||
|
struct Params {
|
||||||
|
|
||||||
|
ElementCompute alpha; ///< scales accumulators
|
||||||
|
ElementCompute beta; ///< scales residual input
|
||||||
|
ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||||
|
ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Params(ElementCompute alpha, ElementCompute beta)
|
||||||
|
: alpha(alpha), beta(beta) {}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
|
||||||
|
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
ElementCompute alpha_;
|
||||||
|
ElementCompute beta_;
|
||||||
|
bool skip_elementwise_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
/// Constructor from Params
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
LinearCombinationResidualBlock(Params const ¶ms) {
|
||||||
|
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||||
|
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||||
|
skip_elementwise_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The "source" tensor corresponds to the residual input
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
bool is_source_needed() const { return true; }
|
||||||
|
|
||||||
|
/// Functionally required for serial reduction in the epilogue
|
||||||
|
/// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
void set_k_partition(int k_partition, int k_partition_count) {
|
||||||
|
if (k_partition) {
|
||||||
|
beta_ = ElementCompute(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (k_partition != k_partition_count - 1) {
|
||||||
|
skip_elementwise_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
|
||||||
|
FragmentC const &residual,
|
||||||
|
FragmentCompute const &bias) const {
|
||||||
|
UnaryOp unary_op;
|
||||||
|
BinaryOp binary_op;
|
||||||
|
ActivationOp activation;
|
||||||
|
|
||||||
|
FragmentCompute tmp_Accum =
|
||||||
|
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
||||||
|
FragmentCompute tmp_residual =
|
||||||
|
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);
|
||||||
|
|
||||||
|
FragmentCompute z =
|
||||||
|
binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
|
||||||
|
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
|
||||||
|
|
||||||
|
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
|
||||||
|
frag_Z = convert_z(result_Z);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Should never be called
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
|
||||||
|
FragmentCompute const &) const {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace thread
|
||||||
|
} // namespace epilogue
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
@ -28,15 +28,16 @@
|
|||||||
|
|
||||||
#include "../../common/cutlass_unit_test.h"
|
#include "../../common/cutlass_unit_test.h"
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/array.h"
|
||||||
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
|
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
|
||||||
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
|
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
|
||||||
|
#include "cutlass/epilogue/thread/activation.h"
|
||||||
|
|
||||||
#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
|
#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
|
||||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||||
|
|
||||||
#include "conv2d_with_broadcast_testbed.h"
|
#include "conv2d_with_broadcast_testbed.h"
|
||||||
|
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
||||||
|
|
||||||
TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
|
TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
|
||||||
@ -83,6 +84,87 @@ TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nh
|
|||||||
EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast<Conv2dFprop>());
|
EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast<Conv2dFprop>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual))
|
||||||
|
// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity.
|
||||||
|
// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op,
|
||||||
|
// which only the last thread block would have an access to, before applying BinaryOp.
|
||||||
|
// The epilogue functor in the last thread block would have to be given three inputs, namely
|
||||||
|
// partial outputs, bias, and residual, but this is not supported in the current interface.
|
||||||
|
// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp.
|
||||||
|
template <
|
||||||
|
typename ElementAccumulator,
|
||||||
|
template<typename T> class ActivationOp,
|
||||||
|
template<typename T> class BinaryOp,
|
||||||
|
template<typename T> class UnaryOp,
|
||||||
|
bool TestSplitK = true
|
||||||
|
>
|
||||||
|
void TestResidaulBlock() {
|
||||||
|
using ElementA = cutlass::half_t;
|
||||||
|
using ElementB = cutlass::half_t;
|
||||||
|
using ElementC = cutlass::half_t;
|
||||||
|
using ElementD = ElementC;
|
||||||
|
using ElementCompute = ElementAccumulator;
|
||||||
|
|
||||||
|
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
|
||||||
|
ElementD,
|
||||||
|
ElementAccumulator,
|
||||||
|
ElementCompute,
|
||||||
|
ElementC,
|
||||||
|
8,
|
||||||
|
ActivationOp,
|
||||||
|
BinaryOp,
|
||||||
|
UnaryOp
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
|
||||||
|
ElementA, cutlass::layout::TensorNHWC,
|
||||||
|
ElementB, cutlass::layout::TensorNHWC,
|
||||||
|
ElementC, cutlass::layout::TensorNHWC,
|
||||||
|
ElementAccumulator,
|
||||||
|
cutlass::arch::OpClassTensorOp,
|
||||||
|
cutlass::arch::Sm75,
|
||||||
|
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||||
|
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||||
|
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||||
|
EpilogueOutputOp,
|
||||||
|
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||||
|
2,
|
||||||
|
cutlass::arch::OpMultiplyAdd,
|
||||||
|
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||||
|
>::Kernel;
|
||||||
|
|
||||||
|
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||||
|
|
||||||
|
struct ReferenceOp {
|
||||||
|
using OutputOp = typename Conv2dFprop::EpilogueOutputOp;
|
||||||
|
using ElementZ = typename OutputOp::ElementZ;
|
||||||
|
|
||||||
|
ActivationOp<ElementCompute> activation;
|
||||||
|
BinaryOp<ElementCompute> binary_op;
|
||||||
|
UnaryOp<ElementCompute> unary_op;
|
||||||
|
|
||||||
|
void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) {
|
||||||
|
Z = ElementZ(unary_op(binary_op(activation(conv2d), residual)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bool passed = test::conv::device::TestAllConv2dWithBroadcast<Conv2dFprop, ReferenceOp, true, TestSplitK>();
|
||||||
|
EXPECT_TRUE(passed);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
|
||||||
|
128x128_32x2_64x64x32) {
|
||||||
|
// Resnet
|
||||||
|
TestResidaulBlock<cutlass::half_t, cutlass::epilogue::thread::Identity, cutlass::plus, cutlass::epilogue::thread::ReLu>();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Multiply_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
|
||||||
|
128x128_32x2_64x64x32) {
|
||||||
|
// EfficientNet V2
|
||||||
|
// Do not run split-K tests since the activation op is not Identity.
|
||||||
|
TestResidaulBlock<float, cutlass::epilogue::thread::Sigmoid, cutlass::multiplies, cutlass::epilogue::thread::Identity, false>();
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED
|
#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED
|
||||||
|
@ -95,7 +95,8 @@ struct Conv2dWithBroadcastReferenceOp {
|
|||||||
|
|
||||||
template <
|
template <
|
||||||
typename Conv2d,
|
typename Conv2d,
|
||||||
typename ReferenceOp = Conv2dWithBroadcastReferenceOp<Conv2d>
|
typename ReferenceOp,
|
||||||
|
bool AddBroadcastFirst = false
|
||||||
>
|
>
|
||||||
class TestbedConv2dWithBroadcast {
|
class TestbedConv2dWithBroadcast {
|
||||||
public:
|
public:
|
||||||
@ -113,7 +114,8 @@ public:
|
|||||||
using ElementT = typename EpilogueOutputOp::ElementT;
|
using ElementT = typename EpilogueOutputOp::ElementT;
|
||||||
|
|
||||||
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
||||||
|
static const bool kAddBroadcastFirst = AddBroadcastFirst;
|
||||||
|
static const bool kStoreT = EpilogueOutputOp::kStoreT;
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Initialization
|
/// Initialization
|
||||||
@ -270,7 +272,7 @@ public:
|
|||||||
cutlass::conv::Conv2dProblemSize const &problem_size,
|
cutlass::conv::Conv2dProblemSize const &problem_size,
|
||||||
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
||||||
ElementCompute alpha = ElementCompute(1),
|
ElementCompute alpha = ElementCompute(1),
|
||||||
ElementCompute beta = ElementCompute(0)) {
|
ElementCompute beta = ElementCompute(1)) {
|
||||||
|
|
||||||
// Waive test if insufficient CUDA device
|
// Waive test if insufficient CUDA device
|
||||||
if (!sufficient()) {
|
if (!sufficient()) {
|
||||||
@ -300,7 +302,7 @@ public:
|
|||||||
{alpha, beta},
|
{alpha, beta},
|
||||||
split_k_mode,
|
split_k_mode,
|
||||||
tensor_Broadcast.device_data(),
|
tensor_Broadcast.device_data(),
|
||||||
tensor_T_computed.device_data(),
|
kStoreT ? tensor_T_computed.device_data() : nullptr,
|
||||||
0, // This must be zero
|
0, // This must be zero
|
||||||
implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
|
implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
|
||||||
);
|
);
|
||||||
@ -338,7 +340,8 @@ public:
|
|||||||
//
|
//
|
||||||
// Reference check
|
// Reference check
|
||||||
//
|
//
|
||||||
|
// When kAddBroadcastFirst is true, add bias on the host
|
||||||
|
ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta;
|
||||||
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
|
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
|
||||||
|
|
||||||
cutlass::reference::device::Conv2d<
|
cutlass::reference::device::Conv2d<
|
||||||
@ -358,7 +361,7 @@ public:
|
|||||||
tensor_C_reference.device_ref(),
|
tensor_C_reference.device_ref(),
|
||||||
tensor_Y_reference.device_ref(),
|
tensor_Y_reference.device_ref(),
|
||||||
alpha,
|
alpha,
|
||||||
beta);
|
beta_ref);
|
||||||
|
|
||||||
// sync host (copy device data to host) for dumping error output in case of mismatches
|
// sync host (copy device data to host) for dumping error output in case of mismatches
|
||||||
tensor_Y_reference.sync_host();
|
tensor_Y_reference.sync_host();
|
||||||
@ -382,7 +385,7 @@ public:
|
|||||||
tensor_C_reference.host_ref(),
|
tensor_C_reference.host_ref(),
|
||||||
tensor_Y_reference.host_ref(),
|
tensor_Y_reference.host_ref(),
|
||||||
alpha,
|
alpha,
|
||||||
beta);
|
beta_ref);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
ReferenceOp reference_op;
|
ReferenceOp reference_op;
|
||||||
@ -395,9 +398,16 @@ public:
|
|||||||
|
|
||||||
ElementZ z;
|
ElementZ z;
|
||||||
ElementT t;
|
ElementT t;
|
||||||
|
ElementCompute accum = tensor_Y_reference.at({n, p, q, k});
|
||||||
reference_op(z, t, tensor_Y_reference.at({n, p, q, k}), tensor_Broadcast.at({0, 0, 0, k}));
|
ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k}));
|
||||||
|
|
||||||
|
if (kAddBroadcastFirst) {
|
||||||
|
reference_op(z, t, accum + bias,
|
||||||
|
beta * ElementCompute(tensor_C_reference.at({n, p, q, k})));
|
||||||
|
} else {
|
||||||
|
reference_op(z, t, accum, bias);
|
||||||
|
}
|
||||||
|
|
||||||
tensor_Z_reference.at({n, p, q, k}) = z;
|
tensor_Z_reference.at({n, p, q, k}) = z;
|
||||||
tensor_T_reference.at({n, p, q, k}) = t;
|
tensor_T_reference.at({n, p, q, k}) = t;
|
||||||
}
|
}
|
||||||
@ -405,11 +415,11 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
passed = cutlass::reference::host::TensorEquals(
|
if (kStoreT) {
|
||||||
tensor_T_computed.host_view(),
|
passed = cutlass::reference::host::TensorEquals(
|
||||||
tensor_T_reference.host_view());
|
tensor_T_computed.host_view(), tensor_T_reference.host_view());
|
||||||
|
EXPECT_TRUE(passed);
|
||||||
EXPECT_TRUE(passed);
|
}
|
||||||
|
|
||||||
passed = cutlass::reference::host::TensorEquals(
|
passed = cutlass::reference::host::TensorEquals(
|
||||||
tensor_Z_computed.host_view(),
|
tensor_Z_computed.host_view(),
|
||||||
@ -479,10 +489,13 @@ public:
|
|||||||
// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
||||||
// (conv_blacklist_sizes)
|
// (conv_blacklist_sizes)
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template <typename ImplicitGemm>
|
template <typename ImplicitGemm,
|
||||||
|
typename ReferenceOp = Conv2dWithBroadcastReferenceOp<ImplicitGemm>,
|
||||||
|
bool AddBroadcastFirst = false,
|
||||||
|
bool TestSplitK = true>
|
||||||
bool TestAllConv2dWithBroadcast(
|
bool TestAllConv2dWithBroadcast(
|
||||||
const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
|
const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(),
|
||||||
const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {
|
const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) {
|
||||||
|
|
||||||
bool passed = true;
|
bool passed = true;
|
||||||
|
|
||||||
@ -490,7 +503,7 @@ bool TestAllConv2dWithBroadcast(
|
|||||||
// Testbed object
|
// Testbed object
|
||||||
//
|
//
|
||||||
|
|
||||||
TestbedConv2dWithBroadcast<ImplicitGemm> testbed;
|
TestbedConv2dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Get conv problem sizes to run conv operator
|
// Get conv problem sizes to run conv operator
|
||||||
@ -597,6 +610,9 @@ bool TestAllConv2dWithBroadcast(
|
|||||||
return passed;
|
return passed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!TestSplitK)
|
||||||
|
return passed;
|
||||||
|
|
||||||
// Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
|
// Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
|
||||||
// a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
|
// a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
|
||||||
// which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
|
// which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
|
||||||
|
Loading…
Reference in New Issue
Block a user