expose StoreT parameter for potential speed (#838)
* expose StoreT parameter for potential speed * add storeT to more elementwise --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
29801e348a
commit
86cae03cea
@ -62,6 +62,7 @@ template <
|
||||
int ElementsPerAccess,
|
||||
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
||||
typename BinaryOp_ = plus<ElementCompute_>,
|
||||
bool StoreT_ = true,
|
||||
typename ElementVector_ = ElementC_
|
||||
>
|
||||
class LinearCombinationBiasElementwise {
|
||||
@ -97,7 +98,7 @@ public:
|
||||
static bool const kStoreZ = true;
|
||||
|
||||
/// If true, the 'T' tensor is stored
|
||||
static bool const kStoreT = true;
|
||||
static bool const kStoreT = StoreT_;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
@ -204,7 +204,7 @@ template <
|
||||
typename ElementCompute_,
|
||||
typename ElementZ_,
|
||||
int ElementsPerAccess,
|
||||
bool StoreT = true,
|
||||
bool StoreT_ = true,
|
||||
typename ElementVector_ = ElementC_
|
||||
>
|
||||
class LinearCombinationBiasRelu {
|
||||
@ -238,7 +238,7 @@ public:
|
||||
static bool const kStoreZ = true;
|
||||
|
||||
/// If true, the 'T' tensor is stored
|
||||
static bool const kStoreT = StoreT;
|
||||
static bool const kStoreT = StoreT_;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
@ -60,6 +60,7 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
template <typename T> class BinaryOp1_,
|
||||
template <typename T> class UnaryOp_,
|
||||
template <typename T> class BinaryOp2_ = detail::NoOp,
|
||||
bool StoreT_ = false,
|
||||
typename ElementVector_ = ElementC_>
|
||||
class LinearCombinationResidualBlock {
|
||||
public:
|
||||
@ -90,7 +91,7 @@ public:
|
||||
|
||||
static bool const kIsHeavy = true;
|
||||
static bool const kStoreZ = true;
|
||||
static bool const kStoreT = false;
|
||||
static bool const kStoreT = StoreT_;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
@ -182,11 +183,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
template <typename T> class ActivationOp_,
|
||||
template <typename T> class BinaryOp1_,
|
||||
template <typename T> class UnaryOp_,
|
||||
bool StoreT_,
|
||||
typename ElementVector_>
|
||||
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
||||
ElementCompute_, ElementC_, ElementsPerAccess,
|
||||
ActivationOp_, BinaryOp1_, UnaryOp_,
|
||||
detail::NoOp, ElementVector_> {
|
||||
detail::NoOp, StoreT_, ElementVector_> {
|
||||
public:
|
||||
static bool const kIsSingleSource = true;
|
||||
|
||||
@ -214,7 +216,7 @@ public:
|
||||
|
||||
static bool const kIsHeavy = true;
|
||||
static bool const kStoreZ = true;
|
||||
static bool const kStoreT = false;
|
||||
static bool const kStoreT = StoreT_;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
Loading…
Reference in New Issue
Block a user