From a424ca6cf9f39658acc02f7fd45a24fbd5f8e8c5 Mon Sep 17 00:00:00 2001
From: Caleb_Du <59528230+CalebDu@users.noreply.github.com>
Date: Fri, 25 Oct 2024 02:38:35 +0800
Subject: [PATCH] fix wrong A/BLayout in MMA_Traits for binary mma and append
 other MMA_Traits support (#1856)

* fix wrong A/BLayout in MMA_Traits and append support for m8n8k128, m16n8k128 mma.and.popc in MMA_Traits instantiation

* add "print" template for subbyte_reference
---
 include/cute/arch/mma_sm80.hpp           | 99 ++++++++++++++++++++++++
 include/cute/atom/mma_traits_sm80.hpp    | 55 ++++++++++++-
 include/cute/container/array_subbyte.hpp |  5 ++
 3 files changed, 155 insertions(+), 4 deletions(-)

diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp
index 8c684b70..60777f22 100644
--- a/include/cute/arch/mma_sm80.hpp
+++ b/include/cute/arch/mma_sm80.hpp
@@ -2141,4 +2141,103 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+// MMA 8x8x128 TN (1-bit A/B, AND + POPC, s32 accumulate)
+struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC
+{
+  using DRegisters = uint32_t[2];
+  using ARegisters = uint32_t[1];
+  using BRegisters = uint32_t[1];
+  using CRegisters = uint32_t[2];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t      & d0, uint32_t      & d1,
+      uint32_t const& a0,
+      uint32_t const& b0,
+      uint32_t const& c0, uint32_t const& c1)
+  {
+#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc "
+      "{%0, %1},"
+      "{%2},"
+      "{%3},"
+      "{%4, %5};\n"
+      : "=r"(d0), "=r"(d1)
+      :  "r"(a0),
+         "r"(b0),
+         "r"(c0),  "r"(c1));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_B1_AND_SM80_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x128 TN (1-bit A/B, AND + POPC, s32 accumulate)
+struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC
+{
+  using DRegisters = uint32_t[4];
+  using ARegisters = uint32_t[2];
+  using BRegisters = uint32_t[1];
+  using CRegisters = uint32_t[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
+      uint32_t const& a0, uint32_t const& a1,
+      uint32_t const& b0,
+      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
+  {
+#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc "
+      "{%0, %1, %2, %3},"
+      "{%4, %5},"
+      "{%6},"
+      "{%7, %8, %9, %10};\n"
+      : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
+      :  "r"(a0),  "r"(a1),
+         "r"(b0),
+         "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_B1_AND_SM80_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x256 TN (1-bit A/B, AND + POPC, s32 accumulate)
+struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC
+{
+  using DRegisters = uint32_t[4];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[2];
+  using CRegisters = uint32_t[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1,
+      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
+  {
+#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
+      "{%0, %1, %2, %3},"
+      "{%4, %5, %6, %7},"
+      "{%8, %9},"
+      "{%10, %11, %12, %13};\n"
+      : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
+      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
+         "r"(b0),  "r"(b1),
+         "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_B1_AND_SM80_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 } // namespace cute
diff --git a/include/cute/atom/mma_traits_sm80.hpp b/include/cute/atom/mma_traits_sm80.hpp
index ab402881..706b10d8 100644
--- a/include/cute/atom/mma_traits_sm80.hpp
+++ b/include/cute/atom/mma_traits_sm80.hpp
@@ -433,10 +433,57 @@ struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
 
   using Shape_MNK = Shape<_16,_8,_256>;
   using ThrID   = Layout<_32>;
-  using ALayout = Layout<Shape <_32,Shape <_8,_4,_2,_2>>,
-                         Stride<_64,Stride<_64,_16,_8,_2048>>>;
-  using BLayout = Layout<Shape <_32,Shape <_32,_2>>,
-                         Stride<_32,Stride< _1,_1024>>>;
+  using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2,_2>>,
+                         Stride<Stride<_512,_1>,Stride<_16,_8,_2048>>>;
+  using BLayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2>>,
+                         Stride<Stride<_256,_1>,Stride< _8,_1024>>>;
   using CLayout = SM80_16x8_Row;
 };
+
+template <>
+struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_ANDPOPC>
+     :MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> {};
+
+template<>
+struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = cute::uint1b_t;
+  using ValTypeB = cute::uint1b_t;
+  using ValTypeC = int32_t;
+
+  using Shape_MNK = Shape<_8,_8,_128>;
+  using ThrID   = Layout<_32>;
+  using ALayout = Layout<Shape<Shape<_4,_8>,_32>,
+                         Stride<Stride<_256,_1>,_8>>;
+  using BLayout = Layout<Shape<Shape<_4,_8>,_32>,
+                         Stride<Stride<_256,_1>,_8>>;
+  using CLayout = SM80_8x8_Row;
+};
+
+template <>
+struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_ANDPOPC>
+     :MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC> {};
+
+template<>
+struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = cute::uint1b_t;
+  using ValTypeB = cute::uint1b_t;
+  using ValTypeC = int32_t;
+
+  using Shape_MNK = Shape<_16,_8,_128>;
+  using ThrID   = Layout<_32>;
+  using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2>>,
+                         Stride<Stride<_512,_1>,Stride<_16,_8>>>;
+  using BLayout = Layout<Shape<Shape<_4,_8>,_32>,
+                         Stride<Stride<_256,_1>,_8>>;
+  using CLayout = SM80_16x8_Row;
+};
+
+template <>
+struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_ANDPOPC>
+     :MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC> {};
+
 } // end namespace cute
diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp
index 6aa26bc9..747dccf8 100644
--- a/include/cute/container/array_subbyte.hpp
+++ b/include/cute/container/array_subbyte.hpp
@@ -346,6 +346,11 @@ print(subbyte_iterator<T> const& x) {
   printf("subptr[%db](%p.%u)", int(sizeof_bits_v<T>), x.ptr_, x.idx_);
 }
 
+template <class T>
+CUTE_HOST_DEVICE void
+print(subbyte_reference<T> const& x) {
+  print(x.get());
+}
 //
 // array_subbyte
 // Statically sized array for non-byte-aligned data types