Add GMMA shape m64n40k16 (#1864)

This commit is contained in:
Tri Dao 2024-10-21 17:41:47 -07:00 committed by GitHub
parent 08101d9d0c
commit 5b50a8faaf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 189 additions and 0 deletions

View File

@ -842,6 +842,11 @@ ss_op_selector()
else if constexpr (Tile_N % 48 == 0) {
return SM90::GMMA::MMA_64x48x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
#endif
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
else if constexpr (Tile_N % 40 == 0) {
return SM90::GMMA::MMA_64x40x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
#endif
else if constexpr (Tile_N % 32 == 0) {
return SM90::GMMA::MMA_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
@ -920,6 +925,11 @@ ss_op_selector()
else if constexpr (Tile_N % 48 == 0) {
return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
#endif
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
else if constexpr (Tile_N % 40 == 0) {
return SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
#endif
else if constexpr (Tile_N % 32 == 0) {
return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};

View File

@ -2595,6 +2595,61 @@ struct MMA_64x32x16_F32F16F16_RS
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x40x16 F32+=F16*F16
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x40x16_F32F16F16_SS
{
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = float[20];
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
float & d00, float & d01, float & d02, float & d03,
float & d04, float & d05, float & d06, float & d07,
float & d08, float & d09, float & d10, float & d11,
float & d12, float & d13, float & d14, float & d15,
float & d16, float & d17, float & d18, float & d19,
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p, %23, %24, %25, %26;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a),
"l"(desc_b),
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x48x16 F32+=F16*F16
template <
@ -5442,6 +5497,61 @@ struct MMA_64x32x16_F32BF16BF16_RS
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x40x16 F32+=BF16*BF16
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x40x16_F32BF16BF16_SS
{
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = float[20];
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
float & d00, float & d01, float & d02, float & d03,
float & d04, float & d05, float & d06, float & d07,
float & d08, float & d09, float & d10, float & d11,
float & d12, float & d13, float & d14, float & d15,
float & d16, float & d17, float & d18, float & d19,
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p, %23, %24, %25, %26;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a),
"l"(desc_b),
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
// GMMA 64x48x16 F32+=BF16*BF16
template <

View File

@ -450,6 +450,9 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
using CLayout_64x32 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _4>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x40 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _5>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x48 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _6>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
@ -1773,6 +1776,39 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<40>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 40, 16>;
using CLayout = GMMA::CLayout_64x40;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
@ -2846,6 +2882,39 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>;
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<40>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 40, 16>;
using CLayout = GMMA::CLayout_64x40;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,