5476 cutlass 3x gemm kernels (#1695)

Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
Author: dePaul Miller, 2024-08-08 10:56:23 -07:00 (committed by GitHub)
Parent: e22ba590cd
Commit: 2049c6c5a2

@@ -4960,12 +4960,33 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
       DataType.bf16, DataType.bf16, DataType.f32,
       OpcodeClass.TensorOp,
       MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 16],
+      DataType.f16, DataType.f16, DataType.f16,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 16],
+      DataType.f16, DataType.f16, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 16],
+      DataType.bf16, DataType.bf16, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
   ]
   min_cc = 90
   max_cc = 90
   for math_inst in math_instructions:
+    tile_descriptions = []
+    tile_descriptions_small = []
+    tile_descriptions_medium = []
+    tile_descriptions_large = []
+    if math_inst.instruction_shape[1] == 128:
       tile_descriptions_small = [
         # Not compatible with TmaWarpSpecializedCooperative
         TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
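
The new [64, 256, 16] instruction shapes feed the same scaling pattern used by the TileDescription entries above: the CTA tile is the instruction shape with M scaled by 1x, 2x or 4x and K scaled by 4x. A minimal standalone sketch of that arithmetic (plain Python with illustrative names, not the cutlass_library API):

def cta_tiles(instruction_shape):
    # Mirrors the multipliers in the TileDescription entries: M*1 / M*2 / M*4, K*4.
    m, n, k = instruction_shape
    return {
        "small":  [m,     n, k * 4],
        "medium": [m * 2, n, k * 4],
        "large":  [m * 4, n, k * 4],
    }

print(cta_tiles([64, 128, 16]))  # existing instructions: 64x128x64, 128x128x64, 256x128x64
print(cta_tiles([64, 256, 16]))  # added instructions:    64x256x64, 128x256x64, 256x256x64
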
@@ -4981,7 +5002,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
         TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
           0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
         TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-          0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
+          0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
       ]
       tile_descriptions_large = [
         TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
@@ -4990,17 +5011,17 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
           0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
         TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
           0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
-        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
-          0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
-        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
-          0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
-        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
-          0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
+        # 128x256x128
+        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
+          0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
       ]
       tile_descriptions = tile_descriptions_medium + tile_descriptions_large
+    else:
+      tile_descriptions = [
+        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
+        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
+        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
+      ]
     data_type = {
       "a_type" : math_inst.element_a,
@@ -5043,7 +5064,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
     # persistent kernels with TMA epilogues
     if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
       # not enough smem for 256x128 f32 out with C allocation
-      if data_type["d_type"] == DataType.f32:
+      if data_type["d_type"] == DataType.f32 and len(tile_descriptions_medium) > 0:
         CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
           [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
            [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
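
The added len(tile_descriptions_medium) > 0 clause matters because of the empty lists above: the new [64, 256, 16] instructions never populate tile_descriptions_medium, so this emission path is skipped for them rather than invoked with an empty tile list. In isolation (hypothetical stand-in values):

d_type = "f32"
tile_descriptions_medium = []  # empty on the 64x256 path
if d_type == "f32" and len(tile_descriptions_medium) > 0:
    print("emit persistent TMA-epilogue kernels for the medium tiles")
else:
    print("skipped: no medium tiles for this math instruction")
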
@@ -5490,20 +5511,30 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
       DataType.u8, DataType.u8, DataType.s32,
       OpcodeClass.TensorOp,
       MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 32],
+      DataType.s8, DataType.s8, DataType.s32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 32],
+      DataType.u8, DataType.u8, DataType.s32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
   ]
   min_cc = 90
   max_cc = 90
   for math_inst in math_instructions:
-    # 64x128x128
+    # 64x128x128 or 64x256x128
     tile_descriptions_small = [
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
     ]
-    # 128x128x128
+    # 128x128x128 or 128x256x128
     tile_descriptions_medium = [
       TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
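
With K = 32 for the integer instructions, the same K*4 scaling yields the tile sizes named in the updated comments; a quick check (plain arithmetic, not generator code):

for m, n, k in ([64, 128, 32], [64, 256, 32]):
    small  = [m,     n, k * 4]
    medium = [m * 2, n, k * 4]
    print(small, medium)
# [64, 128, 128] [128, 128, 128]  -> "64x128x128" / "128x128x128"
# [64, 256, 128] [128, 256, 128]  -> "64x256x128" / "128x256x128"
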
@@ -5670,6 +5701,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
       DataType.e5m2, DataType.e5m2, DataType.f32,
       OpcodeClass.TensorOp,
       MathOperation.multiply_add),
+    # inst 64x256x32
+    MathInstruction(
+      [64, 256, 32],
+      DataType.e4m3, DataType.e4m3, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 32],
+      DataType.e4m3, DataType.e5m2, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 32],
+      DataType.e5m2, DataType.e4m3, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+    MathInstruction(
+      [64, 256, 32],
+      DataType.e5m2, DataType.e5m2, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
   ]
   min_cc = 90
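
The four new FP8 entries cover every ordered pairing of the two FP8 operand types with an f32 accumulator. generator.py spells them out explicitly, but the set is just a cross product (sketch only):

from itertools import product

fp8 = ["e4m3", "e5m2"]
combos = [(a, b, "f32") for a, b in product(fp8, fp8)]
print(combos)
# [('e4m3', 'e4m3', 'f32'), ('e4m3', 'e5m2', 'f32'),
#  ('e5m2', 'e4m3', 'f32'), ('e5m2', 'e5m2', 'f32')]
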
@@ -5788,9 +5840,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
           0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
         TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
           0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
-        # 128x256x128
-        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
-          0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
       ]
       tile_descriptions = [
         # 128x128x128
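
For orientation on the entry removed here: with 1-byte FP8 operands, a 128x256x128 main-loop tile stages about 1.5x the A/B shared memory per pipeline stage of the 128x128x128 tiles that remain (rough arithmetic only; epilogue storage and stage count ignored):

def ab_smem_bytes(m, n, k, bytes_per_element=1):
    # A tile holds m*k elements, B tile holds n*k elements, per pipeline stage.
    return (m * k + n * k) * bytes_per_element

print(ab_smem_bytes(128, 256, 128))  # 49152 bytes, the removed 128x256x128 tile
print(ab_smem_bytes(128, 128, 128))  # 32768 bytes, the remaining 128x128x128 tiles
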
@@ -5801,6 +5850,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
         TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
           0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
       ]
+    elif math_inst.instruction_shape[1] == 256:
+      tile_descriptions_small = [
+        # 64x256x128
+        TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
+        TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+        TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
+      ]
+      tile_descriptions_large = []
+      tile_descriptions = [
+        # 128x256x128
+        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
+        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+        TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
+      ]
     else:
       assert False, "math inst is not supported"
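
The new elif branch repeats each CTA tile at three cluster shapes. A compact sketch of how those triples could be produced (generator.py writes them out explicitly; make_tiles is a hypothetical helper):

def make_tiles(m, n, k, warp_count=(4, 1, 1)):
    cluster_shapes = ([1, 2, 1], [2, 1, 1], [1, 1, 1])
    return [([m, n, k * 4], list(warp_count), c) for c in cluster_shapes]

small = make_tiles(64, 256, 32)   # three 64x256x128 variants
main  = make_tiles(128, 256, 32)  # three 128x256x128 variants
large = []                        # no large tiles on this path
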
@@ -5842,6 +5912,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
            [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]])
       # Large tiles
+      if len(tile_descriptions_large) > 0:
         CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile,
           [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
            [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])