fix call of GELU_Taylor in LinearCombinationGeneric (#634)

This commit is contained in:
Tianqi Zhang (张天启) 2022-09-21 09:00:55 +08:00 committed by GitHub
parent a821280dc7
commit 9f2e3faa69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -575,6 +575,11 @@ struct GELU_taylor {
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <int N>
@ -603,6 +608,11 @@ struct GELU_taylor<Array<half_t, N> > {
}
using Params = LinearCombinationGenericParams<half_t>;
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
template <typename T, int N>
@ -622,6 +632,11 @@ struct GELU_taylor<Array<T, N> > {
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and