diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 681fb828..4d49f52e 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -1712,7 +1712,14 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args): operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + for op in operations: if op.tile_description.threadblock_shape[1] >= 128: op.C.alignment = 8