fix wrong A/BLayout in MMA_Traits for binary mma and append other MMA_Traits support (#1856)
* fix wrong A/BLayout in MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> and append support for m8n8k128, m16n8k128 mma.and.popc in MMA_Traits instantiation * add "print" template for subbyte_reference<T>
This commit is contained in:
parent
be692b48b0
commit
a424ca6cf9
@ -2141,4 +2141,103 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 8x8x128 TN
|
||||
struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC
|
||||
{
|
||||
using DRegisters = uint32_t[2];
|
||||
using ARegisters = uint32_t[1];
|
||||
using BRegisters = uint32_t[1];
|
||||
using CRegisters = uint32_t[2];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1,
|
||||
uint32_t const& a0,
|
||||
uint32_t const& b0,
|
||||
uint32_t const& c0, uint32_t const& c1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc "
|
||||
"{%0, %1},"
|
||||
"{%2},"
|
||||
"{%3},"
|
||||
"{%4, %5};\n"
|
||||
: "=r"(d0), "=r"(d1)
|
||||
: "r"(a0),
|
||||
"r"(b0),
|
||||
"r"(c0), "r"(c1));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x128 TN
|
||||
struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC
|
||||
{
|
||||
using DRegisters = uint32_t[4];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[1];
|
||||
using CRegisters = uint32_t[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0,
|
||||
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc "
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5},"
|
||||
"{%6},"
|
||||
"{%7, %8, %9, %10};\n"
|
||||
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0),
|
||||
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x256 TN
|
||||
struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC
|
||||
{
|
||||
using DRegisters = uint32_t[4];
|
||||
using ARegisters = uint32_t[4];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
||||
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11, %12, %13};\n"
|
||||
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
|
||||
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
|
||||
@ -433,10 +433,57 @@ struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_256>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <_32,Shape < _8, _4,_2, _2>>,
|
||||
Stride<_64,Stride<_64,_16,_8,_2048>>>;
|
||||
using BLayout = Layout<Shape <_32,Shape <_32, _2>>,
|
||||
Stride<_32,Stride< _1,_1024>>>;
|
||||
using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2,_2>>,
|
||||
Stride<Stride<_512,_1>,Stride<_16,_8,_2048>>>;
|
||||
using BLayout = Layout<Shape<Shape <_4,_8>,Shape<_32,_2>>,
|
||||
Stride<Stride<_256,_1>,Stride< _8,_1024>>>;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_ANDPOPC>
|
||||
:MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> {};
|
||||
|
||||
template<>
|
||||
struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC>
|
||||
{
|
||||
using ValTypeD = int32_t;
|
||||
using ValTypeA = cute::uint1b_t;
|
||||
using ValTypeB = cute::uint1b_t;
|
||||
using ValTypeC = int32_t;
|
||||
|
||||
using Shape_MNK = Shape<_8,_8,_128>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape<Shape<_4,_8>,_32>,
|
||||
Stride<Stride<_256,_1>,_8>>;
|
||||
using BLayout = Layout<Shape<Shape<_4,_8>,_32>,
|
||||
Stride<Stride<_256,_1>,_8>>;
|
||||
using CLayout = SM80_8x8_Row;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_ANDPOPC>
|
||||
:MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC> {};
|
||||
|
||||
template<>
|
||||
struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC>
|
||||
{
|
||||
using ValTypeD = int32_t;
|
||||
using ValTypeA = cute::uint1b_t;
|
||||
using ValTypeB = cute::uint1b_t;
|
||||
using ValTypeC = int32_t;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_128>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2>>,
|
||||
Stride<Stride<_512,_1>,Stride<Stride<_16,_8>>>>;
|
||||
using BLayout = Layout<Shape <Shape<_4,_8>,_32>,
|
||||
Stride<Stride<_256,_1>,_8>>;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_ANDPOPC>
|
||||
:MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC> {};
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -346,6 +346,11 @@ print(subbyte_iterator<T> const& x) {
|
||||
printf("subptr[%db](%p.%u)", int(sizeof_bits_v<T>), x.ptr_, x.idx_);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void
|
||||
print(subbyte_reference<T> const& x) {
|
||||
print(x.get());
|
||||
}
|
||||
//
|
||||
// array_subbyte
|
||||
// Statically sized array for non-byte-aligned data types
|
||||
|
||||
Loading…
Reference in New Issue
Block a user