From f78994bb400ccc29ad3bf7d06a41641859e14557 Mon Sep 17 00:00:00 2001 From: Haicheng Wu <57973641+hwu36@users.noreply.github.com> Date: Sat, 25 Dec 2021 07:29:54 -0500 Subject: [PATCH] add the missing pieces (#392) Co-authored-by: Haicheng Wu --- tools/library/scripts/gemm_operation.py | 1 + tools/library/scripts/generator.py | 3 +++ tools/library/scripts/library.py | 4 ++++ 3 files changed, 8 insertions(+) diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index fe7462a3..2914d30a 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -41,6 +41,7 @@ class GemmOperation: complex_operators = [ MathOperation.multiply_add_complex, MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 ] return self.tile_description.math_instruction.math_operation in complex_operators diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 4b2c7805..bc5b599e 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -2641,6 +2641,9 @@ def GenerateSM80(manifest, args): GenerateSM80_TensorOp_1688_fast_math(manifest, args) GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args) GenerateSM80_TensorOp_1688_complex(manifest, args) + # 3xTF32 + GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, args) + GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, args) GenerateSM80_TensorOp_884(manifest, args) GenerateSM80_TensorOp_884_complex(manifest, args) GenerateSM80_TensorOp_884_complex_gaussian(manifest, args) diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index 21ef62bf..de0a1e3a 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -240,6 +240,8 @@ class MathOperation(enum.Enum): xor_popc = enum_auto() multiply_add_fast_bf16 = enum_auto() multiply_add_fast_f16 = enum_auto() + multiply_add_fast_f32 = enum_auto() + multiply_add_complex_fast_f32 = enum_auto() multiply_add_complex = enum_auto() multiply_add_complex_gaussian = enum_auto() @@ -250,6 +252,8 @@ MathOperationTag = { MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', + MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', + MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', }