Fix Int8 and TF32 generator (#976)

parent 87349d3496
commit 473a67073e
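Summary (as read from the diff below): the SM90 tf32 and int8 WGMMA generators were emitting kernels whose C/D alignment was fixed at 1 element. This change raises the output alignment to a full 128-bit access (4 elements for tf32, 16 for int8), splits the tile lists into small (64x128) and medium (128x128) groups so that pingpong and cooperative schedules can each be paired with the tile sizes they support, folds the former f32 data-type dictionary into a single data_types list, and makes the int8 kernels' d_type follow the A operand, recomputing the output alignment per destination format.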
@@ -4197,10 +4197,10 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
 
   # layouts for ABC and their alignments
   layouts_tf32 = [
-    [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
-    [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
-    [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]],
-    [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]],
+    [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
   ]
 
   math_inst = MathInstruction(
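Why 4: alignments in this generator are counted in elements, and the widest useful vector access spans 128 bits. For tf32, which occupies a 32-bit container, that is 128 // 32 = 4 elements; the int8 generator below gets the analogous fix with 128 // 8 = 16. A minimal self-contained sketch of the rule (the bit widths mirror the generator's DataTypeSize table; the helper name is mine, not the generator's):

    # Stand-in for the generator's DataTypeSize table (bits per element).
    data_type_size_bits = {"tf32": 32, "f32": 32, "s8": 8, "u8": 8, "s32": 32}

    def max_alignment_elements(data_type: str) -> int:
        # One 128-bit access divided by the element width in bits.
        return 128 // data_type_size_bits[data_type]

    assert max_alignment_elements("tf32") == 4   # the new tf32 C/D alignment
    assert max_alignment_elements("s8") == 16    # the new int8 C/D alignment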
@@ -4212,13 +4212,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
   min_cc = 90
   max_cc = 90
 
-  tile_descriptions = [
-    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-      0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
-    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
-    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
+  tile_descriptions_medium = [
     TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
       0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
     TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
@@ -4226,19 +4220,38 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
     TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
       0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
   ]
+  tile_descriptions_small = [
+    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+      0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
+    TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+      0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
+  ]
+  tile_descriptions = tile_descriptions_medium + tile_descriptions_small
 
-  data_type_tf32 = {
-    "a_type" : math_inst.element_a,
-    "b_type" : math_inst.element_b,
-    "c_type" : math_inst.element_accumulator,
-    "d_type" : math_inst.element_accumulator,
-    "acc_type" : math_inst.element_accumulator,
-    "epi_type" : math_inst.element_accumulator
-  }
+  data_types = [
+    {
+      "a_type" : math_inst.element_a,
+      "b_type" : math_inst.element_b,
+      "c_type" : math_inst.element_accumulator,
+      "d_type" : math_inst.element_accumulator,
+      "acc_type" : math_inst.element_accumulator,
+      "epi_type" : math_inst.element_accumulator
+    },
+    {
+      "a_type" : DataType.f32,
+      "b_type" : DataType.f32,
+      "c_type" : math_inst.element_accumulator,
+      "d_type" : math_inst.element_accumulator,
+      "acc_type" : math_inst.element_accumulator,
+      "epi_type" : DataType.f32
+    }
+  ]
 
-  schedules = [
+  schedules_default = [
     [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
-    [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
+    [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
   ]
 
   # TMA kernels with TT layout use EpilogueTransposed (NoSmemWarpSpecialized with swapped strides),
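For scale: the old names tile_64x128x32_descriptions and tile_128x128x32_descriptions (removed in the next hunk) imply a tf32 WGMMA instruction shape of [64, 128, 8], so the two groups evaluate to these concrete CTA tiles (an illustrative sketch, not generator code):

    instruction_shape = [64, 128, 8]  # M, N, K of one tf32 WGMMA instruction

    small  = [instruction_shape[0],     instruction_shape[1], instruction_shape[2] * 4]
    medium = [instruction_shape[0] * 2, instruction_shape[1], instruction_shape[2] * 4]

    print(small)   # [64, 128, 32]  -- paired with pingpong schedules below
    print(medium)  # [128, 128, 32] -- paired with cooperative schedules below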
@@ -4250,32 +4263,27 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
 
   # TMA kernels with TN or NN layout
   layouts_tf32_tn_nn = [layouts_tf32[0], layouts_tf32[2]]
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_tf32, schedules)
+  if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
+    CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_small, data_types, [
+      [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
+      [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
+    ])
+    CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_medium, data_types, [
+      [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
+      [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]
+    ])
+  else:
+    CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_types, schedules_default)
 
   # TMA kernels with NT layout, only support 64x128x32 tile for now.
   layouts_tf32_nt = [layouts_tf32[3]]
-  tile_64x128x32_descriptions = [tile_descriptions[0], tile_descriptions[1], tile_descriptions[2]]
-  tile_128x128x32_descriptions = [tile_descriptions[3], tile_descriptions[4], tile_descriptions[5]]
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_tf32, schedules)
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_tf32, [schedules[1]])
+  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_small, data_types, schedules_default)
+  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_medium, data_types, [
+    [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
+  ])
 
   layouts_tf32_tt = [layouts_tf32[1]]
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_tf32, schedules_transposed_epilogue)
-
-  # F32 kernel share same settings with tf32 I/O kernels excluding data type
-  data_type_f32 = {
-    "a_type" : DataType.f32,
-    "b_type" : DataType.f32,
-    "c_type" : math_inst.element_accumulator,
-    "d_type" : math_inst.element_accumulator,
-    "acc_type" : math_inst.element_accumulator,
-    "epi_type" : DataType.f32
-  }
-
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_f32, schedules)
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_f32, schedules)
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_f32, [schedules[1]])
-  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_f32, schedules_transposed_epilogue)
+  CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue)
 
 #
 def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
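The pingpong and cooperative schedule pairs are only generated on CUDA 12.1 or newer, hence the new branch; older toolkits keep the auto/TMA default schedules. CudaToolkitVersionSatisfies is the generator's existing helper; it behaves roughly like the tuple comparison below (assumed behavior, sketched for illustration only):

    def cuda_toolkit_version_satisfies(semver: str, major: int, minor: int) -> bool:
        # e.g. "12.1.105" -> (12, 1), then compare as a tuple
        parts = tuple(int(x) for x in semver.split(".")[:2])
        return parts >= (major, minor)

    assert cuda_toolkit_version_satisfies("12.1", 12, 1)
    assert not cuda_toolkit_version_satisfies("12.0.76", 12, 1)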
@@ -4284,7 +4292,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
 
   # layouts for ABC and their alignments
   layouts = [
-    [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]],
+    [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]],
   ]
 
   math_instructions = [
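Same rule as the tf32 alignment fix above: int8 C/D elements are 8 bits wide, so a 128-bit access holds 128 // 8 = 16 of them, and the output alignment goes from 1 to 16.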
@@ -4304,20 +4312,23 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
   max_cc = 90
 
   for math_inst in math_instructions:
-    tile_descriptions = [
-      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-        0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+    tile_descriptions_small = [
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
-      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
-      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
-        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
       TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
         0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
     ]
+    tile_descriptions_medium = [
+      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+        0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
+      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
+      TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
+        0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
+    ]
+    tile_descriptions = tile_descriptions_medium + tile_descriptions_small
 
     data_types = [
       {
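This mirrors the tf32 small/medium split above. Assuming the usual SM90 int8 WGMMA instruction shape of [64, 128, 32] (8-bit operands quadruple the K extent relative to tf32), the groups work out as follows (a sketch under that assumption):

    instruction_shape = [64, 128, 32]  # assumed M, N, K of one int8 WGMMA

    small  = [instruction_shape[0],     instruction_shape[1], instruction_shape[2] * 4]
    medium = [instruction_shape[0] * 2, instruction_shape[1], instruction_shape[2] * 4]

    print(small)   # [64, 128, 128]
    print(medium)  # [128, 128, 128]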
@@ -4332,13 +4343,15 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
         "a_type" : math_inst.element_a,
         "b_type" : math_inst.element_b,
         "c_type" : DataType.s8,
-        "d_type" : DataType.s8,
+        "d_type" : math_inst.element_a,
         "acc_type" : math_inst.element_accumulator,
         "epi_type" : DataType.f32
       }
     ]
 
     for data_type in data_types:
+      for layout in layouts:
+        layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
       CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type)
 
     # persistent kernels with TMA epilogues
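Two fixes land in this hunk: d_type now follows the A operand, so u8 math instructions no longer hard-code an s8 output, and the output alignment is recomputed before the non-persistent CreateGemmUniversal3xOperator call as well, not only for the persistent kernels below. The inserted loop mutates layouts in place; on a stand-in list it does this:

    # Stand-in layouts entry: [[layout_A, align_A], [layout_B, align_B], [layout_C, align_C]]
    layouts = [[["RowMajor", 16], ["ColumnMajor", 16], ["ColumnMajor", 16]]]

    d_type_bits = 8                        # e.g. DataTypeSize[DataType.s8]
    for layout in layouts:
        layout[2][1] = 128 // d_type_bits  # C/D alignment tracks d_type

    print(layouts[0][2])                   # ['ColumnMajor', 16]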
@@ -4355,13 +4368,17 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
       }
     ]
     for data_type in data_types:
-      # Set alignment d based on Destination format.
+      # Set output alignment based on destination format first
       for layout in layouts:
         layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
+      # Pingpong persistent
       CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
         [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
-         [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
+         [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]])
+      # Cooperative persistent
+      CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type,
+        [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
+         [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]])
 
 #
 def GenerateSM90_TensorOp_1684(manifest, cuda_version):
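The persistent kernels previously paired a pingpong mainloop with a cooperative epilogue in a single schedule list; they are now emitted as separate pingpong and cooperative variants, with the cooperative ones restricted to the medium (CTA-M = 128) tiles. That restriction is consistent with how the two schedules divide work in CUTLASS 3.x: pingpong alternates two consumer warp groups across different output tiles, while cooperative splits one larger tile between both warp groups, so it wants the doubled M extent.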