449 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			449 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without modification, are permitted
 | |
|  * provided that the following conditions are met:
 | |
|  *     * Redistributions of source code must retain the above copyright notice, this list of
 | |
|  *       conditions and the following disclaimer.
 | |
|  *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 | |
|  *       conditions and the following disclaimer in the documentation and/or other materials
 | |
|  *       provided with the distribution.
 | |
|  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 | |
|  *       to endorse or promote products derived from this software without specific prior written
 | |
|  *       permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 | |
|  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 | |
|  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 | |
|  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 | |
|  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 | |
|  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| 
 | |
#include <algorithm>
#include <cmath>
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| /// Result structure
 | |
| struct Result {
 | |
| 
 | |
|   double runtime_ms;
 | |
|   double gflops;
 | |
|   cutlass::Status status;
 | |
|   cudaError_t error;
 | |
|   bool passed;
 | |
| 
 | |
|   //
 | |
|   // Methods
 | |
|   //
 | |
| 
 | |
|   Result(
 | |
|     double runtime_ms = 0,
 | |
|     double gflops = 0,
 | |
|     cutlass::Status status = cutlass::Status::kSuccess,
 | |
|     cudaError_t error = cudaSuccess
 | |
|   ):
 | |
|     runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
 | |
| };
 | |
| 
 | |
| ///////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| // Command line options parsing
 | |
| struct Options {
 | |
| 
 | |
|   bool help;
 | |
| 
 | |
|   cutlass::gemm::GemmCoord problem_size;
 | |
|   int batch_count;
 | |
|   cutlass::Quaternion<float> alpha;
 | |
|   cutlass::Quaternion<float> beta;
 | |
| 
 | |
|   bool reference_check;
 | |
|   int iterations;
 | |
|   
 | |
|   Options():
 | |
|     help(false),
 | |
|     problem_size({1024, 1024, 1024}),
 | |
|     batch_count(1),
 | |
|     reference_check(true),
 | |
|     iterations(20),
 | |
|     alpha(1),
 | |
|     beta() { }
 | |
| 
 | |
|   bool valid() {
 | |
|     return true;
 | |
|   }
 | |
| 
 | |
|   // Parses the command line
 | |
|   void parse(int argc, char const **args) {
 | |
|     cutlass::CommandLine cmd(argc, args);
 | |
| 
 | |
|     if (cmd.check_cmd_line_flag("help")) {
 | |
|       help = true;
 | |
|     }
 | |
| 
 | |
|     cmd.get_cmd_line_argument("m", problem_size.m());
 | |
|     cmd.get_cmd_line_argument("n", problem_size.n());
 | |
|     cmd.get_cmd_line_argument("k", problem_size.k());
 | |
|     cmd.get_cmd_line_argument("batch", batch_count);
 | |
| 
 | |
|     cmd.get_cmd_line_argument("alpha",   alpha.w());
 | |
|     cmd.get_cmd_line_argument("alpha_i", alpha.x());
 | |
|     cmd.get_cmd_line_argument("alpha_j", alpha.y());
 | |
|     cmd.get_cmd_line_argument("alpha_k", alpha.z());
 | |
| 
 | |
|     cmd.get_cmd_line_argument("beta",   beta.w());
 | |
|     cmd.get_cmd_line_argument("beta_i", beta.x());
 | |
|     cmd.get_cmd_line_argument("beta_j", beta.y());
 | |
|     cmd.get_cmd_line_argument("beta_k", beta.z());
 | |
|     
 | |
|     cmd.get_cmd_line_argument("iterations", iterations);
 | |
| 
 | |
|   }
 | |
| 
 | |
|   /// Prints the usage statement.
 | |
|   std::ostream & print_usage(std::ostream &out) const {
 | |
| 
 | |
|     out << "21_quaternion_gemm example\n\n"
 | |
|       << "  This example uses the CUTLASS Library to execute Quaternion GEMM computations.\n\n"
 | |
|       << "Options:\n\n"
 | |
|       << "  --help                      If specified, displays this usage statement.\n\n"
 | |
|       << "  --m <int>                   GEMM M dimension\n"
 | |
|       << "  --n <int>                   GEMM N dimension\n"
 | |
|       << "  --k <int>                   GEMM K dimension\n"
 | |
|       << "  --batch <int>               Number of GEMM operations executed in one batch\n"
 | |
|       << "  --alpha <f32>               Epilogue scalar alpha (real part)\n"
 | |
|       << "  --alpha_i <f32>             Epilogue scalar alpha_i (imaginary part)\n"
 | |
|       << "  --alpha_j <f32>             Epilogue scalar alpha_j (imaginary part)\n"
 | |
|       << "  --alpha_k <f32>             Epilogue scalar alpha_k (imaginary part)\n"
 | |
|       << "  --beta <f32>                Epilogue scalar beta (real part)\n\n"
 | |
|       << "  --beta_i <f32>              Epilogue scalar beta_i (imaginary part)\n\n"
 | |
|       << "  --beta_j <f32>              Epilogue scalar beta_j (imaginary part)\n\n"
 | |
|       << "  --beta_k <f32>              Epilogue scalar beta_k (imaginary part)\n\n"
 | |
|       << "  --iterations <int>          Number of profiling iterations to perform.\n\n";
 | |
| 
 | |
|     out << "\n\nExamples:\n\n"
 | |
|       << "$ ./examples/21_quaternion_gemm/21_quaternion_gemm  --batch=7 --m=1024 --n=512 --k=1024 \\\n"
 | |
|       << "     --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n";
 | |
| 
 | |
|     return out;
 | |
|   }
 | |
| 
 | |
|   /// Compute performance in GFLOP/s
 | |
|   double gflops(double runtime_s) const {
 | |
| 
 | |
|     // Number of real-valued multiply-adds 
 | |
|     int64_t fmas = problem_size.product() * batch_count * 16;
 | |
|     
 | |
|     // Two flops per multiply-add
 | |
|     return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
 | |
|   }
 | |
| };
 | |
