1224 lines
		
	
	
		
			47 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			1224 lines
		
	
	
		
			47 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| 
 | |
| /*! \file
 | |
|     \brief GEMM Permute Example.
 | |
| 
 | |
|     This example computes batched GEMM operations with output results permuted as reshaped tensors.
 | |
| 
 | |
|     We provide layout plugin as a flexible tool for users to add any customized input/output tensor permute operation, 
 | |
|     or any other generalized global memory writeout address computation. To add a customized layout, add new class
 | |
|     in include/cutlass/layout/permute.h
 | |
| 
 | |
|     In this example we use several permute operations (permute([0, 2, 1, 3]))
 | |
| 
 | |
|     In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM
 | |
|     whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on
 | |
|     output matrix. The address computations are performed in compute(col_init, row_init, stride_init, 
 | |
|     BMM_batch_idx) with {col_permute, row_permute and stride_permute} as new addresses after permute op.
 | |
|     (check include/cutlass/layout/permute.h)
 | |
| 
 | |
|     Tips:
 | |
|     
 | |
|       1) Make sure to set batch_stride to zero for BMM permute; also the BMM GEMM should be in mode
 | |
|       cutlass::gemm::GemmUniversalMode::kBatched instead of kArray.
 | |
| 
 | |
|       2) When the contiguous dimension is touched in permute op (for example [0, 2, 3, 1] for row-major matrix 
 | |
|       or [1, 0, 2, 3] for column-major), Alignment should be set to 1 for the corresponding matrix. 
 | |
|       If the last dimension is untouched,  one can set Alignment to be larger like 8 in our example.
 | |
|       As a result, permute op without touching the unit stride dimension is recommended to obtain the best performance.
 | |
| 
 | |
|     Examples:
 | |
| 
 | |
|       # Runs a batched GEMM with 96 batches
 | |
|       $ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96
 | |
| 
 | |
|       # Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)
 | |
|       $ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true
 | |
| 
 | |
|       # Execute batched GEMM and profile with NSight
 | |
|       $ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false
 | |
| 
 | |
| */
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| #include <iostream>
 | |
| #include <fstream>
 | |
| #include <sstream>
 | |
| #include <vector>
 | |
| #include <map>
 | |
| #include <unordered_map>
 | |
| 
 | |
| #include "cutlass/cutlass.h"
 | |
| #include "cutlass/gemm/gemm.h"
 | |
| #include "cutlass/gemm/device/gemm_universal.h"
 | |
| 
 | |
| #include "cutlass/util/command_line.h"
 | |
| #include "cutlass/util/distribution.h"
 | |
| #include "cutlass/util/device_memory.h"
 | |
| #include "cutlass/util/tensor_view_io.h"
 | |
| #include "cutlass/util/host_tensor.h"
 | |
| #include "cutlass/util/reference/host/gemm_complex.h"
 | |
| #include "cutlass/util/reference/device/gemm_complex.h"
 | |
| #include "cutlass/util/reference/host/tensor_compare.h"
 | |
| #include "cutlass/util/reference/device/tensor_compare.h"
 | |
| #include "cutlass/util/reference/host/tensor_copy.h"
 | |
| #include "cutlass/util/reference/device/tensor_fill.h"
 | |
| #include "cutlass/util/reference/host/tensor_norm.h"
 | |
| 
 | |
| #include "cutlass/layout/permute.h"
 | |
| 
 | |
| #include "layouts.h"
 | |
| #include "permute_info.h"
 | |
| 
 | |
/// Tensor4DPermuteBMM0213:
/// a batched-GEMM tensor of shape [B, M, N] is viewed as [B/D1, D1, M, N] and
/// then permuted with permute([0, 2, 1, 3]) across the whole BMM tensor.
constexpr int D1 = 12;

/// Tensor5DPermute20314:
/// a matrix of shape [M, N] is viewed as [M/T1, T1, T2, T3, N/T2/T3] and
/// then permuted with permute([2, 0, 3, 1, 4]).
constexpr int T1 = 16;
constexpr int T2 = 3;
constexpr int T3 = 8;

/// Tensor4DPermute0213:
/// a matrix of shape [M, N] is viewed as [M/S1, S1, S2, N/S2] and
/// then permuted with permute([0, 2, 1, 3]).
constexpr int S1 = 8;
constexpr int S2 = 4;

/// Alignments (in elements) used for the A, B and C operands.
constexpr int AlignmentA = 8;
constexpr int AlignmentB = 8;
constexpr int AlignmentC = 8;
 | |
| 
 | |
/// GEMM element types
/// ElementInput is the half-precision type aliased for GEMM inputs.
using ElementInput = cutlass::half_t;
/// ElementOutput is the half-precision type aliased for GEMM outputs.
using ElementOutput = cutlass::half_t;
/// ElementAccumulator is single precision; initialize_tensor_ inspects its bit
/// width when choosing the range of uniformly distributed test data.
using ElementAccumulator = float;
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| /// Useful macros
 | |
| 
 | |
/// Evaluates `call` (an expression yielding cudaError_t) exactly once. If the
/// result is not cudaSuccess, prints the stringified call together with the
/// CUDA error string to std::cerr and then executes `handler` (for example
/// `return false`).
///
/// The temporary is deliberately not named with a leading double underscore:
/// such identifiers are reserved for the implementation.
#define CHECK_CUDA_CALL(call, handler) \
do { \
  cudaError_t check_cuda_call_err_ = (call); \
  if (check_cuda_call_err_ != cudaSuccess) { \
    std::cerr << #call " failed: " << cudaGetErrorString(check_cuda_call_err_) << std::endl; \
    handler; \
  } \
} while(0)
 | |
