Make cutlass::gemm::device::GemmArray usable (#295)
* Fix the build of cutlass/gemm/device/gemm_array.h and add a demo for GemmArray * Add a reference to GemmArray to the docs Co-authored-by: Ivan Komarov <dfyz@yandex-team.ru>
This commit is contained in:
parent
3cfa5db2a2
commit
e96f00586c
@ -28,12 +28,16 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/gemm/device/gemm_array.h"
|
||||
#include "cutlass/gemm/device/gemm_batched.h"
|
||||
|
||||
#pragma warning( disable : 4503)
|
||||
|
||||
/*
|
||||
This example demonstrates how to use cutlass to compute a batched strided gemm.
|
||||
This example demonstrates how to use cutlass to compute a batched gemm in two different ways:
|
||||
1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
|
||||
matrices of the batch (this is called a strided batched gemm).
|
||||
2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
|
||||
In this example, both A and B matrix are non-transpose and column major matrix
|
||||
batched_C = batched_A x batched_B
|
||||
As an example, matrix C can be seen as
|
||||
@ -89,6 +93,45 @@ The stride (batch_stride_C) between the first element of two batches is k
|
||||
|
||||
*/
|
||||
|
||||
cudaError_t cutlass_array_sgemm(
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
float alpha,
|
||||
float const * const *A,
|
||||
int lda,
|
||||
float const * const *B,
|
||||
int ldb,
|
||||
float * const *C,
|
||||
int ldc,
|
||||
float beta,
|
||||
int batch_count) {
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmArray<
|
||||
float, cutlass::layout::ColumnMajor,
|
||||
float, cutlass::layout::ColumnMajor,
|
||||
float, cutlass::layout::ColumnMajor
|
||||
>;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op({
|
||||
{m, n, k},
|
||||
A, lda,
|
||||
B, ldb,
|
||||
C, ldc,
|
||||
C, ldc,
|
||||
{alpha, beta},
|
||||
batch_count
|
||||
});
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
cudaError_t cutlass_strided_batched_sgemm(
|
||||
int m,
|
||||
int n,
|
||||
@ -188,7 +231,10 @@ cudaError_t strided_batched_gemm_nn_reference(
|
||||
return result;
|
||||
}
|
||||
|
||||
int main() {
|
||||
cudaError_t run_batched_gemm(bool use_array) {
|
||||
|
||||
const char* gemm_desc = use_array ? "array" : "strided batched";
|
||||
std::cout << "Running " << gemm_desc << " gemm" << std::endl;
|
||||
|
||||
// Arbitrary problem size
|
||||
int const m = 520;
|
||||
@ -293,11 +339,69 @@ int main() {
|
||||
}
|
||||
|
||||
// run cutlass
|
||||
result = cutlass_strided_batched_sgemm(
|
||||
m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
|
||||
beta, batch_count);
|
||||
if (result != cudaSuccess)
|
||||
return result;
|
||||
if (use_array) {
|
||||
// allocate the host memory for the pointers to the matrices of the batch
|
||||
std::vector<float*> host_ptr_A(batch_count);
|
||||
std::vector<float*> host_ptr_B(batch_count);
|
||||
std::vector<float*> host_ptr_C(batch_count);
|
||||
|
||||
// permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
|
||||
std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
|
||||
for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
|
||||
host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
|
||||
host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
|
||||
host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
|
||||
}
|
||||
|
||||
// allocate the corresponding device memory
|
||||
float const **ptr_A;
|
||||
float const **ptr_B;
|
||||
float **ptr_C;
|
||||
|
||||
result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// copy the matrix pointers to the device
|
||||
result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
return result;
|
||||
} else {
|
||||
result = cutlass_strided_batched_sgemm(
|
||||
m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
|
||||
beta, batch_count);
|
||||
if (result != cudaSuccess)
|
||||
return result;
|
||||
}
|
||||
|
||||
// copy device memory to host
|
||||
result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
@ -314,7 +418,7 @@ int main() {
|
||||
|
||||
// Expect bit-level accuracy for this simple example
|
||||
if (ref_C != result_C) {
|
||||
std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
|
||||
std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
@ -335,9 +439,19 @@ int main() {
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
int main() {
|
||||
|
||||
cudaError_t result = cudaSuccess;
|
||||
for (bool use_array : {false, true}) {
|
||||
result = run_batched_gemm(use_array);
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Exit.
|
||||
|
@ -376,8 +376,8 @@ public:
|
||||
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
args.batch_count,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.batch_count);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params{
|
||||
|
@ -81,6 +81,7 @@ has semantics similar to cuBLAS.
|
||||
|
||||
The device-wide GEMM API is embodied by the following operators:
|
||||
- [cutlass::gemm::device::Gemm](/include/cutlass/gemm/device/gemm.h) - basic GEMM operation
|
||||
- [cutlass::gemm::device::GemmArray](/include/cutlass/gemm/device/gemm_array.h) - batched GEMM operation in which input matrices are read from arrays of pointers
|
||||
- [cutlass::gemm::device::GemmBatched](/include/cutlass/gemm/device/gemm_batched.h) - batched GEMM operation in which input matrices are separated by a constant stride
|
||||
- [cutlass::gemm::device::GemmSplitKParallel](/include/cutlass/gemm/device/gemm_splitk_parallel.h) - GEMM operation that partitions the GEMM K dimension then launches a separate reduction kernel
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user