fix call of GELU_Taylor in LinearCombinationGeneric (#634)
This commit is contained in:
parent
a821280dc7
commit
9f2e3faa69
@ -575,6 +575,11 @@ struct GELU_taylor {
|
|||||||
|
|
||||||
using Params = LinearCombinationGenericParams<T>;
|
using Params = LinearCombinationGenericParams<T>;
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
T operator()(T const &scalar, Params const ¶ms_) const {
|
||||||
|
return this->operator()(scalar);
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
@ -603,6 +608,11 @@ struct GELU_taylor<Array<half_t, N> > {
|
|||||||
}
|
}
|
||||||
|
|
||||||
using Params = LinearCombinationGenericParams<half_t>;
|
using Params = LinearCombinationGenericParams<half_t>;
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<half_t, N> operator()(Array<half_t, N> const &rhs, Params const ¶ms_) const {
|
||||||
|
return this->operator()(rhs);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
@ -622,6 +632,11 @@ struct GELU_taylor<Array<T, N> > {
|
|||||||
}
|
}
|
||||||
|
|
||||||
using Params = LinearCombinationGenericParams<T>;
|
using Params = LinearCombinationGenericParams<T>;
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const &rhs, Params const ¶ms_) const {
|
||||||
|
return this->operator()(rhs);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and
|
/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and
|
||||||
|
Loading…
Reference in New Issue
Block a user