expose StoreT parameter for potential speed (#838)

* expose StoreT parameter for potential speed

* add storeT to more elementwise

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Edward Rees 2023-03-10 17:58:17 +00:00 committed by GitHub
parent 29801e348a
commit 86cae03cea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 6 deletions

View File

@ -62,6 +62,7 @@ template <
int ElementsPerAccess,
typename ElementwiseOp_ = Identity<ElementCompute_>,
typename BinaryOp_ = plus<ElementCompute_>,
bool StoreT_ = true,
typename ElementVector_ = ElementC_
>
class LinearCombinationBiasElementwise {
@ -97,7 +98,7 @@ public:
static bool const kStoreZ = true;
/// If true, the 'T' tensor is stored
static bool const kStoreT = true;
static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure
struct Params {

View File

@ -204,7 +204,7 @@ template <
typename ElementCompute_,
typename ElementZ_,
int ElementsPerAccess,
bool StoreT = true,
bool StoreT_ = true,
typename ElementVector_ = ElementC_
>
class LinearCombinationBiasRelu {
@ -238,7 +238,7 @@ public:
static bool const kStoreZ = true;
/// If true, the 'T' tensor is stored
static bool const kStoreT = StoreT;
static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure
struct Params {

View File

@ -60,6 +60,7 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_ = detail::NoOp,
bool StoreT_ = false,
typename ElementVector_ = ElementC_>
class LinearCombinationResidualBlock {
public:
@ -90,7 +91,7 @@ public:
static bool const kIsHeavy = true;
static bool const kStoreZ = true;
static bool const kStoreT = false;
static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure
struct Params {
@ -182,11 +183,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
bool StoreT_,
typename ElementVector_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
ElementCompute_, ElementC_, ElementsPerAccess,
ActivationOp_, BinaryOp1_, UnaryOp_,
detail::NoOp, ElementVector_> {
detail::NoOp, StoreT_, ElementVector_> {
public:
static bool const kIsSingleSource = true;
@ -214,7 +216,7 @@ public:
static bool const kIsHeavy = true;
static bool const kStoreZ = true;
static bool const kStoreT = false;
static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure
struct Params {