From fb379eaa5bf05fe599f03db79dfe488aee478a38 Mon Sep 17 00:00:00 2001 From: seventh <43060508+Xseventh@users.noreply.github.com> Date: Tue, 12 Jul 2022 05:30:55 +0800 Subject: [PATCH] epilogue leaky relu support ScaleType (#564) Co-authored-by: xuweiqi --- .../thread/linear_combination_leaky_relu.h | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h index 95150ef4..111fc53e 100644 --- a/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h @@ -55,6 +55,7 @@ template < int Count, ///< Number of elements computed per operation typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling FloatRoundStyle Round = FloatRoundStyle::round_to_nearest > class LinearCombinationLeakyRelu { @@ -65,6 +66,7 @@ public: using ElementCompute = ElementCompute_; static int const kCount = Count; + static const ScaleType::Kind kScale = Scale; using FragmentOutput = Array; using FragmentAccumulator = Array; @@ -123,6 +125,12 @@ public: /// Returns true if source is needed CUTLASS_HOST_DEVICE bool is_source_needed() const { + if (Scale == ScaleType::NoBetaScaling) return true; + + if (Scale == ScaleType::OnlyAlphaScaling) return false; + + if (Scale == ScaleType::Nothing) return false; + return beta_bias_ != ElementCompute(0); } @@ -161,8 +169,15 @@ public: LeakyReLU leakyrelu; - intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + if (Scale == ScaleType::NoBetaScaling) { + intermediate = converted_source; + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } else if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + } // Compute threshold optionally intermediate = leakyrelu(intermediate, leaky_alpha_recip_); @@ -188,7 +203,11 @@ public: multiplies mul_accumulator; LeakyReLU leakyrelu; //printf("in doing with bias"); - intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + if (Scale == ScaleType::Nothing) { + intermediate = converted_accumulator; + } else { + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + } // Compute threshold optionally intermediate = leakyrelu(intermediate, leaky_alpha_recip_);