Fixes to Gelu for half and fusion
This commit is contained in:
parent
c77a524459
commit
41a31b404b
@ -139,7 +139,7 @@ struct GELU {
|
|||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &scalar) const {
|
T operator()(T const &scalar) const {
|
||||||
return T(cutlass::constants::half<T>() * scalar *
|
return T(cutlass::constants::half<T>() * scalar *
|
||||||
(cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
|
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -152,6 +152,15 @@ struct GELU<float> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct GELU<double> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
float operator()(double const &scalar) const {
|
||||||
|
return cutlass::constants::half<double>() * scalar *
|
||||||
|
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct GELU<Array<T, N> > {
|
struct GELU<Array<T, N> > {
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
|
@ -133,7 +133,7 @@ public:
|
|||||||
|
|
||||||
/// Functionally required for serial reduction in the epilogue
|
/// Functionally required for serial reduction in the epilogue
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void set_k_partition(int k_partition) {
|
void set_k_partition(int k_partition, int k_partition_count) {
|
||||||
if (k_partition) {
|
if (k_partition) {
|
||||||
beta_ = ElementCompute(1);
|
beta_ = ElementCompute(1);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user