Add the 64x40x16 wgmma atom (FP32 accumulators, 64-bit descriptor operands for
A and B) for F16 and BF16 inputs. The PTX wrapper structs, the MMA_Traits
specializations and the ss_op_selector() branches taken when Tile_N % 40 == 0
are compiled only when CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED is defined; the
CLayout_64x40 accumulator layout alias is added next to its 64x32/64x48 siblings.

diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp
index 6ab29adc..cd96b2d5 100644
--- a/include/cute/arch/mma_sm90.hpp
+++ b/include/cute/arch/mma_sm90.hpp
@@ -842,6 +842,11 @@ ss_op_selector()
       else if constexpr (Tile_N % 48 == 0) {
         return SM90::GMMA::MMA_64x48x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
       }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 40 == 0) {
+        return SM90::GMMA::MMA_64x40x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
+      }
 #endif
       else if constexpr (Tile_N % 32 == 0) {
         return SM90::GMMA::MMA_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
@@ -920,6 +925,11 @@ ss_op_selector()
       else if constexpr (Tile_N % 48 == 0) {
         return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
       }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 40 == 0) {
+        return SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
+      }
 #endif
       else if constexpr (Tile_N % 32 == 0) {
         return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp
index 4dc01463..1213823b 100644
--- a/include/cute/arch/mma_sm90_gmma.hpp
+++ b/include/cute/arch/mma_sm90_gmma.hpp
@@ -2595,6 +2595,61 @@ struct MMA_64x32x16_F32F16F16_RS
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+// GMMA 64x40x16 F32+=F16*F16
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x40x16_F32F16F16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[20];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %22, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19},"
+      " %20,"
+      " %21,"
+      " p, %23, %24, %25, %26;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
 // GMMA 64x48x16 F32+=F16*F16
 template <
@@ -5442,6 +5497,61 @@ struct MMA_64x32x16_F32BF16BF16_RS
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+// GMMA 64x40x16 F32+=BF16*BF16
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x40x16_F32BF16BF16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[20];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %22, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19},"
+      " %20,"
+      " %21,"
+      " p, %23, %24, %25, %26;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
 // GMMA 64x48x16 F32+=BF16*BF16
 template <
diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp
index 74f3d646..03016fa9 100644
--- a/include/cute/atom/mma_traits_sm90_gmma.hpp
+++ b/include/cute/atom/mma_traits_sm90_gmma.hpp
@@ -450,6 +450,9 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
 using CLayout_64x32 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _4>>,
                              Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
 
+using CLayout_64x40 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _5>>,
+                             Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
+
 using CLayout_64x48 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _6>>,
                              Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
 
@@ -1773,6 +1776,39 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
 
 #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
 
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,Int<40>,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 40, 16>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+
 template <
   GMMA::Major tnspA,
   GMMA::Major tnspB,
@@ -2846,6 +2882,39 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
 
 #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
 
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = bfloat16_t;
+  using ValTypeB = bfloat16_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,Int<40>,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 40, 16>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+
 template <
   GMMA::Major tnspA,
   GMMA::Major tnspB,