diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index d545a78a..2f40cf18 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -357,9 +357,11 @@ public: ReLu relu; if (Scale == ScaleType::NoBetaScaling) - intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = converted_source; else - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X // Compute threshold optionally intermediate = relu(threshold_, intermediate);