Expose StoreT template parameter for potential speedup (#838)

* Expose StoreT template parameter for potential speedup

* Add StoreT to more elementwise classes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Edward Rees 2023-03-10 17:58:17 +00:00 committed by GitHub
parent 29801e348a
commit 86cae03cea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 6 deletions

View File

@ -62,6 +62,7 @@ template <
int ElementsPerAccess, int ElementsPerAccess,
typename ElementwiseOp_ = Identity<ElementCompute_>, typename ElementwiseOp_ = Identity<ElementCompute_>,
typename BinaryOp_ = plus<ElementCompute_>, typename BinaryOp_ = plus<ElementCompute_>,
bool StoreT_ = true,
typename ElementVector_ = ElementC_ typename ElementVector_ = ElementC_
> >
class LinearCombinationBiasElementwise { class LinearCombinationBiasElementwise {
@ -97,7 +98,7 @@ public:
static bool const kStoreZ = true; static bool const kStoreZ = true;
/// If true, the 'T' tensor is stored /// If true, the 'T' tensor is stored
static bool const kStoreT = true; static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure /// Host-constructable parameters structure
struct Params { struct Params {

View File

@ -204,7 +204,7 @@ template <
typename ElementCompute_, typename ElementCompute_,
typename ElementZ_, typename ElementZ_,
int ElementsPerAccess, int ElementsPerAccess,
bool StoreT = true, bool StoreT_ = true,
typename ElementVector_ = ElementC_ typename ElementVector_ = ElementC_
> >
class LinearCombinationBiasRelu { class LinearCombinationBiasRelu {
@ -238,7 +238,7 @@ public:
static bool const kStoreZ = true; static bool const kStoreZ = true;
/// If true, the 'T' tensor is stored /// If true, the 'T' tensor is stored
static bool const kStoreT = StoreT; static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure /// Host-constructable parameters structure
struct Params { struct Params {

View File

@ -60,6 +60,7 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class BinaryOp1_, template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_, template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_ = detail::NoOp, template <typename T> class BinaryOp2_ = detail::NoOp,
bool StoreT_ = false,
typename ElementVector_ = ElementC_> typename ElementVector_ = ElementC_>
class LinearCombinationResidualBlock { class LinearCombinationResidualBlock {
public: public:
@ -90,7 +91,7 @@ public:
static bool const kIsHeavy = true; static bool const kIsHeavy = true;
static bool const kStoreZ = true; static bool const kStoreZ = true;
static bool const kStoreT = false; static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure /// Host-constructable parameters structure
struct Params { struct Params {
@ -182,11 +183,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class ActivationOp_, template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_, template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_, template <typename T> class UnaryOp_,
bool StoreT_,
typename ElementVector_> typename ElementVector_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_, class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
ElementCompute_, ElementC_, ElementsPerAccess, ElementCompute_, ElementC_, ElementsPerAccess,
ActivationOp_, BinaryOp1_, UnaryOp_, ActivationOp_, BinaryOp1_, UnaryOp_,
detail::NoOp, ElementVector_> { detail::NoOp, StoreT_, ElementVector_> {
public: public:
static bool const kIsSingleSource = true; static bool const kIsSingleSource = true;
@ -214,7 +216,7 @@ public:
static bool const kIsHeavy = true; static bool const kIsHeavy = true;
static bool const kStoreZ = true; static bool const kStoreZ = true;
static bool const kStoreT = false; static bool const kStoreT = StoreT_;
/// Host-constructable parameters structure /// Host-constructable parameters structure
struct Params { struct Params {