From eee0cab26c8eedea447eb3b58b3498eeba2294da Mon Sep 17 00:00:00 2001 From: Ali Hassani <68103095+alihassanijr@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:22:29 -0400 Subject: [PATCH] Stamp out 1x1x1 clusters, 128x256 CTA shape (#1665) Adds 128x256 tile shapes to FP16/BF16 and FP8 generators. Also adds 1x1x1 clusters to all existing FP16/BF16/FP8 generators. NOTE: it is important to set kernel filter (--kernels / CUTLASS_LIBRARY_KERNELS) to a non empty string and skip pruning to get all of the new configurations. If profiling exhaustively, they can be set to `*`. Number of CUTLASS 3.X GEMMs before this commit: 2868 Number of CUTLASS 3.X GEMMs after this commit: 4016 Co-authored-by: Ali Hassani --- python/cutlass_library/generator.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index cbc9c326..8aa18b4b 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -4972,22 +4972,33 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions_medium = [ TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions_large = [ TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], 0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), + # 128x256x128 + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions = tile_descriptions_medium + tile_descriptions_large @@ -5766,6 +5777,8 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions_large = [ # 256x128x128 @@ -5773,6 +5786,11 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + # 128x256x128 + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions = [ # 128x128x128 @@ -5780,6 +5798,8 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] else: