/*************************************************************************************************** * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief GEMM Permute Example. This example computes batched GEMM operations with output results permuted as reshaped tensors. 
We provide layout plugin as a flexible tool for users to add any customized output tensor permute operation, or any other generalized global memory writeout address computation. To add a customized layout, add new class in include/cutlass/layout/permute.h In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on output matrix. The address computations are performed in compute(col_init, row_init, stride_init, BMM_batch_idx) with {col_permute, row_permute and stride_permute} as new addresses after permute op. (check include/cutlass/layout/permute.h) Tips: 1) Make sure to set batch_stride_D to zero for BMM permute; Also the BMM GEMM should be in mode cutlass::gemm::GemmUniversalMode::kBatched instead of kArray 2) When the last dimension is touched in permute op (for example permute([0, 2, 3, 1])), AlignmentC should be set to 1. If the last dimension is untouched, one can set AlignmentC to be larger like 8 in our example. As a result, permute op without touching the last dimension is recommended to obtain the best performance gain. 
Examples: # Runs a batched GEMM with 96 batches $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 # Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024) $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true # Execute batched GEMM and profile with NSight $ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false */ ///////////////////////////////////////////////////////////////////////////////////////////////// #include #include #include #include #include #include #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/util/command_line.h" #include "cutlass/util/distribution.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/tensor_view_io.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm_complex.h" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/layout/permute.h" /// Tensor4DPermuteBMM0213 ---> /// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped /// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor. const int D1 = 12; /// Tensor5DPermute20314 ---> /// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped /// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. 
const int T1 = 16;
const int T2 = 3;
const int T3 = 8;