| 
 | |
/// Evaluates `call` (an expression yielding cutlass::Status) exactly once. If
/// the result is not cutlass::Status::kSuccess, prints the stringified call
/// together with the CUTLASS status string to std::cerr and then executes
/// `handler` (for example `return false`).
///
/// The temporary is deliberately not named with a leading double underscore:
/// such identifiers are reserved for the implementation.
#define CHECK_CUTLASS_CALL(call, handler) \
do { \
  cutlass::Status check_cutlass_call_status_ = (call); \
  if (check_cutlass_call_status_ != cutlass::Status::kSuccess) { \
    std::cerr << #call " failed: " << cutlass::cutlassGetStatusString(check_cutlass_call_status_) << std::endl; \
    handler; \
  } \
} while(0)
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| // Command line options parsing
 | |
| struct Options {
 | |
| 
 | |
|   bool help;
 | |
|   bool error;
 | |
|   bool reference_check;
 | |
| 
 | |
|   cutlass::gemm::GemmCoord problem_each;
 | |
| 
 | |
|   int batch_count;
 | |
|   int iterations;
 | |
|   int cuda_streams;
 | |
|   bool verbose;
 | |
|   float alpha;
 | |
|   float beta;
 | |
| 
 | |
|   //
 | |
|   // Methods
 | |
|   // 
 | |
| 
 | |
|   Options():
 | |
|     help(false),
 | |
|     error(false),
 | |
|     reference_check(true),
 | |
|     batch_count(-1),
 | |
|     iterations(20),
 | |
|     cuda_streams(0),
 | |
|     verbose(false),
 | |
|     alpha(1),
 | |
|     beta()
 | |
|   { }
 | |
| 
 | |
|   // Parses the command line
 | |
|   void parse(int argc, char const **args) {
 | |
|     cutlass::CommandLine cmd(argc, args);
 | |
| 
 | |
|     if (cmd.check_cmd_line_flag("help")) {
 | |
|       help = true;
 | |
|       return;
 | |
|     }
 | |
| 
 | |
|     cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
 | |
|     cmd.get_cmd_line_argument("beta", beta, 0.0f);    
 | |
|     cmd.get_cmd_line_argument("iterations", iterations, 20);
 | |
|     cmd.get_cmd_line_argument("streams", cuda_streams, 0);
 | |
|     cmd.get_cmd_line_argument("verbose", verbose, false);
 | |
|     cmd.get_cmd_line_argument("reference-check", reference_check, true);
 | |
| 
 | |
|     int m, n, k;
 | |
| 
 | |
|     cmd.get_cmd_line_argument("m", m, 384);
 | |
|     cmd.get_cmd_line_argument("n", n, 192);
 | |
|     cmd.get_cmd_line_argument("k", k, 384);
 | |
|     cmd.get_cmd_line_argument("batch-count", batch_count, 96);
 | |
| 
 | |
|     problem_each = cutlass::gemm::GemmCoord(m, n, k);
 | |
|   }
 | |
| 
 | |
|   /// Prints the usage statement.
 | |
|   std::ostream & print_usage(std::ostream &out) const {
 | |
| 
 | |
|     out << 
 | |
|       "39_gemm_permute\n"
 | |
|       "\n"
 | |
|       " This example tests and profiles the performance of normal GEMM and batched GEMM with different"
 | |
|       " combinations of fused permutations of input and output tensors."
 | |
|       "\n"
 | |
|       " Permutations considered in this example:\n"
 | |
|       "\n"
 | |
|       " Normal GEMM:\n"
 | |
|       " 1) Tensor4DPermute0213: matrix of shape [X, Y] is reshaped as [X/S1, S1, S2, Y/S2] and has its dimensions"
 | |
|       " permuted as [0, 2, 1, 3], resulting in shape [X/S1, S2, S1, Y/S2] viewed as matrix of shape [X*S2/S1, Y*S1/S2].\n"
 | |
|       " 2) Tensor5DPermute20314: matrix of shape [X, Y] is reshaped as [X/T1, T1, T2, T3, Y/T2/T3] and has its dimensions"
 | |
|       " permuted as [2, 0, 3, 1, 4], resulting in shape [T2, X/T1, T3, T1, Y/T2/T3] viewed as matrix of shape [X*T2/T1, Y*T1/T2].\n"
 | |
|        "\n"
 | |
|       " Batched GEMM:\n"
 | |
|       " 3) Tensor4DPermuteBMM0213: batched tensor of 3D shape [B, X, Y] is reshaped as 4D shape [B/D1, D1, X, Y]"
 | |
|       " and has its dimensions permuted as [0, 2, 1, 3], resulting in shape [B/D1, X, D1, Y] viewed as"
 | |
|       " a matrix of shape [B/D1, X, Y*D1] for batched GEMM purposes.\n"
 | |
|       "\n"
 | |
|       " Note: S1, S2, D1, D2, T1, T2, T3 are compile-time constants defined in gemm_permute.cu."
 | |
|       " Runtime specification of these values is not supported."
 | |
|       " These values along with alignment requirements place constraints on supported matrix sizes.\n"
 | |
|       "\n"
 | |
|       " Note: X, Y above may refer to M, N or K dimensions of GEMM problem, depending on the tensor considered (A, B or D)."
 | |
|       " For the output tensor D the values correspond directly to dimensions of D, whereas for A and B the original dimensions"
 | |
|       " X', Y' are inferred from the ones supplied to the GEMM, taking into account the permute operation.\n"
 | |
|       "\n"
 | |
|       "Options:\n"
 | |
|       "\n"
 | |
|       "  --help                      If specified, displays this usage statement.\n\n"
 | |
|       "  --batch-count=<int>         Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n"
 | |
|       "  --m=<int>                   Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n"
 | |
|       "  --n=<int>                   Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n"
 | |
|       "  --k=<int>                   Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=384)\n"
 | |
|       "  --alpha=<f32>               Epilogue scalar alpha (real part)\n"
 | |
|       "  --beta=<f32>                Epilogue scalar beta (real part)\n\n"
 | |
|       "  --iterations=<int>          Number of profiling iterations to perform.\n"
 | |
|       "  --reference-check=<bool>    If true, performs reference check.\n"
 | |
|       "  --verbose=<bool>            If true, prints problem sizes and batching structure.\n"
 | |
|       "\n"
 | |
|       "Examples:\n"
 | |
|       "\n"
 | |
|       "# Runs a batched GEMM with 96 batches\n"
 | |
|       "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96\n"
 | |
|       "\n"
 | |
|       "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n"
 | |
|       "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true\n"
 | |
|       "\n"
 | |
|       "# Execute batched GEMM and profile with NSight\n"
 | |
|       "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n"
 | |
|       "\n";
 | |
| 
 | |
|     return out;
 | |
|   }
 | |
| 
 | |
|   /// Compute performance in GFLOP/s
 | |
|   double gflops(double runtime_s, bool batched) const {
 | |
| 
 | |
|     // Number of real-valued multiply-adds 
 | |
|     int64_t fmas = int64_t();
 | |
| 
 | |
|     fmas += problem_each.product() * (batched ? batch_count : 1);
 | |
|     
 | |
|     // Two flops per multiply-add
 | |
|     return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
 | |
|   }
 | |
| };
 | |
