Make cutlass::gemm::device::GemmArray usable (#295)

* Fix the build of cutlass/gemm/device/gemm_array.h and add a demo for GemmArray

* Add a reference to GemmArray to the docs

Co-authored-by: Ivan Komarov <dfyz@yandex-team.ru>
Author: Ivan Komarov <dfyz@yandex-team.ru>
Date:   2022-02-18 04:01:05 +03:00
Commit: e96f00586c (parent 3cfa5db2a2)
3 changed files with 127 additions and 12 deletions

--- a/examples/05_batched_gemm/batched_gemm.cu
+++ b/examples/05_batched_gemm/batched_gemm.cu

@@ -28,12 +28,16 @@
 #include "cutlass/cutlass.h"
 #include "cutlass/layout/matrix.h"
+#include "cutlass/gemm/device/gemm_array.h"
 #include "cutlass/gemm/device/gemm_batched.h"
 
 #pragma warning( disable : 4503)
 
 /*
-This example demonstrates how to use cutlass to compute a batched strided gemm.
+This example demonstrates how to use cutlass to compute a batched gemm in two different ways:
+  1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
+     matrices of the batch (this is called a strided batched gemm).
+  2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
 In this example, both A and B matrix are non-transpose and column major matrix
 batched_C = batched_A x batched_B
 As an example, matrix C can be seen as
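The two flavors differ only in how the i-th matrix of a batch is located. A minimal sketch of the distinction, reusing the names (A, batch_stride_A, ptr_A) that appear in the example code below; the helper functions themselves are illustrative, not part of the example:

// Strided batched gemm: every matrix sits at a fixed element stride from a base pointer.
float const *strided_lookup(float const *A, long long batch_stride_A, int i) {
  return A + i * batch_stride_A;
}

// Array gemm: each matrix is reached through a device-resident array of pointers,
// so batch elements may live anywhere in device memory.
float const *array_lookup(float const * const *ptr_A, int i) {
  return ptr_A[i];
}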
@@ -89,6 +93,45 @@ The stride (batch_stride_C) between the first element of two batches is k
 */
 
+cudaError_t cutlass_array_sgemm(
+  int m,
+  int n,
+  int k,
+  float alpha,
+  float const * const *A,
+  int lda,
+  float const * const *B,
+  int ldb,
+  float * const *C,
+  int ldc,
+  float beta,
+  int batch_count) {
+
+  using Gemm = cutlass::gemm::device::GemmArray<
+    float, cutlass::layout::ColumnMajor,
+    float, cutlass::layout::ColumnMajor,
+    float, cutlass::layout::ColumnMajor
+  >;
+
+  Gemm gemm_op;
+
+  cutlass::Status status = gemm_op({
+    {m, n, k},
+    A, lda,
+    B, ldb,
+    C, ldc,
+    C, ldc,
+    {alpha, beta},
+    batch_count
+  });
+
+  if (status != cutlass::Status::kSuccess) {
+    return cudaErrorUnknown;
+  }
+
+  return cudaSuccess;
+}
+
 cudaError_t cutlass_strided_batched_sgemm(
   int m,
   int n,
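Note that the pointer array C and its leading dimension ldc are passed twice in the argument struct above. As with the other device-wide GEMM operators, GemmArray distinguishes the source matrix C from the destination matrix D in D = alpha * A * B + beta * C; supplying the same pointers for both computes the update in place.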
@@ -188,7 +231,10 @@ cudaError_t strided_batched_gemm_nn_reference(
   return result;
 }
 
-int main() {
+cudaError_t run_batched_gemm(bool use_array) {
+  const char* gemm_desc = use_array ? "array" : "strided batched";
+  std::cout << "Running " << gemm_desc << " gemm" << std::endl;
 
   // Arbitrary problem size
   int const m = 520;
@@ -293,11 +339,69 @@ int main() {
   }
 
   // run cutlass
-  result = cutlass_strided_batched_sgemm(
-    m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
-    beta, batch_count);
-  if (result != cudaSuccess)
-    return result;
+  if (use_array) {
+    // allocate the host memory for the pointers to the matrices of the batch
+    std::vector<float*> host_ptr_A(batch_count);
+    std::vector<float*> host_ptr_B(batch_count);
+    std::vector<float*> host_ptr_C(batch_count);
+
+    // permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
+    std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
+    for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
+      host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
+      host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
+      host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
+    }
+
+    // allocate the corresponding device memory
+    float const **ptr_A;
+    float const **ptr_B;
+    float **ptr_C;
+
+    result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+
+    // copy the matrix pointers to the device
+    result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+
+    result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
+    if (result != cudaSuccess)
+      return result;
+  } else {
+    result = cutlass_strided_batched_sgemm(
+      m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
+      beta, batch_count);
+    if (result != cudaSuccess)
+      return result;
+  }
 
   // copy device memory to host
   result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
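The permutation deserves a brief note: because the same permutation indexes host_ptr_A, host_ptr_B, and host_ptr_C, each logical batch element still multiplies its own A and B and writes its own C. Only the order in which the pointer arrays list the matrices changes, which is why the bit-level comparison against the strided reference below still passes.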
@@ -314,7 +418,7 @@ int main() {
 
   // Expect bit-level accuracy for this simple example
   if (ref_C != result_C) {
-    std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
+    std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
     return cudaErrorUnknown;
   }
@@ -335,9 +439,19 @@ int main() {
     return result;
   }
 
-  if (result == cudaSuccess) {
-    std::cout << "Passed." << std::endl;
+  return result;
+}
+
+int main() {
+  cudaError_t result = cudaSuccess;
+  for (bool use_array : {false, true}) {
+    result = run_batched_gemm(use_array);
+    if (result == cudaSuccess) {
+      std::cout << "Passed." << std::endl;
+    } else {
+      break;
+    }
   }
 
   // Exit.

--- a/include/cutlass/gemm/device/gemm_array.h
+++ b/include/cutlass/gemm/device/gemm_array.h

@@ -376,8 +376,8 @@ public:
     cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
       args.problem_size,
-      args.batch_count,
-      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
+      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
+      args.batch_count);
 
     // Initialize the Params structure
     params_ = typename GemmKernel::Params{
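This swap is the actual build fix: get_tiled_shape expects the threadblock tile shape before the batch count, so the old argument order did not type-check. A sketch of the interface the call must match (paraphrased from the batched identity threadblock swizzle; treat the exact qualifiers as an assumption):

cutlass::gemm::GemmCoord get_tiled_shape(
  cutlass::gemm::GemmCoord problem_size,  // the full GEMM extent (m, n, k)
  cutlass::gemm::GemmCoord tile_size,     // one threadblock tile (kM, kN, kK)
  int batch_count) const;                 // number of batch elements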

--- a/media/docs/gemm_api.md
+++ b/media/docs/gemm_api.md

@@ -81,6 +81,7 @@ has semantics similar to cuBLAS.
 The device-wide GEMM API is embodied by the following operators:
 
 - [cutlass::gemm::device::Gemm](/include/cutlass/gemm/device/gemm.h) - basic GEMM operation
+- [cutlass::gemm::device::GemmArray](/include/cutlass/gemm/device/gemm_array.h) - batched GEMM operation in which input matrices are read from arrays of pointers
 - [cutlass::gemm::device::GemmBatched](/include/cutlass/gemm/device/gemm_batched.h) - batched GEMM operation in which input matrices are separated by a constant stride
 - [cutlass::gemm::device::GemmSplitKParallel](/include/cutlass/gemm/device/gemm_splitk_parallel.h) - GEMM operation that partitions the GEMM K dimension then launches a separate reduction kernel
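For quick reference, here is a condensed sketch of GemmArray usage distilled from the example above. The wrapper function and the names d_ptr_A, d_ptr_B, and d_ptr_C are hypothetical; the pointer arrays are assumed to be device-resident, hold batch_count entries each, and be already populated as shown earlier:

#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_array.h"

cudaError_t run_gemm_array(
    int m, int n, int k, float alpha, float beta, int batch_count,
    float const * const *d_ptr_A, int lda,
    float const * const *d_ptr_B, int ldb,
    float * const *d_ptr_C, int ldc) {
  using Gemm = cutlass::gemm::device::GemmArray<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor>;
  Gemm gemm_op;
  cutlass::Status status = gemm_op({
      {m, n, k},
      d_ptr_A, lda,
      d_ptr_B, ldb,
      d_ptr_C, ldc,   // source C
      d_ptr_C, ldc,   // destination D aliases C, so the batch is updated in place
      {alpha, beta},
      batch_count});
  return status == cutlass::Status::kSuccess ? cudaSuccess : cudaErrorUnknown;
}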