Add Sm90LinCombPerColBias (#1774)
Co-authored-by: Jiayu Sun <jiayus@s4124-0071.nvidia.com>
This commit is contained in:
parent
6c3044136b
commit
7369adcaca
@ -158,6 +158,23 @@ struct LinCombPerRowBiasEltAct
|
|||||||
static constexpr bool IsEltActSupported = true;
|
static constexpr bool IsEltActSupported = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// D = alpha * acc + beta * C + per-column bias
|
||||||
|
template<
|
||||||
|
class ElementOutput_,
|
||||||
|
class ElementCompute_,
|
||||||
|
class ElementBias_ = ElementOutput_,
|
||||||
|
class ElementSource_ = ElementOutput_,
|
||||||
|
class ElementScalar_ = ElementCompute_,
|
||||||
|
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
|
||||||
|
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||||
|
>
|
||||||
|
struct LinCombPerColBias
|
||||||
|
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
|
||||||
|
using ElementBias = ElementBias_;
|
||||||
|
static constexpr int AlignmentBias = AlignmentBias_;
|
||||||
|
static constexpr bool IsPerColBiasSupported = true;
|
||||||
|
};
|
||||||
|
|
||||||
// D = activation(alpha * acc + beta * C + per-row bias)
|
// D = activation(alpha * acc + beta * C + per-row bias)
|
||||||
// aux = alpha * acc + beta * C + per-row bias
|
// aux = alpha * acc + beta * C + per-row bias
|
||||||
template<
|
template<
|
||||||
|
|||||||
@ -333,6 +333,90 @@ struct FusionCallbacks<
|
|||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// D = alpha * acc + beta * C + per-column bias
|
||||||
|
template<
|
||||||
|
int StagesC,
|
||||||
|
class CtaTileShapeMNK,
|
||||||
|
class EpilogueTile,
|
||||||
|
class ElementOutput,
|
||||||
|
class ElementCompute,
|
||||||
|
class ElementBias = ElementOutput,
|
||||||
|
class ElementSource = ElementOutput,
|
||||||
|
class ElementScalar = ElementCompute,
|
||||||
|
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||||
|
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||||
|
>
|
||||||
|
using Sm90LinCombPerColBias =
|
||||||
|
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
|
||||||
|
Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||||
|
Sm90SrcFetch<ElementSource>, // C
|
||||||
|
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||||
|
Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||||
|
Sm90AccFetch, // acc
|
||||||
|
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_0,_1,int>, AlignmentBias> // bias
|
||||||
|
>
|
||||||
|
>;
|
||||||
|
|
||||||
|
template <
|
||||||
|
int StagesC,
|
||||||
|
int StagesD,
|
||||||
|
int FragmentSize,
|
||||||
|
bool ReuseSmemC,
|
||||||
|
bool DelayTmaStore,
|
||||||
|
class ElementOutput,
|
||||||
|
class ElementCompute,
|
||||||
|
class ElementBias,
|
||||||
|
class ElementSource,
|
||||||
|
class ElementScalar,
|
||||||
|
int AlignmentBias,
|
||||||
|
FloatRoundStyle RoundStyle,
|
||||||
|
class CtaTileShapeMNK,
|
||||||
|
class EpilogueTile
|
||||||
|
>
|
||||||
|
struct FusionCallbacks<
|
||||||
|
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>,
|
||||||
|
fusion::LinCombPerColBias<ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
|
||||||
|
CtaTileShapeMNK,
|
||||||
|
EpilogueTile
|
||||||
|
> : Sm90LinCombPerColBias<
|
||||||
|
StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> {
|
||||||
|
using Impl = Sm90LinCombPerColBias<
|
||||||
|
StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
|
||||||
|
using Operation = fusion::LinCombPerColBias<
|
||||||
|
ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
ElementScalar alpha = ElementScalar(1);
|
||||||
|
ElementScalar beta = ElementScalar(0);
|
||||||
|
ElementScalar const* alpha_ptr = nullptr;
|
||||||
|
ElementScalar const* beta_ptr = nullptr;
|
||||||
|
|
||||||
|
using StrideBias = Stride<_0,_1,int>;
|
||||||
|
ElementBias const* bias_ptr = nullptr;
|
||||||
|
StrideBias dBias = {};
|
||||||
|
|
||||||
|
operator typename Impl::Arguments() const {
|
||||||
|
return
|
||||||
|
{ // ternary op : beta * C + (alpha * acc + bias)
|
||||||
|
{{beta}, {beta_ptr}}, // leaf args : beta
|
||||||
|
{}, // leaf args : C
|
||||||
|
{ // ternary op : alpha * acc + bias
|
||||||
|
{{alpha}, {alpha_ptr}}, // leaf args : alpha
|
||||||
|
{}, // leaf args : acc
|
||||||
|
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
|
||||||
|
{} // ternary args : multiply_add
|
||||||
|
}, // end ternary op
|
||||||
|
{} // ternary args : multiply_add
|
||||||
|
}; // end ternary op
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ctor inheritance
|
||||||
|
using Impl::Impl;
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// D = activation(alpha * acc + beta * C + per-row bias)
|
// D = activation(alpha * acc + beta * C + per-row bias)
|
||||||
template<
|
template<
|
||||||
class CtaTileShapeMNK,
|
class CtaTileShapeMNK,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user