// Alignment C
const int AlignmentC = 8;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result {

  double runtime_ms;        // average kernel runtime in milliseconds
  double gflops;            // achieved throughput in GFLOP/s
  cutlass::Status status;   // CUTLASS status of the last operation
  cudaError_t error;        // CUDA runtime status of the last operation
  bool passed;              // reference-check outcome

  //
  // Methods
  //

  Result(
    double runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess
  ):
    runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;
  bool error;
  bool reference_check;
  cutlass::gemm::GemmCoord problem_each;

  int batch_count;
  int iterations;
  int cuda_streams;
  bool verbose;
  float alpha;
  float beta;

  //
  // Methods
  //

  Options():
    help(false),
    error(false),
    reference_check(true),
    batch_count(-1),
    iterations(20),
    cuda_streams(0),
    verbose(false),
    alpha(1),
    beta()
  { }

  /// Parses the command line
  void parse(int argc, char const **args) {

    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
      return;
    }

    cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
    cmd.get_cmd_line_argument("beta", beta, 0.0f);
    cmd.get_cmd_line_argument("iterations", iterations, 20);
    cmd.get_cmd_line_argument("streams", cuda_streams, 0);
    cmd.get_cmd_line_argument("verbose", verbose, false);
    cmd.get_cmd_line_argument("reference-check", reference_check, true);

    int m, n, k;

    cmd.get_cmd_line_argument("m", m, 128);
    cmd.get_cmd_line_argument("n", n, 192);
    cmd.get_cmd_line_argument("k", k, 128);
    cmd.get_cmd_line_argument("batch-count", batch_count, 768);

    cutlass::gemm::GemmCoord problem(m, n, k);
    problem_each = problem;

    // Tensor4DPermuteBMM0213 reshapes the batch dimension as [batch_count/D1, D1],
    // so the batch count must be divisible by D1.
    if (batch_count % D1 != 0){
      std::cerr << "\nProblem count error (problem-count = " << batch_count << "). "
        << "problem-count needs to be divided with no remain by " << D1 << " (D1)."
        << " (Required by the Batched GEMM permute Tensor4DPermuteBMM0213)\n\n";
      error = true;
    }

    // Tensor5DPermute20314 reshapes M as [M/T1, T1]; the permuted output must
    // additionally satisfy the AlignmentC requirement.
    if (m % (AlignmentC * T1) != 0){
      std::cerr << "\nProblem m size error (m = " << m << "). "
        << "m needs to be divided with no remain by " << (AlignmentC * T1) << " (AlignmentC * T1)."
        << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n";
      error = true;
    }

    // Tensor5DPermute20314 reshapes N as [T2, T3, N/T2/T3].
    if (n % (AlignmentC * (T2 * T3)) != 0){
      std::cerr << "\nProblem n size error (n = " << n << "). "
        << "n needs to be divided with no remain by " << (AlignmentC * (T2 * T3)) << " (AlignmentC * T2 * T3)."
        << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n";
      error = true;
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    // NOTE(review): the '<int>'/'<f32>'/'<bool>' value placeholders below were stripped
    // from the original source by a tag-stripping pass and have been restored.
    out << "39_gemm_permute\n\n"
      << "  1) This example firstly profiles the performance of a batched GEMM kernel with BMM whole output"
      << " (including output matrices for each batch) as permuted 4D Tensor."
      << " The BMM tensor output in shape of [B, M, N] is reshaped as [B/D1, D1, M, N] and then permuted with"
      << " permute([0, 2, 1, 3]) to be in shape of [B/D1, M, D1, N].\n\n"
      << "  2) This example also profiles the performance of a normal GEMM kernel with output as permuted 5D Tensor."
      << " The GEMM matrix output in shape of [M, N] is reshaped as [M/T1, T1, T2, T3, N/T2/T3] and then permuted"
      << " with permute([2, 0, 3, 1, 4]) to be in shape of [T2, M/T1, T3, T1, N/T2/T3].\n\n"
      << "  Note: D1, T1, T2, T3 are compile-time constants defined in gemm_permute.cu\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --batch-count=<int>         Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n"
      << "  --m=<int>                   Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n"
      << "  --n=<int>                   Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n"
      << "  --k=<int>                   Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=128)\n"
      << "  --alpha=<f32>               Epilogue scalar alpha (real part)\n"
      << "  --beta=<f32>                Epilogue scalar beta (real part)\n\n"
      << "  --iterations=<int>          Number of profiling iterations to perform.\n"
      << "  --reference-check=<bool>    If true, performs reference check.\n"
      << "  --verbose=<bool>            If true, prints problem sizes and batching structure.\n";

    out << "\n\nExamples:\n\n"
      << "# Runs a batched GEMM with 96 batches\n"
      << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96\n\n"
      << "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n"
      << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true\n\n"
      << "# Execute batched GEMM and profile with NSight\n"
      << "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds
    int64_t fmas = int64_t();

    fmas += problem_each.product() * batch_count;

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Testbed that initializes operands, runs and verifies both the batched-permute GEMM
/// (GemmBatched) and the normal-permute GEMM (GemmPermute), and profiles them.
// NOTE(review): the template parameter list was stripped from the original source;
// reconstructed from the uses of GemmBatched / GemmPermute below and in main().
template <typename GemmBatched, typename GemmPermute>
class Testbed {
public:

  //
  // Type definitions
  //

  using ElementA = typename GemmBatched::ElementA;
  using ElementB = typename GemmBatched::ElementB;
  using ElementC = typename GemmBatched::ElementC;
  using ElementAccumulator = typename GemmBatched::ElementAccumulator;

  using EpilogueOutputOp = typename GemmBatched::GemmKernel::Epilogue::OutputOp;
  using ElementCompute = typename EpilogueOutputOp::ElementCompute;

  using LayoutA = typename GemmBatched::LayoutA;
  using LayoutB = typename GemmBatched::LayoutB;
  using LayoutC = typename GemmBatched::LayoutC;

  using MatrixCoord = typename LayoutC::TensorCoord;

private:

  //
  // Data members
  //

  Options & options;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C; uint32_t seed; cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; public: // // Methods // Testbed( Options &options_, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint32_t seed_ = 3090 ): options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } /// Verbose BMM info void print_BMM_info_() { // Print batched GEMM std::cout << "Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor:\n"; auto problem = options.problem_each; std::cout << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << ", batch count: " << options.batch_count << "\n"; std::cout << "output tensor shape: [" << options.batch_count << ", " << problem.m() << ", " << problem.n() <<"]\n"; std::cout << "reshaped as: [" << options.batch_count / D1 << ", " << D1 << ", " << problem.m() << ", " << problem.n() <<"]\n"; std::cout << "finally permuted as: [" << options.batch_count / D1 << ", " << problem.m() << ", " << D1 << ", " << problem.n() <<"]\n"; std::cout << "----------------------------------------------------\n"; } /// Verbose normal GEMM info void print_GEMM_info_() { // Print batched GEMM std::cout << "Normal GEMM with permute([2, 0, 3, 1, 4]):\n"; auto problem = options.problem_each; std::cout << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << "\n"; std::cout << "output tensor shape: [" << problem.m() << ", " << problem.n() <<"]" << std::endl; std::cout << "reshaped as: [" << problem.m() / T1 << ", " << T1 << ", " << T2 << ", " << T3 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; std::cout << "finally permuted as: [" << T2 << ", " << problem.m() / T1 << ", " << T3 << ", " << T1 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; std::cout 
<< "----------------------------------------------------\n"; } private: /// Helper to initialize a tensor view template void initialize_tensor_( Element *ptr, size_t capacity, cutlass::Distribution::Kind dist_kind, uint32_t seed) { if (dist_kind == cutlass::Distribution::Uniform) { Element scope_max, scope_min; int bits_input = cutlass::sizeof_bits::value; int bits_output = cutlass::sizeof_bits::value; if (bits_input == 1) { scope_max = 2; scope_min = 0; } else if (bits_input <= 8) { scope_max = 2; scope_min = -2; } else if (bits_output == 16) { if (cutlass::sizeof_bits::value <= 16) { scope_max = 5; scope_min = -5; } else { scope_max = 8; scope_min = -8; } } else { scope_max = 8; scope_min = -8; } cutlass::reference::device::BlockFillRandomUniform( ptr, capacity, seed, scope_max, scope_min, 0); } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::device::BlockFillRandomGaussian( ptr, capacity, seed, Element(), Element(0.5f)); } else if (dist_kind == cutlass::Distribution::Sequential) { // Fill with increasing elements cutlass::reference::device::BlockFillSequential( ptr, capacity, Element(1), Element()); } else { // Fill with all 1s cutlass::reference::device::BlockFillSequential( ptr, capacity, Element(), Element(1)); } } /// Initializes data structures void initialize_(int batch_count) { // // Choose random problem sizes // // construct a few problems of random sizes srand(seed); int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count; int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count; int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count; int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count; // // Assign space // block_A.reset(total_elements_A); block_B.reset(total_elements_B); block_C.reset(total_elements_C); block_D.reset(total_elements_D); // // Initialize the problems of the 
workspace // initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); cutlass::reference::device::BlockFillSequential( block_D.get(), total_elements_D, ElementC(), ElementC()); } /// Verifies the BMM GEMM result bool verify_BMM_() { bool passed = true; cutlass::gemm::GemmCoord problem = options.problem_each; LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); MatrixCoord extent_A{problem.m(), problem.k()}; MatrixCoord extent_B{problem.k(), problem.n()}; MatrixCoord extent_C{problem.m(), problem.n()}; cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C) * options.batch_count); cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); // Reference GEMM cutlass::reference::device::GemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementCompute, ElementAccumulator >( problem, options.alpha, view_A, GemmBatched::kTransformA, view_B, GemmBatched::kTransformB, options.beta, view_C, view_Ref_device, ElementAccumulator(0), options.batch_count, options.problem_each.m() * options.problem_each.k(), options.problem_each.n() * options.problem_each.k(), options.problem_each.m() * options.problem_each.n(), options.problem_each.m() * options.problem_each.n() ); // Copy to host memory std::vector matrix_D(layout_D.capacity(extent_C) * options.batch_count); std::vector matrix_Ref(layout_D.capacity(extent_C) * options.batch_count); 
cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); // Print out the results and reference in 4D Tensor // [options.batch_count, options.problem_each.m() * options.problem_each.n()] -> [D0, D1, D2, D3]. // After permute Op, -> [D0, D2, D1, D3]. int D0 = options.batch_count / D1; int D2 = options.problem_each.m(); int D3 = options.problem_each.n(); cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D2, D1, D3})), cutlass::Tensor4DCoord({D0, D2, D1, D3})); cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D1, D2, D3})), cutlass::Tensor4DCoord({D0, D1, D2, D3})); // Tensor Permute Op on reference tensor cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor4DCoord({D0, D2, D1, D3})); for (int n = 0; n < D0; ++n) { for (int h = 0; h < D1; ++h) { for (int w = 0; w < D2; ++w) { for (int c = 0; c < D3; ++c) { view_Ref_Permute_Tensor.at({n, w, h, c}) = view_Ref_Tensor.at({n, h, w, c}); } } } } // Reference check passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); if (!passed) { std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; return passed; } std::cout << "Passed verification" << std::endl; return passed; } bool verify_GEMM_normal_() { bool passed = true; cutlass::gemm::GemmCoord problem = options.problem_each; LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); MatrixCoord extent_A{problem.m(), 
problem.k()}; MatrixCoord extent_B{problem.k(), problem.n()}; MatrixCoord extent_C{problem.m(), problem.n()}; cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); // Reference GEMM cutlass::reference::device::GemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementCompute, ElementAccumulator >( problem, options.alpha, view_A, GemmBatched::kTransformA, view_B, GemmBatched::kTransformB, options.beta, view_C, view_Ref_device, ElementAccumulator(0) ); // Copy to host memory std::vector matrix_D(layout_D.capacity(extent_C)); std::vector matrix_Ref(layout_D.capacity(extent_C)); cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); // Print out the results and reference in 5D Tensor // [options.problem_each.m(), options.problem_each.n()] -> [T0, T1, T2, T3, T4]. // options.problem_each.m() == T0 * T1 // options.problem_each.n() == T2 * T3 * T4 // After permute Op, -> [T2, T0, T3, T1, T4]. 
int T0 = options.problem_each.m() / T1; int T4 = options.problem_each.n() / T2 / T3; cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})), cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})), cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})); // Tensor Permute Op on reference tensor cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); for (int n = 0; n < T0; ++n) { for (int d = 0; d < T1; ++d) { for (int h = 0; h < T2; ++h) { for (int w = 0; w < T3; ++w) { for (int c = 0; c < T4; ++c) { view_Ref_Permute_Tensor.at({h, n, w, d, c}) = view_Ref_Tensor.at({n, d, h, w, c}); // permute([2,0,3,1,4]) } } } } } // Reference check passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); if (!passed) { std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; return passed; } std::cout << "Passed verification" << std::endl; return passed; } public: /// Executes a conventional batched GEMM kernel. 
Result profile_batched_kBatched() { std::cout << "\n====================================================" << std::endl; std::cout << "Batched GEMM (CUTLASS):\n" << "====================================================" << std::endl; if (options.verbose) { print_BMM_info_(); } Result result; result.passed = false; // Initialize the problem initialize_(options.batch_count); // Configure the GEMM arguments typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); // Please make sure all problem_sizes are the same for kBatched mode auto problem = options.problem_each; // For regular BMM int64_t batch_stride_C = problem.m() * problem.n(); // For BMM permute output ---> make sure to set batch_stride_D to zero for BMM permute op int64_t batch_stride_D = 0; // Configure GEMM arguments typename GemmBatched::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kBatched, options.problem_each, options.batch_count, epilogue_op, (void*)block_A.get(), (void*)block_B.get(), (void*)block_C.get(), (void*)block_D.get(), problem.m() * problem.k(), problem.n() * problem.k(), batch_stride_C, batch_stride_D, problem.k(), problem.n(), problem.n(), problem.n() }; // Initialize the GEMM object GemmBatched gemm; result.status = gemm.initialize(arguments, nullptr); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; return result; } // Run the batched GEMM object result.status = gemm.run(); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to run CUTLASS Batched GEMM kernel." 
<< std::endl; return result; } // Wait for completion result.error = cudaDeviceSynchronize(); if (result.error != cudaSuccess) { std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); return result; } // // Verify correctness // result.passed = true; if (options.reference_check) { result.passed = verify_BMM_(); } // // Warm-up run of the batched GEMM object // result.status = gemm.run(); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; return result; } // // Construct events // cudaEvent_t events[2]; for (auto & event : events) { result.error = cudaEventCreate(&event); if (result.error != cudaSuccess) { std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; return -1; } } // Record an event at the start of a series of GEMM operations result.error = cudaEventRecord(events[0]); if (result.error != cudaSuccess) { std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // // Run profiling loop // for (int iter = 0; iter < options.iterations; ++iter) { gemm(); } // // Stop profiling loop // // Record an event when the GEMM operations have been launched. result.error = cudaEventRecord(events[1]); if (result.error != cudaSuccess) { std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // Wait for work on the device to complete. result.error = cudaEventSynchronize(events[1]); if (result.error != cudaSuccess) { std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // Measure elapsed runtime float runtime_ms = 0; result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); if (result.error != cudaSuccess) { std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // Compute average runtime and GFLOPs. 
result.runtime_ms = double(runtime_ms) / double(options.iterations); result.gflops = options.gflops(result.runtime_ms / 1000.0); // // Cleanup // for (auto event : events) { (void)cudaEventDestroy(event); } std::cout << " " << 1 << " batched GEMMs launched\n"; std::cout << std::endl; std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms\n"; std::cout << " " << "Batched GFLOPs: " << result.gflops << "\n"; return result; } Result profile_GEMM_permute() { std::cout << "\n====================================================" << std::endl; std::cout << "Normal GEMM (CUTLASS):\n" << "====================================================" << std::endl; if (options.verbose) { print_GEMM_info_(); } Result result; result.passed = false; // Initialize the problem initialize_(1); // Configure the GEMM arguments typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); // Please make sure all problem_sizes are the same for kBatched mode auto problem = options.problem_each; // Configure GEMM arguments typename GemmPermute::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, options.problem_each, 1, epilogue_op, (void*)block_A.get(), (void*)block_B.get(), (void*)block_C.get(), (void*)block_D.get(), 0, 0, 0, 0, problem.k(), problem.n(), problem.n(), problem.n() }; // Initialize the GEMM object GemmPermute gemm_normal; result.status = gemm_normal.initialize(arguments, nullptr); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; return result; } // Run the normal GEMM object result.status = gemm_normal.run(); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to run CUTLASS Batched GEMM kernel." 
<< std::endl; return result; } // Wait for completion result.error = cudaDeviceSynchronize(); if (result.error != cudaSuccess) { std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); return result; } // // Verify correctness // result.passed = true; if (options.reference_check) { result.passed = verify_GEMM_normal_(); } // // Warm-up run of the normal GEMM object // result.status = gemm_normal.run(); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; return result; } // // Construct events // cudaEvent_t events[2]; for (auto & event : events) { result.error = cudaEventCreate(&event); if (result.error != cudaSuccess) { std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; return -1; } } // Record an event at the start of a series of GEMM operations result.error = cudaEventRecord(events[0]); if (result.error != cudaSuccess) { std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // // Run profiling loop // for (int iter = 0; iter < options.iterations; ++iter) { gemm_normal(); } // // Stop profiling loop // // Record an event when the GEMM operations have been launched. result.error = cudaEventRecord(events[1]); if (result.error != cudaSuccess) { std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // Wait for work on the device to complete. 
result.error = cudaEventSynchronize(events[1]); if (result.error != cudaSuccess) { std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // Measure elapsed runtime float runtime_ms = 0; result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); if (result.error != cudaSuccess) { std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // Compute average runtime and GFLOPs. result.runtime_ms = double(runtime_ms) / double(options.iterations); result.gflops = options.gflops(result.runtime_ms / 1000.0); // // Cleanup // for (auto event : events) { (void)cudaEventDestroy(event); } std::cout << std::endl; std::cout << " " << "Normal Runtime: " << result.runtime_ms << " ms" << std::endl; std::cout << " " << "Normal GFLOPs: " << result.gflops << "\n"; return result; } }; /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { // // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. // cudaDeviceProp props; cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; return -1; } if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { // // This example requires an NVIDIA Ampere-architecture GPU. // std::cout << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " << "later (compute capability 80 or greater).\n"; return 0; } // // Parse options // Options options; options.parse(argc, args); if (options.help) { options.print_usage(std::cout) << std::endl; return 0; } if (options.error) { std::cerr << "Aborting execution." 
<< std::endl; return -1; } // // Define the GEMM types // using ElementOutput = cutlass::half_t; using ElementAccumulator = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; // // Define a conventional batched GEMM type // // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 using GemmBatched = cutlass::gemm::device::GemmUniversal< cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementOutput, AlignmentC, //128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, 4, 8, /*alignmentA*/ 8, /*alignmengB*/ cutlass::arch::OpMultiplyAdd, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, false, /*GatherA*/ false, /*GatherB*/ false, /*ScatterD*/ cutlass::layout::Tensor4DPermuteBMM0213 /*PermuteDLayout*/ >; // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 using GemmPermute = cutlass::gemm::device::GemmUniversal< cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementOutput, AlignmentC, //128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, 4, 8, /*alignmentA*/ 8, /*alignmengB*/ cutlass::arch::OpMultiplyAdd, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, false, /*GatherA*/ false, /*GatherB*/ false, /*ScatterD*/ 
cutlass::layout::Tensor5DPermute20314 /*PermuteDLayout*/ >; // // Profile it // Testbed testbed(options); Result result; result = testbed.profile_batched_kBatched(); if (!result.passed) { std::cout << "Profiling batched GEMM has failed.\n"; std::cout << "\nFailed\n"; } else { std::cout << "\nPassed CUTLASS batched GEMM\n"; } result = testbed.profile_GEMM_permute(); if (!result.passed) { std::cout << "Profiling normal GEMM has failed.\n"; std::cout << "\nFailed\n"; } else { std::cout << "\nPassed CUTLASS normal GEMM\n"; } std::cout << "\n====================================================" << std::endl; std::cout << "Finished\n"; std::cout << "====================================================" << std::endl; return 0; } /////////////////////////////////////////////////////////////////////////////////////////////////