example 23 gemm operand reduction fusion (#325)

This commit is contained in:
Manish Gupta 2021-09-20 13:34:47 -07:00 committed by GitHub
parent 2e07c4cc2f
commit 538592dea4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 776 additions and 0 deletions

View File

@ -0,0 +1,28 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
23_ampere_gemm_operand_reduction_fusion
ampere_gemm_operand_reduction_fusion.cu
)

View File

@ -0,0 +1,747 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
The example demenstrates how to reduce one of the operands of the GEMM along the k-dimension when
computing GEMM. So the output also contains either a Mx1 or 1XN vector. It only works with Ampere
HMMA 16x8x16 FP16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
core instructions.
Most of the reduction is done in gemm/warp level, see gemm/warp/mma_with_reduction_tensor_op.h
A few bit of reduction is done in the epilouge before storing the vector, see
epilogue/threadblock/epilogue_gemm_k_reduction.h
*/
#include <iostream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/convolution.h"
#include "helper.h"
// The code section below describes datatype for input, output tensors and computation between
// elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;
// Layout of the output vector
using LayoutGemmKReduction = cutlass::layout::PitchLinear;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
// Number of pipelines you want to use
constexpr int NumStages = 4;
// Reduce A or B operand along the K dimension
constexpr bool ReduceKForA = true;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8,
ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
ReduceKForA,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Below is the reduction kernel used in the case of parallel spiit-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
using ReduceOp = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator,
ElementOutput,
EpilogueOp::kCount
>;
using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<
ReduceGemmSplitKShape,
EpilogueOp,
ReduceOp
>;
using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK<ReduceGemmSplitKKernel>;
using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;;
// This code section describes the epilogue part of the kernel, we use default value
using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::Nothing>;
using ReduceVectorSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<
ReduceVectorSplitKShape,
DummyEpilogueOp,
ReduceOp
>;
using ReduceVectorSplitK = cutlass::reduction::device::ReduceSplitK<ReduceVectorSplitKKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::gemm::GemmCoord problem_size;
int split_k_slices;
bool parallel_split_k;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool benchmark;
std::string tag;
Options():
help(false),
problem_size(1024, 1024, 1024),
split_k_slices(1),
parallel_split_k(false),
reference_check(true),
measure_performance(false),
iterations(20),
save_workspace(false),
alpha(-1),
beta(-1),
benchmark(false) { }
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
bool valid() {
//
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((problem_size.m() % kAlignment) ||
(problem_size.n() % kAlignment) ||
(problem_size.k() % kAlignment)) {
// misaligned tensors
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
cutlass::gemm::GemmCoord problem_size,
int split_k_slices,
bool parallel_split_k) {
this->problem_size = problem_size;
this->split_k_slices = split_k_slices;
this->parallel_split_k = parallel_split_k;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("parallel-split-k")) {
parallel_split_k = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
benchmark = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("split-k-slices", split_k_slices);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "28_ampere_gemm_bias_fusion example\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m <int> GEMM M\n"
<< " --n <int> GEMM N\n"
<< " --k <int> GEMM K\n"
<< " --split-k-slices <int> Split K Slices\n"
<< " --alpha <float> Epilogue scalar alpha\n"
<< " --beta <float> Epilogue scalar beta\n\n"
<< " --parallel-split-k If set (true), use parallel split K\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several problem sizes.\n"
<< " --iterations <int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag <string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/28_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n";
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
cutlass::Status status;
cutlass::Status reference_check;
cudaError_t error;
Result():
runtime_ms(0),
status(cutlass::Status::kSuccess),
reference_check(cutlass::Status::kInvalid),
error(cudaSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "ID,M,N,K,SplitK-Slices,Parallel-SplitK,Runtime";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "gemm_" << idx << ","
<< options.problem_size.m() << ","
<< options.problem_size.n() << ","
<< options.problem_size.k() << ","
<< options.split_k_slices << ","
<< options.parallel_split_k << ","
<< runtime_ms ;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile(Options const &options) {
Result result;
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.problem_size.mk());
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.problem_size.kn());
// Create tensor C with dimensions 1x1x1xk which is the bias vector
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.problem_size.mn());
// Create tensor D used to store output from CUTLASS kernel
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.problem_size.mn());
// Create matrix D with dimensions M x N used to store output from reference
// kernel
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.problem_size.mn());
int reduce_vector_length = ReduceKForA ? options.problem_size.m() : options.problem_size.n();
cutlass::HostTensor<ElementOutput, LayoutGemmKReduction> tensor_reduction({reduce_vector_length, 1});
cutlass::HostTensor<ElementOutput, LayoutGemmKReduction> tensor_ref_reduction({reduce_vector_length, 1});
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(4),
ElementInputA(-4),
0); // <- Fill tensor A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(4),
ElementInputB(-4),
0); // <- Fill tensor B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(4),
ElementOutput(-4),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
cutlass::reference::host::TensorFill(
tensor_reduction.host_view()); // <- fill matrix D on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_reduction.host_view()); // <- fill matrix D for reference on host with zeros
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
tensor_reduction.sync_device();
// Initialize alpha for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha);
ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta);
cutlass::gemm::GemmUniversalMode mode = options.parallel_split_k ?
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel :
cutlass::gemm::GemmUniversalMode::kGemm;
int batch_count = options.split_k_slices;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated CUTLASS kernel
typename Gemm::Arguments arguments{
mode,
options.problem_size,
batch_count,
{alpha, beta},
tensor_a.device_ref().data(), // <- reference to tensor A on device
tensor_b.device_ref().data(), // <- reference to tensor B on device
tensor_c.device_ref().data(), // <- reference to matrix C on device
tensor_d.device_ref().data(), // <- reference to matrix C on device
tensor_reduction.device_ref().data(), // <- reference to tensor B on device
options.problem_size.m() * options.problem_size.k(),
options.problem_size.n() * options.problem_size.k(),
options.problem_size.m() * options.problem_size.n(),
options.problem_size.m() * options.problem_size.n(),
reduce_vector_length,
tensor_a.layout().stride(0),
tensor_b.layout().stride(0),
tensor_c.layout().stride(0),
tensor_d.layout().stride(0),
tensor_reduction.layout().stride(0)
};
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
result.status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(result.status);
// Initialize CUTLASS kernel with arguments and workspace pointer
result.status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(result.status);
// Launch initialized CUTLASS kernel
result.status = gemm_op();
CUTLASS_CHECK(result.status);
if (options.parallel_split_k && batch_count > 1) {
// reduce gemm
int splitk_gemm_stride = options.problem_size.m();
cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride);
void * workspace_gemm_ptr = workspace.get();
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> workspace_gemm_tensorref(static_cast<ElementOutput *>(workspace_gemm_ptr), splitk_gemm_layout);
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_d_tensorref(tensor_d.device_ref().data(), splitk_gemm_layout);
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_c_tensorref(tensor_c.device_ref().data(), splitk_gemm_layout);
typename ReduceGemmSplitK::Arguments reduce_gemm_splitk_arguments{
cutlass::MatrixCoord(options.problem_size.n(), options.problem_size.m()),
batch_count,
size_t(options.problem_size.m() * options.problem_size.n()),
workspace_gemm_tensorref,
tensor_d_tensorref,
tensor_c_tensorref,
{alpha, beta}
};
ReduceGemmSplitK reduce_gemm_splitk_op;
result.status = reduce_gemm_splitk_op.initialize(reduce_gemm_splitk_arguments);
CUTLASS_CHECK(result.status);
result.status = reduce_gemm_splitk_op();
CUTLASS_CHECK(result.status);
// reduce k vector
cutlass::layout::RowMajor splitk_vector_layout(reduce_vector_length);
ElementOutput *workspace_vector_ptr = static_cast<ElementOutput *>(workspace_gemm_ptr) + batch_count * options.problem_size.m() * options.problem_size.n();
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> workspace_vector_tensorref(workspace_vector_ptr, splitk_vector_layout);
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_reduction_tensorref(tensor_reduction.device_ref().data(), splitk_vector_layout);
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_nullptr_tensorref(nullptr, splitk_vector_layout);
typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments{
cutlass::MatrixCoord(1, reduce_vector_length),
batch_count,
size_t(reduce_vector_length),
workspace_vector_tensorref,
tensor_reduction_tensorref,
tensor_nullptr_tensorref,
{1.0f, 0.0f}
};
ReduceVectorSplitK reduce_vector_splitk_op;
result.status = reduce_vector_splitk_op.initialize(reduce_vector_splitk_arguments);
CUTLASS_CHECK(result.status);
result.status = reduce_vector_splitk_op();
CUTLASS_CHECK(result.status);
}
//
// Create instantiation for device reference conv kernel
//
if (options.reference_check) {
// Launch device reference to compute strictly the product A * B
cutlass::reference::device::Gemm<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator> gemm_device;
gemm_device
(
options.problem_size,
alpha,
tensor_a.device_ref(),
tensor_b.device_ref(),
beta,
tensor_c.device_ref(),
tensor_ref_d.device_ref()
);
// Wait for kernels to finish
cudaDeviceSynchronize();
// Copy output data from CUTLASS and reference kernel to host for comparison
tensor_d.sync_host();
tensor_ref_d.sync_host();
tensor_reduction.sync_host();
// Compute bias + relu in host code
if (ReduceKForA) {
for (int m = 0; m < options.problem_size.m(); ++m) {
for (int k = 0; k < options.problem_size.k(); ++k) {
tensor_ref_reduction.at({m, 0}) +=
tensor_a.at(cutlass::MatrixCoord(m, k));
}
}
} else {
for (int k = 0; k < options.problem_size.k(); ++k) {
for (int n = 0; n < options.problem_size.n(); ++n) {
tensor_ref_reduction.at({n, 0}) +=
tensor_b.at(cutlass::MatrixCoord(k, n));
}
}
}
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(),
tensor_ref_d.host_view());
pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(),
tensor_reduction.host_view());
if (!pass) {
result.reference_check = cutlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
} else {
result.reference_check = cutlass::Status::kSuccess;
std::cout << "Passed.\n";
}
} else {
result.reference_check = cutlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "23_ampere_gemm_operand_reduction_fusion"
<< options.problem_size.m() << "x" << options.problem_size.n() << "x" << options.problem_size.k()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "A = \n" << tensor_a.host_view() << "\n\n"
<< "B = \n" << tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n";
output_workspace << "Reference reduction vector= \n" << tensor_ref_reduction.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = gemm_op();
CUTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
}
return result;
}
int main(int argc, char const **args) {
bool notSupported = false;
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;
}
if (notSupported) {
return 0;
}
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
struct Benchmark {
int m, n, k, split_k_slices, parallel_split_k;
} problem_sizes[] = {
{4096, 6144, 4096, 1, false},
};
Result::print_header(std::cout, options) << "\n";
int idx = 1;
for (auto const &problem_size : problem_sizes) {
options.update({problem_size.m, problem_size.n, problem_size.k},
problem_size.split_k_slices, problem_size.parallel_split_k);
Result result = profile(options);
result.print(std::cout, idx, options) << "\n";
++idx;
}
} else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << "\n";
return -1;
}
Result result = profile(options);
Result::print_header(std::cout, options) << "\n";
result.print(std::cout, 1, options) << "\n";
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -93,6 +93,7 @@ foreach(EXAMPLE
20_simt_canonical
21_quaternion_gemm
22_quaternion_conv
23_ampere_gemm_operand_reduction_fusion
)
add_subdirectory(${EXAMPLE})