Fix Int8 and TF32 generator (#976)

ANIKET SHIVAM authored on 2023-06-12 09:32:52 -07:00; committed by GitHub
parent 87349d3496
commit 473a67073e


@@ -4197,10 +4197,10 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):

   # layouts for ABC and their alignments
   layouts_tf32 = [
-    [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
-    [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
-    [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
-    [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
+    [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
   ]

   math_inst = MathInstruction(
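Note: this hunk raises the C-matrix alignment for the TF32 kernels from 1 to 4 elements, matching the 128-bit vectorized accesses the epilogue can issue. A minimal sketch of the rule being encoded, with a stand-in DataTypeSize table (the generator defines its own, in bits per element):

    # Widest vectorized global access on SM90 is 128 bits, so the element
    # alignment is 128 bits divided by the element width in bits.
    DataTypeSize = {"tf32": 32, "f32": 32, "s8": 8, "s32": 32}  # stand-in table

    def max_alignment(dtype):
        return 128 // DataTypeSize[dtype]

    assert max_alignment("tf32") == 4   # the new C alignment in this hunk
    assert max_alignment("s8") == 16    # the int8 C alignment fixed below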
@@ -4212,13 +4212,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
   min_cc = 90
   max_cc = 90

-  tile_descriptions = [
-    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-      0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
-    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
-    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
+  tile_descriptions_medium = [
     TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
       0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
     TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
@@ -4226,19 +4220,38 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
     TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
       0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
   ]

+  tile_descriptions_small = [
+    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+      0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
+    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
+  ]
+
+  tile_descriptions = tile_descriptions_medium + tile_descriptions_small

-  data_type_tf32 = {
-    "a_type" : math_inst.element_a,
-    "b_type" : math_inst.element_b,
-    "c_type" : math_inst.element_accumulator,
-    "d_type" : math_inst.element_accumulator,
-    "acc_type" : math_inst.element_accumulator,
-    "epi_type" : math_inst.element_accumulator
-  }
+  data_types = [
+    {
+      "a_type" : math_inst.element_a,
+      "b_type" : math_inst.element_b,
+      "c_type" : math_inst.element_accumulator,
+      "d_type" : math_inst.element_accumulator,
+      "acc_type" : math_inst.element_accumulator,
+      "epi_type" : math_inst.element_accumulator
+    },
+    {
+      "a_type" : DataType.f32,
+      "b_type" : DataType.f32,
+      "c_type" : math_inst.element_accumulator,
+      "d_type" : math_inst.element_accumulator,
+      "acc_type" : math_inst.element_accumulator,
+      "epi_type" : DataType.f32
+    }
+  ]

-  schedules = [
+  schedules_default = [
     [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
-    [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
+    [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
   ]

   # TMA kernels with TT layout use EpilogueTransposed (NoSmemWarpSpecialized with swapped strides),
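Note: the single tile_descriptions list is split into a medium group (CTA M doubled) and a small group, so later emission calls can pair each group with a compatible kernel schedule; the combined list keeps the non-persistent paths unchanged. A sketch of the resulting CTA tile shapes, assuming a [64, 128, 8] TF32 WGMMA instruction shape (the hunks above only show it symbolically):

    # Assumed instruction shape for SM90 TF32 WGMMA; K is multiplied by 4.
    instruction_shape = [64, 128, 8]
    small  = [instruction_shape[0],     instruction_shape[1], instruction_shape[2] * 4]
    medium = [instruction_shape[0] * 2, instruction_shape[1], instruction_shape[2] * 4]
    print(small, medium)  # [64, 128, 32] [128, 128, 32]

These are the same 64x128x32 and 128x128x32 tiles the old code selected by index as tile_64x128x32_descriptions and tile_128x128x32_descriptions.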
@@ -4250,32 +4263,27 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):

   # TMA kernels with TN or NN layout
   layouts_tf32_tn_nn = [layouts_tf32[0], layouts_tf32[2]]
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_tf32, schedules)
+  if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
+    CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_small, data_types, [
+      [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
+      [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
+    ])
+    CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_medium, data_types, [
+      [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
+      [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]
+    ])
+  else:
+    CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_types, schedules_default)

   # TMA kernels with NT layout, only support 64x128x32 tile for now.
   layouts_tf32_nt = [layouts_tf32[3]]
-  tile_64x128x32_descriptions = [tile_descriptions[0], tile_descriptions[1], tile_descriptions[2]]
-  tile_128x128x32_descriptions = [tile_descriptions[3], tile_descriptions[4], tile_descriptions[5]]
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_tf32, schedules)
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_tf32, [schedules[1]])
+  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_small, data_types, schedules_default)
+  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_medium, data_types, [
+    [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
+  ])

   layouts_tf32_tt = [layouts_tf32[1]]
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_tf32, schedules_transposed_epilogue)
-
-  # F32 kernel share same settings with tf32 I/O kernels excluding data type
-  data_type_f32 = {
-    "a_type" : DataType.f32,
-    "b_type" : DataType.f32,
-    "c_type" : math_inst.element_accumulator,
-    "d_type" : math_inst.element_accumulator,
-    "acc_type" : math_inst.element_accumulator,
-    "epi_type" : DataType.f32
-  }
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_f32, schedules)
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_f32, schedules)
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_f32, [schedules[1]])
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_f32, schedules_transposed_epilogue)
+  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue)
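Note: the separate data_type_f32 block and its four emission calls are folded into the data_types list above, and on CUDA 12.1 and newer the TN/NN kernels gain persistent schedules (pingpong on the small tiles, cooperative on the medium tiles, each with and without a TMA epilogue); older toolkits keep the previous defaults. A minimal sketch of the toolkit gate, assuming semantics like the generator's CudaToolkitVersionSatisfies helper:

    def cuda_version_satisfies(cuda_version, major, minor):
        # Compare only the leading "major.minor" of a version string.
        ver_major, ver_minor = (int(x) for x in cuda_version.split(".")[:2])
        return (ver_major, ver_minor) >= (major, minor)

    assert cuda_version_satisfies("12.1", 12, 1)       # persistent schedules emitted
    assert not cuda_version_satisfies("12.0", 12, 1)   # falls back to schedules_default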
 #
 def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
@@ -4284,7 +4292,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):

   # layouts for ABC and their alignments
   layouts = [
-    [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]],
+    [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]],
   ]

   math_instructions = [
@@ -4304,20 +4312,23 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
   max_cc = 90

   for math_inst in math_instructions:
-    tile_descriptions = [
-      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-        0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+    tile_descriptions_small = [
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
-      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
-      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
     ]
+    tile_descriptions_medium = [
+      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+        0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
+      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
+    ]
+    tile_descriptions = tile_descriptions_medium + tile_descriptions_small
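Note: the int8 path gets the same small/medium split, with the interleaved list untangled into two groups; concatenating medium first keeps the combined list covering both groups for the non-persistent emission below. A sketch of the CTA tiles this yields, assuming a [64, 128, 32] int8 WGMMA instruction shape (not shown explicitly in this diff):

    # Assumed instruction shape for SM90 int8 WGMMA; K is multiplied by 4.
    shape = [64, 128, 32]
    small  = [shape[0],     shape[1], shape[2] * 4]   # 64x128x128 CTA tile
    medium = [shape[0] * 2, shape[1], shape[2] * 4]   # 128x128x128 CTA tile
    tile_descriptions = [medium, small]               # medium group first, as in the diff
    print(tile_descriptions)  # [[128, 128, 128], [64, 128, 128]]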
     data_types = [
       {
@@ -4332,13 +4343,15 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
         "a_type" : math_inst.element_a,
         "b_type" : math_inst.element_b,
         "c_type" : DataType.s8,
-        "d_type" : DataType.s8,
+        "d_type" : math_inst.element_a,
         "acc_type" : math_inst.element_accumulator,
         "epi_type" : DataType.f32
       }
     ]

     for data_type in data_types:
+      for layout in layouts:
+        layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
       CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type)
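Note: the new inner loop resets the C/D alignment per destination type before each emission, mirroring the loop already present in the TMA-epilogue block below. A worked example with a stand-in DataTypeSize table (bits per element):

    DataTypeSize = {"s8": 8, "s32": 32}  # stand-in for the generator's table

    layouts = [[["RowMajor", 16], ["ColumnMajor", 16], ["ColumnMajor", 16]]]
    for d_type in ("s8", "s32"):
        for layout in layouts:
            layout[2][1] = 128 // DataTypeSize[d_type]
        print(d_type, layout[2][1])  # s8 -> 16, s32 -> 4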
     # persistent kernels with TMA epilogues
@@ -4355,13 +4368,17 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
       }
     ]

     for data_type in data_types:
-      # Set alignment d based on Destination format.
+      # Set output alignment based on destination format first
       for layout in layouts:
         layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
+      # Pingpong persistent
       CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
         [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
-        [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
+        [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]])
+      # Cooperative persistent
+      CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
+        [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
+        [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]])
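Note: the persistent kernels are now emitted in two calls so that each kernel schedule pairs only with its matching epilogue, pingpong covering all tiles and cooperative restricted to the medium tiles, with a NoSmem variant of each. A sketch enumerating the four combinations emitted above:

    emitted = [
        ("tile_descriptions",        "TmaWarpSpecializedPingpong",    "TmaWarpSpecialized"),
        ("tile_descriptions",        "TmaWarpSpecializedPingpong",    "NoSmemWarpSpecialized"),
        ("tile_descriptions_medium", "TmaWarpSpecializedCooperative", "TmaWarpSpecializedCooperative"),
        ("tile_descriptions_medium", "TmaWarpSpecializedCooperative", "NoSmemWarpSpecialized"),
    ]
    for tiles, kernel, epilogue in emitted:
        print(f"{tiles:26} {kernel:30} {epilogue}")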
 #
 def GenerateSM90_TensorOp_1684(manifest, cuda_version):