diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index ca1f72c2..9763f5fc 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -98,6 +98,45 @@ struct ReLu<Array<T, N>> {
   }
 };
 
+// Tanh operator: applies fast_tanh() to a single value of type T.
+template <typename T>
+struct Tanh {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &scalar) const {
+    return fast_tanh(scalar);
+  }
+};
+
+// Array specialization: applies the scalar Tanh<T> to each of the N elements.
+template <typename T, int N>
+struct Tanh<Array<T, N> > {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs) const {
+    Array<T, N> y;
+    Tanh<T> tanh_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      y[i] = tanh_op(rhs[i]);
+    }
+
+    return y;
+  }
+};
+
+// half_t array specialization: hands the whole array to fast_tanh_op in one
+// call rather than looping over scalars.
+template <int N>
+struct Tanh<Array<half_t, N>> {
+  using T = half_t;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const& z) const {
+    fast_tanh_op<Array<T, N>> tanh_op;
+    return tanh_op(z);
+  }
+};
+
 // Leaky Relu operator
 template <typename T>
 struct LeakyReLU {