/***************************************************************************************************
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Unit tests for threadblock-level GEMV
*/

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/core_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "cutlass/gemm/threadblock/gemv.h"
#include "cutlass/gemm/threadblock/default_gemv_core.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Gemv, typename LongIndex, typename RefA, typename RefB, typename RefC>
__global__ void batched_gemv_threadblock_test_kernel(
  cutlass::gemm::GemmCoord problem_size,
  LongIndex stride_a,
  LongIndex stride_b,
  LongIndex stride_c,
  RefA ref_A,
  RefB ref_B,
  RefC ref_C
  ) {

  typename Gemv::IteratorA::TensorCoord threadblock_offset_A(0, 0);
  typename Gemv::IteratorB::TensorCoord threadblock_offset_B(0, 0);
  typename Gemv::IteratorC::TensorCoord threadblock_offset_C(0, 0);

  // Move to the right batch for these threads: threadIdx.y selects the batch
  ref_A.add_pointer_offset(threadIdx.y * stride_a);
  ref_B.add_pointer_offset(threadIdx.y * stride_b);
  ref_C.add_pointer_offset(threadIdx.y * stride_c);

  // Construct iterators to the A and B operands
  typename Gemv::IteratorA::Params params_A(ref_A.layout());
  typename Gemv::IteratorA iterator_A(params_A, ref_A.data(),
                                      {problem_size.m(), problem_size.k()},
                                      0, threadblock_offset_A);

  typename Gemv::IteratorB::Params params_B(ref_B.layout());
  typename Gemv::IteratorB iterator_B(params_B, ref_B.data(),
                                      {problem_size.k(), problem_size.n()},
                                      threadIdx.x, threadblock_offset_B);

  Gemv gemv;

  typename Gemv::FragmentC accum;
  accum.clear();

  // Compute the threadblock-scoped matrix multiply-add
  gemv(problem_size, accum, iterator_A, iterator_B, accum);

  // IteratorC uses a PitchLinear layout and assumes n() is the contiguous dimension
  typename Gemv::IteratorC::Params params_C(ref_C.layout());
  typename Gemv::IteratorC iterator_C(params_C, ref_C.data(),
                                      {problem_size.m(), problem_size.n()},
                                      threadIdx.x, threadblock_offset_C);
  iterator_C.store(accum);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
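// How the kernel above is launched (see batched_gemv_threadblock_test() below): a single
// CTA computes all batches, with threadIdx.x selecting an LDG_N-wide slice of output
// columns and threadIdx.y selecting the batch. Batches are packed back to back in memory,
// so the per-batch pointer stride equals HostTensor::capacity(). A minimal sketch of the
// block shape follows; gemv_test_block_shape is a hypothetical name used only for
// illustration, not a CUTLASS API.
#if 0
inline dim3 gemv_test_block_shape(int shape_n, int ldg_n, int num_batch) {
  return dim3(shape_n / ldg_n,  // one thread per LDG_N columns of the output row
              num_batch,        // one thread row per batch
              1);
}
#endif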
template <typename Shape_,
          typename ElementAB_,
          typename ElementC_,
          typename LayoutA_,
          typename LayoutB_,
          typename LayoutC_,
          int LDG_N,
          int LDG_K,
          int MAX_THREADS_PER_BLOCK = 512,
          bool DEBUG = false>
void batched_gemv_threadblock_test(cutlass::gemm::GemmCoord problem_size, int num_batch) {

  using Shape = Shape_;
  using ElementA = ElementAB_;
  using LayoutA = LayoutA_;
  using ElementB = ElementAB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;

  using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N, LDG_K>;

  using Core = typename cutlass::gemm::threadblock::DefaultGemvCore<
    Shape,
    ThreadShape,
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC
  >;

  if (DEBUG) {
    num_batch = 1;
  }

  using Mma = cutlass::gemm::threadblock::Gemv<Core>;

  // Create host tensors that will be the backing store for the batches.
  // Note that no device memory is initially allocated.
  cutlass::HostTensor<ElementA, LayoutA> matrix_A({problem_size.m(), problem_size.k()}, false);
  cutlass::HostTensor<ElementB, LayoutB> matrix_B({problem_size.k(), problem_size.n()}, false);
  cutlass::HostTensor<ElementC, LayoutC> matrix_C_computed({problem_size.m(), problem_size.n()}, false);
  cutlass::HostTensor<ElementC, LayoutC> matrix_C_reference({problem_size.m(), problem_size.n()}, false);

  // Reserve memory for the batch of tensors
  matrix_A.reserve(problem_size.m() * problem_size.k() * num_batch);
  matrix_B.reserve(problem_size.n() * problem_size.k() * num_batch);
  matrix_C_computed.reserve(problem_size.m() * problem_size.n() * num_batch);
  matrix_C_reference.reserve(problem_size.m() * problem_size.n() * num_batch, false);

  // Fill each tensor batch
  const int seed = 6834;
  for (int b = 0; b < num_batch; b++) {
    if (DEBUG) {
      cutlass::reference::host::BlockFillSequential(
        matrix_A.host_data_ptr_offset(b * matrix_A.capacity()), matrix_A.capacity());
      cutlass::reference::host::BlockFillSequential(
        matrix_B.host_data_ptr_offset(b * matrix_B.capacity()), matrix_B.capacity());
    } else {
      cutlass::reference::host::TensorFillRandomUniform(
        matrix_A.host_view(b * matrix_A.capacity()),
        seed + 1660,
        8,
        -8,
        0
      );
      cutlass::reference::host::TensorFillRandomUniform(
        matrix_B.host_view(b * matrix_B.capacity()),
        seed + 1880,
        8,
        -8,
        0
      );
    }

    cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b * matrix_C_computed.capacity()));
    cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b * matrix_C_reference.capacity()));
  }

  matrix_A.sync_device();
  matrix_B.sync_device();
  matrix_C_computed.sync_device();

  dim3 grid(1, 1);  // only 1 CTA is used
  dim3 block(Shape::kN / LDG_N, num_batch, 1);

#if 0
  printf("block dim = %d x %d\n", block.x, block.y);
#endif

  // Some sanity checks
  EXPECT_TRUE(problem_size.n() % LDG_N == 0);
  EXPECT_TRUE(block.x * block.y <= MAX_THREADS_PER_BLOCK);

  test::gemm::threadblock::batched_gemv_threadblock_test_kernel<Mma><<<grid, block>>>(
    problem_size,
    matrix_A.capacity(),
    matrix_B.capacity(),
    matrix_C_computed.capacity(),
    matrix_A.device_ref(),
    matrix_B.device_ref(),
    matrix_C_computed.device_ref()
  );

  cudaError_t result = cudaDeviceSynchronize();
  EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result);

  matrix_C_computed.sync_host();
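  // The check below uses cutlass::reference::host::Gemm. As a minimal sketch of what it
  // computes per batch, assuming hypothetical element accessors A[b](m, k), B[b](k, n),
  // and C_ref[b](m, n) (shown for exposition only; the real code goes through TensorRef):
#if 0
  for (int b = 0; b < num_batch; ++b) {
    for (int n = 0; n < problem_size.n(); ++n) {
      ElementC sum = ElementC(0);
      for (int k = 0; k < problem_size.k(); ++k) {
        sum += ElementC(A[b](0, k)) * ElementC(B[b](k, n));  // M == 1: A is a row vector
      }
      C_ref[b](0, n) = sum;  // alpha == 1, beta == 0
    }
  }
#endif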
  // Compute the batched GEMMs on the host as a reference
  for (int b = 0; b < num_batch; b++) {
    cutlass::reference::host::Gemm<ElementA, LayoutA, ElementB, LayoutB,
                                   ElementC, LayoutC, ElementC, ElementC> reference_gemm;

    reference_gemm(
      problem_size.mnk(),
      ElementC(1),
      matrix_A.host_ref(b * matrix_A.capacity()),
      matrix_B.host_ref(b * matrix_B.capacity()),
      ElementC(0),
      matrix_C_reference.host_ref(b * matrix_C_reference.capacity())
    );

    bool passed = cutlass::reference::host::TensorEquals(
      matrix_C_computed.host_view(b * matrix_C_computed.capacity()),
      matrix_C_reference.host_view(b * matrix_C_reference.capacity()));

    EXPECT_TRUE(passed)
      //<< "A:\n" << matrix_A.host_view() << "\n"
      //<< "B:\n" << matrix_B.host_view() << "\n"
      << "Batch: " << b << "\n"
      << "Reference:\n" << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) << "\n"
      << "Computed:\n" << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) << "\n";
  }
}

} // namespace threadblock
} // namespace gemm
} // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////
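// Test naming convention, inferred from the cases below:
//   <num_batch>x<M>x<N>x<K>_<layouts of A,B,C: c = ColumnMajor, r = RowMajor>_<A/B element>_<C element>_<LDG_N>N_<LDG_K>K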
// A: ColumnMajor
// B: RowMajor
// C: ColumnMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp32_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_crc_fp32_fp32_4N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 128, 128);
  const int num_batch = 5;
  const int LDG_N = 4;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp32_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_8K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 8;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp16_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_i8_i32_2N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_i8_i32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

// A: RowMajor
// B: ColumnMajor
// C: RowMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp32_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcr_fp32_fp32_4N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 128, 128);
  const int num_batch = 5;
  const int LDG_N = 4;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp32_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_8K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 8;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp16_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_i8_i32_2N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_i8_i32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}
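// Additional layout combinations follow the same pattern. As an illustration only (this
// case is not part of the suite and has not been verified against DefaultGemvCore's
// supported specializations), a hypothetical ccr instantiation would read:
#if 0
TEST(SM50_batched_gemv_threadblock, 4x1x64x64_ccr_fp32_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor, LDG_N, LDG_K>(problem_size, num_batch);
}
#endif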
// A: RowMajor
// B: ColumnMajor
// C: ColumnMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp32_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcc_fp32_fp32_4N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 128, 128);
  const int num_batch = 5;
  const int LDG_N = 4;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp32_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_8K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 8;
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp16_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_i8_i32_2N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_i8_i32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;
  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor, LDG_N, LDG_K>(problem_size, num_batch);
}