Fix type bug in conv2d/gemm with broadcast (#796)

add ElementVector

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Shuai Shao 2023-02-09 17:53:25 -08:00 committed by GitHub
parent 2e10404d26
commit ce8597dc14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 21 additions and 11 deletions

View File

@ -107,7 +107,7 @@ struct DefaultConv2dFpropWithBroadcast {
ImplicitGemmBase::Epilogue::kPartitionsK, ImplicitGemmBase::Epilogue::kPartitionsK,
ElementC, ElementC,
typename EpilogueOutputOp::ElementT, typename EpilogueOutputOp::ElementT,
ElementC, typename EpilogueOutputOp::ElementVector,
EpilogueOutputOp, EpilogueOutputOp,
ImplicitGemmBase::Epilogue::kElementsPerAccess ImplicitGemmBase::Epilogue::kElementsPerAccess
>::Epilogue; >::Epilogue;

View File

@ -61,7 +61,8 @@ template <
typename ElementT_, typename ElementT_,
int ElementsPerAccess, int ElementsPerAccess,
typename ElementwiseOp_ = Identity<ElementCompute_>, typename ElementwiseOp_ = Identity<ElementCompute_>,
typename BinaryOp_ = plus<ElementCompute_> typename BinaryOp_ = plus<ElementCompute_>,
typename ElementVector_ = ElementC_
> >
class LinearCombinationBiasElementwise { class LinearCombinationBiasElementwise {
public: public:
@ -72,6 +73,7 @@ public:
using ElementCompute = ElementCompute_; using ElementCompute = ElementCompute_;
using ElementZ = ElementZ_; using ElementZ = ElementZ_;
using ElementT = ElementT_; using ElementT = ElementT_;
using ElementVector = ElementVector_;
static int const kElementsPerAccess = ElementsPerAccess; static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess; static int const kCount = kElementsPerAccess;

View File

@ -204,7 +204,8 @@ template <
typename ElementCompute_, typename ElementCompute_,
typename ElementZ_, typename ElementZ_,
int ElementsPerAccess, int ElementsPerAccess,
bool StoreT = true bool StoreT = true,
typename ElementVector_ = ElementC_
> >
class LinearCombinationBiasRelu { class LinearCombinationBiasRelu {
public: public:
@ -214,6 +215,7 @@ public:
using ElementAccumulator = ElementAccumulator_; using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_; using ElementCompute = ElementCompute_;
using ElementZ = ElementZ_; using ElementZ = ElementZ_;
using ElementVector = ElementVector_;
using ElementT = uint1b_t; using ElementT = uint1b_t;

View File

@ -59,7 +59,8 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class ActivationOp_, template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_, template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_, template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_ = detail::NoOp> template <typename T> class BinaryOp2_ = detail::NoOp,
typename ElementVector_ = ElementC_>
class LinearCombinationResidualBlock { class LinearCombinationResidualBlock {
public: public:
static bool const kIsSingleSource = false; static bool const kIsSingleSource = false;
@ -68,6 +69,7 @@ public:
using ElementC = ElementC_; using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_; using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_; using ElementCompute = ElementCompute_;
using ElementVector = ElementVector_;
static int const kElementsPerAccess = ElementsPerAccess; static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess; static int const kCount = kElementsPerAccess;
@ -179,11 +181,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
typename ElementCompute_, typename ElementC_, int ElementsPerAccess, typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
template <typename T> class ActivationOp_, template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_, template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_> template <typename T> class UnaryOp_,
typename ElementVector_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_, class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
ElementCompute_, ElementC_, ElementsPerAccess, ElementCompute_, ElementC_, ElementsPerAccess,
ActivationOp_, BinaryOp1_, UnaryOp_, ActivationOp_, BinaryOp1_, UnaryOp_,
detail::NoOp> { detail::NoOp, ElementVector_> {
public: public:
static bool const kIsSingleSource = true; static bool const kIsSingleSource = true;
@ -191,6 +194,7 @@ public:
using ElementC = ElementC_; using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_; using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_; using ElementCompute = ElementCompute_;
using ElementVector = ElementVector_;
static int const kElementsPerAccess = ElementsPerAccess; static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess; static int const kCount = kElementsPerAccess;

View File

@ -121,7 +121,7 @@ struct DefaultGemmWithBroadcast {
GemmBase::Epilogue::kPartitionsK, GemmBase::Epilogue::kPartitionsK,
ElementC_, ElementC_,
typename EpilogueOutputOp::ElementT, typename EpilogueOutputOp::ElementT,
ElementC_, typename EpilogueOutputOp::ElementVector,
EpilogueOutputOp, EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess GemmBase::Epilogue::kElementsPerAccess
>::Epilogue; >::Epilogue;
@ -221,7 +221,7 @@ struct DefaultGemmWithBroadcast<
GemmBase::Epilogue::kPartitionsK, GemmBase::Epilogue::kPartitionsK,
ElementC_, ElementC_,
typename EpilogueOutputOp::ElementT, typename EpilogueOutputOp::ElementT,
ElementC_, typename EpilogueOutputOp::ElementVector,
EpilogueOutputOp, EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess GemmBase::Epilogue::kElementsPerAccess
>::Epilogue; >::Epilogue;

View File

@ -120,6 +120,7 @@ public:
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
using ElementZ = typename EpilogueOutputOp::ElementZ; using ElementZ = typename EpilogueOutputOp::ElementZ;
using ElementT = typename EpilogueOutputOp::ElementT; using ElementT = typename EpilogueOutputOp::ElementT;
using ElementVector = typename EpilogueOutputOp::ElementVector;
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
static const bool kAddBroadcastFirst = AddBroadcastFirst; static const bool kAddBroadcastFirst = AddBroadcastFirst;
@ -142,7 +143,7 @@ public:
cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed; cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference; cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference; cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
cutlass::HostTensor<ElementC, LayoutC> tensor_Broadcast; // Input Broadcast cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast; // Input Broadcast
public: public:

View File

@ -105,7 +105,8 @@ struct TestbedGemmWithBroadcast {
using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
using ElementC = typename Gemm::ElementC; using ElementC = typename Gemm::ElementC;
using ElementAccumulator = typename Gemm::ElementAccumulator; using ElementAccumulator = typename Gemm::ElementAccumulator;
using ElementCOmpute = typename OutputOp::ElementCompute; using ElementCompute = typename OutputOp::ElementCompute;
using ElementVector = typename OutputOp::ElementVector;
using ElementZ = typename OutputOp::ElementZ; using ElementZ = typename OutputOp::ElementZ;
using ElementT = typename OutputOp::ElementT; using ElementT = typename OutputOp::ElementT;
@ -118,7 +119,7 @@ struct TestbedGemmWithBroadcast {
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A; // Input A cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A; // Input A
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B; // Input B cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B; // Input B
cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_C; // Input C cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_C; // Input C
cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_Broadcast; // Input Broadcast cutlass::HostTensor<ElementVector, typename Gemm::LayoutC> tensor_Broadcast; // Input Broadcast
cutlass::HostTensor<ElementZ, typename Gemm::LayoutC> tensor_Z; cutlass::HostTensor<ElementZ, typename Gemm::LayoutC> tensor_Z;
cutlass::HostTensor<ElementT, typename Gemm::LayoutC> tensor_T; cutlass::HostTensor<ElementT, typename Gemm::LayoutC> tensor_T;