epilogue leaky relu support ScaleType (#564)

Co-authored-by: xuweiqi <xuweiqi117@gmail.com>
This commit is contained in:
seventh 2022-07-12 05:30:55 +08:00 committed by GitHub
parent 8a766804ad
commit fb379eaa5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -55,6 +55,7 @@ template <
int Count, ///< Number of elements computed per operation
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationLeakyRelu {
@ -65,6 +66,7 @@ public:
using ElementCompute = ElementCompute_;
static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
@ -123,6 +125,12 @@ public:
/// Reports whether the source operand is required by this epilogue.
///
/// The compile-time Scale kind is consulted first: NoBetaScaling always
/// requires the source, whereas OnlyAlphaScaling and Nothing never do.
/// For any other kind the source is required only when beta_bias_ is non-zero.
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
  switch (Scale) {
    case ScaleType::NoBetaScaling:
      return true;
    case ScaleType::OnlyAlphaScaling:
    case ScaleType::Nothing:
      return false;
    default:
      // Remaining kinds depend on the stored beta_bias_ value.
      return beta_bias_ != ElementCompute(0);
  }
}
@ -161,8 +169,15 @@ public:
LeakyReLU<ComputeFragment> leakyrelu;
intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}
// Compute threshold optionally
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
@ -188,7 +203,11 @@ public:
multiplies<ComputeFragment> mul_accumulator;
LeakyReLU<ComputeFragment> leakyrelu;
//printf("in doing with bias");
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}
// Compute threshold optionally
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);