Fix type bug in conv2d/gemm with broadcast (#796)
add ElementVector --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
2e10404d26
commit
ce8597dc14
@ -107,7 +107,7 @@ struct DefaultConv2dFpropWithBroadcast {
|
|||||||
ImplicitGemmBase::Epilogue::kPartitionsK,
|
ImplicitGemmBase::Epilogue::kPartitionsK,
|
||||||
ElementC,
|
ElementC,
|
||||||
typename EpilogueOutputOp::ElementT,
|
typename EpilogueOutputOp::ElementT,
|
||||||
ElementC,
|
typename EpilogueOutputOp::ElementVector,
|
||||||
EpilogueOutputOp,
|
EpilogueOutputOp,
|
||||||
ImplicitGemmBase::Epilogue::kElementsPerAccess
|
ImplicitGemmBase::Epilogue::kElementsPerAccess
|
||||||
>::Epilogue;
|
>::Epilogue;
|
||||||
|
@ -61,7 +61,8 @@ template <
|
|||||||
typename ElementT_,
|
typename ElementT_,
|
||||||
int ElementsPerAccess,
|
int ElementsPerAccess,
|
||||||
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
||||||
typename BinaryOp_ = plus<ElementCompute_>
|
typename BinaryOp_ = plus<ElementCompute_>,
|
||||||
|
typename ElementVector_ = ElementC_
|
||||||
>
|
>
|
||||||
class LinearCombinationBiasElementwise {
|
class LinearCombinationBiasElementwise {
|
||||||
public:
|
public:
|
||||||
@ -72,6 +73,7 @@ public:
|
|||||||
using ElementCompute = ElementCompute_;
|
using ElementCompute = ElementCompute_;
|
||||||
using ElementZ = ElementZ_;
|
using ElementZ = ElementZ_;
|
||||||
using ElementT = ElementT_;
|
using ElementT = ElementT_;
|
||||||
|
using ElementVector = ElementVector_;
|
||||||
static int const kElementsPerAccess = ElementsPerAccess;
|
static int const kElementsPerAccess = ElementsPerAccess;
|
||||||
static int const kCount = kElementsPerAccess;
|
static int const kCount = kElementsPerAccess;
|
||||||
|
|
||||||
|
@ -204,7 +204,8 @@ template <
|
|||||||
typename ElementCompute_,
|
typename ElementCompute_,
|
||||||
typename ElementZ_,
|
typename ElementZ_,
|
||||||
int ElementsPerAccess,
|
int ElementsPerAccess,
|
||||||
bool StoreT = true
|
bool StoreT = true,
|
||||||
|
typename ElementVector_ = ElementC_
|
||||||
>
|
>
|
||||||
class LinearCombinationBiasRelu {
|
class LinearCombinationBiasRelu {
|
||||||
public:
|
public:
|
||||||
@ -214,6 +215,7 @@ public:
|
|||||||
using ElementAccumulator = ElementAccumulator_;
|
using ElementAccumulator = ElementAccumulator_;
|
||||||
using ElementCompute = ElementCompute_;
|
using ElementCompute = ElementCompute_;
|
||||||
using ElementZ = ElementZ_;
|
using ElementZ = ElementZ_;
|
||||||
|
using ElementVector = ElementVector_;
|
||||||
|
|
||||||
using ElementT = uint1b_t;
|
using ElementT = uint1b_t;
|
||||||
|
|
||||||
|
@ -59,7 +59,8 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
|||||||
template <typename T> class ActivationOp_,
|
template <typename T> class ActivationOp_,
|
||||||
template <typename T> class BinaryOp1_,
|
template <typename T> class BinaryOp1_,
|
||||||
template <typename T> class UnaryOp_,
|
template <typename T> class UnaryOp_,
|
||||||
template <typename T> class BinaryOp2_ = detail::NoOp>
|
template <typename T> class BinaryOp2_ = detail::NoOp,
|
||||||
|
typename ElementVector_ = ElementC_>
|
||||||
class LinearCombinationResidualBlock {
|
class LinearCombinationResidualBlock {
|
||||||
public:
|
public:
|
||||||
static bool const kIsSingleSource = false;
|
static bool const kIsSingleSource = false;
|
||||||
@ -68,6 +69,7 @@ public:
|
|||||||
using ElementC = ElementC_;
|
using ElementC = ElementC_;
|
||||||
using ElementAccumulator = ElementAccumulator_;
|
using ElementAccumulator = ElementAccumulator_;
|
||||||
using ElementCompute = ElementCompute_;
|
using ElementCompute = ElementCompute_;
|
||||||
|
using ElementVector = ElementVector_;
|
||||||
static int const kElementsPerAccess = ElementsPerAccess;
|
static int const kElementsPerAccess = ElementsPerAccess;
|
||||||
static int const kCount = kElementsPerAccess;
|
static int const kCount = kElementsPerAccess;
|
||||||
|
|
||||||
@ -179,11 +181,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
|||||||
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
||||||
template <typename T> class ActivationOp_,
|
template <typename T> class ActivationOp_,
|
||||||
template <typename T> class BinaryOp1_,
|
template <typename T> class BinaryOp1_,
|
||||||
template <typename T> class UnaryOp_>
|
template <typename T> class UnaryOp_,
|
||||||
|
typename ElementVector_>
|
||||||
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
||||||
ElementCompute_, ElementC_, ElementsPerAccess,
|
ElementCompute_, ElementC_, ElementsPerAccess,
|
||||||
ActivationOp_, BinaryOp1_, UnaryOp_,
|
ActivationOp_, BinaryOp1_, UnaryOp_,
|
||||||
detail::NoOp> {
|
detail::NoOp, ElementVector_> {
|
||||||
public:
|
public:
|
||||||
static bool const kIsSingleSource = true;
|
static bool const kIsSingleSource = true;
|
||||||
|
|
||||||
@ -191,6 +194,7 @@ public:
|
|||||||
using ElementC = ElementC_;
|
using ElementC = ElementC_;
|
||||||
using ElementAccumulator = ElementAccumulator_;
|
using ElementAccumulator = ElementAccumulator_;
|
||||||
using ElementCompute = ElementCompute_;
|
using ElementCompute = ElementCompute_;
|
||||||
|
using ElementVector = ElementVector_;
|
||||||
static int const kElementsPerAccess = ElementsPerAccess;
|
static int const kElementsPerAccess = ElementsPerAccess;
|
||||||
static int const kCount = kElementsPerAccess;
|
static int const kCount = kElementsPerAccess;
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ struct DefaultGemmWithBroadcast {
|
|||||||
GemmBase::Epilogue::kPartitionsK,
|
GemmBase::Epilogue::kPartitionsK,
|
||||||
ElementC_,
|
ElementC_,
|
||||||
typename EpilogueOutputOp::ElementT,
|
typename EpilogueOutputOp::ElementT,
|
||||||
ElementC_,
|
typename EpilogueOutputOp::ElementVector,
|
||||||
EpilogueOutputOp,
|
EpilogueOutputOp,
|
||||||
GemmBase::Epilogue::kElementsPerAccess
|
GemmBase::Epilogue::kElementsPerAccess
|
||||||
>::Epilogue;
|
>::Epilogue;
|
||||||
@ -221,7 +221,7 @@ struct DefaultGemmWithBroadcast<
|
|||||||
GemmBase::Epilogue::kPartitionsK,
|
GemmBase::Epilogue::kPartitionsK,
|
||||||
ElementC_,
|
ElementC_,
|
||||||
typename EpilogueOutputOp::ElementT,
|
typename EpilogueOutputOp::ElementT,
|
||||||
ElementC_,
|
typename EpilogueOutputOp::ElementVector,
|
||||||
EpilogueOutputOp,
|
EpilogueOutputOp,
|
||||||
GemmBase::Epilogue::kElementsPerAccess
|
GemmBase::Epilogue::kElementsPerAccess
|
||||||
>::Epilogue;
|
>::Epilogue;
|
||||||
|
@ -120,6 +120,7 @@ public:
|
|||||||
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
|
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
|
||||||
using ElementZ = typename EpilogueOutputOp::ElementZ;
|
using ElementZ = typename EpilogueOutputOp::ElementZ;
|
||||||
using ElementT = typename EpilogueOutputOp::ElementT;
|
using ElementT = typename EpilogueOutputOp::ElementT;
|
||||||
|
using ElementVector = typename EpilogueOutputOp::ElementVector;
|
||||||
|
|
||||||
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
||||||
static const bool kAddBroadcastFirst = AddBroadcastFirst;
|
static const bool kAddBroadcastFirst = AddBroadcastFirst;
|
||||||
@ -142,7 +143,7 @@ public:
|
|||||||
cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
|
cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
|
||||||
cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
|
cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
|
||||||
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
|
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
|
||||||
cutlass::HostTensor<ElementC, LayoutC> tensor_Broadcast; // Input Broadcast
|
cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast; // Input Broadcast
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -105,7 +105,8 @@ struct TestbedGemmWithBroadcast {
|
|||||||
using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
|
using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
|
||||||
using ElementC = typename Gemm::ElementC;
|
using ElementC = typename Gemm::ElementC;
|
||||||
using ElementAccumulator = typename Gemm::ElementAccumulator;
|
using ElementAccumulator = typename Gemm::ElementAccumulator;
|
||||||
using ElementCOmpute = typename OutputOp::ElementCompute;
|
using ElementCompute = typename OutputOp::ElementCompute;
|
||||||
|
using ElementVector = typename OutputOp::ElementVector;
|
||||||
using ElementZ = typename OutputOp::ElementZ;
|
using ElementZ = typename OutputOp::ElementZ;
|
||||||
using ElementT = typename OutputOp::ElementT;
|
using ElementT = typename OutputOp::ElementT;
|
||||||
|
|
||||||
@ -118,7 +119,7 @@ struct TestbedGemmWithBroadcast {
|
|||||||
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A; // Input A
|
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A; // Input A
|
||||||
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B; // Input B
|
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B; // Input B
|
||||||
cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_C; // Input C
|
cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_C; // Input C
|
||||||
cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_Broadcast; // Input Broadcast
|
cutlass::HostTensor<ElementVector, typename Gemm::LayoutC> tensor_Broadcast; // Input Broadcast
|
||||||
|
|
||||||
cutlass::HostTensor<ElementZ, typename Gemm::LayoutC> tensor_Z;
|
cutlass::HostTensor<ElementZ, typename Gemm::LayoutC> tensor_Z;
|
||||||
cutlass::HostTensor<ElementT, typename Gemm::LayoutC> tensor_T;
|
cutlass::HostTensor<ElementT, typename Gemm::LayoutC> tensor_T;
|
||||||
|
Loading…
Reference in New Issue
Block a user