| 
 | |
| ///////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| // The code section below describes datatype for input, output matrices and computation between
 | |
| // elements in input matrices.
 | |
| using precision = float;
 | |
| using Element = cutlass::Quaternion<float>;
 | |
| using ElementComputeEpilogue = Element;  // <- data type of epilogue operations
 | |
| using ElementAccumulator = Element;      // <- data type of accumulator
 | |
| using ElementInputA = Element;           // <- data type of elements in input matrix A
 | |
| using ElementInputB = Element;           // <- data type of elements in input matrix B
 | |
| using ElementOutput = Element;           // <- data type of elements in output matrix D
 | |
| 
 | |
| // The code section below describes matrix layout of input and output matrices. Column Major for
 | |
| // Matrix A, Row Major for Matrix B and Row Major for Matrix C
 | |
| using LayoutInputA = cutlass::layout::RowMajor;
 | |
| using LayoutInputB = cutlass::layout::ColumnMajor;
 | |
| using LayoutOutput = cutlass::layout::RowMajor;
 | |
| 
 | |
| // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
 | |
| using MMAOp = cutlass::arch::OpClassSimt;
 | |
| 
 | |
| // This code section describes CUDA SM architecture number
 | |
| using SmArch = cutlass::arch::Sm50;
 | |
| 
 | |
| // This code section describes the tile size a thread block will compute
 | |
| using ShapeMMAThreadBlock =
 | |
|     cutlass::gemm::GemmShape<64, 64, 4>;                   // <- threadblock tile M = 64, N = 64, K = 8
 | |
| // This code section describes tile size a warp will compute
 | |
| using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 16, 4>;  // <- warp tile M = 32, N = 16, K = 8
 | |
| // This code section describes the size of MMA op
 | |
| using ShapeMMAOp = cutlass::gemm::GemmShape<1, 1, 1>;      // <- MMA Op tile M = 1, N = 1, K = 1
 | |
| 
 | |
| // This code section describes how threadblocks are scheduled on GPU
 | |
| using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- Defaults
 | |
| 
 | |
| // This code section describes the epilogue part of the kernel
 | |
| using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
 | |
|     ElementOutput,                                    // <- data type of output matrix
 | |
|     128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
 | |
|                                                       // memory access. For a byte, it's 16
 | |
|                                                       // elements. This becomes the vector width of
 | |
|                                                       // math instructions in the epilogue too
 | |
|     ElementAccumulator,                               // <- data type of accumulator
 | |
|     ElementComputeEpilogue>;                          // <- data type for alpha/beta in linear combination function
 | |
| 
 | |
| // Number of pipelines you want to use
 | |
| constexpr int NumStages = 2;
 | |
| 
 | |
| using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
 | |
|                                          LayoutInputA,
 | |
|                                          ElementInputB,
 | |
|                                          LayoutInputB,
 | |
|                                          ElementOutput,
 | |
|                                          LayoutOutput,
 | |
|                                          ElementAccumulator,
 | |
|                                          MMAOp,
 | |
|                                          SmArch,
 | |
