cutlass/examples/05_batched_gemm/batched_gemm.cu

346 lines
11 KiB
Plaintext
Raw Normal View History

2018-09-19 07:58:03 +08:00
/***************************************************************************************************
2021-02-26 22:58:26 +08:00
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
2018-09-19 07:58:03 +08:00
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2018-09-19 07:58:03 +08:00
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include <vector>
2018-09-19 07:58:03 +08:00
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_batched.h"
2018-09-19 07:58:03 +08:00
#pragma warning( disable : 4503)
2018-09-19 07:58:03 +08:00
/*
This example demonstrates how to use cutlass to compute a batched strided gemm.
In this example, both the A and B matrices are non-transposed and stored in column-major order
batched_C = batched_A x batched_B
As an example, matrix C can be seen as
-----------------------------------------------------------
(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
-----------------------------------------------------------
(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------------------------------------
(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
-----------------------------------------------------------
(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
-----------------------------------------------------------
(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
-----------------------------------------------------------
(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
-----------------------------------------------------------
batch 0 | batch 1
where we denote each element with (batch_idx, row_idx, column_idx)
In this example, batch size is 2, M is 6 and N is 3
The stride (batch_stride_C) between the first element of two batches is ldc * n
matrix A can be seen as
---------------------------------------
(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
---------------------------------------
(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
---------------------------------------
(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
---------------------------------------
(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
---------------------------------------
(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
---------------------------------------
(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
---------------------------------------
batch 0 | batch 1
, where batch size is 2, M is 6 and K is 2
The stride (batch_stride_A) between the first element of two batches is lda * k
matrix B can be seen as
-----------------------------
(0,0,0) | (0,0,1) | (0,0,2) |
----------------------------- batch 0
(0,1,0) | (0,1,1) | (0,1,2) |
-------------------------------------
(1,0,0) | (1,0,1) | (1,0,2) |
----------------------------- batch 1
(1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------
, where the batch size is 2, N is 3 and K is 2
The stride (batch_stride_B) between the first element of two batches is k
*/
/// Runs a strided batched single-precision GEMM through CUTLASS.
///
/// All three operands are described to CUTLASS as column-major float matrices.
/// Each operand is given as a base pointer, a leading dimension and the element
/// stride between the first elements of two consecutive batches.
///
/// \param m, n, k          GEMM problem size shared by every batch
/// \param alpha, beta      scalars forwarded to the CUTLASS epilogue
/// \param A, lda, batch_stride_A   operand A
/// \param B, ldb, batch_stride_B   operand B
/// \param C, ldc, batch_stride_C   operand C; it is passed to CUTLASS twice, so
///                                 the same buffer serves as the source operand
///                                 and as the destination of the result
/// \param batch_count      number of GEMMs in the batch
///
/// \return cudaSuccess when CUTLASS reports Status::kSuccess, cudaErrorUnknown
///         for any other status (the specific CUTLASS status is not preserved).
cudaError_t cutlass_strided_batched_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const *A,
  int lda,
  long long int batch_stride_A,
  float const *B,
  int ldb,
  long long int batch_stride_B,
  float *C,
  int ldc,
  long long int batch_stride_C,
  float beta,
  int batch_count) {

  using Gemm = cutlass::gemm::device::GemmBatched<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
    {m, n, k},
    {A, lda},
    batch_stride_A,
    {B, ldb},
    batch_stride_B,
    {C, ldc},          // source operand
    batch_stride_C,
    {C, ldc},          // destination operand: the result overwrites C in place
    batch_stride_C,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
/// Returns true when every element addressed by a strided batch of column-major
/// `rows` x `cols` matrices lies inside `buffer`.
///
/// Element (row, col) of batch b is addressed as b * batch_stride + col * ld + row,
/// so the largest index touched is the one for the last row, last column and last
/// batch. An empty batch (any extent <= 0) touches nothing and always fits.
/// Negative leading dimensions or strides are rejected because they would
/// address elements before the start of the buffer.
template<typename T>
static bool strided_batch_fits(
  std::vector<T> const &buffer,
  int rows,
  int cols,
  int ld,
  long long int batch_stride,
  int batch_count) {

  if (rows <= 0 || cols <= 0 || batch_count <= 0) {
    return true;
  }
  if (ld < 0 || batch_stride < 0) {
    return false;
  }

  // 64-bit arithmetic: the products can exceed the range of int.
  long long int const last_index =
      static_cast<long long int>(batch_count - 1) * batch_stride
    + static_cast<long long int>(cols - 1) * ld
    + static_cast<long long int>(rows - 1);

  return static_cast<unsigned long long>(last_index) < buffer.size();
}

/// Host reference for a strided batched GEMM with non-transposed, column-major
/// operands: for every batch, C = alpha * A * B + beta * C.
///
/// Element (row, col) of batch b of an operand X is read from
/// X[b * batch_stride_X + col * ldx + row].
///
/// \param m, n, k      problem size of each GEMM (A is m x k, B is k x n, C is m x n)
/// \param alpha, beta  scaling factors
/// \param A, B         input operands
/// \param C            input/output operand, updated in place
/// \param batch_count  number of GEMMs
///
/// \return cudaErrorInvalidValue (after printing a message to std::cout) when a
///         buffer is too small for the indices the computation would touch,
///         cudaSuccess otherwise.
template<typename T>
cudaError_t strided_batched_gemm_nn_reference(
  int m,
  int n,
  int k,
  T alpha,
  std::vector<T> const &A,
  int lda,
  long long int batch_stride_A,
  std::vector<T> const &B,
  int ldb,
  long long int batch_stride_B,
  std::vector<T> &C,
  int ldc,
  long long int batch_stride_C,
  T beta,
  int batch_count) {

  // Validate against the exact extent implied by the leading dimensions and
  // batch strides, so the loops below can never index past a buffer.
  if (!strided_batch_fits(A, m, k, lda, batch_stride_A, batch_count)) {
    std::cout << "the size of A is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (!strided_batch_fits(B, k, n, ldb, batch_stride_B, batch_count)) {
    std::cout << "the size of B is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (!strided_batch_fits(C, m, n, ldc, batch_stride_C, batch_count)) {
    std::cout << "the size of C is too small" << std::endl;
    return cudaErrorInvalidValue;
  }

  for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
    long long int const batch_offset_A = batch_idx * batch_stride_A;
    long long int const batch_offset_B = batch_idx * batch_stride_B;
    long long int const batch_offset_C = batch_idx * batch_stride_C;

    for (int n_idx = 0; n_idx < n; n_idx++) {
      // Start of column n_idx in B and C for this batch (64-bit to avoid int overflow).
      long long int const column_B = batch_offset_B + static_cast<long long int>(n_idx) * ldb;
      long long int const column_C = batch_offset_C + static_cast<long long int>(n_idx) * ldc;

      for (int m_idx = 0; m_idx < m; m_idx++) {
        T accum = beta * C[column_C + m_idx];
        for (int k_idx = 0; k_idx < k; k_idx++) {
          // Keep the (alpha * a) * b evaluation order so results stay
          // bit-identical to the previous implementation.
          accum += alpha
            * A[batch_offset_A + static_cast<long long int>(k_idx) * lda + m_idx]
            * B[column_B + k_idx];
        }
        C[column_C + m_idx] = accum;
      }
    }
  }

  return cudaSuccess;
}
/// Runs the strided batched SGEMM example end to end.
///
/// Fills host matrices with small integer values, copies them to the device,
/// runs cutlass_strided_batched_sgemm, computes the same product on the host
/// with strided_batched_gemm_nn_reference and compares the two results exactly.
///
/// Returns 0 and prints "Passed." on success. On failure it returns the numeric
/// value of the failing cudaError_t (cudaErrorUnknown for a result mismatch).
/// Device buffers are released on every exit path.
int main() {

  // Arbitrary problem size
  int const m = 520;
  int const n = 219;
  int const k = 129;
  int const batch_count = 17;

  // A, B are non-transpose, column major
  int const lda = m;
  int const ldb = k * batch_count;
  int const ldc = m;

  int const count_A = batch_count * lda * k;
  int const count_B = ldb * n;
  int const count_C = batch_count * ldc * n;

  // A and C store their batches back to back. The batches of B are stacked
  // along the K (row) dimension, so consecutive batches of B start k elements apart.
  long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
  long long int batch_stride_B = static_cast<long long int>(k);
  long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);

  // alpha and beta
  float alpha = 1.0f;
  float beta = 2.0f;

  cudaError_t result = cudaSuccess;

  // allocate the host memory
  std::vector<float> host_A(count_A);
  std::vector<float> host_B(count_B);
  std::vector<float> host_C(count_C);
  std::vector<float> result_C(count_C);

  // Device buffers start out null so the cleanup helper can run on any exit
  // path, including before all three allocations have succeeded.
  float *A = nullptr;
  float *B = nullptr;
  float *C = nullptr;

  // Frees all device buffers, reports every cudaFree failure and returns the
  // first error encountered (cudaSuccess if none).
  auto release_device_memory = [&]() -> cudaError_t {
    cudaError_t first_error = cudaSuccess;
    float **buffers[] = {&A, &B, &C};
    for (float **buffer : buffers) {
      cudaError_t free_status = cudaFree(*buffer);
      *buffer = nullptr;
      if (free_status != cudaSuccess) {
        std::cerr << "cudaFree result = " << free_status << std::endl;
        if (first_error == cudaSuccess) {
          first_error = free_status;
        }
      }
    }
    return first_error;
  };

  // Reports a failed call, releases the device buffers and yields the exit code.
  auto fail = [&](char const *what, cudaError_t error) -> int {
    std::cerr << what << " result = " << error << std::endl;
    release_device_memory();
    return static_cast<int>(error);
  };

  // allocate the device memory
  result = cudaMalloc(&A, count_A * sizeof(float));
  if (result != cudaSuccess) {
    return fail("cudaMalloc", result);
  }
  result = cudaMalloc(&B, count_B * sizeof(float));
  if (result != cudaSuccess) {
    return fail("cudaMalloc", result);
  }
  result = cudaMalloc(&C, count_C * sizeof(float));
  if (result != cudaSuccess) {
    return fail("cudaMalloc", result);
  }

  // Limit range to avoid floating-point errors
  int const kRange = 8;

  // fill A
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < k; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        int const idx = row_idx + col_idx * lda + b_idx * lda * k;
        host_A[idx] = static_cast<float>(idx % kRange);
      }
    }
  }

  // fill B
  // The remainder keeps the sign of its left operand, so these values lie
  // strictly between -kRange and kRange.
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < k; row_idx++) {
        int const idx = row_idx + col_idx * ldb + b_idx * k;
        host_B[idx] = static_cast<float>(((n + k * ldb + batch_count * k) - idx) % kRange);
      }
    }
  }

  // fill C
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
      }
    }
  }

  // The reference updates C in place, so it gets its own copy of the initial
  // values; A and B are only read and can be shared with the host buffers.
  std::vector<float> ref_C(host_C);

  // copy host memory to device
  result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    return fail("cudaMemcpy", result);
  }
  result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    return fail("cudaMemcpy", result);
  }
  result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    return fail("cudaMemcpy", result);
  }

  // run cutlass
  result = cutlass_strided_batched_sgemm(
    m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
    beta, batch_count);
  if (result != cudaSuccess) {
    return fail("cutlass_strided_batched_sgemm", result);
  }

  // copy device memory to host
  result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    return fail("cudaMemcpy", result);
  }

  // compare with reference code
  result = strided_batched_gemm_nn_reference(
    m, n, k, alpha, host_A, lda, batch_stride_A, host_B, ldb, batch_stride_B, ref_C, ldc,
    batch_stride_C, beta, batch_count);
  if (result != cudaSuccess) {
    // The reference already printed which buffer was too small.
    release_device_memory();
    return static_cast<int>(result);
  }

  // Expect bit-level accuracy for this simple example
  if (ref_C != result_C) {
    std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
    release_device_memory();
    return static_cast<int>(cudaErrorUnknown);
  }

  // free memory
  result = release_device_memory();
  if (result != cudaSuccess) {
    return static_cast<int>(result);
  }

  std::cout << "Passed." << std::endl;

  // Exit.
  return 0;
}