Add Sm90LinCombPerColBias (#1774)

Co-authored-by: Jiayu Sun <jiayus@s4124-0071.nvidia.com>
This commit is contained in:
JiayuSun 2024-09-05 03:11:24 +08:00 committed by GitHub
parent 6c3044136b
commit 7369adcaca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 101 additions and 0 deletions

View File

@ -158,6 +158,23 @@ struct LinCombPerRowBiasEltAct
static constexpr bool IsEltActSupported = true;
};
// D = alpha * acc + beta * C + per-column bias
//
// Fusion-operation tag. Inherits everything from LinearCombination and adds the
// bias element type, the bias alignment, and the IsPerColBiasSupported flag.
template<
  class ElementOutput_,
  class ElementCompute_,
  class ElementBias_ = ElementOutput_,    // bias element type defaults to the output type
  class ElementSource_ = ElementOutput_,  // source (C) element type defaults to the output type
  class ElementScalar_ = ElementCompute_, // scalar type defaults to the compute type
  // Default alignment: 128 divided by sizeof_bits_v of the bias type
  // (presumably the number of bias elements in 128 bits -- confirm against sizeof_bits_v).
  int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
  FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinCombPerColBias
    : LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
  // Flag declaring that this operation carries a per-column bias.
  static constexpr bool IsPerColBiasSupported = true;

  // Re-export the bias-related template parameters as members.
  static constexpr int AlignmentBias = AlignmentBias_;
  using ElementBias = ElementBias_;
};
// D = activation(alpha * acc + beta * C + per-row bias)
// aux = alpha * acc + beta * C + per-row bias
template<

View File

@ -333,6 +333,90 @@ struct FusionCallbacks<
/////////////////////////////////////////////////////////////////////////////////////////////////
// D = alpha * acc + beta * C + per-column bias
//
// Expressed as two nested multiply-add nodes:
//   outer: beta  * C   + inner
//   inner: alpha * acc + bias
// The inner node produces ElementCompute; the outer node produces ElementOutput.
// NOTE: StagesC and EpilogueTile are accepted as template parameters but are not
// referenced anywhere in the tree below.
template<
  int StagesC,
  class CtaTileShapeMNK,
  class EpilogueTile,
  class ElementOutput,
  class ElementCompute,
  class ElementBias = ElementOutput,
  class ElementSource = ElementOutput,
  class ElementScalar = ElementCompute,
  int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombPerColBias =
  Sm90EVT<
    // Outer node: beta * C + (alpha * acc + bias)
    Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>,
    Sm90ScalarBroadcast<ElementScalar>,  // child 0: beta
    Sm90SrcFetch<ElementSource>,         // child 1: C
    Sm90EVT<
      // Inner node (child 2 of the outer node): alpha * acc + bias
      Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
      Sm90ScalarBroadcast<ElementScalar>,  // child 0: alpha
      Sm90AccFetch,                        // child 1: acc
      // child 2: bias, broadcast with stride <_0,_1,int> and a leading 0 argument
      Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_0,_1,int>, AlignmentBias>
    >
  >;
// FusionCallbacks specialization that pairs the Sm90TmaWarpSpecialized dispatch
// policy with the fusion::LinCombPerColBias operation. The implementation is the
// Sm90LinCombPerColBias visitor tree; this wrapper only supplies a flat Arguments
// struct and converts it into the tree's nested argument layout.
template <
  int StagesC,
  int StagesD,
  int FragmentSize,
  bool ReuseSmemC,
  bool DelayTmaStore,
  class ElementOutput,
  class ElementCompute,
  class ElementBias,
  class ElementSource,
  class ElementScalar,
  int AlignmentBias,
  FloatRoundStyle RoundStyle,
  class CtaTileShapeMNK,
  class EpilogueTile
>
struct FusionCallbacks<
    epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>,
    fusion::LinCombPerColBias<ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
    CtaTileShapeMNK,
    EpilogueTile
> : Sm90LinCombPerColBias<
      StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> {

  // The fusion operation this specialization is selected for.
  using Operation = fusion::LinCombPerColBias<
    ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;

  // The visitor tree that serves as the base class.
  using Impl = Sm90LinCombPerColBias<
    StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;

  // Flat, user-facing arguments. Member order is part of the interface for
  // callers that brace-initialize this struct, so it must not be rearranged.
  struct Arguments {
    // Same stride type as the bias broadcast node in Sm90LinCombPerColBias.
    using StrideBias = Stride<_0,_1,int>;

    // Scalars, defaulting to alpha = 1 and beta = 0.
    ElementScalar alpha = ElementScalar(1);
    ElementScalar beta = ElementScalar(0);
    // Optional pointers to the scalars.
    // NOTE(review): presumably these take precedence over alpha/beta when
    // non-null -- confirm against Sm90ScalarBroadcast.
    ElementScalar const* alpha_ptr = nullptr;
    ElementScalar const* beta_ptr = nullptr;

    // Bias vector pointer and its stride.
    ElementBias const* bias_ptr = nullptr;
    StrideBias dBias = {};

    // Repack the flat fields into the nested layout of Impl::Arguments,
    // mirroring the shape of the visitor tree (outer node wrapping inner node).
    operator typename Impl::Arguments() const {
      return
        {                                        // outer node : beta * C + (alpha * acc + bias)
          {{beta}, {beta_ptr}},                  //   leaf : beta
          {},                                    //   leaf : C (no arguments)
          {                                      //   inner node : alpha * acc + bias
            {{alpha}, {alpha_ptr}},              //     leaf : alpha
            {},                                  //     leaf : acc (no arguments)
            // leaf : bias. NOTE(review): the ElementBias(0) is presumably the
            // value used when bias_ptr is null -- confirm against
            // Sm90RowBroadcast::Arguments.
            {bias_ptr, ElementBias(0), dBias},
            {}                                   //     node args : multiply_add (none)
          },                                     //   end inner node
          {}                                     //   node args : multiply_add (none)
        };                                       // end outer node
    }
  };

  // Inherit the constructors of the visitor tree.
  using Impl::Impl;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// D = activation(alpha * acc + beta * C + per-row bias)
template<
class CtaTileShapeMNK,