1127 lines
37 KiB
Plaintext
1127 lines
37 KiB
Plaintext
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
|
|
/*! \file
|
|
\brief GEMM Permute Example.
|
|
|
|
This example computes batched GEMM operations with output results permuted as reshaped tensors.
|
|
|
|
We provide layout plugin as a flexible tool for users to add any customized output tensor permute operation,
|
|
or any other generalized global memory writeout address computation. To add a customized layout, add new class
|
|
in include/cutlass/layout/permute.h
|
|
|
|
In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM
|
|
whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on
|
|
output matrix. The address computations are performed in compute(col_init, row_init, stride_init,
|
|
BMM_batch_idx) with {col_permute, row_permute and stride_permute} as new addresses after permute op.
|
|
(check include/cutlass/layout/permute.h)
|
|
|
|
Tips:
|
|
|
|
1) Make sure to set batch_stride_D to zero for BMM permute; Also the BMM GEMM should be in mode
|
|
cutlass::gemm::GemmUniversalMode::kBatched instead of kArray
|
|
|
|
2) When the last dimension is touched in permute op (for example permute([0, 2, 3, 1])), AlignmentC should
|
|
be set to 1. If the last dimension is untouched, one can set AlignmentC to be larger like 8 in our example.
|
|
As a result, permute op without touching the last dimension is recommended to obtain the best performance gain.
|
|
|
|
Examples:
|
|
|
|
# Runs a batched GEMM with 96 batches
|
|
$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96
|
|
|
|
# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)
|
|
$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true
|
|
|
|
# Execute batched GEMM and profile with NSight
|
|
$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false
|
|
|
|
*/
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#include <iostream>
|
|
#include <fstream>
|
|
#include <sstream>
|
|
#include <vector>
|
|
#include <map>
|
|
#include <unordered_map>
|
|
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/gemm/gemm.h"
|
|
#include "cutlass/gemm/device/gemm_universal.h"
|
|
|
|
#include "cutlass/util/command_line.h"
|
|
#include "cutlass/util/distribution.h"
|
|
#include "cutlass/util/device_memory.h"
|
|
#include "cutlass/util/tensor_view_io.h"
|
|
#include "cutlass/util/host_tensor.h"
|
|
#include "cutlass/util/reference/host/gemm_complex.h"
|
|
#include "cutlass/util/reference/device/gemm_complex.h"
|
|
#include "cutlass/util/reference/host/tensor_compare.h"
|
|
#include "cutlass/util/reference/host/tensor_copy.h"
|
|
#include "cutlass/util/reference/device/tensor_fill.h"
|
|
#include "cutlass/util/reference/host/tensor_norm.h"
|
|
|
|
#include "cutlass/layout/permute.h"
|
|
|
|
/// Tensor4DPermuteBMM0213 --->
/// Batched GEMM permute: the whole BMM output tensor (shape [B, M, N]) is viewed as
/// [B/D1, D1, M, N] and then written out with permute([0, 2, 1, 3]) applied.
constexpr int D1 = 12;

/// Tensor5DPermute20314 --->
/// Normal GEMM permute: the output matrix (shape [M, N]) is viewed as
/// [M/T1, T1, T2, T3, N/T2/T3] and then written out with permute([2, 0, 3, 1, 4]) applied.
constexpr int T1 = 16;
constexpr int T2 = 3;
constexpr int T3 = 8;

