Fixes to Gelu for half and fusion
This commit is contained in:
parent
c77a524459
commit
41a31b404b
@ -139,7 +139,7 @@ struct GELU {
|
||||
/// Generic GELU functor call: evaluates half * scalar * (one + erf(scalar / root_two))
/// using the cutlass::constants helpers for T, and returns the result converted to T.
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &scalar) const {
|
||||
return T(cutlass::constants::half<T>() * scalar *
|
||||
// NOTE(review): the next two expression lines are alternatives from a one-line diff
// change; this one appears to be the pre-commit form, passing the T-typed quotient
// straight to erff.
(cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
|
||||
// This one appears to be the post-commit form: the quotient is cast to float before
// erff, and the erff result is cast back to T before being added to one<T>().
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
|
||||
}
|
||||
};
|
||||
|
||||
@ -152,6 +152,15 @@ struct GELU<float> {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GELU<double> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
float operator()(double const &scalar) const {
|
||||
return cutlass::constants::half<double>() * scalar *
|
||||
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct GELU<Array<T, N> > {
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
@ -133,7 +133,7 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user