Update linear_combination_generic.h (#472)
add `skip_elementwise_` to support serial splitk in linear_combination_generic.h`
This commit is contained in:
parent
dae6b6893b
commit
e45e773436
@ -126,6 +126,7 @@ private:
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
bool skip_elementwise_;
|
||||
|
||||
public:
|
||||
|
||||
@ -135,6 +136,7 @@ public:
|
||||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
skip_elementwise_ = false;
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
@ -155,6 +157,10 @@ public:
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
if (k_partition != k_partition_count - 1) {
|
||||
skip_elementwise_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
@ -188,7 +194,7 @@ public:
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
}
|
||||
|
||||
intermediate = activation(intermediate);
|
||||
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
@ -219,7 +225,7 @@ public:
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
}
|
||||
|
||||
intermediate = activation(intermediate);
|
||||
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
Loading…
Reference in New Issue
Block a user