Fix Int8 and TF32 generator (#976)

This commit is contained in:
ANIKET SHIVAM 2023-06-12 09:32:52 -07:00 committed by GitHub
parent 87349d3496
commit 473a67073e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4197,10 +4197,10 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
# layouts for ABC and their alignments
layouts_tf32 = [
[[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
[[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
[[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
[[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
[[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
[[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
[[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
[[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
]
math_inst = MathInstruction(
@ -4212,13 +4212,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
min_cc = 90
max_cc = 90
tile_descriptions = [
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
@ -4226,19 +4220,38 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_small = [
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
]
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
data_type_tf32 = {
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : math_inst.element_accumulator,
"d_type" : math_inst.element_accumulator,
"acc_type" : math_inst.element_accumulator,
"epi_type" : math_inst.element_accumulator
}
data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : math_inst.element_accumulator,
"d_type" : math_inst.element_accumulator,
"acc_type" : math_inst.element_accumulator,
"epi_type" : math_inst.element_accumulator
},
{
"a_type" : DataType.f32,
"b_type" : DataType.f32,
"c_type" : math_inst.element_accumulator,
"d_type" : math_inst.element_accumulator,
"acc_type" : math_inst.element_accumulator,
"epi_type" : DataType.f32
}
]
schedules = [
schedules_default = [
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
]
# TMA kernels with TT layout use EpilogueTransposed (NoSmemWarpSpecialized with swapped strides),
@ -4250,32 +4263,27 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
# TMA kernels with TN or NN layout
layouts_tf32_tn_nn = [layouts_tf32[0], layouts_tf32[2]]
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_tf32, schedules)
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_small, data_types, [
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
])
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_medium, data_types, [
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]
])
else:
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_types, schedules_default)
# TMA kernels with NT layout, only support 64x128x32 tile for now.
layouts_tf32_nt = [layouts_tf32[3]]
tile_64x128x32_descriptions = [tile_descriptions[0], tile_descriptions[1], tile_descriptions[2]]
tile_128x128x32_descriptions = [tile_descriptions[3], tile_descriptions[4], tile_descriptions[5]]
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_tf32, schedules)
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_tf32, [schedules[1]])
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_small, data_types, schedules_default)
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_medium, data_types, [
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
])
layouts_tf32_tt = [layouts_tf32[1]]
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_tf32, schedules_transposed_epilogue)
# F32 kernel share same settings with tf32 I/O kernels excluding data type
data_type_f32 = {
"a_type" : DataType.f32,
"b_type" : DataType.f32,
"c_type" : math_inst.element_accumulator,
"d_type" : math_inst.element_accumulator,
"acc_type" : math_inst.element_accumulator,
"epi_type" : DataType.f32
}
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_f32, schedules)
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_f32, schedules)
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_f32, [schedules[1]])
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_f32, schedules_transposed_epilogue)
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue)
#
def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
@ -4284,7 +4292,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
# layouts for ABC and their alignments
layouts = [
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]],
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]],
]
math_instructions = [
@ -4304,20 +4312,23 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
max_cc = 90
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
tile_descriptions_small = [
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
data_types = [
{
@ -4332,13 +4343,15 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.s8,
"d_type" : DataType.s8,
"d_type" : math_inst.element_a,
"acc_type" : math_inst.element_accumulator,
"epi_type" : DataType.f32
}
]
for data_type in data_types:
for layout in layouts:
layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type)
# persistent kernels with TMA epilogues
@ -4355,13 +4368,17 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
}
]
for data_type in data_types:
# Set alignment d based on Destination format.
# Set output alignment based on destination format first
for layout in layouts:
layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
# Pingpong persistent
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]])
# Cooperative persistent
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]])
#
def GenerateSM90_TensorOp_1684(manifest, cuda_version):