Fix type bug in conv2d/gemm with broadcast (#796)
add ElementVector --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
2e10404d26
commit
ce8597dc14
@ -107,7 +107,7 @@ struct DefaultConv2dFpropWithBroadcast {
|
||||
ImplicitGemmBase::Epilogue::kPartitionsK,
|
||||
ElementC,
|
||||
typename EpilogueOutputOp::ElementT,
|
||||
ElementC,
|
||||
typename EpilogueOutputOp::ElementVector,
|
||||
EpilogueOutputOp,
|
||||
ImplicitGemmBase::Epilogue::kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
@ -61,7 +61,8 @@ template <
|
||||
typename ElementT_,
|
||||
int ElementsPerAccess,
|
||||
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
||||
typename BinaryOp_ = plus<ElementCompute_>
|
||||
typename BinaryOp_ = plus<ElementCompute_>,
|
||||
typename ElementVector_ = ElementC_
|
||||
>
|
||||
class LinearCombinationBiasElementwise {
|
||||
public:
|
||||
@ -72,6 +73,7 @@ public:
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementZ = ElementZ_;
|
||||
using ElementT = ElementT_;
|
||||
using ElementVector = ElementVector_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
|
@ -204,7 +204,8 @@ template <
|
||||
typename ElementCompute_,
|
||||
typename ElementZ_,
|
||||
int ElementsPerAccess,
|
||||
bool StoreT = true
|
||||
bool StoreT = true,
|
||||
typename ElementVector_ = ElementC_
|
||||
>
|
||||
class LinearCombinationBiasRelu {
|
||||
public:
|
||||
@ -214,6 +215,7 @@ public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementZ = ElementZ_;
|
||||
using ElementVector = ElementVector_;
|
||||
|
||||
using ElementT = uint1b_t;
|
||||
|
||||
|
@ -59,7 +59,8 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
template <typename T> class ActivationOp_,
|
||||
template <typename T> class BinaryOp1_,
|
||||
template <typename T> class UnaryOp_,
|
||||
template <typename T> class BinaryOp2_ = detail::NoOp>
|
||||
template <typename T> class BinaryOp2_ = detail::NoOp,
|
||||
typename ElementVector_ = ElementC_>
|
||||
class LinearCombinationResidualBlock {
|
||||
public:
|
||||
static bool const kIsSingleSource = false;
|
||||
@ -68,6 +69,7 @@ public:
|
||||
using ElementC = ElementC_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementVector = ElementVector_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
@ -179,11 +181,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
||||
template <typename T> class ActivationOp_,
|
||||
template <typename T> class BinaryOp1_,
|
||||
template <typename T> class UnaryOp_>
|
||||
template <typename T> class UnaryOp_,
|
||||
typename ElementVector_>
|
||||
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
||||
ElementCompute_, ElementC_, ElementsPerAccess,
|
||||
ActivationOp_, BinaryOp1_, UnaryOp_,
|
||||
detail::NoOp> {
|
||||
detail::NoOp, ElementVector_> {
|
||||
public:
|
||||
static bool const kIsSingleSource = true;
|
||||
|
||||
@ -191,6 +194,7 @@ public:
|
||||
using ElementC = ElementC_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementVector = ElementVector_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
|
@ -121,7 +121,7 @@ struct DefaultGemmWithBroadcast {
|
||||
GemmBase::Epilogue::kPartitionsK,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementT,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementVector,
|
||||
EpilogueOutputOp,
|
||||
GemmBase::Epilogue::kElementsPerAccess
|
||||
>::Epilogue;
|
||||
@ -221,7 +221,7 @@ struct DefaultGemmWithBroadcast<
|
||||
GemmBase::Epilogue::kPartitionsK,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementT,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementVector,
|
||||
EpilogueOutputOp,
|
||||
GemmBase::Epilogue::kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
@ -120,6 +120,7 @@ public:
|
||||
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
|
||||
using ElementZ = typename EpilogueOutputOp::ElementZ;
|
||||
using ElementT = typename EpilogueOutputOp::ElementT;
|
||||
using ElementVector = typename EpilogueOutputOp::ElementVector;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
||||
static const bool kAddBroadcastFirst = AddBroadcastFirst;
|
||||
@ -142,7 +143,7 @@ public:
|
||||
cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
|
||||
cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
|
||||
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_Broadcast; // Input Broadcast
|
||||
cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast; // Input Broadcast
|
||||
|
||||
public:
|
||||
|
||||
|
@ -105,7 +105,8 @@ struct TestbedGemmWithBroadcast {
|
||||
using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
using ElementAccumulator = typename Gemm::ElementAccumulator;
|
||||
using ElementCOmpute = typename OutputOp::ElementCompute;
|
||||
using ElementCompute = typename OutputOp::ElementCompute;
|
||||
using ElementVector = typename OutputOp::ElementVector;
|
||||
using ElementZ = typename OutputOp::ElementZ;
|
||||
using ElementT = typename OutputOp::ElementT;
|
||||
|
||||
@ -118,7 +119,7 @@ struct TestbedGemmWithBroadcast {
|
||||
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A; // Input A
|
||||
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B; // Input B
|
||||
cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_C; // Input C
|
||||
cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_Broadcast; // Input Broadcast
|
||||
cutlass::HostTensor<ElementVector, typename Gemm::LayoutC> tensor_Broadcast; // Input Broadcast
|
||||
|
||||
cutlass::HostTensor<ElementZ, typename Gemm::LayoutC> tensor_Z;
|
||||
cutlass::HostTensor<ElementT, typename Gemm::LayoutC> tensor_T;
|
||||
|
Loading…
Reference in New Issue
Block a user