// Alignment C
constexpr int AlignmentC = 8;
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Result structure
|
|
struct Result {
|
|
|
|
double runtime_ms;
|
|
double gflops;
|
|
cutlass::Status status;
|
|
cudaError_t error;
|
|
bool passed;
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
Result(
|
|
double runtime_ms = 0,
|
|
double gflops = 0,
|
|
cutlass::Status status = cutlass::Status::kSuccess,
|
|
cudaError_t error = cudaSuccess
|
|
):
|
|
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
|
|
};
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Command line options parsing
|
|
struct Options {
|
|
|
|
bool help;
|
|
bool error;
|
|
bool reference_check;
|
|
|
|
cutlass::gemm::GemmCoord problem_each;
|
|
|
|
int batch_count;
|
|
int iterations;
|
|
int cuda_streams;
|
|
bool verbose;
|
|
float alpha;
|
|
float beta;
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
Options():
|
|
help(false),
|
|
error(false),
|
|
reference_check(true),
|
|
batch_count(-1),
|
|
iterations(20),
|
|
cuda_streams(0),
|
|
verbose(false),
|
|
alpha(1),
|
|
beta()
|
|
{ }
|
|
|
|
// Parses the command line
|
|
void parse(int argc, char const **args) {
|
|
cutlass::CommandLine cmd(argc, args);
|
|
|
|
if (cmd.check_cmd_line_flag("help")) {
|
|
help = true;
|
|
return;
|
|
}
|
|
|
|
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
|
|
cmd.get_cmd_line_argument("beta", beta, 0.0f);
|
|
cmd.get_cmd_line_argument("iterations", iterations, 20);
|
|
cmd.get_cmd_line_argument("streams", cuda_streams, 0);
|
|
cmd.get_cmd_line_argument("verbose", verbose, false);
|
|
cmd.get_cmd_line_argument("reference-check", reference_check, true);
|
|
|
|
int m, n, k;
|
|
|
|
cmd.get_cmd_line_argument("m", m, 128);
|
|
cmd.get_cmd_line_argument("n", n, 192);
|
|
cmd.get_cmd_line_argument("k", k, 128);
|
|
cmd.get_cmd_line_argument("batch-count", batch_count, 768);
|
|
|
|
cutlass::gemm::GemmCoord problem(m, n, k);
|
|
problem_each = problem;
|
|
|
|
if (batch_count % D1 != 0){
|
|
std::cerr << "\nProblem count error (problem-count = " << batch_count << "). "
|
|
<< "problem-count needs to be divided with no remain by " << D1 << " (D1)."
|
|
<< " (Required by the Batched GEMM permute Tensor4DPermuteBMM0213)\n\n";
|
|
error = true;
|
|
}
|
|
|
|
if (m % (AlignmentC * T1) != 0){
|
|
std::cerr << "\nProblem m size error (m = " << m << "). "
|
|
<< "m needs to be divided with no remain by " << (AlignmentC * T1) << " (AlignmentC * T1)."
|
|
<< " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n";
|
|
error = true;
|
|
}
|
|
|
|
if (n % (AlignmentC * (T2 * T3)) != 0){
|
|
std::cerr << "\nProblem n size error (n = " << n << "). "
|
|
<< "n needs to be divided with no remain by " << (AlignmentC * (T2 * T3)) << " (AlignmentC * T2 * T3)."
|
|
<< " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n";
|
|
error = true;
|
|
}
|
|
}
|
|
|
|
/// Prints the usage statement.
|
|
std::ostream & print_usage(std::ostream &out) const {
|
|
|
|
out << "39_gemm_permute\n\n"
|
|
<< " 1) This example firstly profiles the performance of a batched GEMM kernel with BMM whole output"
|
|
<< " (including output matrices for each batch) as permuted 4D Tensor."
|
|
<< " The BMM tensor output in shape of [B, M, N] is reshaped as [B/D1, D1, M, N] and then permuted with"
|
|
<< " permute([0, 2, 1, 3]) to be in shape of [B/D1, M, D1, N].\n\n"
|
|
<< " 2) This example also profiles the performance of a normal GEMM kernel with output as permuted 5D Tensor."
|
|
<< " The GEMM matrix output in shape of [M, N] is reshaped as [M/T1, T1, T2, T3, N/T2/T3] and then permuted"
|
|
<< " with permute([2, 0, 3, 1, 4]) to be in shape of [T2, M/T1, T3, T1, N/T2/T3].\n\n"
|
|
<< " Note: D1, T1, T2, T3 are compile-time constants defined in gemm_permute.cu\n\n"
|
|
<< "Options:\n\n"
|
|
<< " --help If specified, displays this usage statement.\n\n"
|
|
<< " --batch-count=<int> Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n"
|
|
<< " --m=<int> Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n"
|
|
<< " --n=<int> Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n"
|
|
<< " --k=<int> Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=128)\n"
|
|
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
|
|
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
|
|
<< " --iterations=<int> Number of profiling iterations to perform.\n"
|
|
<< " --reference-check=<bool> If true, performs reference check.\n"
|
|
<< " --verbose=<bool> If true, prints problem sizes and batching structure.\n";
|
|
|
|
out << "\n\nExamples:\n\n"
|
|
|
|
<< "# Runs a batched GEMM with 96 batches\n"
|
|
<< "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96\n\n"
|
|
|
|
<< "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n"
|
|
<< "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true\n\n"
|
|
|
|
<< "# Execute batched GEMM and profile with NSight\n"
|
|
<< "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";
|
|
|
|
return out;
|
|
}
|
|
|
|
/// Compute performance in GFLOP/s
|
|
double gflops(double runtime_s) const {
|
|
|
|
// Number of real-valued multiply-adds
|
|
int64_t fmas = int64_t();
|
|
|
|
fmas += problem_each.product() * batch_count;
|
|
|
|
// Two flops per multiply-add
|
|
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
|
}
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename GemmBatched, typename GemmPermute>
|
|
class Testbed {
|
|
public:
|
|
|
|
//
|
|
// Type definitions
|
|
//
|
|
|
|
using ElementA = typename GemmBatched::ElementA;
|
|
using ElementB = typename GemmBatched::ElementB;
|
|
using ElementC = typename GemmBatched::ElementC;
|
|
using ElementAccumulator = typename GemmBatched::ElementAccumulator;
|
|
|
|
using EpilogueOutputOp = typename GemmBatched::GemmKernel::Epilogue::OutputOp;
|
|
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
|
|
|
|
using LayoutA = typename GemmBatched::LayoutA;
|
|
using LayoutB = typename GemmBatched::LayoutB;
|
|
using LayoutC = typename GemmBatched::LayoutC;
|
|
|
|
using MatrixCoord = typename LayoutC::TensorCoord;
|
|
|
|
private:
|
|
|
|
//
|
|
// Data members
|
|
//
|
|
|
|
Options & options;
|
|
|
|
/// Initialization
|
|
cutlass::Distribution::Kind init_A;
|
|
cutlass::Distribution::Kind init_B;
|
|
cutlass::Distribution::Kind init_C;
|
|
uint32_t seed;
|
|
|
|
cutlass::DeviceAllocation<ElementA> block_A;
|
|
cutlass::DeviceAllocation<ElementB> block_B;
|
|
cutlass::DeviceAllocation<ElementC> block_C;
|
|
cutlass::DeviceAllocation<ElementC> block_D;
|
|
|
|
public:
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
Testbed(
|
|
Options &options_,
|
|
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
|
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
|
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
|
uint32_t seed_ = 3090
|
|
):
|
|
options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
|
|
|
/// Verbose BMM info
|
|
void print_BMM_info_() {
|
|
|
|
// Print batched GEMM
|
|
std::cout << "Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor:\n";
|
|
|
|
auto problem = options.problem_each;
|
|
std::cout
|
|
<< problem.m() << "-by-" << problem.n() << "-by-" << problem.k()
|
|
<< ", batch count: " << options.batch_count << "\n";
|
|
|
|
std::cout << "output tensor shape: [" << options.batch_count << ", " << problem.m() << ", "
|
|
<< problem.n() <<"]\n";
|
|
std::cout << "reshaped as: [" << options.batch_count / D1 << ", " << D1 << ", "
|
|
<< problem.m() << ", " << problem.n() <<"]\n";
|
|
std::cout << "finally permuted as: [" << options.batch_count / D1 << ", " << problem.m() << ", "
|
|
<< D1 << ", " << problem.n() <<"]\n";
|
|
|
|
std::cout << "----------------------------------------------------\n";
|
|
|
|
}
|
|
|
|
/// Verbose normal GEMM info
|
|
void print_GEMM_info_() {
|
|
|
|
// Print batched GEMM
|
|
std::cout << "Normal GEMM with permute([2, 0, 3, 1, 4]):\n";
|
|
|
|
auto problem = options.problem_each;
|
|
std::cout
|
|
<< problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << "\n";
|
|
|
|
std::cout << "output tensor shape: [" << problem.m() << ", " << problem.n() <<"]" << std::endl;
|
|
std::cout << "reshaped as: [" << problem.m() / T1 << ", " << T1 << ", "
|
|
<< T2 << ", " << T3 << ", " << problem.n() / T2 / T3 <<"]" << std::endl;
|
|
std::cout << "finally permuted as: [" << T2 << ", " << problem.m() / T1 << ", "
|
|
<< T3 << ", " << T1 << ", " << problem.n() / T2 / T3 <<"]" << std::endl;
|
|
|
|
std::cout << "----------------------------------------------------\n";
|
|
|
|
}
|
|
|
|
private:
|
|
|
|
/// Helper to initialize a tensor view
|
|
template <typename Element>
|
|
void initialize_tensor_(
|
|
Element *ptr,
|
|
size_t capacity,
|
|
cutlass::Distribution::Kind dist_kind,
|
|
uint32_t seed) {
|
|
|
|
if (dist_kind == cutlass::Distribution::Uniform) {
|
|
|
|
Element scope_max, scope_min;
|
|
int bits_input = cutlass::sizeof_bits<Element>::value;
|
|
int bits_output = cutlass::sizeof_bits<typename GemmBatched::ElementC>::value;
|
|
|
|
if (bits_input == 1) {
|
|
scope_max = 2;
|
|
scope_min = 0;
|
|
} else if (bits_input <= 8) {
|
|
scope_max = 2;
|
|
scope_min = -2;
|
|
} else if (bits_output == 16) {
|
|
if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
|
|
scope_max = 5;
|
|
scope_min = -5;
|
|
}
|
|
else {
|
|
scope_max = 8;
|
|
scope_min = -8;
|
|
}
|
|
} else {
|
|
scope_max = 8;
|
|
scope_min = -8;
|
|
}
|
|
|
|
cutlass::reference::device::BlockFillRandomUniform(
|
|
ptr, capacity, seed, scope_max, scope_min, 0);
|
|
}
|
|
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
|
|
|
cutlass::reference::device::BlockFillRandomGaussian(
|
|
ptr, capacity, seed, Element(), Element(0.5f));
|
|
}
|
|
else if (dist_kind == cutlass::Distribution::Sequential) {
|
|
|
|
// Fill with increasing elements
|
|
cutlass::reference::device::BlockFillSequential(
|
|
ptr, capacity, Element(1), Element());
|
|
}
|
|
else {
|
|
|
|
// Fill with all 1s
|
|
cutlass::reference::device::BlockFillSequential(
|
|
ptr, capacity, Element(), Element(1));
|
|
}
|
|
}
|
|
|
|
/// Initializes data structures
|
|
void initialize_(int batch_count) {
|
|
|
|
//
|
|
// Choose random problem sizes
|
|
//
|
|
|
|
// construct a few problems of random sizes
|
|
srand(seed);
|
|
|
|
int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count;
|
|
int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count;
|
|
int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count;
|
|
int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count;
|
|
|
|
//
|
|
// Assign space
|
|
//
|
|
|
|
block_A.reset(total_elements_A);
|
|
block_B.reset(total_elements_B);
|
|
block_C.reset(total_elements_C);
|
|
block_D.reset(total_elements_D);
|
|
|
|
//
|
|
// Initialize the problems of the workspace
|
|
//
|
|
|
|
initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021);
|
|
initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022);
|
|
initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023);
|
|
|
|
cutlass::reference::device::BlockFillSequential(
|
|
block_D.get(), total_elements_D, ElementC(), ElementC());
|
|
}
|
|
|
|
/// Verifies the BMM GEMM result
|
|
bool verify_BMM_() {
|
|
|
|
bool passed = true;
|
|
|
|
cutlass::gemm::GemmCoord problem = options.problem_each;
|
|
|
|
LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0));
|
|
LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0));
|
|
LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0));
|
|
LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0));
|
|
|
|
MatrixCoord extent_A{problem.m(), problem.k()};
|
|
MatrixCoord extent_B{problem.k(), problem.n()};
|
|
MatrixCoord extent_C{problem.m(), problem.n()};
|
|
|
|
cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get(), layout_A, extent_A);
|
|
cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get(), layout_B, extent_B);
|
|
cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get(), layout_C, extent_C);
|
|
|
|
cutlass::DeviceAllocation<ElementC> block_Ref(layout_D.capacity(extent_C) * options.batch_count);
|
|
cutlass::TensorView<ElementC, LayoutC> view_Ref_device(block_Ref.get(), layout_D, extent_C);
|
|
|
|
// Reference GEMM
|
|
cutlass::reference::device::GemmComplex<
|
|
ElementA, LayoutA,
|
|
ElementB, LayoutB,
|
|
ElementC, LayoutC,
|
|
ElementCompute, ElementAccumulator
|
|
>(
|
|
problem,
|
|
options.alpha,
|
|
view_A,
|
|
GemmBatched::kTransformA,
|
|
view_B,
|
|
GemmBatched::kTransformB,
|
|
options.beta,
|
|
view_C,
|
|
view_Ref_device,
|
|
ElementAccumulator(0),
|
|
options.batch_count,
|
|
options.problem_each.m() * options.problem_each.k(),
|
|
options.problem_each.n() * options.problem_each.k(),
|
|
options.problem_each.m() * options.problem_each.n(),
|
|
options.problem_each.m() * options.problem_each.n()
|
|
);
|
|
|
|
// Copy to host memory
|
|
std::vector<ElementC> matrix_D(layout_D.capacity(extent_C) * options.batch_count);
|
|
std::vector<ElementC> matrix_Ref(layout_D.capacity(extent_C) * options.batch_count);
|
|
|
|
cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size());
|
|
cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size());
|
|
|
|
// Print out the results and reference in 4D Tensor
|
|
// [options.batch_count, options.problem_each.m() * options.problem_each.n()] -> [D0, D1, D2, D3].
|
|
// After permute Op, -> [D0, D2, D1, D3].
|
|
int D0 = options.batch_count / D1;
|
|
int D2 = options.problem_each.m();
|
|
int D3 = options.problem_each.n();
|
|
|
|
cutlass::TensorView<ElementC, cutlass::layout::TensorNHWC> view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently
|
|
cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D2, D1, D3})), cutlass::Tensor4DCoord({D0, D2, D1, D3}));
|
|
|
|
cutlass::TensorView<ElementC, cutlass::layout::TensorNHWC> view_Ref_Tensor(matrix_Ref.data(),
|
|
cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D1, D2, D3})), cutlass::Tensor4DCoord({D0, D1, D2, D3}));
|
|
|
|
// Tensor Permute Op on reference tensor
|
|
cutlass::HostTensor<ElementC, cutlass::layout::TensorNHWC> view_Ref_Permute_Tensor(cutlass::Tensor4DCoord({D0, D2, D1, D3}));
|
|
for (int n = 0; n < D0; ++n) {
|
|
for (int h = 0; h < D1; ++h) {
|
|
for (int w = 0; w < D2; ++w) {
|
|
for (int c = 0; c < D3; ++c) {
|
|
view_Ref_Permute_Tensor.at({n, w, h, c}) = view_Ref_Tensor.at({n, h, w, c});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Reference check
|
|
passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor);
|
|
|
|
if (!passed) {
|
|
std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl;
|
|
return passed;
|
|
}
|
|
|
|
std::cout << "Passed verification" << std::endl;
|
|
return passed;
|
|
}
|
|
|
|
bool verify_GEMM_normal_() {
|
|
|
|
bool passed = true;
|
|
|
|
cutlass::gemm::GemmCoord problem = options.problem_each;
|
|
|
|
LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0));
|
|
LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0));
|
|
LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0));
|
|
LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0));
|
|
|
|
MatrixCoord extent_A{problem.m(), problem.k()};
|
|
MatrixCoord extent_B{problem.k(), problem.n()};
|
|
MatrixCoord extent_C{problem.m(), problem.n()};
|
|
|
|
cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get(), layout_A, extent_A);
|
|
cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get(), layout_B, extent_B);
|
|
cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get(), layout_C, extent_C);
|
|
|
|
cutlass::DeviceAllocation<ElementC> block_Ref(layout_D.capacity(extent_C));
|
|
cutlass::TensorView<ElementC, LayoutC> view_Ref_device(block_Ref.get(), layout_D, extent_C);
|
|
|
|
// Reference GEMM
|
|
cutlass::reference::device::GemmComplex<
|
|
ElementA, LayoutA,
|
|
ElementB, LayoutB,
|
|
ElementC, LayoutC,
|
|
ElementCompute, ElementAccumulator
|
|
>(
|
|
problem,
|
|
options.alpha,
|
|
view_A,
|
|
GemmBatched::kTransformA,
|
|
view_B,
|
|
GemmBatched::kTransformB,
|
|
options.beta,
|
|
view_C,
|
|
view_Ref_device,
|
|
ElementAccumulator(0)
|
|
);
|
|
|
|
// Copy to host memory
|
|
std::vector<ElementC> matrix_D(layout_D.capacity(extent_C));
|
|
std::vector<ElementC> matrix_Ref(layout_D.capacity(extent_C));
|
|
|
|
cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size());
|
|
cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size());
|
|
|
|
// Print out the results and reference in 5D Tensor
|
|
// [options.problem_each.m(), options.problem_each.n()] -> [T0, T1, T2, T3, T4].
|
|
// options.problem_each.m() == T0 * T1
|
|
// options.problem_each.n() == T2 * T3 * T4
|
|
// After permute Op, -> [T2, T0, T3, T1, T4].
|
|
int T0 = options.problem_each.m() / T1;
|
|
int T4 = options.problem_each.n() / T2 / T3;
|
|
|
|
cutlass::TensorView<ElementC, cutlass::layout::TensorNDHWC> view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently
|
|
cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})), cutlass::Tensor5DCoord({T2, T0, T3, T1, T4}));
|
|
cutlass::TensorView<ElementC, cutlass::layout::TensorNDHWC> view_Ref_Tensor(matrix_Ref.data(),
|
|
cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})), cutlass::Tensor5DCoord({T0, T1, T2, T3, T4}));
|
|
|
|
// Tensor Permute Op on reference tensor
|
|
cutlass::HostTensor<ElementC, cutlass::layout::TensorNDHWC> view_Ref_Permute_Tensor(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4}));
|
|
for (int n = 0; n < T0; ++n) {
|
|
for (int d = 0; d < T1; ++d) {
|
|
for (int h = 0; h < T2; ++h) {
|
|
for (int w = 0; w < T3; ++w) {
|
|
for (int c = 0; c < T4; ++c) {
|
|
view_Ref_Permute_Tensor.at({h, n, w, d, c}) = view_Ref_Tensor.at({n, d, h, w, c}); // permute([2,0,3,1,4])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Reference check
|
|
passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor);
|
|
|
|
if (!passed) {
|
|
std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl;
|
|
return passed;
|
|
}
|
|
|
|
std::cout << "Passed verification" << std::endl;
|
|
return passed;
|
|
}
|
|
|
|
public:
|
|
/// Executes a conventional batched GEMM kernel.
|
|
Result profile_batched_kBatched() {
|
|
|
|
std::cout << "\n====================================================" << std::endl;
|
|
std::cout << "Batched GEMM (CUTLASS):\n"
|
|
<< "====================================================" << std::endl;
|
|
|
|
if (options.verbose) {
|
|
print_BMM_info_();
|
|
}
|
|
|
|
Result result;
|
|
|
|
result.passed = false;
|
|
|
|
// Initialize the problem
|
|
initialize_(options.batch_count);
|
|
|
|
// Configure the GEMM arguments
|
|
typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta);
|
|
|
|
// Please make sure all problem_sizes are the same for kBatched mode
|
|
auto problem = options.problem_each;
|
|
|
|
// For regular BMM
|
|
int64_t batch_stride_C = problem.m() * problem.n();
|
|
// For BMM permute output ---> make sure to set batch_stride_D to zero for BMM permute op
|
|
int64_t batch_stride_D = 0;
|
|
|
|
// Configure GEMM arguments
|
|
typename GemmBatched::Arguments arguments{
|
|
cutlass::gemm::GemmUniversalMode::kBatched,
|
|
options.problem_each,
|
|
options.batch_count,
|
|
epilogue_op,
|
|
(void*)block_A.get(),
|
|
(void*)block_B.get(),
|
|
(void*)block_C.get(),
|
|
(void*)block_D.get(),
|
|
problem.m() * problem.k(),
|
|
problem.n() * problem.k(),
|
|
batch_stride_C,
|
|
batch_stride_D,
|
|
problem.k(),
|
|
problem.n(),
|
|
problem.n(),
|
|
problem.n()
|
|
};
|
|
|
|
// Initialize the GEMM object
|
|
GemmBatched gemm;
|
|
|
|
result.status = gemm.initialize(arguments, nullptr);
|
|
|
|
if (result.status != cutlass::Status::kSuccess) {
|
|
std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Run the batched GEMM object
|
|
result.status = gemm.run();
|
|
|
|
if (result.status != cutlass::Status::kSuccess) {
|
|
std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Wait for completion
|
|
result.error = cudaDeviceSynchronize();
|
|
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error);
|
|
return result;
|
|
}
|
|
|
|
//
|
|
// Verify correctness
|
|
//
|
|
result.passed = true;
|
|
|
|
if (options.reference_check) {
|
|
result.passed = verify_BMM_();
|
|
}
|
|
|
|
//
|
|
// Warm-up run of the batched GEMM object
|
|
//
|
|
result.status = gemm.run();
|
|
|
|
if (result.status != cutlass::Status::kSuccess) {
|
|
std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl;
|
|
return result;
|
|
}
|
|
|
|
//
|
|
// Construct events
|
|
//
|
|
|
|
cudaEvent_t events[2];
|
|
|
|
for (auto & event : events) {
|
|
result.error = cudaEventCreate(&event);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
// Record an event at the start of a series of GEMM operations
|
|
result.error = cudaEventRecord(events[0]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
//
|
|
// Run profiling loop
|
|
//
|
|
|
|
for (int iter = 0; iter < options.iterations; ++iter) {
|
|
gemm();
|
|
}
|
|
|
|
//
|
|
// Stop profiling loop
|
|
//
|
|
|
|
// Record an event when the GEMM operations have been launched.
|
|
result.error = cudaEventRecord(events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Wait for work on the device to complete.
|
|
result.error = cudaEventSynchronize(events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Measure elapsed runtime
|
|
float runtime_ms = 0;
|
|
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Compute average runtime and GFLOPs.
|
|
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
|
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
|
|
|
//
|
|
// Cleanup
|
|
//
|
|
|
|
for (auto event : events) {
|
|
(void)cudaEventDestroy(event);
|
|
}
|
|
|
|
std::cout << " " << 1 << " batched GEMMs launched\n";
|
|
|
|
std::cout << std::endl;
|
|
std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms\n";
|
|
std::cout << " " << "Batched GFLOPs: " << result.gflops << "\n";
|
|
|
|
return result;
|
|
}
|
|
|
|
Result profile_GEMM_permute() {
|
|
|
|
std::cout << "\n====================================================" << std::endl;
|
|
std::cout << "Normal GEMM (CUTLASS):\n"
|
|
<< "====================================================" << std::endl;
|
|
|
|
if (options.verbose) {
|
|
print_GEMM_info_();
|
|
}
|
|
|
|
Result result;
|
|
|
|
result.passed = false;
|
|
|
|
// Initialize the problem
|
|
initialize_(1);
|
|
|
|
// Configure the GEMM arguments
|
|
typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta);
|
|
|
|
// Please make sure all problem_sizes are the same for kBatched mode
|
|
auto problem = options.problem_each;
|
|
|
|
// Configure GEMM arguments
|
|
typename GemmPermute::Arguments arguments{
|
|
cutlass::gemm::GemmUniversalMode::kGemm,
|
|
options.problem_each,
|
|
1,
|
|
epilogue_op,
|
|
(void*)block_A.get(),
|
|
(void*)block_B.get(),
|
|
(void*)block_C.get(),
|
|
(void*)block_D.get(),
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
problem.k(),
|
|
problem.n(),
|
|
problem.n(),
|
|
problem.n()
|
|
};
|
|
|
|
// Initialize the GEMM object
|
|
GemmPermute gemm_normal;
|
|
|
|
result.status = gemm_normal.initialize(arguments, nullptr);
|
|
|
|
if (result.status != cutlass::Status::kSuccess) {
|
|
std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Run the normal GEMM object
|
|
result.status = gemm_normal.run();
|
|
|
|
if (result.status != cutlass::Status::kSuccess) {
|
|
std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Wait for completion
|
|
result.error = cudaDeviceSynchronize();
|
|
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error);
|
|
return result;
|
|
}
|
|
|
|
//
|
|
// Verify correctness
|
|
//
|
|
result.passed = true;
|
|
|
|
if (options.reference_check) {
|
|
result.passed = verify_GEMM_normal_();
|
|
}
|
|
|
|
//
|
|
// Warm-up run of the normal GEMM object
|
|
//
|
|
result.status = gemm_normal.run();
|
|
|
|
if (result.status != cutlass::Status::kSuccess) {
|
|
std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl;
|
|
return result;
|
|
}
|
|
|
|
//
|
|
// Construct events
|
|
//
|
|
|
|
cudaEvent_t events[2];
|
|
|
|
for (auto & event : events) {
|
|
result.error = cudaEventCreate(&event);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
// Record an event at the start of a series of GEMM operations
|
|
result.error = cudaEventRecord(events[0]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
//
|
|
// Run profiling loop
|
|
//
|
|
|
|
for (int iter = 0; iter < options.iterations; ++iter) {
|
|
gemm_normal();
|
|
}
|
|
|
|
//
|
|
// Stop profiling loop
|
|
//
|
|
|
|
// Record an event when the GEMM operations have been launched.
|
|
result.error = cudaEventRecord(events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Wait for work on the device to complete.
|
|
result.error = cudaEventSynchronize(events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Measure elapsed runtime
|
|
float runtime_ms = 0;
|
|
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Compute average runtime and GFLOPs.
|
|
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
|
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
|
|
|
//
|
|
// Cleanup
|
|
//
|
|
|
|
for (auto event : events) {
|
|
(void)cudaEventDestroy(event);
|
|
}
|
|
|
|
std::cout << std::endl;
|
|
std::cout << " " << "Normal Runtime: " << result.runtime_ms << " ms" << std::endl;
|
|
std::cout << " " << "Normal GFLOPs: " << result.gflops << "\n";
|
|
|
|
return result;
|
|
}
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
int main(int argc, char const **args) {
|
|
|
|
//
|
|
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
|
|
//
|
|
|
|
cudaDeviceProp props;
|
|
|
|
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
|
if (error != cudaSuccess) {
|
|
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
|
return -1;
|
|
}
|
|
|
|
if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
|
|
|
|
//
|
|
// This example requires an NVIDIA Ampere-architecture GPU.
|
|
//
|
|
|
|
std::cout
|
|
<< "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or "
|
|
<< "later (compute capability 80 or greater).\n";
|
|
|
|
return 0;
|
|
}
|
|
|
|
//
|
|
// Parse options
|
|
//
|
|
|
|
Options options;
|
|
|
|
options.parse(argc, args);
|
|
|
|
if (options.help) {
|
|
options.print_usage(std::cout) << std::endl;
|
|
return 0;
|
|
}
|
|
|
|
if (options.error) {
|
|
std::cerr << "Aborting execution." << std::endl;
|
|
return -1;
|
|
}
|
|
|
|
//
|
|
// Define the GEMM types
|
|
//
|
|
|
|
using ElementOutput = cutlass::half_t;
|
|
using ElementAccumulator = float;
|
|
|
|
using LayoutA = cutlass::layout::RowMajor;
|
|
using LayoutB = cutlass::layout::RowMajor;
|
|
using LayoutC = cutlass::layout::RowMajor;
|
|
|
|
//
|
|
// Define a conventional batched GEMM type
|
|
//
|
|
|
|
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8
|
|
using GemmBatched = cutlass::gemm::device::GemmUniversal<
|
|
cutlass::half_t, LayoutA,
|
|
cutlass::half_t, LayoutB,
|
|
ElementOutput, LayoutC,
|
|
ElementAccumulator,
|
|
cutlass::arch::OpClassTensorOp,
|
|
cutlass::arch::Sm80,
|
|
cutlass::gemm::GemmShape<128, 128, 32>,
|
|
cutlass::gemm::GemmShape<64, 64, 32>,
|
|
cutlass::gemm::GemmShape<16, 8, 16>,
|
|
cutlass::epilogue::thread::LinearCombination<
|
|
ElementOutput,
|
|
AlignmentC, //128 / cutlass::sizeof_bits<ElementOutput>::value,
|
|
ElementAccumulator,
|
|
ElementAccumulator
|
|
>,
|
|
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
|
|
4,
|
|
8, /*alignmentA*/
|
|
8, /*alignmengB*/
|
|
cutlass::arch::OpMultiplyAdd,
|
|
cutlass::ComplexTransform::kNone,
|
|
cutlass::ComplexTransform::kNone,
|
|
false, /*GatherA*/
|
|
false, /*GatherB*/
|
|
false, /*ScatterD*/
|
|
cutlass::layout::Tensor4DPermuteBMM0213<D1> /*PermuteDLayout*/
|
|
>;
|
|
|
|
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8
|
|
using GemmPermute = cutlass::gemm::device::GemmUniversal<
|
|
cutlass::half_t, LayoutA,
|
|
cutlass::half_t, LayoutB,
|
|
ElementOutput, LayoutC,
|
|
ElementAccumulator,
|
|
cutlass::arch::OpClassTensorOp,
|
|
cutlass::arch::Sm80,
|
|
cutlass::gemm::GemmShape<128, 128, 32>,
|
|
cutlass::gemm::GemmShape<64, 64, 32>,
|
|
cutlass::gemm::GemmShape<16, 8, 16>,
|
|
cutlass::epilogue::thread::LinearCombination<
|
|
ElementOutput,
|
|
AlignmentC, //128 / cutlass::sizeof_bits<ElementOutput>::value,
|
|
ElementAccumulator,
|
|
ElementAccumulator
|
|
>,
|
|
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
|
|
4,
|
|
8, /*alignmentA*/
|
|
8, /*alignmengB*/
|
|
cutlass::arch::OpMultiplyAdd,
|
|
cutlass::ComplexTransform::kNone,
|
|
cutlass::ComplexTransform::kNone,
|
|
false, /*GatherA*/
|
|
false, /*GatherB*/
|
|
false, /*ScatterD*/
|
|
cutlass::layout::Tensor5DPermute20314<T1, T2, T3> /*PermuteDLayout*/
|
|
>;
|
|
|
|
//
|
|
// Profile it
|
|
//
|
|
|
|
Testbed<GemmBatched, GemmPermute> testbed(options);
|
|
|
|
Result result;
|
|
result = testbed.profile_batched_kBatched();
|
|
if (!result.passed) {
|
|
std::cout << "Profiling batched GEMM has failed.\n";
|
|
std::cout << "\nFailed\n";
|
|
} else {
|
|
std::cout << "\nPassed CUTLASS batched GEMM\n";
|
|
}
|
|
|
|
result = testbed.profile_GEMM_permute();
|
|
if (!result.passed) {
|
|
std::cout << "Profiling normal GEMM has failed.\n";
|
|
std::cout << "\nFailed\n";
|
|
} else {
|
|
std::cout << "\nPassed CUTLASS normal GEMM\n";
|
|
}
|
|
|
|
std::cout << "\n====================================================" << std::endl;
|
|
std::cout << "Finished\n";
|
|
std::cout << "====================================================" << std::endl;
|
|
|
|
return 0;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|