| 
 | |
| ///////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| namespace { // (anonymous)
 | |
| 
 | |
| /// Dimension-generic permutation loop
 | |
| template<int I, typename Element, typename Layout, typename PermuteOp, typename Coord>
 | |
| void permute_host_impl(
 | |
|     cutlass::TensorView<Element const, Layout> const & input,
 | |
|     cutlass::TensorView<Element, Layout> const & output,
 | |
|     PermuteOp && permute,
 | |
|     Coord & coord
 | |
| ) {
 | |
|   static_assert(Layout::kRank == Coord::kRank, "Incompatible Layout and Coord types");
 | |
|   if constexpr (I == Coord::kRank) {
 | |
|     output.at(permute(coord)) = input.at(coord);
 | |
|   }
 | |
|   else {
 | |
|     for (coord[I] = 0; coord[I] < input.extent(I); ++coord[I]) {
 | |
|       permute_host_impl<I+1>(input, output, std::forward<PermuteOp>(permute), coord);
 | |
|     }
 | |
|   }
 | |
| }
 | |
| 
 | |
| } // namespace (anonymous)
 | |
| 
 | |
| /// Perform a reference (host-based) permutation of an input tensor
 | |
| template<typename PermuteLayout, typename Element, typename Layout>
 | |
| void permute_host(
 | |
|     cutlass::TensorView<Element const, Layout> const &input,
 | |
|     cutlass::TensorView<Element, Layout> const &output,
 | |
|     int batch_count) {
 | |
|   Layout layout = input.layout();
 | |
|   cutlass::MatrixCoord extent = input.extent();
 | |
| 
 | |
|   std::size_t num_elems = layout.capacity(extent) * batch_count;
 | |
|   std::vector<Element> h_input(num_elems);
 | |
|   cutlass::device_memory::copy_to_host(h_input.data(), input.data(), num_elems);
 | |
| 
 | |
|   std::vector<Element> h_output(num_elems);
 | |
| 
 | |
|   using Info = PermuteInfo<PermuteLayout>;
 | |
|   using TensorLayout = typename Info::Layout;
 | |
| 
 | |
|   auto shape_orig = Info::original_shape(extent, batch_count);
 | |
|   auto shape_perm = Info::permute(shape_orig);
 | |
| 
 | |
|   cutlass::TensorView<Element const, TensorLayout> view_input(h_input.data(), TensorLayout::packed(shape_orig), shape_orig); 
 | |
|   cutlass::TensorView<Element, TensorLayout> view_output(h_output.data(), TensorLayout::packed(shape_perm), shape_perm);
 | |
| 
 | |
|   decltype(shape_orig) coord;
 | |
|   permute_host_impl<0>(view_input, view_output, Info::permute, coord);
 | |
| 
 | |
|   cutlass::device_memory::copy_to_device(output.data(), h_output.data(), num_elems);
 | |
| }
 | |
| 
 | |
| template<typename Layout>
 | |
| struct LayoutInfo;
 | |
| 
 | |
| template<>
 | |
| struct LayoutInfo<cutlass::layout::RowMajor> {
 | |
|   static std::string name() { return "RowMajor"; }
 | |
| };
 | |
| 
 | |
| template<>
 | |
| struct LayoutInfo<cutlass::layout::ColumnMajor> {
 | |
|   static std::string name() { return "ColumnMajor"; }
 | |
| };
 | |
| 
 | |
| ///////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| template <typename ElementA, typename ElementB, typename ElementC>
 | |
