From dbfced05e73bd80e8c5ded5c60e29556474a8a2f Mon Sep 17 00:00:00 2001 From: Alexander Zinoviev <8257131+alexander-zinoviev@users.noreply.github.com> Date: Wed, 10 Jul 2024 08:00:52 -0700 Subject: [PATCH] Fix typos in convolution tests (#1433) --- ...rad_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rad_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rad_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rop_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rop_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rop_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rad_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rad_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- ...rad_implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 ++++++++-------- 9 files changed, 72 insertions(+), 72 deletions(-) diff --git a/test/unit/conv/device_3x/dgrad/sm90_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/dgrad/sm90_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 4206bb48..fa98bdee 100644 --- a/test/unit/conv/device_3x/dgrad/sm90_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/dgrad/sm90_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -79,7 +79,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -122,7 +122,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -165,7 +165,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -208,7 +208,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -256,7 +256,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -299,7 +299,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -342,7 +342,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -385,7 +385,7 @@ TEST(SM90_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 9cb96c35..d8f09fb6 100644 --- a/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -80,7 +80,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -124,7 +124,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -168,7 +168,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -212,7 +212,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -260,7 +260,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -304,7 +304,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -348,7 +348,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -392,7 +392,7 @@ TEST(SM90_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/dgrad/sm90_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/dgrad/sm90_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 03954253..deead8e5 100644 --- a/test/unit/conv/device_3x/dgrad/sm90_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/dgrad/sm90_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -77,7 +77,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -122,7 +122,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -166,7 +166,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -210,7 +210,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -258,7 +258,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -302,7 +302,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -346,7 +346,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -390,7 +390,7 @@ TEST(SM90_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kDgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/fprop/sm90_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/fprop/sm90_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 694d0d85..924b977e 100644 --- a/test/unit/conv/device_3x/fprop/sm90_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/fprop/sm90_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -76,7 +76,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -119,7 +119,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -162,7 +162,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -205,7 +205,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -253,7 +253,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -296,7 +296,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -339,7 +339,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -382,7 +382,7 @@ TEST(SM90_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 7eaca6a3..7768c7f6 100644 --- a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -77,7 +77,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -121,7 +121,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -165,7 +165,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -209,7 +209,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -257,7 +257,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -301,7 +301,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -345,7 +345,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -389,7 +389,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 5f838a7b..91841f61 100644 --- a/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -77,7 +77,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -121,7 +121,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -165,7 +165,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -209,7 +209,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -257,7 +257,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -301,7 +301,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -345,7 +345,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -389,7 +389,7 @@ TEST(SM90_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kFprop, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu index df69443d..85057119 100644 --- a/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -76,7 +76,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -119,7 +119,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -162,7 +162,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -205,7 +205,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 6 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -253,7 +253,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -296,7 +296,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -339,7 +339,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -382,7 +382,7 @@ TEST(SM90_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f16, 1 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNWC, 8, + ElementAct, cutlass::layout::TensorNWC, 8, ElementFlt, cutlass::layout::TensorNWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/wgrad/sm90_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/wgrad/sm90_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 0b257794..c7c8f0ec 100644 --- a/test/unit/conv/device_3x/wgrad/sm90_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/wgrad/sm90_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -77,7 +77,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -121,7 +121,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -165,7 +165,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -209,7 +209,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -257,7 +257,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -301,7 +301,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -345,7 +345,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -389,7 +389,7 @@ TEST(SM90_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNHWC, 8, + ElementAct, cutlass::layout::TensorNHWC, 8, ElementFlt, cutlass::layout::TensorNHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, diff --git a/test/unit/conv/device_3x/wgrad/sm90_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/wgrad/sm90_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu index e9c3f59d..54173452 100644 --- a/test/unit/conv/device_3x/wgrad/sm90_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/wgrad/sm90_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -77,7 +77,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -121,7 +121,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -165,7 +165,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -210,7 +210,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -258,7 +258,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -302,7 +302,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -346,7 +346,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK, @@ -390,7 +390,7 @@ TEST(SM90_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_ using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::conv::Operator::kWgrad, - ElementAcc, cutlass::layout::TensorNDHWC, 8, + ElementAct, cutlass::layout::TensorNDHWC, 8, ElementFlt, cutlass::layout::TensorNDHWC, 8, ElementAcc, TileShapeMNK, ClusterShapeMNK,