Add Sm90LinCombPerColBias (#1774)
Co-authored-by: Jiayu Sun <jiayus@s4124-0071.nvidia.com>
This commit is contained in:
parent
6c3044136b
commit
7369adcaca
@ -158,6 +158,23 @@ struct LinCombPerRowBiasEltAct
|
||||
static constexpr bool IsEltActSupported = true;
|
||||
};
|
||||
|
||||
// D = alpha * acc + beta * C + per-column bias
|
||||
template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementSource_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct LinCombPerColBias
|
||||
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
|
||||
using ElementBias = ElementBias_;
|
||||
static constexpr int AlignmentBias = AlignmentBias_;
|
||||
static constexpr bool IsPerColBiasSupported = true;
|
||||
};
|
||||
|
||||
// D = activation(alpha * acc + beta * C + per-row bias)
|
||||
// aux = alpha * acc + beta * C + per-row bias
|
||||
template<
|
||||
|
||||
@ -333,6 +333,90 @@ struct FusionCallbacks<
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// D = alpha * acc + beta * C + per-column bias
|
||||
template<
|
||||
int StagesC,
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementSource = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90LinCombPerColBias =
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
Sm90SrcFetch<ElementSource>, // C
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_0,_1,int>, AlignmentBias> // bias
|
||||
>
|
||||
>;
|
||||
|
||||
template <
|
||||
int StagesC,
|
||||
int StagesD,
|
||||
int FragmentSize,
|
||||
bool ReuseSmemC,
|
||||
bool DelayTmaStore,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementSource,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
FloatRoundStyle RoundStyle,
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile
|
||||
>
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>,
|
||||
fusion::LinCombPerColBias<ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile
|
||||
> : Sm90LinCombPerColBias<
|
||||
StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> {
|
||||
using Impl = Sm90LinCombPerColBias<
|
||||
StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
|
||||
using Operation = fusion::LinCombPerColBias<
|
||||
ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
|
||||
|
||||
struct Arguments {
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
ElementScalar const* alpha_ptr = nullptr;
|
||||
ElementScalar const* beta_ptr = nullptr;
|
||||
|
||||
using StrideBias = Stride<_0,_1,int>;
|
||||
ElementBias const* bias_ptr = nullptr;
|
||||
StrideBias dBias = {};
|
||||
|
||||
operator typename Impl::Arguments() const {
|
||||
return
|
||||
{ // ternary op : beta * C + (alpha * acc + bias)
|
||||
{{beta}, {beta_ptr}}, // leaf args : beta
|
||||
{}, // leaf args : C
|
||||
{ // ternary op : alpha * acc + bias
|
||||
{{alpha}, {alpha_ptr}}, // leaf args : alpha
|
||||
{}, // leaf args : acc
|
||||
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
|
||||
{} // ternary args : multiply_add
|
||||
}, // end ternary op
|
||||
{} // ternary args : multiply_add
|
||||
}; // end ternary op
|
||||
}
|
||||
};
|
||||
|
||||
// Ctor inheritance
|
||||
using Impl::Impl;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// D = activation(alpha * acc + beta * C + per-row bias)
|
||||
template<
|
||||
class CtaTileShapeMNK,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user