From b97404837e018d0624daec55b93bf0a02f603404 Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Wed, 17 May 2023 14:13:23 -0700 Subject: [PATCH] Adding 128x256 tile for 16b input datatype WGMMA gemm (#950) --- tools/library/scripts/generator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 95595af4..f83a7dc3 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -4084,6 +4084,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]), #TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], # 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - Not compatible with TmaWarpSpecializedCooperative TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], @@ -4092,6 +4094,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]), #TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),- Not compatible with TmaWarpSpecializedCooperative TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],