diff --git a/include/cutlass/epilogue/thread/linear_combination_generic.h b/include/cutlass/epilogue/thread/linear_combination_generic.h index eeb67cc2..d43ce5c4 100644 --- a/include/cutlass/epilogue/thread/linear_combination_generic.h +++ b/include/cutlass/epilogue/thread/linear_combination_generic.h @@ -126,6 +126,7 @@ private: ElementCompute alpha_; ElementCompute beta_; + bool skip_elementwise_; public: @@ -135,6 +136,7 @@ public: alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + skip_elementwise_ = false; } /// Returns true if source is needed @@ -155,6 +157,10 @@ public: if (k_partition) { beta_ = ElementCompute(1); } + + if (k_partition != k_partition_count - 1) { + skip_elementwise_ = true; + } } /// Computes linear scaling: D = alpha * accumulator + beta * source @@ -188,7 +194,7 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X } - intermediate = activation(intermediate); + intermediate = skip_elementwise_ ? intermediate : activation(intermediate); // Convert to destination numeric type NumericArrayConverter destination_converter; @@ -219,7 +225,7 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum } - intermediate = activation(intermediate); + intermediate = skip_elementwise_ ? intermediate : activation(intermediate); // Convert to destination numeric type NumericArrayConverter destination_converter;