Update linear_combination_generic.h (#472)

add `skip_elementwise_` to support serial splitk in linear_combination_generic.h`
This commit is contained in:
Haicheng Wu 2022-06-28 07:29:38 -04:00 committed by GitHub
parent dae6b6893b
commit e45e773436
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -126,6 +126,7 @@ private:
ElementCompute alpha_;
ElementCompute beta_;
bool skip_elementwise_;
public:
@ -135,6 +136,7 @@ public:
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
skip_elementwise_ = false;
}
/// Returns true if source is needed
@ -155,6 +157,10 @@ public:
if (k_partition) {
beta_ = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
skip_elementwise_ = true;
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
@ -188,7 +194,7 @@ public:
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
intermediate = activation(intermediate);
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
@ -219,7 +225,7 @@ public:
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
intermediate = activation(intermediate);
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;