parent
04a9777b87
commit
1eb6355182
@ -98,6 +98,43 @@ struct ReLu<Array<T, N>> {
|
||||
}
|
||||
};
|
||||
|
||||
// Tanh operator
|
||||
template <typename T>
|
||||
struct Tanh {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &scalar) const {
|
||||
return fast_tanh(scalar);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct Tanh<Array<T, N> > {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &rhs) const {
|
||||
Array<T, N> y;
|
||||
Tanh<T> tanh_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
y[i] = tanh_op(rhs[i]);
|
||||
}
|
||||
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct Tanh<Array<half_t, N>> {
|
||||
using T = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const& z) const {
|
||||
fast_tanh_op<Array<T, N>> tanh;
|
||||
return tanh(z);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
// Leaky Relu operator
|
||||
template <typename T>
|
||||
struct LeakyReLU {
|
||||
|
Loading…
Reference in New Issue
Block a user