From 9f2e3faa69920d9c8ea81465f5a22bd2ffc01b51 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tianqi=20Zhang=20=28=E5=BC=A0=E5=A4=A9=E5=90=AF=29?=
Date: Wed, 21 Sep 2022 09:00:55 +0800
Subject: [PATCH] fix call of GELU_Taylor in LinearCombinationGeneric (#634)

---
 include/cutlass/epilogue/thread/activation.h | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index 4854f3ad..5f8ba63e 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -575,6 +575,11 @@ struct GELU_taylor {
 
   using Params = LinearCombinationGenericParams<T>;
 
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &scalar, Params const &params_) const {
+    return this->operator()(scalar);
+  }
+
 };
 
 template <typename T, int N>
@@ -603,6 +608,11 @@ struct GELU_taylor<Array<half_t, N> > {
   }
 
   using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  Array<half_t, N> operator()(Array<half_t, N> const &rhs, Params const &params_) const {
+    return this->operator()(rhs);
+  }
 };
 
 template <typename T, int N>
@@ -622,6 +632,11 @@ struct GELU_taylor<Array<T, N> > {
   }
 
   using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
+    return this->operator()(rhs);
+  }
 };
 
 /// Computes backwards pass for GELU operator assuming d_t is the layer gradient and
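
Note (not part of the patch): the generic epilogue functor (LinearCombinationGeneric, in include/cutlass/epilogue/thread/linear_combination_generic.h) invokes its activation with an extra Params argument, so an activation that only defines the one-argument operator() cannot be used there. The patch adds a two-argument overload to each GELU_taylor specialization that simply forwards to the existing one-argument form. The plain C++ sketch below illustrates that calling pattern under stated assumptions; ParamsLike, GeluTaylorLike, and apply_activation are hypothetical stand-ins, not CUTLASS code, and the GELU constants mirror the tanh approximation used in the patched file.

// Standalone illustration of why the two-argument operator() overload is needed.
// All names here are hypothetical stand-ins for the CUTLASS types involved.
#include <cmath>
#include <cstdio>

// Stand-in for LinearCombinationGenericParams<T>: scale/bias parameters that the
// generic epilogue threads through every activation call.
template <typename T>
struct ParamsLike {
  T alpha = T(1);
  T beta  = T(0);
};

// Stand-in for GELU_taylor<T> after the patch: the original one-argument
// operator() plus a (value, params) overload that forwards to it.
template <typename T>
struct GeluTaylorLike {
  using Params = ParamsLike<T>;

  T operator()(T const &z) const {
    T k0 = T(0.7978845608028654);  // sqrt(2/pi)
    T k1 = T(0.044715);
    return T(0.5) * z * (T(1) + std::tanh(k0 * z * (T(1) + k1 * z * z)));
  }

  // The overload a LinearCombinationGeneric-style caller expects; the params
  // are accepted but unused, exactly as in the patch.
  T operator()(T const &scalar, Params const &) const {
    return this->operator()(scalar);
  }
};

// Generic caller in the style of LinearCombinationGeneric: it always passes the
// params object along, so without the added overload this line fails to compile.
template <template <typename> class Activation, typename T>
T apply_activation(T value, ParamsLike<T> const &params) {
  Activation<T> act;
  return act(value, params);
}

int main() {
  ParamsLike<float> params;
  std::printf("gelu(1.0f) = %f\n", apply_activation<GeluTaylorLike>(1.0f, params));
  return 0;
}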