diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 0162b6e4..d694ea8f 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -49,18 +49,6 @@ namespace cutlass { namespace epilogue { namespace thread { -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Identity { - static const bool kIsHeavy=false; - - CUTLASS_HOST_DEVICE - T operator()(T value) const { - return value; - } -}; - ///////////////////////////////////////////////////////////////////////////////////////////////// template struct LinearCombinationGenericParams { @@ -95,6 +83,39 @@ struct LinearCombinationGenericParams { ///////////////////////////////////////////////////////////////////////////////////////////////// +// Identity operator +template +struct Identity { + static const bool kIsHeavy=false; + + CUTLASS_HOST_DEVICE + T operator()(T value) const { + return value; + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &value, Params const ¶ms_) const { + return this->operator()(value); + } +}; + +template +struct Identity > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + return rhs; + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } +}; + /// ReLu operator - propagates NaNs /// Always put threshold in the right hand side of max to propagate NaN. template