| class Testbed {
 | |
| private:
 | |
| 
 | |
|   //
 | |
|   // Data members
 | |
|   //
 | |
| 
 | |
|   Options & options;
 | |
| 
 | |
|   /// Initialization
 | |
|   cutlass::Distribution::Kind init_A;
 | |
|   cutlass::Distribution::Kind init_B;
 | |
|   cutlass::Distribution::Kind init_C;
 | |
|   uint32_t seed;
 | |
| 
 | |
|   cutlass::DeviceAllocation<ElementA> block_A;
 | |
|   cutlass::DeviceAllocation<ElementB> block_B;
 | |
|   cutlass::DeviceAllocation<ElementC> block_C;
 | |
|   cutlass::DeviceAllocation<ElementC> block_D;
 | |
| 
 | |
| public:
 | |
| 
 | |
|   //
 | |
|   // Methods
 | |
|   //
 | |
| 
 | |
|   Testbed(
 | |
|     Options &options_,
 | |
|     cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
 | |
|     cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
 | |
|     cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
 | |
|     uint32_t seed_ = 3090
 | |
|   ):
 | |
|     options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
 | |
| 
 | |
| private:
 | |
| 
 | |
|   /// Print permutation info for one tensor
 | |
|   template<typename PermuteLayout>
 | |
|   void print_tensor_info(
 | |
|       std::ostream & os,
 | |
|       std::string const &tensor_name,
 | |
|       int row_dim,
 | |
|       int col_dim) {
 | |
| 
 | |
|     cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim));
 | |
|     using Info = PermuteInfo<PermuteLayout>;
 | |
| 
 | |
|     os << "tensor " << tensor_name << ": " << Info::desc() << "\n";
 | |
|     os << "    extent: [" << extent.row() << ", " << extent.column() << "]";
 | |
|     if (Info::kBatched) {
 | |
|       os << ", batch count: " << options.batch_count;
 | |
|     }
 | |
|     os << "\n";
 | |
|     if (!cutlass::layout::is_trivial_permute<PermuteLayout>) {
 | |
|       auto shape_orig = Info::original_shape(extent, options.batch_count);
 | |
|       auto shape_perm = Info::permute(shape_orig);
 | |
|       os << "  original: [" << shape_orig << "]\n";
 | |
|       os << "  permuted: [" << shape_perm << "]\n";
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   /// Check shape compatibility for one tensor
 | |
|   template<typename Layout, typename PermuteLayout, int Alignment>
 | |
|   bool check_tensor_shape(
 | |
|       std::string const &tensor_name,
 | |
|       int row_dim,
 | |
|       int col_dim) {
 | |
| 
 | |
|     cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim));
 | |
| 
 | |
|     using Info = PermuteInfo<PermuteLayout>;
 | |
| 
 | |
|     auto rowAlign = cutlass::platform::is_same<Layout, cutlass::layout::ColumnMajor>::value ? Alignment : 1;
 | |
|     auto colAlign = cutlass::platform::is_same<Layout, cutlass::layout::RowMajor>::value ? Alignment : 1;
 | |
| 
 | |
|     auto rowFactor = Info::kRowFactor * rowAlign;
 | |
|     auto colFactor = Info::kColumnFactor * colAlign;
 | |
| 
 | |
|     // Assumes row-major layout
 | |
|     bool const valid_row = extent.row() % rowFactor == 0;
 | |
|     if (!valid_row) {
 | |
|       std::cerr << "Invalid tensor " << tensor_name << " row size = " << extent.row() << ", "
 | |
|                    "must be divisible by " << rowFactor << ", "
 | |
|                    "required by " << Info::name() << 
 | |
|                    (rowAlign > 1 ? (" and alignment of " + std::to_string(rowAlign)) : "") << std::endl;
 | |
|     }
 | |
| 
 | |
|     bool const valid_col = extent.column() % colFactor == 0;
 | |
|     if (!valid_col) {
 | |
|       std::cerr << "Invalid tensor " << tensor_name << " column size = " << extent.column() << ", "
 | |
|                    "must be divisible by " << colFactor << ", "
 | |
|                    "required by " << Info::name() << 
 | |
|                    (colAlign > 1 ? (" and alignment of " + std::to_string(colAlign)) : "") << std::endl;
 | |
|     }
 | |
| 
 | |
|     bool const valid_bsz = options.batch_count % Info::kBatchFactor == 0;
 | |
|     if (!valid_bsz) {
 | |
|       std::cerr << "Invalid batch count = " << options.batch_count << ", "
 | |
|                    "must be divisible by " << Info::kBatchFactor << ", "
 | |
|                    "required by " << Info::name() << std::endl;
 | |
|     }
 | |
| 
 | |
|     return valid_row && valid_col && valid_bsz;
 | |
|   }
 | |
| 
 | |
|   /// Helper to initialize a tensor view
 | |
|   template <typename Element>
 | |
|   void initialize_tensor_(
 | |
|       Element *ptr,
 | |
|       size_t capacity, 
 | |
|       cutlass::Distribution::Kind dist_kind,
 | |
|       uint32_t seed) {
 | |
| 
 | |
|     if (dist_kind == cutlass::Distribution::Uniform) {
 | |
| 
 | |
|       Element scope_max, scope_min;
 | |
|       int bits_input = cutlass::sizeof_bits<Element>::value;
 | |
|       int bits_output = cutlass::sizeof_bits<ElementC>::value;
 | |
| 
 | |
|       if (bits_input == 1) {
 | |
|         scope_max = 2;
 | |
|         scope_min = 0;
 | |
|       } else if (bits_input <= 8) {
 | |
|         scope_max = 2;
 | |
|         scope_min = -2;
 | |
|       } else if (bits_output == 16) {
 | |
|         if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
 | |
|           scope_max = 5;
 | |
|           scope_min = -5;
 | |
|         }
 | |
|         else {
 | |
|           scope_max = 8;
 | |
|           scope_min = -8;
 | |
|         }
 | |
|       } else {
 | |
|         scope_max = 8;
 | |
|         scope_min = -8;
 | |
|       }
 | |
| 
 | |
|       cutlass::reference::device::BlockFillRandomUniform(
 | |
|         ptr, capacity, seed, scope_max, scope_min, 0);
 | |
|     } 
 | |
