[Kernel] Suppress mma.sp warning on CUDA 12.5 and later (#5401)
This commit is contained in:
parent
15985680e2
commit
348616ac4b
@ -20,6 +20,19 @@
|
|||||||
|
|
||||||
namespace marlin_24 {
|
namespace marlin_24 {
|
||||||
|
|
||||||
|
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
|
||||||
|
// is not supported. On later versions of CUDA the version without ordered
|
||||||
|
// metadata results in the following warning:
|
||||||
|
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
|
||||||
|
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
|
||||||
|
// | reduced performance on some future architectures
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12500
|
||||||
|
#define MMA_SP_INST \
|
||||||
|
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||||
|
#else
|
||||||
|
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||||
|
#endif
|
||||||
|
|
||||||
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
|
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
|
||||||
// output/accumulation.
|
// output/accumulation.
|
||||||
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
|
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
|
||||||
@ -29,41 +42,38 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
|
|||||||
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
|
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
|
||||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||||
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
|
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
|
||||||
|
|
||||||
float* c = reinterpret_cast<float*>(&frag_c);
|
float* c = reinterpret_cast<float*>(&frag_c);
|
||||||
if (psel == 0) {
|
if (psel == 0) {
|
||||||
asm volatile(
|
asm volatile(MMA_SP_INST
|
||||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
|
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||||
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
|
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||||
"r"(e[0]));
|
asm volatile(MMA_SP_INST
|
||||||
asm volatile(
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
|
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||||
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
|
|
||||||
"r"(e[0]));
|
|
||||||
} else {
|
} else {
|
||||||
asm volatile(
|
asm volatile(MMA_SP_INST
|
||||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
|
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||||
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
|
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||||
"r"(e[0]));
|
asm volatile(MMA_SP_INST
|
||||||
asm volatile(
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||||
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
|
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||||
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
|
|
||||||
"r"(e[0]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user