Fix Int8 and TF32 generator (#976)
This commit is contained in:
parent
87349d3496
commit
473a67073e
@ -4197,10 +4197,10 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
|
||||
# layouts for ABC and their alignments
|
||||
layouts_tf32 = [
|
||||
[[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
|
||||
[[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
|
||||
[[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
|
||||
[[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
|
||||
[[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
|
||||
[[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
|
||||
[[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
|
||||
[[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
|
||||
]
|
||||
|
||||
math_inst = MathInstruction(
|
||||
@ -4212,13 +4212,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
min_cc = 90
|
||||
max_cc = 90
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
@ -4226,19 +4220,38 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_small = [
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||
|
||||
data_type_tf32 = {
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : math_inst.element_accumulator,
|
||||
"d_type" : math_inst.element_accumulator,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : math_inst.element_accumulator
|
||||
},
|
||||
{
|
||||
"a_type" : DataType.f32,
|
||||
"b_type" : DataType.f32,
|
||||
"c_type" : math_inst.element_accumulator,
|
||||
"d_type" : math_inst.element_accumulator,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : DataType.f32
|
||||
}
|
||||
]
|
||||
|
||||
schedules = [
|
||||
schedules_default = [
|
||||
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
]
|
||||
|
||||
# TMA kernels with TT layout use EpilogueTransposed (NoSmemWarpSpecialized with swapped strides),
|
||||
@ -4250,32 +4263,27 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
|
||||
# TMA kernels with TN or NN layout
|
||||
layouts_tf32_tn_nn = [layouts_tf32[0], layouts_tf32[2]]
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_tf32, schedules)
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_small, data_types, [
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
])
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_medium, data_types, [
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
])
|
||||
else:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_types, schedules_default)
|
||||
|
||||
# TMA kernels with NT layout, only support 64x128x32 tile for now.
|
||||
layouts_tf32_nt = [layouts_tf32[3]]
|
||||
tile_64x128x32_descriptions = [tile_descriptions[0], tile_descriptions[1], tile_descriptions[2]]
|
||||
tile_128x128x32_descriptions = [tile_descriptions[3], tile_descriptions[4], tile_descriptions[5]]
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_tf32, schedules)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_tf32, [schedules[1]])
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_small, data_types, schedules_default)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_medium, data_types, [
|
||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
])
|
||||
|
||||
layouts_tf32_tt = [layouts_tf32[1]]
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_tf32, schedules_transposed_epilogue)
|
||||
|
||||
# F32 kernel share same settings with tf32 I/O kernels excluding data type
|
||||
data_type_f32 = {
|
||||
"a_type" : DataType.f32,
|
||||
"b_type" : DataType.f32,
|
||||
"c_type" : math_inst.element_accumulator,
|
||||
"d_type" : math_inst.element_accumulator,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : DataType.f32
|
||||
}
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_f32, schedules)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_f32, schedules)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_f32, [schedules[1]])
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_f32, schedules_transposed_epilogue)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue)
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
@ -4284,7 +4292,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
|
||||
# layouts for ABC and their alignments
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]],
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]],
|
||||
]
|
||||
|
||||
math_instructions = [
|
||||
@ -4304,20 +4312,23 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
max_cc = 90
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
tile_descriptions_small = [
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||
|
||||
data_types = [
|
||||
{
|
||||
@ -4332,13 +4343,15 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.s8,
|
||||
"d_type" : DataType.s8,
|
||||
"d_type" : math_inst.element_a,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : DataType.f32
|
||||
}
|
||||
]
|
||||
|
||||
for data_type in data_types:
|
||||
for layout in layouts:
|
||||
layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type)
|
||||
|
||||
# persistent kernels with TMA epilogues
|
||||
@ -4355,13 +4368,17 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
}
|
||||
]
|
||||
for data_type in data_types:
|
||||
# Set alignment d based on Destination format.
|
||||
# Set output alignment based on destination format first
|
||||
for layout in layouts:
|
||||
layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
|
||||
# Pingpong persistent
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]])
|
||||
# Cooperative persistent
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]])
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684(manifest, cuda_version):
|
||||
|
Loading…
Reference in New Issue
Block a user