|     else if (dist_kind == cutlass::Distribution::Gaussian) {
 | |
| 
 | |
|       cutlass::reference::device::BlockFillRandomGaussian(
 | |
|         ptr, capacity, seed, Element(), Element(0.5f));
 | |
|     }
 | |
|     else if (dist_kind == cutlass::Distribution::Sequential) {
 | |
| 
 | |
|       // Fill with increasing elements
 | |
|       cutlass::reference::device::BlockFillSequential(
 | |
|         ptr, capacity, Element(1), Element());
 | |
|     } 
 | |
|     else {
 | |
| 
 | |
|       // Fill with all 1s
 | |
|       cutlass::reference::device::BlockFillSequential(
 | |
|         ptr, capacity, Element(), Element(1));
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   /// Initializes data structures
 | |
|   void initialize(int batch_count) {
 | |
| 
 | |
|     srand(seed);
 | |
| 
 | |
|     int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count;
 | |
|     int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count;
 | |
|     int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count;
 | |
|     int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count;
 | |
| 
 | |
|     // Allocate space
 | |
|     block_A.reset(total_elements_A);
 | |
|     block_B.reset(total_elements_B);
 | |
|     block_C.reset(total_elements_C);
 | |
|     block_D.reset(total_elements_D);
 | |
| 
 | |
|     // Initialize input tensors
 | |
|     initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021);
 | |
|     initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022);
 | |
|     initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023);
 | |
| 
 | |
|     cutlass::reference::device::BlockFillSequential(
 | |
|       block_D.get(), total_elements_D, ElementC(), ElementC());
 | |
|   }
 | |
| 
 | |
| 
 | |
|   /// Check device GEMM results against a reference implementation with separate host-based permutation
 | |
|   template<typename Gemm>
 | |
|   bool validate(Gemm const &gemm) {
 | |
| 
 | |
|     bool constexpr kBatched = PermuteInfo<typename Gemm::PermuteALayout>::kBatched 
 | |
|                            || PermuteInfo<typename Gemm::PermuteBLayout>::kBatched 
 | |
|                            || PermuteInfo<typename Gemm::PermuteDLayout>::kBatched;
 | |
|                       
 | |
|     int const batch_count = kBatched ? options.batch_count : 1;
 | |
| 
 | |
|     cutlass::gemm::GemmCoord problem = options.problem_each;
 | |
| 
 | |
|     cutlass::MatrixCoord extent_A{problem.m(), problem.k()};
 | |
|     cutlass::MatrixCoord extent_B{problem.k(), problem.n()};
 | |
|     cutlass::MatrixCoord extent_C{problem.m(), problem.n()};
 | |
| 
 | |
|     using LayoutA = typename Gemm::LayoutA;
 | |
|     using LayoutB = typename Gemm::LayoutB;
 | |
|     using LayoutC = typename Gemm::LayoutC;
 | |
| 
 | |
|     LayoutA layout_A(LayoutA::packed(extent_A));
 | |
|     LayoutB layout_B(LayoutB::packed(extent_B));
 | |
|     LayoutC layout_C(LayoutC::packed(extent_C));
 | |
| 
 | |
|     auto size_A = layout_A.capacity(extent_A) * batch_count;
 | |
|     auto size_B = layout_B.capacity(extent_B) * batch_count;
 | |
|     auto size_C = layout_C.capacity(extent_C) * batch_count;
 | |
|     
 | |
|     cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get(), layout_A, extent_A);
 | |
|     cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get(), layout_B, extent_B);
 | |
|     cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get(), layout_C, extent_C);
 | |
|     cutlass::TensorView<ElementC, LayoutC> view_D(block_D.get(), layout_C, extent_C);
 | |
| 
 | |
|     cutlass::DeviceAllocation<ElementA> block_A_perm(size_A);
 | |
|     cutlass::DeviceAllocation<ElementA> block_B_perm(size_B);
 | |
| 
 | |
|     cutlass::TensorView<ElementA, LayoutA> view_A_perm(block_A_perm.get(), layout_A, extent_A);
 | |
|     cutlass::TensorView<ElementB, LayoutB> view_B_perm(block_B_perm.get(), layout_B, extent_B);
 | |
| 
 | |
|     permute_host<typename Gemm::PermuteALayout>(view_A.const_view(), view_A_perm, batch_count);
 | |
|     permute_host<typename Gemm::PermuteBLayout>(view_B.const_view(), view_B_perm, batch_count);
 | |
| 
 | |
|     cutlass::DeviceAllocation<ElementC>    block_D_ref(size_C);
 | |
|     cutlass::TensorView<ElementC, LayoutC> view_D_ref(block_D_ref.get(), layout_C, extent_C);
 | |
| 
 | |
|     using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
 | |
| 
 | |
|     // Reference GEMM
 | |
|     cutlass::reference::device::GemmComplex<
 | |
|         ElementA, LayoutA,
 | |
|         ElementB, LayoutB,
 | |
|         ElementC, LayoutC, 
 | |
|         typename EpilogueOutputOp::ElementCompute,
 | |
