
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_batched.h"

#pragma warning( disable : 4503)

/*
This example demonstrates how to use CUTLASS to compute a batched strided GEMM.
In this example, both A and B are non-transposed, column-major matrices:

  batched_C = batched_A x batched_B

As an example, matrix C can be seen as
-----------------------------------------------------------
(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
-----------------------------------------------------------
(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------------------------------------
(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
-----------------------------------------------------------
(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
-----------------------------------------------------------
(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
-----------------------------------------------------------
(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
-----------------------------------------------------------
          batch 0           |          batch 1

where each element is denoted (batch_idx, row_idx, column_idx).

In this example, the batch size is 2, M is 6 and N is 3.
The stride (batch_stride_C) between the first elements of two consecutive batches is ldc * n.

matrix A can be seen as
---------------------------------------
(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
---------------------------------------
(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
---------------------------------------
(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
---------------------------------------
(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
---------------------------------------
(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
---------------------------------------
(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
---------------------------------------
     batch 0      |      batch 1

where the batch size is 2, M is 6 and K is 2.
The stride (batch_stride_A) between the first elements of two consecutive batches is lda * k.

matrix B can be seen as
-----------------------------
(0,0,0) | (0,0,1) | (0,0,2) |
----------------------------- batch 0
(0,1,0) | (0,1,1) | (0,1,2) |
-------------------------------------
(1,0,0) | (1,0,1) | (1,0,2) |
----------------------------- batch 1
(1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------

where the batch size is 2, N is 3 and K is 2.
The stride (batch_stride_B) between the first elements of two consecutive batches is k.

*/

cudaError_t cutlass_strided_batched_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const *A,
  int lda,
  long long int batch_stride_A,
  float const *B,
  int ldb,
  long long int batch_stride_B,
  float *C,
  int ldc,
  long long int batch_stride_C,
  float beta,
  int batch_count) {

  using Gemm = cutlass::gemm::device::GemmBatched<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;
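
  // Arguments to this GemmBatched call, in order: the problem size {m, n, k}, the A operand and
  // its batch stride, the B operand and its batch stride, the C source operand and its batch
  // stride, the D destination operand and its batch stride (here D aliases C, so the update is
  // done in place), the epilogue scalars {alpha, beta}, and the number of batches.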
  cutlass::Status status = gemm_op({
    {m, n, k},
    {A, lda},
    batch_stride_A,
    {B, ldb},
    batch_stride_B,
    {C, ldc},
    batch_stride_C,
    {C, ldc},
    batch_stride_C,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}

template<typename T>
cudaError_t strided_batched_gemm_nn_reference(
  int m,
  int n,
  int k,
  T alpha,
  std::vector<T> const &A,
  int lda,
  long long int batch_stride_A,
  std::vector<T> const &B,
  int ldb,
  long long int batch_stride_B,
  std::vector<T> &C,
  int ldc,
  long long int batch_stride_C,
  T beta,
  int batch_count) {
  /*
  strided batched gemm NN
  */

  cudaError_t result = cudaSuccess;

  if (A.size() < lda * k * batch_count) {
    std::cout << "the size of A is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (B.size() < ldb * n) {
    std::cout << "the size of B is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (C.size() < ldc * n * batch_count) {
    std::cout << "the size of C is too small" << std::endl;
    return cudaErrorInvalidValue;
  }

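  // Column-major addressing: element (row, col) of batch b in a matrix X with leading dimension
  // ldx is X[b * batch_stride_X + col * ldx + row]. The loops below apply this to A, B, and C.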
  for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
    for (int n_idx = 0; n_idx < n; n_idx++) {
      for (int m_idx = 0; m_idx < m; m_idx++) {
        T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx];
        for (int k_idx = 0; k_idx < k; k_idx++) {
          accum += alpha
            * A[batch_idx * batch_stride_A + k_idx * lda + m_idx]
            * B[batch_idx * batch_stride_B + n_idx * ldb + k_idx];
        }
        C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum;
      }
    }
  }

  return result;
}

int main() {

  // Arbitrary problem size
  int const m = 520;
  int const n = 219;
  int const k = 129;
  int const batch_count = 17;

  // A, B are non-transpose, column major
  int const lda = m;
  int const ldb = k * batch_count;
  int const ldc = m;

  int const count_A = batch_count * lda * k;
  int const count_B = ldb * n;
  int const count_C = batch_count * ldc * n;

  // B is batched along the K dimension: within each stored column, the k rows of batch 0 are
  // followed by the k rows of batch 1, and so on, so its batch stride is just k. A and C use
  // full per-matrix strides.
  long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
  long long int batch_stride_B = static_cast<long long int>(k);
  long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);
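  // With these strides, batch b of A starts at A + b * batch_stride_A, batch b of B at
  // B + b * batch_stride_B, and batch b of C at C + b * batch_stride_C.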

  // alpha and beta
  float alpha = 1.0f;
  float beta = 2.0f;

  cudaError_t result = cudaSuccess;

  // allocate the host memory
  std::vector<float> host_A(count_A);
  std::vector<float> host_B(count_B);
  std::vector<float> host_C(count_C);
  std::vector<float> result_C(count_C);

  // allocate the device memory
  float *A;
  float *B;
  float *C;

  result = cudaMalloc(&A, count_A * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&B, count_B * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&C, count_C * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }

  // Limit range to avoid floating-point errors
  int const kRange = 8;
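  // Because every input value is a small non-negative integer (< kRange), all products and
  // partial sums stay well within the range fp32 represents exactly, so no rounding occurs and
  // the device result can be compared with the host reference bit for bit.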

  // fill A
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < k; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_A[row_idx + col_idx * lda + b_idx * lda * k] = static_cast<float>((row_idx + col_idx * lda + b_idx * lda * k) % kRange);
      }
    }
  }

  // fill B
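  // Element (b_idx, row_idx, col_idx) of B lives at row_idx + col_idx * ldb + b_idx * k,
  // matching batch_stride_B = k.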
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < k; row_idx++) {
        host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast<float>(((n + k * ldb + batch_count * k) - (row_idx + col_idx * ldb + b_idx * k)) % kRange);
      }
    }
  }

  // fill C
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
      }
    }
  }

  // ref memory
  std::vector<float> ref_A(host_A);
  std::vector<float> ref_B(host_B);
  std::vector<float> ref_C(host_C);

  // copy host memory to device
  result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // run cutlass
  result = cutlass_strided_batched_sgemm(
    m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
    beta, batch_count);
  if (result != cudaSuccess)
    return result;

  // copy device memory to host
  result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // compare with reference code
  result = strided_batched_gemm_nn_reference(m, n, k, alpha, ref_A, lda, batch_stride_A, ref_B, ldb, batch_stride_B, ref_C, ldc, batch_stride_C,
    beta, batch_count);
  if (result != cudaSuccess)
    return result;

  // Expect bit-level accuracy for this simple example
  if (ref_C != result_C) {
    std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
    return cudaErrorUnknown;
  }

  // free memory
  result = cudaFree(A);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }
  result = cudaFree(B);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }
  result = cudaFree(C);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }

  if (result == cudaSuccess) {
    std::cout << "Passed." << std::endl;
  }

  // Exit.
  return result == cudaSuccess ? 0 : -1;
}