From e45e77343693e261f9285ee05be4f0498848e5a5 Mon Sep 17 00:00:00 2001 From: Haicheng Wu <57973641+hwu36@users.noreply.github.com> Date: Tue, 28 Jun 2022 07:29:38 -0400 Subject: [PATCH] Update linear_combination_generic.h (#472) Add `skip_elementwise_` to support serial split-K in `linear_combination_generic.h`. --- .../epilogue/thread/linear_combination_generic.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/thread/linear_combination_generic.h b/include/cutlass/epilogue/thread/linear_combination_generic.h index eeb67cc2..d43ce5c4 100644 --- a/include/cutlass/epilogue/thread/linear_combination_generic.h +++ b/include/cutlass/epilogue/thread/linear_combination_generic.h @@ -126,6 +126,7 @@ private: ElementCompute alpha_; ElementCompute beta_; + bool skip_elementwise_; public: @@ -135,6 +136,7 @@ public: alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + skip_elementwise_ = false; } /// Returns true if source is needed @@ -155,6 +157,10 @@ public: if (k_partition) { beta_ = ElementCompute(1); } + + if (k_partition != k_partition_count - 1) { + skip_elementwise_ = true; + } } /// Computes linear scaling: D = alpha * accumulator + beta * source @@ -188,7 +194,7 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X } - intermediate = activation(intermediate); + intermediate = skip_elementwise_ ? intermediate : activation(intermediate); // Convert to destination numeric type NumericArrayConverter destination_converter; @@ -219,7 +225,7 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum } - intermediate = activation(intermediate); + intermediate = skip_elementwise_ ? intermediate : activation(intermediate); // Convert to destination numeric type NumericArrayConverter destination_converter;