parent
04a9777b87
commit
1eb6355182
@ -98,6 +98,43 @@ struct ReLu<Array<T, N>> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Tanh operator
|
||||||
|
template <typename T>
|
||||||
|
struct Tanh {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
T operator()(T const &scalar) const {
|
||||||
|
return fast_tanh(scalar);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int N>
|
||||||
|
struct Tanh<Array<T, N> > {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const &rhs) const {
|
||||||
|
Array<T, N> y;
|
||||||
|
Tanh<T> tanh_op;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
y[i] = tanh_op(rhs[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
struct Tanh<Array<half_t, N>> {
|
||||||
|
using T = half_t;
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const& z) const {
|
||||||
|
fast_tanh_op<Array<T, N>> tanh;
|
||||||
|
return tanh(z);
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Leaky Relu operator
|
// Leaky Relu operator
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LeakyReLU {
|
struct LeakyReLU {
|
||||||
|
Loading…
Reference in New Issue
Block a user