From 86cae03cea301817232290d231092ef2c18961d7 Mon Sep 17 00:00:00 2001 From: Edward Rees <53308037+erees1@users.noreply.github.com> Date: Fri, 10 Mar 2023 17:58:17 +0000 Subject: [PATCH] expose StoreT parameter for potential speed (#838) * expose StoreT parameter for potential speed * add storeT to more elementwise --------- Co-authored-by: Haicheng Wu --- .../epilogue/thread/linear_combination_bias_elementwise.h | 3 ++- .../epilogue/thread/linear_combination_bias_relu.h | 4 ++-- .../epilogue/thread/linear_combination_residual_block.h | 8 +++++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 4bcc9df1..d145bfa7 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -62,6 +62,7 @@ template < int ElementsPerAccess, typename ElementwiseOp_ = Identity, typename BinaryOp_ = plus, + bool StoreT_ = true, typename ElementVector_ = ElementC_ > class LinearCombinationBiasElementwise { @@ -97,7 +98,7 @@ public: static bool const kStoreZ = true; /// If true, the 'T' tensor is stored - static bool const kStoreT = true; + static bool const kStoreT = StoreT_; /// Host-constructable parameters structure struct Params { diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h index 1f1a0179..0549d753 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h @@ -204,7 +204,7 @@ template < typename ElementCompute_, typename ElementZ_, int ElementsPerAccess, - bool StoreT = true, + bool StoreT_ = true, typename ElementVector_ = ElementC_ > class LinearCombinationBiasRelu { @@ -238,7 +238,7 @@ public: static bool const kStoreZ = true; /// If true, the 'T' tensor is stored - static bool const kStoreT = StoreT; + static bool const kStoreT = StoreT_; /// Host-constructable parameters structure struct Params { diff --git a/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/include/cutlass/epilogue/thread/linear_combination_residual_block.h index 8aca8e32..42d14662 100644 --- a/include/cutlass/epilogue/thread/linear_combination_residual_block.h +++ b/include/cutlass/epilogue/thread/linear_combination_residual_block.h @@ -60,6 +60,7 @@ template class BinaryOp1_, template class UnaryOp_, template class BinaryOp2_ = detail::NoOp, + bool StoreT_ = false, typename ElementVector_ = ElementC_> class LinearCombinationResidualBlock { public: @@ -90,7 +91,7 @@ public: static bool const kIsHeavy = true; static bool const kStoreZ = true; - static bool const kStoreT = false; + static bool const kStoreT = StoreT_; /// Host-constructable parameters structure struct Params { @@ -182,11 +183,12 @@ template class ActivationOp_, template class BinaryOp1_, template class UnaryOp_, + bool StoreT_, typename ElementVector_> class LinearCombinationResidualBlock { + detail::NoOp, StoreT_, ElementVector_> { public: static bool const kIsSingleSource = true; @@ -214,7 +216,7 @@ public: static bool const kIsHeavy = true; static bool const kStoreZ = true; - static bool const kStoreT = false; + static bool const kStoreT = StoreT_; /// Host-constructable parameters structure struct Params {