Expose StoreT parameter for potential speedup (#838)
* Expose the StoreT parameter for a potential speedup
* Add StoreT to more elementwise classes
---------
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
29801e348a
commit
86cae03cea
@ -62,6 +62,7 @@ template <
|
|||||||
int ElementsPerAccess,
|
int ElementsPerAccess,
|
||||||
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
||||||
typename BinaryOp_ = plus<ElementCompute_>,
|
typename BinaryOp_ = plus<ElementCompute_>,
|
||||||
|
bool StoreT_ = true,
|
||||||
typename ElementVector_ = ElementC_
|
typename ElementVector_ = ElementC_
|
||||||
>
|
>
|
||||||
class LinearCombinationBiasElementwise {
|
class LinearCombinationBiasElementwise {
|
||||||
@ -97,7 +98,7 @@ public:
|
|||||||
static bool const kStoreZ = true;
|
static bool const kStoreZ = true;
|
||||||
|
|
||||||
/// If true, the 'T' tensor is stored
|
/// If true, the 'T' tensor is stored
|
||||||
static bool const kStoreT = true;
|
static bool const kStoreT = StoreT_;
|
||||||
|
|
||||||
/// Host-constructable parameters structure
|
/// Host-constructable parameters structure
|
||||||
struct Params {
|
struct Params {
|
||||||
|
@ -204,7 +204,7 @@ template <
|
|||||||
typename ElementCompute_,
|
typename ElementCompute_,
|
||||||
typename ElementZ_,
|
typename ElementZ_,
|
||||||
int ElementsPerAccess,
|
int ElementsPerAccess,
|
||||||
bool StoreT = true,
|
bool StoreT_ = true,
|
||||||
typename ElementVector_ = ElementC_
|
typename ElementVector_ = ElementC_
|
||||||
>
|
>
|
||||||
class LinearCombinationBiasRelu {
|
class LinearCombinationBiasRelu {
|
||||||
@ -238,7 +238,7 @@ public:
|
|||||||
static bool const kStoreZ = true;
|
static bool const kStoreZ = true;
|
||||||
|
|
||||||
/// If true, the 'T' tensor is stored
|
/// If true, the 'T' tensor is stored
|
||||||
static bool const kStoreT = StoreT;
|
static bool const kStoreT = StoreT_;
|
||||||
|
|
||||||
/// Host-constructable parameters structure
|
/// Host-constructable parameters structure
|
||||||
struct Params {
|
struct Params {
|
||||||
|
@ -60,6 +60,7 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
|||||||
template <typename T> class BinaryOp1_,
|
template <typename T> class BinaryOp1_,
|
||||||
template <typename T> class UnaryOp_,
|
template <typename T> class UnaryOp_,
|
||||||
template <typename T> class BinaryOp2_ = detail::NoOp,
|
template <typename T> class BinaryOp2_ = detail::NoOp,
|
||||||
|
bool StoreT_ = false,
|
||||||
typename ElementVector_ = ElementC_>
|
typename ElementVector_ = ElementC_>
|
||||||
class LinearCombinationResidualBlock {
|
class LinearCombinationResidualBlock {
|
||||||
public:
|
public:
|
||||||
@ -90,7 +91,7 @@ public:
|
|||||||
|
|
||||||
static bool const kIsHeavy = true;
|
static bool const kIsHeavy = true;
|
||||||
static bool const kStoreZ = true;
|
static bool const kStoreZ = true;
|
||||||
static bool const kStoreT = false;
|
static bool const kStoreT = StoreT_;
|
||||||
|
|
||||||
/// Host-constructable parameters structure
|
/// Host-constructable parameters structure
|
||||||
struct Params {
|
struct Params {
|
||||||
@ -182,11 +183,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
|||||||
template <typename T> class ActivationOp_,
|
template <typename T> class ActivationOp_,
|
||||||
template <typename T> class BinaryOp1_,
|
template <typename T> class BinaryOp1_,
|
||||||
template <typename T> class UnaryOp_,
|
template <typename T> class UnaryOp_,
|
||||||
|
bool StoreT_,
|
||||||
typename ElementVector_>
|
typename ElementVector_>
|
||||||
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
||||||
ElementCompute_, ElementC_, ElementsPerAccess,
|
ElementCompute_, ElementC_, ElementsPerAccess,
|
||||||
ActivationOp_, BinaryOp1_, UnaryOp_,
|
ActivationOp_, BinaryOp1_, UnaryOp_,
|
||||||
detail::NoOp, ElementVector_> {
|
detail::NoOp, StoreT_, ElementVector_> {
|
||||||
public:
|
public:
|
||||||
static bool const kIsSingleSource = true;
|
static bool const kIsSingleSource = true;
|
||||||
|
|
||||||
@ -214,7 +216,7 @@ public:
|
|||||||
|
|
||||||
static bool const kIsHeavy = true;
|
static bool const kIsHeavy = true;
|
||||||
static bool const kStoreZ = true;
|
static bool const kStoreZ = true;
|
||||||
static bool const kStoreT = false;
|
static bool const kStoreT = StoreT_;
|
||||||
|
|
||||||
/// Host-constructable parameters structure
|
/// Host-constructable parameters structure
|
||||||
struct Params {
|
struct Params {
|
||||||
|
Loading…
Reference in New Issue
Block a user