diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu index 2a270013..3abeae37 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu @@ -50,7 +50,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -81,7 +81,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { // single cta, two warps horizontally using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -112,7 +112,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { // single cta, two warps vertically using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -143,7 +143,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { // single cta, two warps horizontally two waprs vertically using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -174,7 +174,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -205,7 +205,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -236,7 +236,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -267,7 +267,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -299,7 +299,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -333,7 +333,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16 TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t, @@ -367,7 +367,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; + using ElementAccumulator = float; using Gemm = cutlass::gemm::device::Gemm< cutlass::half_t,