Add more GMMA shapes (#1630)
* Add more GMMA shapes
* Add more shapes for BF16
This commit is contained in:
parent
be60a0b272
commit
5b283c872c
@ -556,18 +556,42 @@ ss_op_selector()
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 240 == 0) {
|
||||
return SM90_64x240x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 224 == 0) {
|
||||
return SM90_64x224x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 176 == 0) {
|
||||
return SM90_64x176x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 160 == 0) {
|
||||
return SM90_64x160x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 144 == 0) {
|
||||
return SM90_64x144x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 112 == 0) {
|
||||
return SM90_64x112x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 80 == 0) {
|
||||
return SM90_64x80x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 48 == 0) {
|
||||
return SM90_64x48x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
@ -590,18 +614,42 @@ ss_op_selector()
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 240 == 0) {
|
||||
return SM90_64x240x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 224 == 0) {
|
||||
return SM90_64x224x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 176 == 0) {
|
||||
return SM90_64x176x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 160 == 0) {
|
||||
return SM90_64x160x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 144 == 0) {
|
||||
return SM90_64x144x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 112 == 0) {
|
||||
return SM90_64x112x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 80 == 0) {
|
||||
return SM90_64x80x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 48 == 0) {
|
||||
return SM90_64x48x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
@ -1011,18 +1059,42 @@ rs_op_selector()
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 240 == 0) {
|
||||
return SM90_64x240x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 224 == 0) {
|
||||
return SM90_64x224x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 176 == 0) {
|
||||
return SM90_64x176x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 160 == 0) {
|
||||
return SM90_64x160x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 144 == 0) {
|
||||
return SM90_64x144x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 112 == 0) {
|
||||
return SM90_64x112x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 80 == 0) {
|
||||
return SM90_64x80x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 48 == 0) {
|
||||
return SM90_64x48x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
@ -1045,18 +1117,42 @@ rs_op_selector()
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 240 == 0) {
|
||||
return SM90_64x240x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 224 == 0) {
|
||||
return SM90_64x224x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 176 == 0) {
|
||||
return SM90_64x176x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 160 == 0) {
|
||||
return SM90_64x160x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 144 == 0) {
|
||||
return SM90_64x144x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 112 == 0) {
|
||||
return SM90_64x112x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 80 == 0) {
|
||||
return SM90_64x80x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 48 == 0) {
|
||||
return SM90_64x48x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -392,18 +392,42 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
|
||||
// Accumulator (C) layouts for the 64xN GMMA tiles.
// Every alias below uses the same leading shape modes and the same strides; the only thing that
// varies is the last value-mode extent, which is N/8 for a 64xN tile. The extent is spelled
// uniformly as Int<K> (CuTe's `_K` shorthands name the same types).
using CLayout_64x32  = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<4>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x48  = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<6>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x64  = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<8>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x80  = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<10>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x96  = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<12>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x112 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<14>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x128 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<16>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x144 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<18>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x160 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<20>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x176 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<22>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x192 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<24>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x224 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<28>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x240 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<30>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;

using CLayout_64x256 = Layout<Shape <Shape <_4,_8,_4>, Shape <_2,_2,Int<32>>>,
                              Stride<Stride<_128,_1,_16>, Stride<_64,_8,_512>>>;
|
||||
|
||||
@ -898,6 +922,49 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x48x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<48>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 48, 16>;
|
||||
using CLayout = GMMA::CLayout_64x48;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x48x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<48>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 48, 16>;
|
||||
using CLayout = GMMA::CLayout_64x48;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x64x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -941,6 +1008,49 @@ struct MMA_Traits<SM90_64x64x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x80x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<80>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 80, 16>;
|
||||
using CLayout = GMMA::CLayout_64x80;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x80x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<80>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 80, 16>;
|
||||
using CLayout = GMMA::CLayout_64x80;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x96x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -984,6 +1094,49 @@ struct MMA_Traits<SM90_64x96x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x112x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<112>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 112, 16>;
|
||||
using CLayout = GMMA::CLayout_64x112;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x112x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<112>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 112, 16>;
|
||||
using CLayout = GMMA::CLayout_64x112;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x128x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -1027,6 +1180,135 @@ struct MMA_Traits<SM90_64x128x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x144x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<144>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 144, 16>;
|
||||
using CLayout = GMMA::CLayout_64x144;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x144x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<144>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 144, 16>;
|
||||
using CLayout = GMMA::CLayout_64x144;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x160x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<160>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 160, 16>;
|
||||
using CLayout = GMMA::CLayout_64x160;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x160x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<160>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 160, 16>;
|
||||
using CLayout = GMMA::CLayout_64x160;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x176x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<176>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 176, 16>;
|
||||
using CLayout = GMMA::CLayout_64x176;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x176x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<176>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 176, 16>;
|
||||
using CLayout = GMMA::CLayout_64x176;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x192x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -1070,6 +1352,92 @@ struct MMA_Traits<SM90_64x192x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x224x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<224>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 224, 16>;
|
||||
using CLayout = GMMA::CLayout_64x224;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x224x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<224>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 224, 16>;
|
||||
using CLayout = GMMA::CLayout_64x224;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x240x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<240>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 240, 16>;
|
||||
using CLayout = GMMA::CLayout_64x240;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x240x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<240>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 240, 16>;
|
||||
using CLayout = GMMA::CLayout_64x240;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x256x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -1242,6 +1610,49 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x48x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<48>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 48, 16>;
|
||||
using CLayout = GMMA::CLayout_64x48;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x48x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<48>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 48, 16>;
|
||||
using CLayout = GMMA::CLayout_64x48;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x64x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -1285,6 +1696,49 @@ struct MMA_Traits<SM90_64x64x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x80x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<80>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 80, 16>;
|
||||
using CLayout = GMMA::CLayout_64x80;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x80x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<80>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 80, 16>;
|
||||
using CLayout = GMMA::CLayout_64x80;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x96x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -1328,6 +1782,49 @@ struct MMA_Traits<SM90_64x96x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x112x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<112>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 112, 16>;
|
||||
using CLayout = GMMA::CLayout_64x112;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x112x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<112>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 112, 16>;
|
||||
using CLayout = GMMA::CLayout_64x112;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x128x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -1371,6 +1868,135 @@ struct MMA_Traits<SM90_64x128x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x144x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<144>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 144, 16>;
|
||||
using CLayout = GMMA::CLayout_64x144;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x144x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<144>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 144, 16>;
|
||||
using CLayout = GMMA::CLayout_64x144;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x160x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<160>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 160, 16>;
|
||||
using CLayout = GMMA::CLayout_64x160;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x160x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<160>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 160, 16>;
|
||||
using CLayout = GMMA::CLayout_64x160;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x176x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<176>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 176, 16>;
|
||||
using CLayout = GMMA::CLayout_64x176;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x176x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<176>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 176, 16>;
|
||||
using CLayout = GMMA::CLayout_64x176;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x192x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
@ -1414,6 +2040,92 @@ struct MMA_Traits<SM90_64x192x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x224x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<224>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 224, 16>;
|
||||
using CLayout = GMMA::CLayout_64x224;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x224x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<224>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 224, 16>;
|
||||
using CLayout = GMMA::CLayout_64x224;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x240x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<240>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 240, 16>;
|
||||
using CLayout = GMMA::CLayout_64x240;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x240x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<240>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ALayout_64x16;
|
||||
using BLayout = GMMA::ABLayout< 240, 16>;
|
||||
using CLayout = GMMA::CLayout_64x240;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x256x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user