/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_batched.h"

#pragma warning( disable : 4503)

/*
This example demonstrates how to use CUTLASS to compute a batched strided GEMM in two different ways:

  1. By specifying pointers to the first matrices of the batch and the stride between consecutive
     matrices of the batch (this is called a strided batched gemm).

  2. By copying pointers to all matrices of the batch to device memory (this is called an array gemm).

In this example, A and B are both non-transposed, column-major matrices, and the batched product is

  batched_C = batched_A x batched_B

As an example, matrix C can be seen as

-----------------------------------------------------------
(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
-----------------------------------------------------------
(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------------------------------------
(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
-----------------------------------------------------------
(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
-----------------------------------------------------------
(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
-----------------------------------------------------------
(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
-----------------------------------------------------------
          batch 0           |           batch 1

where each element is denoted by (batch_idx, row_idx, column_idx).

In this example, the batch size is 2, M is 6 and N is 3.

The stride (batch_stride_C) between the first elements of two consecutive batches is ldc * n.

matrix A can be seen as

---------------------------------------
(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
---------------------------------------
(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
---------------------------------------
(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
---------------------------------------
(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
---------------------------------------
(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
---------------------------------------
(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
---------------------------------------
      batch 0     |      batch 1

where the batch size is 2, M is 6 and K is 2.

The stride (batch_stride_A) between the first elements of two consecutive batches is lda * k.

matrix B can be seen as

-----------------------------
(0,0,0) | (0,0,1) | (0,0,2) |
----------------------------- batch 0
(0,1,0) | (0,1,1) | (0,1,2) |
-------------------------------------
(1,0,0) | (1,0,1) | (1,0,2) |
----------------------------- batch 1
(1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------

where the batch size is 2, N is 3 and K is 2.

The batches of B are interleaved along the K dimension (ldb = k * batch_count), so the stride
(batch_stride_B) between the first elements of two consecutive batches is just k.
*/

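// Wrapper around cutlass::gemm::device::GemmArray for a single-precision, column-major GEMM.
// A, B and C are device-side arrays of batch_count pointers, one per batch element, so the
// matrices of the batch do not have to be separated by a uniform stride.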
cudaError_t cutlass_array_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const * const *A,
  int lda,
  float const * const *B,
  int ldb,
  float * const *C,
  int ldc,
  float beta,
  int batch_count) {

  using Gemm = cutlass::gemm::device::GemmArray<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

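  // Arguments: problem size, pointer arrays for A and B with their leading dimensions,
  // C twice (first as the source read by the epilogue, then as the destination D),
  // the epilogue scalars {alpha, beta}, and the number of batches.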
  cutlass::Status status = gemm_op({
    {m, n, k},
    A, lda,
    B, ldb,
    C, ldc,
    C, ldc,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}

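// Wrapper around cutlass::gemm::device::GemmBatched for a single-precision, column-major GEMM.
// Each operand is described by a single base pointer, a leading dimension, and a fixed stride
// (in elements) between consecutive matrices of the batch.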
cudaError_t cutlass_strided_batched_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const *A,
  int lda,
  long long int batch_stride_A,
  float const *B,
  int ldb,
  long long int batch_stride_B,
  float *C,
  int ldc,
  long long int batch_stride_C,
  float beta,
  int batch_count) {

  using Gemm = cutlass::gemm::device::GemmBatched<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

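  // Arguments: problem size, then each operand as a {pointer, leading dimension} reference
  // followed by its batch stride; C appears twice (source, then destination), followed by
  // {alpha, beta} and the batch count.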
  cutlass::Status status = gemm_op({
    {m, n, k},
    {A, lda},
    batch_stride_A,
    {B, ldb},
    batch_stride_B,
    {C, ldc},
    batch_stride_C,
    {C, ldc},
    batch_stride_C,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}

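// Host reference implementation of the column-major, non-transposed (NN) strided batched GEMM:
// C[b] = alpha * A[b] * B[b] + beta * C[b] for each batch element b, using the same leading
// dimensions and batch strides as the device kernels above.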
template<typename T>
cudaError_t strided_batched_gemm_nn_reference(
  int m,
  int n,
  int k,
  T alpha,
  std::vector<T> const &A,
  int lda,
  long long int batch_stride_A,
  std::vector<T> const &B,
  int ldb,
  long long int batch_stride_B,
  std::vector<T> &C,
  int ldc,
  long long int batch_stride_C,
  T beta,
  int batch_count) {
  /*
  strided batched gemm NN
  */

  cudaError_t result = cudaSuccess;

  if (A.size() < lda * k * batch_count) {
    std::cout << "the size of A is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (B.size() < ldb * n) {
    std::cout << "the size of B is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (C.size() < ldc * n * batch_count) {
    std::cout << "the size of C is too small" << std::endl;
    return cudaErrorInvalidValue;
  }

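  // Column-major indexing: element (m_idx, n_idx) of batch batch_idx lives at offset
  // batch_idx * batch_stride + n_idx * ld + m_idx for each operand.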
  for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
    for (int n_idx = 0; n_idx < n; n_idx++) {
      for (int m_idx = 0; m_idx < m; m_idx++) {
        T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx];
        for (int k_idx = 0; k_idx < k; k_idx++) {
          accum += alpha
            * A[batch_idx * batch_stride_A + k_idx * lda + m_idx]
            * B[batch_idx * batch_stride_B + n_idx * ldb + k_idx];
        }
        C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum;
      }
    }
  }

  return result;
}

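// Sets up a test problem, runs either the strided batched GEMM or the array GEMM on the device,
// and verifies the result against the host reference implementation.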
cudaError_t run_batched_gemm(bool use_array) {

  const char* gemm_desc = use_array ? "array" : "strided batched";
  std::cout << "Running " << gemm_desc << " gemm" << std::endl;

  // Arbitrary problem size
  int const m = 520;
  int const n = 219;
  int const k = 129;
  int const batch_count = 17;

  // A and B are non-transposed; all matrices are column major
  int const lda = m;
  int const ldb = k * batch_count;
  int const ldc = m;

  int const count_A = batch_count * lda * k;
  int const count_B = ldb * n;
  int const count_C = batch_count * ldc * n;

  // A and C are stored contiguously batch after batch; B is batched along the K dimension,
  // interleaving all batches in one ldb-by-n allocation (ldb = k * batch_count), so its batch stride is just k
  long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
  long long int batch_stride_B = static_cast<long long int>(k);
  long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);

  // alpha and beta
  float alpha = 1.0f;
  float beta = 2.0f;

  cudaError_t result = cudaSuccess;

  // allocate the host memory
  std::vector<float> host_A(count_A);
  std::vector<float> host_B(count_B);
  std::vector<float> host_C(count_C);
  std::vector<float> result_C(count_C);

  // allocate the device memory
  float *A;
  float *B;
  float *C;

  result = cudaMalloc(&A, count_A * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&B, count_B * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&C, count_C * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }

  // Limit range to avoid floating-point errors
  int const kRange = 8;

  // fill A
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < k; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_A[row_idx + col_idx * lda + b_idx * lda * k] = static_cast<float>((row_idx + col_idx * lda + b_idx * lda * k) % kRange);
      }
    }
  }
  // fill B
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < k; row_idx++) {
        host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast<float>(((n + k * ldb + batch_count * k) - (row_idx + col_idx * ldb + b_idx * k)) % kRange);
      }
    }
  }
  // fill C
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
      }
    }
  }

  // ref memory
  std::vector<float> ref_A(host_A);
  std::vector<float> ref_B(host_B);
  std::vector<float> ref_C(host_C);

  // copy host memory to device
  result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // run cutlass
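  // The array path first builds device-side tables of per-batch matrix pointers (in permuted
  // order); the strided batched path passes the base pointers and batch strides directly.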
  if (use_array) {
    // allocate the host memory for the pointers to the matrices of the batch
    std::vector<float*> host_ptr_A(batch_count);
    std::vector<float*> host_ptr_B(batch_count);
    std::vector<float*> host_ptr_C(batch_count);

    // permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
    std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
    for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
      host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
      host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
      host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
    }

    // allocate the corresponding device memory
    float const **ptr_A;
    float const **ptr_B;
    float **ptr_C;

    result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
    if (result != cudaSuccess) {
      std::cerr << "cudaMalloc result = " << result << std::endl;
      return result;
    }
    result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
    if (result != cudaSuccess) {
      std::cerr << "cudaMalloc result = " << result << std::endl;
      return result;
    }
    result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
    if (result != cudaSuccess) {
      std::cerr << "cudaMalloc result = " << result << std::endl;
      return result;
    }

    // copy the matrix pointers to the device
    result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
    if (result != cudaSuccess) {
      std::cerr << "cudaMemcpy result = " << result << std::endl;
      return result;
    }
    result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
    if (result != cudaSuccess) {
      std::cerr << "cudaMemcpy result = " << result << std::endl;
      return result;
    }
    result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
    if (result != cudaSuccess) {
      std::cerr << "cudaMemcpy result = " << result << std::endl;
      return result;
    }

    result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);

    if (result != cudaSuccess)
      return result;

  } else {
    result = cutlass_strided_batched_sgemm(
      m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
      beta, batch_count);
    if (result != cudaSuccess)
      return result;
  }

  // copy device memory to host
  result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // compare with the reference code
  result = strided_batched_gemm_nn_reference(m, n, k, alpha, ref_A, lda, batch_stride_A, ref_B, ldb, batch_stride_B, ref_C, ldc, batch_stride_C,
    beta, batch_count);
  if (result != cudaSuccess)
    return result;

  // Expect bit-level accuracy for this simple example
  if (ref_C != result_C) {
    std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
    return cudaErrorUnknown;
  }

  // free memory
  result = cudaFree(A);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }
  result = cudaFree(B);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }
  result = cudaFree(C);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }

  return result;
}

int main() {

  cudaError_t result = cudaSuccess;
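
  // Run the strided batched GEMM first (use_array == false), then the array GEMM (use_array == true).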
  for (bool use_array : {false, true}) {
    result = run_batched_gemm(use_array);
    if (result == cudaSuccess) {
      std::cout << "Passed." << std::endl;
    } else {
      break;
    }
  }

  // Exit.
  return result == cudaSuccess ? 0 : -1;
}