Fix identity sigmoid activation (#659)

* activation support Identity

* fix Sigmoid activation operator() with CUTLASS_HOST_DEVICE
This commit is contained in:
seventh 2022-11-10 03:42:23 +08:00 committed by GitHub
parent 168ea8b0e1
commit 06eb90cc0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -49,18 +49,6 @@ namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct Identity {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T value) const {
return value;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct LinearCombinationGenericParams {
@ -95,6 +83,39 @@ struct LinearCombinationGenericParams {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Identity operator
template <typename T>
struct Identity {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T value) const {
return value;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &value, Params const &params_) const {
return this->operator()(value);
}
};
template <typename T, int N>
struct Identity<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
return rhs;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
/// ReLu operator - propagates NaNs
/// Always put threshold in the right hand side of max to propagate NaN.
template <typename T>