diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 7b52d835..09d92e9a 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -4197,10 +4197,10 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): # layouts for ABC and their alignments layouts_tf32 = [ - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], ] math_inst = MathInstruction( @@ -4212,13 +4212,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): min_cc = 90 max_cc = 90 - tile_descriptions = [ - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + tile_descriptions_medium = [ TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], @@ -4226,19 +4220,38 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] + tile_descriptions_small = [ + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]) + ] + tile_descriptions = tile_descriptions_medium + tile_descriptions_small - data_type_tf32 = { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator - } + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] - schedules = [ + schedules_default = [ [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], - [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] + [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized], ] # TMA kernels with TT layout use EpilogueTransposed (NoSmemWarpSpecialized with swapped strides), @@ -4250,32 +4263,27 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): # TMA kernels with TN or NN layout layouts_tf32_tn_nn = [layouts_tf32[0], layouts_tf32[2]] - CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_tf32, schedules) + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_small, data_types, [ + [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized] + ]) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_medium, data_types, [ + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized] + ]) + else: + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_types, schedules_default) # TMA kernels with NT layout, only support 64x128x32 tile for now. layouts_tf32_nt = [layouts_tf32[3]] - tile_64x128x32_descriptions = [tile_descriptions[0], tile_descriptions[1], tile_descriptions[2]] - tile_128x128x32_descriptions = [tile_descriptions[3], tile_descriptions[4], tile_descriptions[5]] - CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_tf32, schedules) - CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_tf32, [schedules[1]]) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_small, data_types, schedules_default) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_medium, data_types, [ + [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] + ]) layouts_tf32_tt = [layouts_tf32[1]] - CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_tf32, schedules_transposed_epilogue) - - # F32 kernel share same settings with tf32 I/O kernels excluding data type - data_type_f32 = { - "a_type" : DataType.f32, - "b_type" : DataType.f32, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32 - } - - CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_f32, schedules) - CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_f32, schedules) - CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_128x128x32_descriptions, data_type_f32, [schedules[1]]) - CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_f32, schedules_transposed_epilogue) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue) # def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): @@ -4284,7 +4292,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): # layouts for ABC and their alignments layouts = [ - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], ] math_instructions = [ @@ -4304,20 +4312,23 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): max_cc = 90 for math_inst in math_instructions: - tile_descriptions = [ - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + tile_descriptions_small = [ TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] + tile_descriptions_medium = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions = tile_descriptions_medium + tile_descriptions_small data_types = [ { @@ -4332,13 +4343,15 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, "c_type" : DataType.s8, - "d_type" : DataType.s8, + "d_type" : math_inst.element_a, "acc_type" : math_inst.element_accumulator, "epi_type" : DataType.f32 } ] for data_type in data_types: + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_type["d_type"]] CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type) # persistent kernels with TMA epilogues @@ -4355,13 +4368,17 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): } ] for data_type in data_types: - # Set alignment d based on Destination format. + # Set output alignment based on destination format first for layout in layouts: layout[2][1] = 128 // DataTypeSize[data_type["d_type"]] + # Pingpong persistent CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], - [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) - + [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]]) + # Cooperative persistent + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]]) # def GenerateSM90_TensorOp_1684(manifest, cuda_version):