Adding 128x256 tile for 16b input datatype WGMMA gemm (#950)

This commit is contained in:
Manish Gupta 2023-05-17 14:13:23 -07:00 committed by GitHub
parent e2953d47c5
commit b97404837e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4084,6 +4084,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - Not compatible with TmaWarpSpecializedCooperative
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
@ -4092,6 +4094,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]),
#TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
# 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),- Not compatible with TmaWarpSpecializedCooperative
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],