|         typename Gemm::ElementAccumulator
 | |
|     >(
 | |
|       problem,
 | |
|       options.alpha, 
 | |
|       view_A_perm,
 | |
|       Gemm::kTransformA,
 | |
|       view_B_perm,
 | |
|       Gemm::kTransformB,
 | |
|       options.beta, 
 | |
|       view_C, 
 | |
|       view_D_ref, 
 | |
|       ElementAccumulator(0),
 | |
|       batch_count,
 | |
|       options.problem_each.m() * options.problem_each.k(),
 | |
|       options.problem_each.n() * options.problem_each.k(),
 | |
|       options.problem_each.m() * options.problem_each.n(),
 | |
|       options.problem_each.m() * options.problem_each.n()
 | |
|     );
 | |
| 
 | |
|     cutlass::DeviceAllocation<ElementC>    block_D_perm(size_C);
 | |
|     cutlass::TensorView<ElementC, LayoutC> view_D_perm(block_D_perm.get(), layout_C, extent_C);
 | |
|     permute_host<typename Gemm::PermuteDLayout>(view_D_ref.const_view(), view_D_perm, batch_count);
 | |
| 
 | |
|     // Reference check
 | |
|     return cutlass::reference::device::BlockCompareEqual(view_D_perm.data(), view_D.data(), size_C);
 | |
| }
 | |
| 
 | |
| public:
 | |
| 
 | |
|   template<typename Gemm>
 | |
|   bool profile_GEMM_permute() {
 | |
| 
 | |
|     using LayoutA = typename Gemm::LayoutA;
 | |
|     using LayoutB = typename Gemm::LayoutB;
 | |
|     using LayoutC = typename Gemm::LayoutC;
 | |
| 
 | |
|     using PermuteALayout = typename Gemm::PermuteALayout;
 | |
|     using PermuteBLayout = typename Gemm::PermuteBLayout;
 | |
|     using PermuteDLayout = typename Gemm::PermuteDLayout;
 | |
| 
 | |
|     bool constexpr kBatched = PermuteInfo<PermuteALayout>::kBatched 
 | |
|                            || PermuteInfo<PermuteBLayout>::kBatched 
 | |
|                            || PermuteInfo<PermuteDLayout>::kBatched;
 | |
| 
 | |
|     std::cout << "\n"
 | |
|                  "====================================================\n"
 | |
|                  << (kBatched ? "Batched" : "Normal") << " GEMM:"
 | |
|                  << "\n  A=" << LayoutInfo<LayoutA>::name() << "," << PermuteInfo<PermuteALayout>::name()
 | |
|                  << "\n  B=" << LayoutInfo<LayoutB>::name() << "," << PermuteInfo<PermuteBLayout>::name()
 | |
|                  << "\n  D=" << LayoutInfo<LayoutC>::name() << "," << PermuteInfo<PermuteDLayout>::name()
 | |
|                  << "\n"
 | |
|                  "====================================================\n";
 | |
| 
 | |
|     if (options.verbose) {
 | |
|       print_tensor_info<PermuteALayout>(std::cout, "A", 0, 2);
 | |
|       print_tensor_info<PermuteBLayout>(std::cout, "B", 2, 1);
 | |
|       print_tensor_info<PermuteDLayout>(std::cout, "D", 0, 1);
 | |
|     }
 | |
|     std::cout << std::endl;
 | |
| 
 | |
|     bool valid = true;
 | |
|     valid &= check_tensor_shape<LayoutA, PermuteALayout, Gemm::kAlignmentA>("A", 0, 2);
 | |
|     valid &= check_tensor_shape<LayoutB, PermuteBLayout, Gemm::kAlignmentB>("B", 2, 1);
 | |
|     valid &= check_tensor_shape<LayoutC, PermuteDLayout, Gemm::kAlignmentC>("D", 0, 1);
 | |
|     if (!valid)
 | |
|     {
 | |
|       std::cout << "Skipped test" << std::endl;
 | |
|       return true;
 | |
|     }
 | |
| 
 | |
|     int const batch_count = kBatched ? options.batch_count : 1;
 | |
| 
 | |
|     // Initialize the problem
 | |
|     initialize(batch_count);
 | |
| 
 | |
|     // Configure the GEMM arguments
 | |
|     using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
 | |
|     typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta);
 | |
| 
 | |
|     // Please make sure all problem_sizes are the same for kBatched mode
 | |
|     auto problem = options.problem_each;
 | |
| 
 | |
|     cutlass::MatrixCoord extent_A{problem.m(), problem.k()};
 | |
|     cutlass::MatrixCoord extent_B{problem.k(), problem.n()};
 | |
|     cutlass::MatrixCoord extent_C{problem.m(), problem.n()};
 | |
| 
 | |
|     LayoutA layout_A(LayoutA::packed(extent_A));
 | |
|     LayoutB layout_B(LayoutB::packed(extent_B));
 | |
|     LayoutC layout_C(LayoutC::packed(extent_C));
 | |
| 
 | |
|     // Configure GEMM arguments
 | |
