5476 cutlass 3x gemm kernels (#1695)
Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
This commit is contained in:
parent
e22ba590cd
commit
2049c6c5a2
@ -4960,12 +4960,33 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
DataType.bf16, DataType.bf16, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 16],
|
||||
DataType.f16, DataType.f16, DataType.f16,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 16],
|
||||
DataType.f16, DataType.f16, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 16],
|
||||
DataType.bf16, DataType.bf16, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 90
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = []
|
||||
tile_descriptions_small = []
|
||||
tile_descriptions_medium = []
|
||||
tile_descriptions_large = []
|
||||
|
||||
if math_inst.instruction_shape[1] == 128:
|
||||
tile_descriptions_small = [
|
||||
# Not compatible with TmaWarpSpecializedCooperative
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
@ -4981,7 +5002,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
]
|
||||
tile_descriptions_large = [
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
@ -4990,17 +5011,17 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
# 128x256x128
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_large
|
||||
else:
|
||||
tile_descriptions = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
data_type = {
|
||||
"a_type" : math_inst.element_a,
|
||||
@ -5043,7 +5064,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
# persistent kernels with TMA epilogues
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
# not enough smem for 256x128 f32 out with C allocation
|
||||
if data_type["d_type"] == DataType.f32:
|
||||
if data_type["d_type"] == DataType.f32 and len(tile_descriptions_medium) > 0:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
@ -5490,20 +5511,30 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
DataType.u8, DataType.u8, DataType.s32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 32],
|
||||
DataType.s8, DataType.s8, DataType.s32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 32],
|
||||
DataType.u8, DataType.u8, DataType.s32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 90
|
||||
|
||||
for math_inst in math_instructions:
|
||||
# 64x128x128
|
||||
# 64x128x128 or 64x256x128
|
||||
tile_descriptions_small = [
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
]
|
||||
# 128x128x128
|
||||
# 128x128x128 or 128x256x128
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
@ -5670,6 +5701,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
DataType.e5m2, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
# inst 64x256x32
|
||||
MathInstruction(
|
||||
[64, 256, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 32],
|
||||
DataType.e4m3, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 32],
|
||||
DataType.e5m2, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
MathInstruction(
|
||||
[64, 256, 32],
|
||||
DataType.e5m2, DataType.e5m2, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
|
||||
min_cc = 90
|
||||
@ -5788,9 +5840,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
# 128x256x128
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions = [
|
||||
# 128x128x128
|
||||
@ -5801,6 +5850,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
elif math_inst.instruction_shape[1] == 256:
|
||||
tile_descriptions_small = [
|
||||
# 64x256x128
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_large = []
|
||||
tile_descriptions = [
|
||||
# 128x256x128
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
|
||||
else:
|
||||
assert False, "math inst is not supported"
|
||||
@ -5842,6 +5912,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]])
|
||||
|
||||
# Large tiles
|
||||
if len(tile_descriptions_large) > 0:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile,
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
|
Loading…
Reference in New Issue
Block a user