Fix type bug in conv2d/gemm with broadcast (#796)
Add an ElementVector type for the broadcast vector operand.

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
		
							parent
							
								
									2e10404d26
								
							
						
					
					
						commit
						ce8597dc14
					
				| @ -107,7 +107,7 @@ struct DefaultConv2dFpropWithBroadcast { | |||||||
|     ImplicitGemmBase::Epilogue::kPartitionsK, |     ImplicitGemmBase::Epilogue::kPartitionsK, | ||||||
|     ElementC, |     ElementC, | ||||||
|     typename EpilogueOutputOp::ElementT, |     typename EpilogueOutputOp::ElementT, | ||||||
|     ElementC, |     typename EpilogueOutputOp::ElementVector, | ||||||
|     EpilogueOutputOp, |     EpilogueOutputOp, | ||||||
|     ImplicitGemmBase::Epilogue::kElementsPerAccess |     ImplicitGemmBase::Epilogue::kElementsPerAccess | ||||||
|   >::Epilogue; |   >::Epilogue; | ||||||
|  | |||||||
| @ -61,7 +61,8 @@ template < | |||||||
|   typename ElementT_, |   typename ElementT_, | ||||||
|   int ElementsPerAccess, |   int ElementsPerAccess, | ||||||
|   typename ElementwiseOp_ = Identity<ElementCompute_>, |   typename ElementwiseOp_ = Identity<ElementCompute_>, | ||||||
|   typename BinaryOp_ = plus<ElementCompute_> |   typename BinaryOp_ = plus<ElementCompute_>, | ||||||
|  |   typename ElementVector_ = ElementC_ | ||||||
| > | > | ||||||
| class LinearCombinationBiasElementwise { | class LinearCombinationBiasElementwise { | ||||||
| public: | public: | ||||||
| @ -72,6 +73,7 @@ public: | |||||||
|   using ElementCompute = ElementCompute_; |   using ElementCompute = ElementCompute_; | ||||||
|   using ElementZ = ElementZ_; |   using ElementZ = ElementZ_; | ||||||
|   using ElementT = ElementT_; |   using ElementT = ElementT_; | ||||||
|  |   using ElementVector = ElementVector_; | ||||||
|   static int const kElementsPerAccess = ElementsPerAccess; |   static int const kElementsPerAccess = ElementsPerAccess; | ||||||
|   static int const kCount = kElementsPerAccess; |   static int const kCount = kElementsPerAccess; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -204,7 +204,8 @@ template < | |||||||
|   typename ElementCompute_, |   typename ElementCompute_, | ||||||
|   typename ElementZ_, |   typename ElementZ_, | ||||||
|   int ElementsPerAccess, |   int ElementsPerAccess, | ||||||
|   bool StoreT = true |   bool StoreT = true, | ||||||
|  |   typename ElementVector_ = ElementC_ | ||||||
| > | > | ||||||
| class LinearCombinationBiasRelu { | class LinearCombinationBiasRelu { | ||||||
| public: | public: | ||||||
| @ -214,6 +215,7 @@ public: | |||||||
|   using ElementAccumulator = ElementAccumulator_; |   using ElementAccumulator = ElementAccumulator_; | ||||||
|   using ElementCompute = ElementCompute_; |   using ElementCompute = ElementCompute_; | ||||||
|   using ElementZ = ElementZ_; |   using ElementZ = ElementZ_; | ||||||
|  |   using ElementVector = ElementVector_; | ||||||
| 
 | 
 | ||||||
|   using ElementT = uint1b_t; |   using ElementT = uint1b_t; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -59,7 +59,8 @@ template <typename ElementOutput_, typename ElementAccumulator_, | |||||||
|           template <typename T> class ActivationOp_, |           template <typename T> class ActivationOp_, | ||||||
|           template <typename T> class BinaryOp1_, |           template <typename T> class BinaryOp1_, | ||||||
|           template <typename T> class UnaryOp_, |           template <typename T> class UnaryOp_, | ||||||
|           template <typename T> class BinaryOp2_ = detail::NoOp> |           template <typename T> class BinaryOp2_ = detail::NoOp, | ||||||
|  |           typename ElementVector_ = ElementC_> | ||||||
| class LinearCombinationResidualBlock { | class LinearCombinationResidualBlock { | ||||||
| public: | public: | ||||||
|   static bool const kIsSingleSource = false; |   static bool const kIsSingleSource = false; | ||||||
| @ -68,6 +69,7 @@ public: | |||||||
|   using ElementC = ElementC_; |   using ElementC = ElementC_; | ||||||
|   using ElementAccumulator = ElementAccumulator_; |   using ElementAccumulator = ElementAccumulator_; | ||||||
|   using ElementCompute = ElementCompute_; |   using ElementCompute = ElementCompute_; | ||||||
|  |   using ElementVector = ElementVector_; | ||||||
|   static int const kElementsPerAccess = ElementsPerAccess; |   static int const kElementsPerAccess = ElementsPerAccess; | ||||||
|   static int const kCount = kElementsPerAccess; |   static int const kCount = kElementsPerAccess; | ||||||
| 
 | 
 | ||||||
| @ -179,11 +181,12 @@ template <typename ElementOutput_, typename ElementAccumulator_, | |||||||
|           typename ElementCompute_, typename ElementC_, int ElementsPerAccess, |           typename ElementCompute_, typename ElementC_, int ElementsPerAccess, | ||||||
|           template <typename T> class ActivationOp_, |           template <typename T> class ActivationOp_, | ||||||
|           template <typename T> class BinaryOp1_, |           template <typename T> class BinaryOp1_, | ||||||
|           template <typename T> class UnaryOp_> |           template <typename T> class UnaryOp_, | ||||||
|  |           typename ElementVector_> | ||||||
| class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_, | class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_, | ||||||
|           ElementCompute_, ElementC_, ElementsPerAccess, |           ElementCompute_, ElementC_, ElementsPerAccess, | ||||||
|           ActivationOp_, BinaryOp1_, UnaryOp_, |           ActivationOp_, BinaryOp1_, UnaryOp_, | ||||||
|           detail::NoOp> { |           detail::NoOp, ElementVector_> { | ||||||
| public: | public: | ||||||
|   static bool const kIsSingleSource = true; |   static bool const kIsSingleSource = true; | ||||||
| 
 | 
 | ||||||
| @ -191,6 +194,7 @@ public: | |||||||
|   using ElementC = ElementC_; |   using ElementC = ElementC_; | ||||||
|   using ElementAccumulator = ElementAccumulator_; |   using ElementAccumulator = ElementAccumulator_; | ||||||
|   using ElementCompute = ElementCompute_; |   using ElementCompute = ElementCompute_; | ||||||
|  |   using ElementVector = ElementVector_; | ||||||
|   static int const kElementsPerAccess = ElementsPerAccess; |   static int const kElementsPerAccess = ElementsPerAccess; | ||||||
|   static int const kCount = kElementsPerAccess; |   static int const kCount = kElementsPerAccess; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -121,7 +121,7 @@ struct DefaultGemmWithBroadcast { | |||||||
|     GemmBase::Epilogue::kPartitionsK, |     GemmBase::Epilogue::kPartitionsK, | ||||||
|     ElementC_, |     ElementC_, | ||||||
|     typename EpilogueOutputOp::ElementT, |     typename EpilogueOutputOp::ElementT, | ||||||
|     ElementC_, |     typename EpilogueOutputOp::ElementVector, | ||||||
|     EpilogueOutputOp, |     EpilogueOutputOp, | ||||||
|     GemmBase::Epilogue::kElementsPerAccess |     GemmBase::Epilogue::kElementsPerAccess | ||||||
|   >::Epilogue; |   >::Epilogue; | ||||||
| @ -221,7 +221,7 @@ struct DefaultGemmWithBroadcast< | |||||||
|     GemmBase::Epilogue::kPartitionsK, |     GemmBase::Epilogue::kPartitionsK, | ||||||
|     ElementC_, |     ElementC_, | ||||||
|     typename EpilogueOutputOp::ElementT, |     typename EpilogueOutputOp::ElementT, | ||||||
|     ElementC_, |     typename EpilogueOutputOp::ElementVector, | ||||||
|     EpilogueOutputOp, |     EpilogueOutputOp, | ||||||
|     GemmBase::Epilogue::kElementsPerAccess |     GemmBase::Epilogue::kElementsPerAccess | ||||||
|   >::Epilogue; |   >::Epilogue; | ||||||
|  | |||||||
| @ -120,6 +120,7 @@ public: | |||||||
|   using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; |   using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; | ||||||
|   using ElementZ = typename EpilogueOutputOp::ElementZ; |   using ElementZ = typename EpilogueOutputOp::ElementZ; | ||||||
|   using ElementT = typename EpilogueOutputOp::ElementT; |   using ElementT = typename EpilogueOutputOp::ElementT; | ||||||
|  |   using ElementVector = typename EpilogueOutputOp::ElementVector; | ||||||
| 
 | 
 | ||||||
|   static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; |   static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; | ||||||
|   static const bool kAddBroadcastFirst = AddBroadcastFirst; |   static const bool kAddBroadcastFirst = AddBroadcastFirst; | ||||||
| @ -142,7 +143,7 @@ public: | |||||||
|   cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed; |   cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed; | ||||||
|   cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference; |   cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference; | ||||||
|   cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference; |   cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference; | ||||||
|   cutlass::HostTensor<ElementC, LayoutC> tensor_Broadcast;                 // Input Broadcast
 |   cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast;            // Input Broadcast
 | ||||||
| 
 | 
 | ||||||
| public: | public: | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -105,7 +105,8 @@ struct TestbedGemmWithBroadcast { | |||||||
|   using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; |   using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; | ||||||
|   using ElementC = typename Gemm::ElementC; |   using ElementC = typename Gemm::ElementC; | ||||||
|   using ElementAccumulator = typename Gemm::ElementAccumulator; |   using ElementAccumulator = typename Gemm::ElementAccumulator; | ||||||
|   using ElementCOmpute = typename OutputOp::ElementCompute; |   using ElementCompute = typename OutputOp::ElementCompute; | ||||||
|  |   using ElementVector = typename OutputOp::ElementVector; | ||||||
|   using ElementZ = typename OutputOp::ElementZ; |   using ElementZ = typename OutputOp::ElementZ; | ||||||
|   using ElementT = typename OutputOp::ElementT; |   using ElementT = typename OutputOp::ElementT; | ||||||
| 
 | 
 | ||||||
| @ -118,7 +119,7 @@ struct TestbedGemmWithBroadcast { | |||||||
|   cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;          // Input A
 |   cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;          // Input A
 | ||||||
|   cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;          // Input B
 |   cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;          // Input B
 | ||||||
|   cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_C;                         // Input C
 |   cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_C;                         // Input C
 | ||||||
|   cutlass::HostTensor<ElementC, typename Gemm::LayoutC> tensor_Broadcast;                 // Input Broadcast
 |   cutlass::HostTensor<ElementVector, typename Gemm::LayoutC> tensor_Broadcast;            // Input Broadcast
 | ||||||
| 
 | 
 | ||||||
|   cutlass::HostTensor<ElementZ, typename Gemm::LayoutC> tensor_Z; |   cutlass::HostTensor<ElementZ, typename Gemm::LayoutC> tensor_Z; | ||||||
|   cutlass::HostTensor<ElementT, typename Gemm::LayoutC> tensor_T; |   cutlass::HostTensor<ElementT, typename Gemm::LayoutC> tensor_T; | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Shuai Shao
						Shuai Shao