From 06eb90cc0daae633b1e25e80ace1ef81ac158baa Mon Sep 17 00:00:00 2001 From: seventh <43060508+Xseventh@users.noreply.github.com> Date: Thu, 10 Nov 2022 03:42:23 +0800 Subject: [PATCH] Fix identity sigmoid activation (#659) * activation support Identity * fix Sigmoid activation operator() with CUTLASS_HOST_DEVICE --- include/cutlass/epilogue/thread/activation.h | 45 ++++++++++++++------ 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 0162b6e4..d694ea8f 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -49,18 +49,6 @@ namespace cutlass { namespace epilogue { namespace thread { -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Identity { - static const bool kIsHeavy=false; - - CUTLASS_HOST_DEVICE - T operator()(T value) const { - return value; - } -}; - ///////////////////////////////////////////////////////////////////////////////////////////////// template struct LinearCombinationGenericParams { @@ -95,6 +83,39 @@ struct LinearCombinationGenericParams { ///////////////////////////////////////////////////////////////////////////////////////////////// +// Identity operator +template +struct Identity { + static const bool kIsHeavy=false; + + CUTLASS_HOST_DEVICE + T operator()(T value) const { + return value; + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &value, Params const ¶ms_) const { + return this->operator()(value); + } +}; + +template +struct Identity > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + return rhs; + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } +}; + /// ReLu operator - propagates NaNs /// Always put threshold in the right hand side of max to propagate NaN. template