|                                          ShapeMMAThreadBlock,
 | |
|                                          ShapeMMAWarp,
 | |
|                                          ShapeMMAOp,
 | |
|                                          EpilogueOp,
 | |
|                                          SwizzleThreadBlock,
 | |
|                                          NumStages>;
 | |
| 
 | |
| int run(Options options) {
 | |
| 
 | |
|   // PASS/FAIL status
 | |
|   bool passed = true;
 | |
| 
 | |
|   // Create a tuple of problem size for matrix multiplication
 | |
|   cutlass::gemm::GemmCoord problem_size = options.problem_size;
 | |
| 
 | |
|   // Initialize tensors using CUTLASS helper functions
 | |
|   cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
 | |
|       problem_size.mk());  // <- Create matrix A with dimensions M x K
 | |
|   cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
 | |
|       problem_size.kn());  // <- Create matrix B with dimensions K x N
 | |
|   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
 | |
|       problem_size.mn());  // <- Create matrix C with dimensions M x N
 | |
|   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
 | |
|       problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
 | |
|                            // CUTLASS kernel
 | |
|   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
 | |
|       problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
 | |
|                            // reference kernel
 | |
| 
 | |
|   // Fill input and output matrices on host using CUTLASS helper functions
 | |
|   cutlass::reference::host::TensorFillRandomUniform(
 | |
|       tensor_a.host_view(),
 | |
|       1,
 | |
|       4,
 | |
|       -4,
 | |
|       0);  // <- Fill matrix A on host with uniform-distribution random data
 | |
| 
 | |
|   cutlass::reference::host::TensorFillRandomUniform(
 | |
|       tensor_b.host_view(),
 | |
|       1,
 | |
|       4,
 | |
|       -4,
 | |
|       0);  // <- Fill matrix B on host with uniform-distribution random data
 | |
| 
 | |
|   cutlass::reference::host::TensorFillRandomUniform(
 | |
|       tensor_c.host_view(),
 | |
|       1,
 | |
|       4,
 | |
|       -4,
 | |
|       0);  // <- Fill matrix C on host with uniform-distribution random data
 | |
| 
 | |
|   cutlass::reference::host::TensorFill(
 | |
|       tensor_d.host_view());  // <- fill matrix D on host with zeros
 | |
|   cutlass::reference::host::TensorFill(
 | |
|       tensor_ref_d.host_view());  // <- fill matrix D for reference on host with zeros
 | |
| 
 | |
|   // Copy data from host to GPU
 | |
|   tensor_a.sync_device();
 | |
|   tensor_b.sync_device();
 | |
|   tensor_c.sync_device();
 | |
|   tensor_d.sync_device();
 | |
|   tensor_ref_d.sync_device();
 | |
| 
 | |
|   // Initialize alpha and beta for dot product computation
 | |
|   ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
 | |
|   ElementComputeEpilogue beta = ElementComputeEpilogue(0);
 | |
| 
 | |
|   // Split K dimension into 1 partitions
 | |
|   int split_k_slices = 1;
 | |
| 
 | |
|   // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
 | |
|   // instantiated CUTLASS kernel
 | |
|   typename Gemm::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
 | |
|                                      tensor_a.device_ref(),  // <- reference to matrix A on device
 | |
|                                      tensor_b.device_ref(),  // <- reference to matrix B on device
 | |
|                                      tensor_c.device_ref(),  // <- reference to matrix C on device
 | |
|                                      tensor_d.device_ref(),  // <- reference to matrix D on device
 | |
|                                      {alpha, beta},          // <- tuple of alpha and beta
 | |
|                                      split_k_slices};        // <- k-dimension split factor
 | |
| 
 | |
|   // Using the arguments, query for extra workspace required for matrix multiplication computation
 | |
|   size_t workspace_size = Gemm::get_workspace_size(arguments);
 | |
| 
 | |
|   // Allocate workspace memory
 | |
|   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
 | |
| 
 | |
|   // Instantiate CUTLASS kernel depending on templates
 | |
|   Gemm gemm_op;
 | |
| 
 | |
|   // Check the problem size is supported or not 
 | |
|   cutlass::Status status = gemm_op.can_implement(arguments);
 | |
|   CUTLASS_CHECK(status);
 | |
| 
 | |
|   // Initialize CUTLASS kernel with arguments and workspace pointer
 | |
|   status = gemm_op.initialize(arguments, workspace.get());
 | |
|   CUTLASS_CHECK(status);
 | |
|   
 | |
|   // Result structure
 | |
