Add more GMMA shapes (#1630)

* Add more GMMA shapes

* Add more shapes for BF16
This commit is contained in:
Tri Dao 2024-07-29 16:09:51 -07:00 committed by GitHub
parent be60a0b272
commit 5b283c872c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3616 additions and 0 deletions

View File

@ -556,18 +556,42 @@ ss_op_selector()
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 240 == 0) {
return SM90_64x240x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 224 == 0) {
return SM90_64x224x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 176 == 0) {
return SM90_64x176x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 160 == 0) {
return SM90_64x160x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 144 == 0) {
return SM90_64x144x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 112 == 0) {
return SM90_64x112x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 80 == 0) {
return SM90_64x80x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 48 == 0) {
return SM90_64x48x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
@ -590,18 +614,42 @@ ss_op_selector()
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 240 == 0) {
return SM90_64x240x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 224 == 0) {
return SM90_64x224x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 176 == 0) {
return SM90_64x176x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 160 == 0) {
return SM90_64x160x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 144 == 0) {
return SM90_64x144x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 112 == 0) {
return SM90_64x112x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 80 == 0) {
return SM90_64x80x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 48 == 0) {
return SM90_64x48x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
@ -1011,18 +1059,42 @@ rs_op_selector()
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 240 == 0) {
return SM90_64x240x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 224 == 0) {
return SM90_64x224x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 176 == 0) {
return SM90_64x176x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 160 == 0) {
return SM90_64x160x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 144 == 0) {
return SM90_64x144x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 112 == 0) {
return SM90_64x112x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 80 == 0) {
return SM90_64x80x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 48 == 0) {
return SM90_64x48x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
@ -1045,18 +1117,42 @@ rs_op_selector()
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 240 == 0) {
return SM90_64x240x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 224 == 0) {
return SM90_64x224x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 176 == 0) {
return SM90_64x176x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 160 == 0) {
return SM90_64x160x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 144 == 0) {
return SM90_64x144x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 112 == 0) {
return SM90_64x112x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 80 == 0) {
return SM90_64x80x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 48 == 0) {
return SM90_64x48x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}

File diff suppressed because it is too large Load Diff

View File

@ -392,18 +392,42 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
using CLayout_64x32 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _4>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x48 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _6>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x64 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _8>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x80 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _10>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x96 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _12>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x112 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, Int<14>>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x128 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _16>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x144 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, Int<18>>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x160 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, Int<20>>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x176 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, Int<22>>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x192 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _24>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x224 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, Int<28>>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x240 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, Int<30>>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x256 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _32>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
@ -898,6 +922,49 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x48x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<48>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 48, 16>;
using CLayout = GMMA::CLayout_64x48;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x48x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<48>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 48, 16>;
using CLayout = GMMA::CLayout_64x48;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x64x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -941,6 +1008,49 @@ struct MMA_Traits<SM90_64x64x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x80x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<80>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 80, 16>;
using CLayout = GMMA::CLayout_64x80;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x80x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<80>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 80, 16>;
using CLayout = GMMA::CLayout_64x80;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x96x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -984,6 +1094,49 @@ struct MMA_Traits<SM90_64x96x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x112x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<112>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 112, 16>;
using CLayout = GMMA::CLayout_64x112;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x112x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<112>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 112, 16>;
using CLayout = GMMA::CLayout_64x112;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x128x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -1027,6 +1180,135 @@ struct MMA_Traits<SM90_64x128x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x144x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<144>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 144, 16>;
using CLayout = GMMA::CLayout_64x144;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x144x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<144>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 144, 16>;
using CLayout = GMMA::CLayout_64x144;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x160x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<160>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 160, 16>;
using CLayout = GMMA::CLayout_64x160;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x160x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<160>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 160, 16>;
using CLayout = GMMA::CLayout_64x160;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x176x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<176>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 176, 16>;
using CLayout = GMMA::CLayout_64x176;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x176x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<176>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 176, 16>;
using CLayout = GMMA::CLayout_64x176;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x192x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -1070,6 +1352,92 @@ struct MMA_Traits<SM90_64x192x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x224x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<224>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 224, 16>;
using CLayout = GMMA::CLayout_64x224;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x224x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<224>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 224, 16>;
using CLayout = GMMA::CLayout_64x224;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x240x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<240>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 240, 16>;
using CLayout = GMMA::CLayout_64x240;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x240x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<240>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 240, 16>;
using CLayout = GMMA::CLayout_64x240;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x256x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -1242,6 +1610,49 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x48x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<48>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 48, 16>;
using CLayout = GMMA::CLayout_64x48;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x48x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<48>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 48, 16>;
using CLayout = GMMA::CLayout_64x48;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x64x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -1285,6 +1696,49 @@ struct MMA_Traits<SM90_64x64x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x80x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<80>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 80, 16>;
using CLayout = GMMA::CLayout_64x80;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x80x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<80>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 80, 16>;
using CLayout = GMMA::CLayout_64x80;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x96x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -1328,6 +1782,49 @@ struct MMA_Traits<SM90_64x96x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x112x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<112>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 112, 16>;
using CLayout = GMMA::CLayout_64x112;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x112x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<112>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 112, 16>;
using CLayout = GMMA::CLayout_64x112;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x128x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -1371,6 +1868,135 @@ struct MMA_Traits<SM90_64x128x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x144x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<144>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 144, 16>;
using CLayout = GMMA::CLayout_64x144;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x144x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<144>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 144, 16>;
using CLayout = GMMA::CLayout_64x144;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x160x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<160>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 160, 16>;
using CLayout = GMMA::CLayout_64x160;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x160x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<160>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 160, 16>;
using CLayout = GMMA::CLayout_64x160;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x176x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<176>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 176, 16>;
using CLayout = GMMA::CLayout_64x176;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x176x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<176>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 176, 16>;
using CLayout = GMMA::CLayout_64x176;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x192x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
@ -1414,6 +2040,92 @@ struct MMA_Traits<SM90_64x192x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x224x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<224>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 224, 16>;
using CLayout = GMMA::CLayout_64x224;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x224x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<224>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 224, 16>;
using CLayout = GMMA::CLayout_64x224;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x240x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<240>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 240, 16>;
using CLayout = GMMA::CLayout_64x240;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x240x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<240>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ALayout_64x16;
using BLayout = GMMA::ABLayout< 240, 16>;
using CLayout = GMMA::CLayout_64x240;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x256x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{