Add GMMA shape m64n40k16 (#1864)
This commit is contained in:
parent
08101d9d0c
commit
5b50a8faaf
@ -842,6 +842,11 @@ ss_op_selector()
|
||||
else if constexpr (Tile_N % 48 == 0) {
|
||||
return SM90::GMMA::MMA_64x48x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
#endif
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
else if constexpr (Tile_N % 40 == 0) {
|
||||
return SM90::GMMA::MMA_64x40x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
#endif
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90::GMMA::MMA_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
@ -920,6 +925,11 @@ ss_op_selector()
|
||||
else if constexpr (Tile_N % 48 == 0) {
|
||||
return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
#endif
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
else if constexpr (Tile_N % 40 == 0) {
|
||||
return SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
#endif
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
|
||||
@ -2595,6 +2595,61 @@ struct MMA_64x32x16_F32F16F16_RS
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
// GMMA 64x40x16 F32+=F16*F16
// Thin wrapper around the PTX instruction
//   wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16
// tnspA / tnspB / scaleA / scaleB are handed to the instruction as compile-time
// immediates ("n" asm constraints); scaleA and scaleB default to GMMA::ScaleIn::One.
// NOTE(review): the `_SS` suffix together with the two 64-bit `desc_*` operands suggests
// that A and B are both supplied as matrix descriptors rather than registers -- confirm.
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
||||
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
||||
>
|
||||
struct MMA_64x40x16_F32F16F16_SS
|
||||
{
|
||||
// No separate D register set is declared; fma() updates its accumulator
// arguments in place (they are read-write "+f" asm operands).
using DRegisters = void;
|
||||
// A and B are each a single 64-bit value (desc_a / desc_b in fma()).
using ARegisters = uint64_t[1];
|
||||
using BRegisters = uint64_t[1];
|
||||
// Twenty float accumulators, matching d00..d19 in fma().
using CRegisters = float[20];
|
||||
|
||||
// Issues one wgmma with the given operands.
//   desc_a, desc_b : 64-bit operands for A and B (asm operands %20 and %21).
//   d00..d19       : accumulators, read and written by the instruction (%0..%19).
//   scale_D        : converted to int32 and compared against 0 to form predicate `p`,
//                    which is passed to wgmma right after the two descriptors.
//                    Defaults to GMMA::ScaleOut::One.
// When CUTE_ARCH_MMA_SM90A_ENABLED is not defined, the body is
// CUTE_INVALID_CONTROL_PATH with an explanatory message instead of the asm.
CUTE_HOST_DEVICE static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
float & d00, float & d01, float & d02, float & d03,
|
||||
float & d04, float & d05, float & d06, float & d07,
|
||||
float & d08, float & d09, float & d10, float & d11,
|
||||
float & d12, float & d13, float & d14, float & d15,
|
||||
float & d16, float & d17, float & d18, float & d19,
|
||||
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
||||
// Operand numbering used by the asm template below:
//   %0..%19  accumulators d00..d19      %20 desc_a      %21 desc_b
//   %22      scale_D (runtime register) %23 scaleA      %24 scaleB
//   %25      tnspA                      %26 tnspB
asm volatile(
|
||||
"{\n"
|
||||
".reg .pred p;\n"
|
||||
// p = (scale_D != 0)
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p, %23, %24, %25, %26;\n"
|
||||
"}\n"
|
||||
// Outputs: the 20 accumulators, read-write.
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
||||
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
||||
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
// Inputs: two 64-bit descriptors, one 32-bit register, four immediates.
: "l"(desc_a),
|
||||
"l"(desc_b),
|
||||
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
// GMMA 64x48x16 F32+=F16*F16
|
||||
template <
|
||||
@ -5442,6 +5497,61 @@ struct MMA_64x32x16_F32BF16BF16_RS
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
// GMMA 64x40x16 F32+=BF16*BF16
// Thin wrapper around the PTX instruction
//   wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16
// tnspA / tnspB / scaleA / scaleB are handed to the instruction as compile-time
// immediates ("n" asm constraints); scaleA and scaleB default to GMMA::ScaleIn::One.
// NOTE(review): the `_SS` suffix together with the two 64-bit `desc_*` operands suggests
// that A and B are both supplied as matrix descriptors rather than registers -- confirm.
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
||||
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
||||
>
|
||||
struct MMA_64x40x16_F32BF16BF16_SS
|
||||
{
|
||||
// No separate D register set is declared; fma() updates its accumulator
// arguments in place (they are read-write "+f" asm operands).
using DRegisters = void;
|
||||
// A and B are each a single 64-bit value (desc_a / desc_b in fma()).
using ARegisters = uint64_t[1];
|
||||
using BRegisters = uint64_t[1];
|
||||
// Twenty float accumulators, matching d00..d19 in fma().
using CRegisters = float[20];
|
||||
|
||||
// Issues one wgmma with the given operands.
//   desc_a, desc_b : 64-bit operands for A and B (asm operands %20 and %21).
//   d00..d19       : accumulators, read and written by the instruction (%0..%19).
//   scale_D        : converted to int32 and compared against 0 to form predicate `p`,
//                    which is passed to wgmma right after the two descriptors.
//                    Defaults to GMMA::ScaleOut::One.
// When CUTE_ARCH_MMA_SM90A_ENABLED is not defined, the body is
// CUTE_INVALID_CONTROL_PATH with an explanatory message instead of the asm.
CUTE_HOST_DEVICE static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
float & d00, float & d01, float & d02, float & d03,
|
||||
float & d04, float & d05, float & d06, float & d07,
|
||||
float & d08, float & d09, float & d10, float & d11,
|
||||
float & d12, float & d13, float & d14, float & d15,
|
||||
float & d16, float & d17, float & d18, float & d19,
|
||||
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
||||
// Operand numbering used by the asm template below:
//   %0..%19  accumulators d00..d19      %20 desc_a      %21 desc_b
//   %22      scale_D (runtime register) %23 scaleA      %24 scaleB
//   %25      tnspA                      %26 tnspB
asm volatile(
|
||||
"{\n"
|
||||
".reg .pred p;\n"
|
||||
// p = (scale_D != 0)
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p, %23, %24, %25, %26;\n"
|
||||
"}\n"
|
||||
// Outputs: the 20 accumulators, read-write.
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
||||
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
||||
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
// Inputs: two 64-bit descriptors, one 32-bit register, four immediates.
: "l"(desc_a),
|
||||
"l"(desc_b),
|
||||
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
// GMMA 64x48x16 F32+=BF16*BF16
|
||||
template <
|
||||
|
||||
@ -450,6 +450,9 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
|
||||
// Layout aliases for 64xN tiles with N = 32, 40 and 48. The three aliases share the
// same strides and the same first shape mode (4*8*4 = 128 entries); they differ only in
// the last extent of the second shape mode (_4, _5, _6, i.e. N/8).
// NOTE(review): presumably the first mode indexes threads and the second mode indexes
// per-thread values -- confirm against the users of these layouts.
using CLayout_64x32 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _4>>,
|
||||
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
|
||||
|
||||
// 64x40 variant: second mode has 2*2*5 = 20 entries.
using CLayout_64x40 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _5>>,
|
||||
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
|
||||
|
||||
// 64x48 variant: second mode has 2*2*6 = 24 entries.
using CLayout_64x48 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _6>>,
|
||||
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
|
||||
|
||||
@ -1773,6 +1776,39 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
||||
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
||||
>
|
||||
using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<40>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 40, 16>;
|
||||
using CLayout = GMMA::CLayout_64x40;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
@ -2846,6 +2882,39 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
||||
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
||||
>
|
||||
using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>;
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<40>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 40, 16>;
|
||||
using CLayout = GMMA::CLayout_64x40;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user