Update linear_combination_generic.h (#472)
add `skip_elementwise_` to support serial splitk in linear_combination_generic.h`
This commit is contained in:
parent
dae6b6893b
commit
e45e773436
@ -126,6 +126,7 @@ private:
|
|||||||
|
|
||||||
ElementCompute alpha_;
|
ElementCompute alpha_;
|
||||||
ElementCompute beta_;
|
ElementCompute beta_;
|
||||||
|
bool skip_elementwise_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
@ -135,6 +136,7 @@ public:
|
|||||||
|
|
||||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||||
|
skip_elementwise_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if source is needed
|
/// Returns true if source is needed
|
||||||
@ -155,6 +157,10 @@ public:
|
|||||||
if (k_partition) {
|
if (k_partition) {
|
||||||
beta_ = ElementCompute(1);
|
beta_ = ElementCompute(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (k_partition != k_partition_count - 1) {
|
||||||
|
skip_elementwise_ = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||||
@ -188,7 +194,7 @@ public:
|
|||||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||||
}
|
}
|
||||||
|
|
||||||
intermediate = activation(intermediate);
|
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
|
||||||
|
|
||||||
// Convert to destination numeric type
|
// Convert to destination numeric type
|
||||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||||
@ -219,7 +225,7 @@ public:
|
|||||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||||
}
|
}
|
||||||
|
|
||||||
intermediate = activation(intermediate);
|
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
|
||||||
|
|
||||||
// Convert to destination numeric type
|
// Convert to destination numeric type
|
||||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||||
|
Loading…
Reference in New Issue
Block a user