fix call of GELU_Taylor in LinearCombinationGeneric (#634)
This commit is contained in:
parent
a821280dc7
commit
9f2e3faa69
@ -575,6 +575,11 @@ struct GELU_taylor {
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &scalar, Params const ¶ms_) const {
|
||||
return this->operator()(scalar);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <int N>
|
||||
@ -603,6 +608,11 @@ struct GELU_taylor<Array<half_t, N> > {
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<half_t>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<half_t, N> operator()(Array<half_t, N> const &rhs, Params const ¶ms_) const {
|
||||
return this->operator()(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
@ -622,6 +632,11 @@ struct GELU_taylor<Array<T, N> > {
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &rhs, Params const ¶ms_) const {
|
||||
return this->operator()(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and
|
||||
|
Loading…
Reference in New Issue
Block a user