|   Result result;
 | |
| 
 | |
|   //
 | |
|   // Construct events
 | |
|   //
 | |
| 
 | |
|   cudaEvent_t events[2];
 | |
| 
 | |
|   for (auto & event : events) {
 | |
|     result.error = cudaEventCreate(&event);
 | |
|     if (result.error != cudaSuccess) {
 | |
|       std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
 | |
|       return -1;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // Record an event at the start of a series of GEMMs
 | |
|   result.error = cudaEventRecord(events[0]);
 | |
|   if (result.error != cudaSuccess) {
 | |
|     std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   //
 | |
|   // Run profiling loop
 | |
|   //
 | |
| 
 | |
|   for (int iter = 0; iter < options.iterations; ++iter) {
 | |
| 
 | |
|     // Launch initialized CUTLASS kernel
 | |
|     status = gemm_op();
 | |
|     CUTLASS_CHECK(status);
 | |
| 
 | |
|   }
 | |
| 
 | |
|   //
 | |
|   // Stop profiling loop
 | |
|   //
 | |
| 
 | |
|   // Record an event when the GEMMs are complete
 | |
|   result.error = cudaEventRecord(events[1]);
 | |
|   if (result.error != cudaSuccess) {
 | |
|     std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   // Wait for work on the device to complete.
 | |
|   result.error = cudaEventSynchronize(events[1]);
 | |
|   if (result.error != cudaSuccess) {
 | |
|     std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   // Measure elapsed runtime
 | |
|   float runtime_ms = 0;
 | |
|   result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
 | |
|   if (result.error != cudaSuccess) {
 | |
|     std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   // Compute average runtime and GFLOPs.
 | |
|   result.runtime_ms = double(runtime_ms) / double(options.iterations);
 | |
|   result.gflops = options.gflops(result.runtime_ms / 1000.0);
 | |
| 
 | |
|   // Cleanup
 | |
|   for (auto event : events) {
 | |
|     (void)cudaEventDestroy(event);
 | |
|   }
 | |
| 
 | |
|   if (options.reference_check) {
 | |
| 
 | |
|     // Create instantiation for device reference gemm kernel
 | |
|     cutlass::reference::device::Gemm<ElementInputA,
 | |
|                                      LayoutInputA,
 | |
|                                      ElementInputB,
 | |
|                                      LayoutInputB,
 | |
|                                      ElementOutput,
 | |
|                                      LayoutOutput,
 | |
|                                      ElementComputeEpilogue,
 | |
|                                      ElementComputeEpilogue> gemm_device;
 | |
| 
 | |
|     // Launch device reference gemm kernel
 | |
|     gemm_device(problem_size,
 | |
|                 alpha,
 | |
|                 tensor_a.device_ref(),
 | |
|                 tensor_b.device_ref(),
 | |
|                 beta,
 | |
|                 tensor_c.device_ref(),
 | |
|                 tensor_ref_d.device_ref());
 | |
| 
 | |
|     // Wait for kernels to finish
 | |
|     cudaDeviceSynchronize();
 | |
| 
 | |
|     // Copy output data from CUTLASS and reference kernel to host for comparison
 | |
|     tensor_d.sync_host();
 | |
|     tensor_ref_d.sync_host();
 | |
| 
 | |
|     // Check if output from CUTLASS kernel and reference kernel are equal or not
 | |
|     passed &= cutlass::reference::host::TensorEquals(
 | |
|       tensor_d.host_view(),
 | |
|       tensor_ref_d.host_view());
 | |
| 
 | |
|   }
 | |
| 
 | |
|   if (passed) {
 | |
|     std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
 | |
|     std::cout << " GFLOPs: " << result.gflops << std::endl;
 | |
|   }
 | |
| 
 | |
|   std::cout << (passed ? "Passed" : "Failed") << std::endl;
 | |
|   return (passed ? 0  : -1);
 | |
| }
 | |
| 
 | |
| int main(int argc, char const** argv) {
 | |
| 
 | |
|   Options options;
 | |
|   options.parse(argc, argv);
 | |
| 
 | |
|   if (options.help) {
 | |
|     options.print_usage(std::cout) << std::endl;
 | |
|     return 0;
 | |
|   }
 | |
| 
 | |
|   printf("%d x %d x %d Single Precision Quaternion Matrix Multiply\n", \
 | |
|     options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
 | |
| 
 | |
|   if (!options.valid()) {
 | |
|     std::cerr << "Invalid problem." << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   return run(options);
 | |
| }
 | |
| 
 | 
