Fix identity sigmoid activation (#659)
* activation support Identity * fix Sigmoid activation operator() with CUTLASS_HOST_DEVICE
This commit is contained in:
parent
168ea8b0e1
commit
06eb90cc0d
@ -49,18 +49,6 @@ namespace cutlass {
|
|||||||
namespace epilogue {
|
namespace epilogue {
|
||||||
namespace thread {
|
namespace thread {
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Identity {
|
|
||||||
static const bool kIsHeavy=false;
|
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
|
||||||
T operator()(T value) const {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LinearCombinationGenericParams {
|
struct LinearCombinationGenericParams {
|
||||||
@ -95,6 +83,39 @@ struct LinearCombinationGenericParams {
|
|||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Identity operator
|
||||||
|
template <typename T>
|
||||||
|
struct Identity {
|
||||||
|
static const bool kIsHeavy=false;
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
T operator()(T value) const {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
using Params = LinearCombinationGenericParams<T>;
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
T operator()(T const &value, Params const ¶ms_) const {
|
||||||
|
return this->operator()(value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int N>
|
||||||
|
struct Identity<Array<T, N> > {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const &rhs) const {
|
||||||
|
return rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
using Params = LinearCombinationGenericParams<T>;
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const &rhs, Params const ¶ms_) const {
|
||||||
|
return this->operator()(rhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// ReLu operator - propagates NaNs
|
/// ReLu operator - propagates NaNs
|
||||||
/// Always put threshold in the right hand side of max to propagate NaN.
|
/// Always put threshold in the right hand side of max to propagate NaN.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
Loading…
Reference in New Issue
Block a user