|     typename Gemm::Arguments arguments{
 | |
|       kBatched ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm,
 | |
|       problem,
 | |
|       batch_count,
 | |
|       epilogue_op,
 | |
|       (void*)block_A.get(),
 | |
|       (void*)block_B.get(),
 | |
|       (void*)block_C.get(),
 | |
|       (void*)block_D.get(),
 | |
|       // For any non-trivial permute the batch stride must be set to 0
 | |
|       cutlass::layout::is_trivial_permute<PermuteALayout> ? layout_A.capacity(extent_A) : 0,
 | |
|       cutlass::layout::is_trivial_permute<PermuteBLayout> ? layout_B.capacity(extent_B) : 0,
 | |
|       layout_C.capacity(extent_C),
 | |
|       cutlass::layout::is_trivial_permute<PermuteDLayout> ? layout_C.capacity(extent_C) : 0,
 | |
|       layout_A.stride(0),
 | |
|       layout_B.stride(0),
 | |
|       layout_C.stride(0),
 | |
|       layout_C.stride(0),
 | |
|     };
 | |
| 
 | |
|     // Initialize the GEMM object
 | |
|     Gemm gemm_normal;
 | |
| 
 | |
|     CHECK_CUTLASS_CALL(gemm_normal.initialize(arguments, nullptr), return false);
 | |
| 
 | |
|     // Run the normal GEMM object
 | |
|     CHECK_CUTLASS_CALL(gemm_normal.run(), return false);
 | |
| 
 | |
|     // Wait for completion
 | |
|     CHECK_CUDA_CALL(cudaDeviceSynchronize(), return false);
 | |
| 
 | |
|     //
 | |
|     // Verify correctness
 | |
|     //
 | |
|     if (options.reference_check) {
 | |
|       if (validate(gemm_normal)) {
 | |
|         std::cout << "\nPassed verification\n" << std::endl;
 | |
|       }
 | |
|       else {
 | |
|         std::cerr << "\n*** Error - problem failed the QA check ***\n" << std::endl;
 | |
|         return false;
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     // Warm-up run of the normal GEMM object
 | |
|     CHECK_CUTLASS_CALL(gemm_normal.run(), return false);
 | |
| 
 | |
|     // Construct events
 | |
|     cudaEvent_t events[2];
 | |
|     for (auto & event : events) {
 | |
|       CHECK_CUDA_CALL(cudaEventCreate(&event), return false);
 | |
|     }
 | |
| 
 | |
|     // Record an event at the start of a series of GEMM operations
 | |
|     CHECK_CUDA_CALL(cudaEventRecord(events[0]), return false);
 | |
| 
 | |
|     // Run profiling loop
 | |
|     for (int iter = 0; iter < options.iterations; ++iter) {
 | |
|       gemm_normal();
 | |
|     }
 | |
| 
 | |
|     // Record an event when the GEMM operations have been launched.
 | |
|     CHECK_CUDA_CALL(cudaEventRecord(events[1]), return false);
 | |
| 
 | |
|     // Wait for work on the device to complete.
 | |
|     CHECK_CUDA_CALL(cudaEventSynchronize(events[1]), return false);
 | |
| 
 | |
|     // Measure elapsed runtime
 | |
|     float runtime_total_ms = 0;
 | |
|     CHECK_CUDA_CALL(cudaEventElapsedTime(&runtime_total_ms, events[0], events[1]), return false);
 | |
| 
 | |
|     // Compute average runtime and GFLOPs.
 | |
|     double runtime_avg_ms = double(runtime_total_ms) / double(options.iterations);
 | |
|     double gflops = options.gflops(runtime_avg_ms / 1000.0, kBatched);
 | |
| 
 | |
|     // Cleanup
 | |
|     for (auto event : events) {
 | |
|       CHECK_CUDA_CALL(cudaEventDestroy(event), return false);
 | |
|     }
 | |
| 
 | |
|     std::cout << "    Runtime: " << runtime_avg_ms << " ms\n"
 | |
|                  "     GFLOPs: " << gflops << std::endl;
 | |
| 
 | |
|     return true;
 | |
|   }
 | |
| };
 | |
| 
 | |
| /// Shorthand alist for GEMM instantiations
 | |
| template<typename LayoutA, typename PermuteALayout,
 | |
|          typename LayoutB, typename PermuteBLayout,
 | |
|          typename LayoutC, typename PermuteDLayout>
 | |
| using GemmPermute = cutlass::gemm::device::GemmUniversal<
 | |
|   ElementInput, LayoutA,
 | |
|   ElementInput, LayoutB,
 | |
|   ElementOutput, LayoutC,
 | |
|   ElementAccumulator,
 | |
|   cutlass::arch::OpClassTensorOp,
 | |
|   cutlass::arch::Sm80,
 | |
|   cutlass::gemm::GemmShape<128, 128, 32>,
 | |
|   cutlass::gemm::GemmShape<64, 64, 32>,
 | |
|   cutlass::gemm::GemmShape<16, 8, 16>,
 | |
|   cutlass::epilogue::thread::LinearCombination<
 | |
|     ElementOutput, 
 | |
|     AlignmentC, //128 / cutlass::sizeof_bits<ElementOutput>::value,
 | |
|     ElementAccumulator, 
 | |
|     ElementAccumulator
 | |
|   >,
 | |
|   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
 | |
|   4,              /*kStages*/
 | |
|   AlignmentA,     /*AlignmentA*/
 | |
|   AlignmentB,     /*AlignmentB*/
 | |
|   cutlass::arch::OpMultiplyAdd,
 | |
|   cutlass::ComplexTransform::kNone,
 | |
|   cutlass::ComplexTransform::kNone,
 | |
|   false,  /*GatherA*/
 | |
|   false,  /*GatherB*/
 | |
|   false,  /*ScatterD*/
 | |
|   PermuteDLayout,  /*PermuteDLayout*/
 | |
|   typename cutlass::layout::InversePermute<PermuteALayout>::type,  /*PermuteALayout*/
 | |
|   typename cutlass::layout::InversePermute<PermuteBLayout>::type   /*PermuteBLayout*/
 | |
