Add ScaleType support to the LeakyReLU epilogue (#564)
Co-authored-by: xuweiqi <xuweiqi117@gmail.com>
This commit is contained in:
parent
8a766804ad
commit
fb379eaa5b
@ -55,6 +55,7 @@ template <
|
|||||||
int Count, ///< Number of elements computed per operation
|
int Count, ///< Number of elements computed per operation
|
||||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||||
|
ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling
|
||||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||||
>
|
>
|
||||||
class LinearCombinationLeakyRelu {
|
class LinearCombinationLeakyRelu {
|
||||||
@ -65,6 +66,7 @@ public:
|
|||||||
using ElementCompute = ElementCompute_;
|
using ElementCompute = ElementCompute_;
|
||||||
|
|
||||||
static int const kCount = Count;
|
static int const kCount = Count;
|
||||||
|
static const ScaleType::Kind kScale = Scale;
|
||||||
|
|
||||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||||
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
||||||
@ -123,6 +125,12 @@ public:
|
|||||||
/// Returns true if source is needed
|
/// Returns true if source is needed
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
bool is_source_needed() const {
|
bool is_source_needed() const {
|
||||||
|
if (Scale == ScaleType::NoBetaScaling) return true;
|
||||||
|
|
||||||
|
if (Scale == ScaleType::OnlyAlphaScaling) return false;
|
||||||
|
|
||||||
|
if (Scale == ScaleType::Nothing) return false;
|
||||||
|
|
||||||
return beta_bias_ != ElementCompute(0);
|
return beta_bias_ != ElementCompute(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,8 +169,15 @@ public:
|
|||||||
|
|
||||||
LeakyReLU<ComputeFragment> leakyrelu;
|
LeakyReLU<ComputeFragment> leakyrelu;
|
||||||
|
|
||||||
|
if (Scale == ScaleType::NoBetaScaling) {
|
||||||
|
intermediate = converted_source;
|
||||||
|
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||||
|
} else if (Scale == ScaleType::Nothing) {
|
||||||
|
intermediate = converted_accumulator;
|
||||||
|
} else {
|
||||||
intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform
|
intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform
|
||||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||||
|
}
|
||||||
// Compute threshold optionally
|
// Compute threshold optionally
|
||||||
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
|
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
|
||||||
|
|
||||||
@ -188,7 +203,11 @@ public:
|
|||||||
multiplies<ComputeFragment> mul_accumulator;
|
multiplies<ComputeFragment> mul_accumulator;
|
||||||
LeakyReLU<ComputeFragment> leakyrelu;
|
LeakyReLU<ComputeFragment> leakyrelu;
|
||||||
//printf("in doing with bias");
|
//printf("in doing with bias");
|
||||||
|
if (Scale == ScaleType::Nothing) {
|
||||||
|
intermediate = converted_accumulator;
|
||||||
|
} else {
|
||||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||||
|
}
|
||||||
|
|
||||||
// Compute threshold optionally
|
// Compute threshold optionally
|
||||||
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
|
intermediate = leakyrelu(intermediate, leaky_alpha_recip_);
|
||||||
|
Loading…
Reference in New Issue
Block a user