From 3cfa5db2a24b3a68c8640bd102e5b7054bae947b Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Wed, 16 Feb 2022 06:53:21 -0800 Subject: [PATCH] =?UTF-8?q?Actually=20use=20float=20accumulation=20in=20ge?= =?UTF-8?q?mm=5Ff16t=5Ff16t=5Ff16t=5Fwmma=5Ftensor=5Fop=E2=80=A6=20(#407)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Actually use float accumulation in gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu As title * Update gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu change the missing one Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com> --- ..._f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu index 2a270013..3abeae37 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu @@ -50,7 +50,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -81,7 +81,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { // single cta, two warps horizontally using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -112,7 +112,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { // single cta, two warps vertically using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -143,7 +143,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { // single cta, two warps horizontally two waprs vertically using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -174,7 +174,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -205,7 +205,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -236,7 +236,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -267,7 +267,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -299,7 +299,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -333,7 +333,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -367,7 +367,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t,