| >;
 | |
| 
 | |
| ///////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| int main(int argc, char const **args) {
 | |
| 
 | |
|   //
 | |
|   // This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
 | |
|   //
 | |
| 
 | |
|   cudaDeviceProp props;
 | |
| 
 | |
|   CHECK_CUDA_CALL(cudaGetDeviceProperties(&props, 0), return EXIT_FAILURE);
 | |
| 
 | |
|   if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
 | |
|   
 | |
|     //
 | |
|     // This example requires an NVIDIA Ampere-architecture GPU.
 | |
|     //
 | |
| 
 | |
|     std::cout << "CUTLASS's GEMM+Permute example requires a GPU of NVIDIA's Ampere Architecture "
 | |
|                  "or later (compute capability 80 or greater).\n";
 | |
| 
 | |
|     return EXIT_SUCCESS;
 | |
|   }
 | |
| 
 | |
|   //
 | |
|   // Parse options
 | |
|   //
 | |
| 
 | |
|   Options options;
 | |
|   
 | |
|   options.parse(argc, args);
 | |
| 
 | |
|   if (options.help) {
 | |
|     options.print_usage(std::cout) << std::endl;
 | |
|     return EXIT_SUCCESS;
 | |
|   }
 | |
| 
 | |
|   if (options.error) {
 | |
|     std::cerr << "Aborting execution." << std::endl;
 | |
|     return EXIT_FAILURE;
 | |
|   }
 | |
| 
 | |
|   //
 | |
|   // Define GEMM types to test
 | |
|   //
 | |
| 
 | |
|   //
 | |
|   // TTT (Row-major) GEMMs
 | |
|   //
 | |
| 
 | |
|   using TTTGemmNormalPermuteNone = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmNormalPermuteA = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmNormalPermuteAD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmNormalPermuteB = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmNormalPermuteBD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmNormalPermuteD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmNormalPermuteAB = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmNormalPermuteABD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   //
 | |
|   // NNN (Col-major) GEMMs
 | |
|   //
 | |
| 
 | |
|   using NNNGemmNormalPermuteNone = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmNormalPermuteA = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmNormalPermuteAD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmNormalPermuteB = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmNormalPermuteBD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmNormalPermuteD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmNormalPermuteAB = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmNormalPermuteABD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   //
 | |
|   // NNT (Col-major inputs, row-major output) GEMMs
 | |
|   //
 | |
| 
 | |
|   using NNTGemmNormalPermuteNone = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNTGemmNormalPermuteA = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNTGemmNormalPermuteAD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using NNTGemmNormalPermuteB = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNTGemmNormalPermuteBD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using NNTGemmNormalPermuteD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using NNTGemmNormalPermuteAB = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNTGemmNormalPermuteABD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   //
 | |
|   // TTN (Row-major inputs, col-major output) GEMMs
 | |
|   //
 | |
| 
 | |
|   using TTNGemmNormalPermuteNone = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTNGemmNormalPermuteA = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTNGemmNormalPermuteAD = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using TTNGemmNormalPermuteB = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTNGemmNormalPermuteBD = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using TTNGemmNormalPermuteD = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   using TTNGemmNormalPermuteAB = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTNGemmNormalPermuteABD = GemmPermute<
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::RowMajor,    cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
 | |
|   >;
 | |
| 
 | |
|   //
 | |
|   // TTT (Row-major) BMMs
 | |
|   //
 | |
| 
 | |
|   using TTTGemmBatchedPermuteA = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmBatchedPermuteAD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmBatchedPermuteB = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmBatchedPermuteBD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmBatchedPermuteD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmBatchedPermuteAB = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   using TTTGemmBatchedPermuteABD = GemmPermute<
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
 | |
|     cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   //
 | |
|   // NNN (Col-major) BMMs
 | |
|   //
 | |
| 
 | |
|   using NNNGemmBatchedPermuteA = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmBatchedPermuteAD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmBatchedPermuteB = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmBatchedPermuteBD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmBatchedPermuteD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmBatchedPermuteAB = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
 | |
|   >;
 | |
| 
 | |
|   using NNNGemmBatchedPermuteABD = GemmPermute<
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
 | |
|     cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
 | |
|   >;
 | |
| 
 | |
|   //
 | |
|   // Profile it
 | |
|   //
 | |
| 
 | |
|   Testbed<ElementInput, ElementInput, ElementOutput> testbed(options);
 | |
| 
 | |
|   bool result = true;
 | |
| 
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteNone>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteA>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteAD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteB>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteBD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteAB>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteABD>();
 | |
| 
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteNone>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteA>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteAD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteB>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteBD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteAB>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteABD>();
 | |
| 
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteNone>();
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteA>();
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteAD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteB>();
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteBD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteAB>();
 | |
|   result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteABD>();
 | |
| 
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteNone>();
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteA>();
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteAD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteB>();
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteBD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteAB>();
 | |
|   result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteABD>();
 | |
| 
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteA>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteAD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteB>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteBD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteD>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteAB>();
 | |
|   result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteABD>();
 | |
| 
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteA>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteAD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteB>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteBD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteD>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteAB>();
 | |
|   result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteABD>();
 | |
| 
 | |
|   std::cout << "\n"
 | |
|                "====================================================\n"
 | |
|                "Finished (" << (result ? "PASS" : "FAIL") << ")\n"
 | |
|                "====================================================" << std::endl;
 | |
| 
 | |
|   return result ? EXIT_SUCCESS : EXIT_FAILURE;
 | |
| }
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
