Adjust profiler space for SM89 (#1553)

This commit is contained in:
Wenlei Bao 2024-09-19 08:40:30 -07:00 committed by GitHub
parent 2991ce18d3
commit 44dae8b90e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4881,7 +4881,8 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
return
layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor)
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor)
]
math_instructions = [
@ -4935,43 +4936,49 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_types = [
@ -4981,6 +4988,12 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
DataType.f32,
math_inst.element_accumulator
],
[
math_inst.element_a,
math_inst.element_b,
DataType.bf16,
math_inst.element_accumulator
],
]
operations = []