fix wrong A/BLayout in MMA_Traits for binary mma and append other MMA_Traits support (#1856)

* fix wrong A/BLayout in  MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> and append support for  m8n8k128, m16n8k128  mma.and.popc in MMA_Traits instantiation

* add "print" template for  subbyte_reference<T>
This commit is contained in:
Caleb_Du 2024-10-25 02:38:35 +08:00 committed by GitHub
parent be692b48b0
commit a424ca6cf9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 155 additions and 4 deletions

View File

@ -2141,4 +2141,103 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 8x8x128 TN
struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[2];
using ARegisters = uint32_t[1];
using BRegisters = uint32_t[1];
using CRegisters = uint32_t[2];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1,
uint32_t const& a0,
uint32_t const& b0,
uint32_t const& c0, uint32_t const& c1)
{
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1},"
"{%2},"
"{%3},"
"{%4, %5};\n"
: "=r"(d0), "=r"(d1)
: "r"(a0),
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x128 TN
struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[1];
using CRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6},"
"{%7, %8, %9, %10};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1),
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x256 TN
struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
uint32_t const& b0, uint32_t const& b1,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cute

View File

@ -433,10 +433,57 @@ struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
using Shape_MNK = Shape<_16,_8,_256>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <_32,Shape < _8, _4,_2, _2>>,
Stride<_64,Stride<_64,_16,_8,_2048>>>;
using BLayout = Layout<Shape <_32,Shape <_32, _2>>,
Stride<_32,Stride< _1,_1024>>>;
using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2,_2>>,
Stride<Stride<_512,_1>,Stride<_16,_8,_2048>>>;
using BLayout = Layout<Shape<Shape <_4,_8>,Shape<_32,_2>>,
Stride<Stride<_256,_1>,Stride< _8,_1024>>>;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_ANDPOPC>
:MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> {};
template<>
struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC>
{
using ValTypeD = int32_t;
using ValTypeA = cute::uint1b_t;
using ValTypeB = cute::uint1b_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_8,_8,_128>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4,_8>,_32>,
Stride<Stride<_256,_1>,_8>>;
using BLayout = Layout<Shape<Shape<_4,_8>,_32>,
Stride<Stride<_256,_1>,_8>>;
using CLayout = SM80_8x8_Row;
};
template <>
struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_ANDPOPC>
:MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC> {};
template<>
struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC>
{
using ValTypeD = int32_t;
using ValTypeA = cute::uint1b_t;
using ValTypeB = cute::uint1b_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_16,_8,_128>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2>>,
Stride<Stride<_512,_1>,Stride<Stride<_16,_8>>>>;
using BLayout = Layout<Shape <Shape<_4,_8>,_32>,
Stride<Stride<_256,_1>,_8>>;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_ANDPOPC>
:MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC> {};
} // end namespace cute

View File

@ -346,6 +346,11 @@ print(subbyte_iterator<T> const& x) {
printf("subptr[%db](%p.%u)", int(sizeof_bits_v<T>), x.ptr_, x.idx_);
}
template <class T>
CUTE_HOST_DEVICE void
print(subbyte_reference<T> const& x) {
print(x.get());
}
//
// array_subbyte
// Statically sized array for non-byte-aligned data types