Actually use float accumulation in gemm_f16t_f16t_f16t_wmma_tensor_op… (#407)

* Actually use float accumulation in gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu

As title

* Update gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu

change the missing one

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
Jongsoo Park 2022-02-16 06:53:21 -08:00 committed by GitHub
parent 1db6971a8d
commit 3cfa5db2a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -50,7 +50,7 @@
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -81,7 +81,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) {
// single cta, two warps horizontally
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -112,7 +112,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) {
// single cta, two warps vertically
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -143,7 +143,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) {
// single cta, two warps horizontally two waprs vertically
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -174,7 +174,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -205,7 +205,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -236,7 +236,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -267,7 +267,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -299,7 +299,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -333,7 +333,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
@ -367,7 +367,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,