diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h index d744ae8d..9ffb05e7 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -107,7 +107,7 @@ struct DefaultConv2dFpropWithBroadcast { ImplicitGemmBase::Epilogue::kPartitionsK, ElementC, typename EpilogueOutputOp::ElementT, - ElementC, + typename EpilogueOutputOp::ElementVector, EpilogueOutputOp, ImplicitGemmBase::Epilogue::kElementsPerAccess >::Epilogue; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 6892efb5..4bcc9df1 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -61,7 +61,8 @@ template < typename ElementT_, int ElementsPerAccess, typename ElementwiseOp_ = Identity, - typename BinaryOp_ = plus + typename BinaryOp_ = plus, + typename ElementVector_ = ElementC_ > class LinearCombinationBiasElementwise { public: @@ -72,6 +73,7 @@ public: using ElementCompute = ElementCompute_; using ElementZ = ElementZ_; using ElementT = ElementT_; + using ElementVector = ElementVector_; static int const kElementsPerAccess = ElementsPerAccess; static int const kCount = kElementsPerAccess; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h index b095c91e..1f1a0179 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h @@ -204,7 +204,8 @@ template < typename ElementCompute_, typename ElementZ_, int ElementsPerAccess, - bool StoreT = true + bool StoreT = true, + typename ElementVector_ = ElementC_ > class LinearCombinationBiasRelu { public: @@ -214,6 +215,7 @@ public: using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; using ElementZ = ElementZ_; + using ElementVector = ElementVector_; using ElementT = uint1b_t; diff --git a/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/include/cutlass/epilogue/thread/linear_combination_residual_block.h index 7c47c24e..8aca8e32 100644 --- a/include/cutlass/epilogue/thread/linear_combination_residual_block.h +++ b/include/cutlass/epilogue/thread/linear_combination_residual_block.h @@ -59,7 +59,8 @@ template class ActivationOp_, template class BinaryOp1_, template class UnaryOp_, - template class BinaryOp2_ = detail::NoOp> + template class BinaryOp2_ = detail::NoOp, + typename ElementVector_ = ElementC_> class LinearCombinationResidualBlock { public: static bool const kIsSingleSource = false; @@ -68,6 +69,7 @@ public: using ElementC = ElementC_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + using ElementVector = ElementVector_; static int const kElementsPerAccess = ElementsPerAccess; static int const kCount = kElementsPerAccess; @@ -179,11 +181,12 @@ template class ActivationOp_, template class BinaryOp1_, - template class UnaryOp_> + template class UnaryOp_, + typename ElementVector_> class LinearCombinationResidualBlock { + detail::NoOp, ElementVector_> { public: static bool const kIsSingleSource = true; @@ -191,6 +194,7 @@ public: using ElementC = ElementC_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + using ElementVector = ElementVector_; static int const kElementsPerAccess = ElementsPerAccess; static int const kCount = kElementsPerAccess; diff --git a/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h b/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h index 1356b490..e3ef316b 100644 --- a/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h +++ b/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h @@ -121,7 +121,7 @@ struct DefaultGemmWithBroadcast { GemmBase::Epilogue::kPartitionsK, ElementC_, typename EpilogueOutputOp::ElementT, - ElementC_, + typename EpilogueOutputOp::ElementVector, EpilogueOutputOp, GemmBase::Epilogue::kElementsPerAccess >::Epilogue; @@ -221,7 +221,7 @@ struct DefaultGemmWithBroadcast< GemmBase::Epilogue::kPartitionsK, ElementC_, typename EpilogueOutputOp::ElementT, - ElementC_, + typename EpilogueOutputOp::ElementVector, EpilogueOutputOp, GemmBase::Epilogue::kElementsPerAccess >::Epilogue; diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h index 117fef0d..d678e3b5 100644 --- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -120,6 +120,7 @@ public: using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; using ElementZ = typename EpilogueOutputOp::ElementZ; using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; static const bool kAddBroadcastFirst = AddBroadcastFirst; @@ -142,7 +143,7 @@ public: cutlass::HostTensor tensor_T_computed; cutlass::HostTensor tensor_T_reference; cutlass::HostTensor tensor_Y_reference; - cutlass::HostTensor tensor_Broadcast; // Input Broadcast + cutlass::HostTensor tensor_Broadcast; // Input Broadcast public: diff --git a/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/test/unit/gemm/device/testbed_gemm_with_broadcast.h index 10d5d3f0..c28fc8dd 100644 --- a/test/unit/gemm/device/testbed_gemm_with_broadcast.h +++ b/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -105,7 +105,8 @@ struct TestbedGemmWithBroadcast { using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; using ElementC = typename Gemm::ElementC; using ElementAccumulator = typename Gemm::ElementAccumulator; - using ElementCOmpute = typename OutputOp::ElementCompute; + using ElementCompute = typename OutputOp::ElementCompute; + using ElementVector = typename OutputOp::ElementVector; using ElementZ = typename OutputOp::ElementZ; using ElementT = typename OutputOp::ElementT; @@ -118,7 +119,7 @@ struct TestbedGemmWithBroadcast { cutlass::HostTensor tensor_A; // Input A cutlass::HostTensor tensor_B; // Input B cutlass::HostTensor tensor_C; // Input C - cutlass::HostTensor tensor_Broadcast; // Input Broadcast + cutlass::HostTensor tensor_Broadcast; // Input Broadcast cutlass::HostTensor tensor_Z; cutlass::HostTensor tensor_T;