diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 8aa18b4b..1f2eb86e 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -4960,47 +4960,68 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): DataType.bf16, DataType.bf16, DataType.f32, OpcodeClass.TensorOp, MathOperation.multiply_add), + MathInstruction( + [64, 256, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), ] min_cc = 90 max_cc = 90 for math_inst in math_instructions: - tile_descriptions_small = [ - # Not compatible with TmaWarpSpecializedCooperative - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - ] - tile_descriptions_medium = [ - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - ] - tile_descriptions_large = [ - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), - # 128x256x128 - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), - ] - tile_descriptions = tile_descriptions_medium + tile_descriptions_large + tile_descriptions = [] + tile_descriptions_small = [] + tile_descriptions_medium = [] + tile_descriptions_large = [] + + if math_inst.instruction_shape[1] == 128: + tile_descriptions_small = [ + # Not compatible with TmaWarpSpecializedCooperative + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions_medium = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + ] + tile_descriptions_large = [ + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions = tile_descriptions_medium + tile_descriptions_large + else: + tile_descriptions = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] data_type = { "a_type" : math_inst.element_a, @@ -5043,7 +5064,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): # persistent kernels with TMA epilogues if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # not enough smem for 256x128 f32 out with C allocation - if data_type["d_type"] == DataType.f32: + if data_type["d_type"] == DataType.f32 and len(tile_descriptions_medium) > 0: CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) @@ -5490,20 +5511,30 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): DataType.u8, DataType.u8, DataType.s32, OpcodeClass.TensorOp, MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.u8, DataType.u8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), ] min_cc = 90 max_cc = 90 for math_inst in math_instructions: - # 64x128x128 + # 64x128x128 or 64x256x128 tile_descriptions_small = [ TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), ] - # 128x128x128 + # 128x128x128 or 128x256x128 tile_descriptions_medium = [ TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), @@ -5670,6 +5701,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): DataType.e5m2, DataType.e5m2, DataType.f32, OpcodeClass.TensorOp, MathOperation.multiply_add), + # inst 64x256x32 + MathInstruction( + [64, 256, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), ] min_cc = 90 @@ -5788,9 +5840,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - # 128x256x128 - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions = [ # 128x128x128 @@ -5801,6 +5850,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] + elif math_inst.instruction_shape[1] == 256: + tile_descriptions_small = [ + # 64x256x128 + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions_large = [] + tile_descriptions = [ + # 128x256x128 + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + else: assert False, "math inst is not supported" @@ -5842,9 +5912,10 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]) # Large tiles - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile, - [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], - [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + if len(tile_descriptions_large) > 0: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) # Add stream-K variants (with and without TMA epilogues) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])