diff --git a/include/cutlass/gemm/thread/mma_sm60.h b/include/cutlass/gemm/thread/mma_sm60.h
index 562c682e..839e07a7 100644
--- a/include/cutlass/gemm/thread/mma_sm60.h
+++ b/include/cutlass/gemm/thread/mma_sm60.h
@@ -70,15 +70,17 @@ struct Mma_HFMA2;
 // Specialization for NNN  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2 <
-  Shape,
+  Shape_,
   layout::ColumnMajor,
   layout::ColumnMajor,
   layout::ColumnMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kM % 2),
     "Mma_HFMA2 requires the M dimension to be divisible by 2."
@@ -159,15 +161,17 @@ struct Mma_HFMA2 <
 // Specialization for NNT  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2<
-  Shape,
+  Shape_,
   layout::ColumnMajor,
   layout::ColumnMajor,
   layout::RowMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kN % 2),
     "Mma_HFMA2 requires the N dimension to be divisible by 2."
@@ -253,15 +257,17 @@ struct Mma_HFMA2<
 // Specialization for NTN  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2 <
-  Shape,
+  Shape_,
   layout::ColumnMajor,
   layout::RowMajor,
   layout::ColumnMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kM % 2),
     "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
@@ -342,15 +348,17 @@ struct Mma_HFMA2 <
 // Specialization for NTT  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2<
-  Shape,
+  Shape_,
   layout::ColumnMajor,
   layout::RowMajor,
   layout::RowMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kN % 2),
     "Mma_HFMA2 requires the N dimension to be divisible by 2."
@@ -431,15 +439,17 @@ struct Mma_HFMA2<
 // Specialization for TNN  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2 <
-  Shape,
+  Shape_,
   layout::RowMajor,
   layout::ColumnMajor,
   layout::ColumnMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kM % 2),
     "Mma_HFMA2 requires the M dimension to be divisible by 2."
@@ -524,15 +534,17 @@ struct Mma_HFMA2 <
 // Specialization for TNT  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2 <
-  Shape,
+  Shape_,
   layout::RowMajor,
   layout::ColumnMajor,
   layout::RowMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kN % 2),
     "Mma_HFMA2 requires the N dimension to be divisible by 2."
@@ -617,15 +629,17 @@ struct Mma_HFMA2 <
 // Specialization for TTN  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2 <
-  Shape,
+  Shape_,
   layout::RowMajor,
   layout::RowMajor,
   layout::ColumnMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kM % 2),
     "Mma_HFMA2 requires the M dimension to be divisible by 2."
@@ -711,15 +725,17 @@ struct Mma_HFMA2 <
 // Specialization for TTT  //
 /////////////////////////////
 
-template <typename Shape>
+template <typename Shape_>
 struct Mma_HFMA2<
-  Shape,
+  Shape_,
   layout::RowMajor,
   layout::RowMajor,
   layout::RowMajor,
   true
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kN % 2),
     "Mma_HFMA2 requires the N dimension to be divisible by 2."
@@ -800,15 +816,17 @@ struct Mma_HFMA2<
 // Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
 /////////////////////////////////////////////////////////////////////
 
-template <typename Shape, typename LayoutA, typename LayoutB>
+template <typename Shape_, typename LayoutA, typename LayoutB>
 struct Mma_HFMA2<
-  Shape,
+  Shape_,
   LayoutA,
   LayoutB,
   layout::RowMajor,
   false
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kK % 2),
     "Mma_HFMA2 requires the K dimension to be divisible by 2."
@@ -882,15 +900,17 @@ struct Mma_HFMA2<
 // Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
 /////////////////////////////////////////////////////////////////////
 
-template <typename Shape, typename LayoutA, typename LayoutB>
+template <typename Shape_, typename LayoutA, typename LayoutB>
 struct Mma_HFMA2<
-  Shape,
+  Shape_,
   LayoutA,
   LayoutB,
   layout::ColumnMajor,
   false
   > {
 
+  using Shape = Shape_;
+
   static_assert(
     !(Shape::kK % 2),
     "Mma_HFMA2 requires the K dimension to be divisible by 2."
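
The mma_sm60.h changes above are interface-only: each Mma_HFMA2 specialization renames its shape template parameter to Shape_ and re-exports it as the member typedef Shape, following the CUTLASS convention of exposing template arguments as nested types. A minimal sketch of what the typedef enables is below; the MmaNNN alias, the GemmShape<8, 8, 4> tile, and the detail namespace qualification are illustrative assumptions, not part of the patch.

// Sketch only: after this patch, the tile shape is recoverable from the type.
// The NNN half2-vectorized specialization is instantiated with an arbitrary
// GemmShape<8, 8, 4> tile, which satisfies its kM % 2 == 0 static_assert.
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/thread/mma_sm60.h"
#include "cutlass/layout/matrix.h"

using MmaNNN = cutlass::gemm::thread::detail::Mma_HFMA2<
    cutlass::gemm::GemmShape<8, 8, 4>,   // Shape_ (M, N, K)
    cutlass::layout::ColumnMajor,        // LayoutA
    cutlass::layout::ColumnMajor,        // LayoutB
    cutlass::layout::ColumnMajor,        // LayoutC
    true>;                               // half2 vectorization

// Before the patch, Shape existed only as a template parameter name; now
// dependent code can query it as a nested type:
static_assert(MmaNNN::Shape::kM == 8, "Shape is re-exported as a member typedef");

This is what lets higher-level components (tile iterators, default kernel configurations) read the thread-level shape off the Mma type instead of having it threaded through separately.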
diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt
index 15786256..e60c232a 100644
--- a/test/unit/conv/device/CMakeLists.txt
+++ b/test/unit/conv/device/CMakeLists.txt
@@ -101,6 +101,9 @@ cutlass_test_unit_add_executable(
   conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu
   conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu
   conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu
+
+  # F16
+  conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
 )
 
 if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80)
diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
new file mode 100644
index 00000000..cd555e46
--- /dev/null
+++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu
@@ -0,0 +1,128 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide Implicit GEMM interface
+*/
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/conv/kernel/default_conv2d_fprop.h"
+#include "cutlass/conv/device/implicit_gemm_convolution.h"
+
+#include "conv2d_testbed.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM60_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16,
+  128x128_8x2_64x64x8) {
+
+  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
+  using ElementA           = cutlass::half_t;
+  using ElementB           = cutlass::half_t;
+  using ElementC           = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+  using ElementCompute     = cutlass::half_t;
+
+  /// Device-level Conv2d instance
+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
+    ElementA,
+    cutlass::layout::TensorNHWC,
+    ElementB,
+    cutlass::layout::TensorNHWC,
+    ElementC,
+    cutlass::layout::TensorNHWC,
+    ElementAccumulator,
+    cutlass::arch::OpClassSimt,
+    cutlass::arch::Sm60,
+    cutlass::gemm::GemmShape<128, 128, 8>,
+    cutlass::gemm::GemmShape<64, 64, 8>,
+    cutlass::gemm::GemmShape<1, 1, 1>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementC,
+      1,
+      ElementAccumulator,
+      ElementCompute
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    2,
+    cutlass::arch::OpMultiplyAdd,
+    cutlass::conv::IteratorAlgorithm::kAnalytic
+  >::Kernel;
+
+  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
+
+  /// Run all unit test sizes with device-level Conv2d instance
+  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
+
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM60_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16,
+  128x128_8x2_64x64x8) {
+
+  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
+  using ElementA           = cutlass::half_t;
+  using ElementB           = cutlass::half_t;
+  using ElementC           = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+  using ElementCompute     = cutlass::half_t;
+
+  /// Device-level Conv2d instance
+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
+    ElementA,
+    cutlass::layout::TensorNHWC,
+    ElementB,
+    cutlass::layout::TensorNHWC,
+    ElementC,
+    cutlass::layout::TensorNHWC,
+    ElementAccumulator,
+    cutlass::arch::OpClassSimt,
+    cutlass::arch::Sm60,
+    cutlass::gemm::GemmShape<128, 128, 8>,
+    cutlass::gemm::GemmShape<64, 64, 8>,
+    cutlass::gemm::GemmShape<1, 1, 1>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementC,
+      1,
+      ElementAccumulator,
+      ElementCompute
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,
+    2,
+    cutlass::arch::OpMultiplyAdd,
+    cutlass::conv::IteratorAlgorithm::kOptimized
+  >::Kernel;
+
+  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
+
+  /// Run all unit test sizes with device-level Conv2d instance
+  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
+
+}
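
Outside the bundled testbed, the new SM60 f16 SIMT kernel can be driven directly through ImplicitGemmConvolution's standard interface: build an Arguments object (problem size, A/B/C/D tensor references, epilogue scalars), check can_implement, initialize, then launch. The sketch below is hypothetical and not part of the patch: it assumes the Conv2dFprop alias from either test above has been hoisted to namespace scope, and the run_fprop and TensorF16 names are invented for illustration.

// Host-side sketch of invoking the f16 SIMT fprop kernel defined above.
// Tensor allocation and initialization are elided for brevity.
#include "cutlass/cutlass.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/util/host_tensor.h"

using TensorF16 = cutlass::HostTensor<cutlass::half_t, cutlass::layout::TensorNHWC>;

cutlass::Status run_fprop(
    TensorF16 &A, TensorF16 &B, TensorF16 &C,
    cutlass::conv::Conv2dProblemSize const &problem_size) {

  // Arguments: problem size, A/B/C/D tensor refs, epilogue scalars.
  // D aliases C here for brevity since beta == 0.
  typename Conv2dFprop::Arguments arguments{
    problem_size,
    A.device_ref(),
    B.device_ref(),
    C.device_ref(),
    C.device_ref(),
    {cutlass::half_t(1), cutlass::half_t(0)}  // alpha = 1, beta = 0
  };

  Conv2dFprop op;

  // Reject unsupported sizes/alignments before touching the device.
  cutlass::Status status = op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = op.initialize(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  return op();  // launch the implicit-GEMM convolution on the default stream
}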