diff --git a/examples/05_batched_gemm/batched_gemm.cu b/examples/05_batched_gemm/batched_gemm.cu
index 1d41932c..e39d405b 100644
--- a/examples/05_batched_gemm/batched_gemm.cu
+++ b/examples/05_batched_gemm/batched_gemm.cu
@@ -28,12 +28,16 @@
 #include "cutlass/cutlass.h"
 #include "cutlass/layout/matrix.h"
+#include "cutlass/gemm/device/gemm_array.h"
 #include "cutlass/gemm/device/gemm_batched.h"
 
 #pragma warning( disable : 4503)
 
 /*
-This example demonstrates how to use cutlass to compute a batched strided gemm.
+This example demonstrates how to use cutlass to compute a batched gemm in two different ways:
+  1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
+     matrices of the batch (this is called a strided batched gemm).
+  2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
 In this example, both A and B matrix are non-transpose and column major matrix
 batched_C = batched_A x batched_B
 As an example, matrix C can be seen as
@@ -89,6 +93,45 @@ The stride (batch_stride_C) between the first element of two batches is k
 
 */
 
+cudaError_t cutlass_array_sgemm(
+  int m,
+  int n,
+  int k,
+  float alpha,
+  float const * const *A,
+  int lda,
+  float const * const *B,
+  int ldb,
+  float * const *C,
+  int ldc,
+  float beta,
+  int batch_count) {
+
+  using Gemm = cutlass::gemm::device::GemmArray<
+    float, cutlass::layout::ColumnMajor,
+    float, cutlass::layout::ColumnMajor,
+    float, cutlass::layout::ColumnMajor
+  >;
+
+  Gemm gemm_op;
+
+  cutlass::Status status = gemm_op({
+    {m, n, k},
+    A, lda,
+    B, ldb,
+    C, ldc,
+    C, ldc,
+    {alpha, beta},
+    batch_count
+  });
+
+  if (status != cutlass::Status::kSuccess) {
+    return cudaErrorUnknown;
+  }
+
+  return cudaSuccess;
+}
+
 cudaError_t cutlass_strided_batched_sgemm(
   int m,
   int n,
@@ -188,7 +231,10 @@ cudaError_t strided_batched_gemm_nn_reference(
   return result;
 }
 
-int main() {
+cudaError_t run_batched_gemm(bool use_array) {
+
+  const char* gemm_desc = use_array ? "array" : "strided batched";
+  std::cout << "Running " << gemm_desc << " gemm" << std::endl;
 
   // Arbitrary problem size
   int const m = 520;
@@ -293,11 +339,69 @@ int main() {
   }
 
   // run cutlass
-  result = cutlass_strided_batched_sgemm(
-    m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
-    beta, batch_count);
-  if (result != cudaSuccess)
-    return result;
+  if (use_array) {
+    // allocate the host memory for the pointers to the matrices of the batch
+    std::vector<float const *> host_ptr_A(batch_count);
+    std::vector<float const *> host_ptr_B(batch_count);
+    std::vector<float *> host_ptr_C(batch_count);
+
+    // permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
+    std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
+    for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
+      host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
+      host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
+      host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
+    }
+
+    // allocate the corresponding device memory
+    float const **ptr_A;
+    float const **ptr_B;
+    float **ptr_C;
+
+    result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+
+    // copy the matrix pointers to the device
+    result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+
+    result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
+
+    if (result != cudaSuccess)
+      return result;
+  } else {
+    result = cutlass_strided_batched_sgemm(
+      m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
+      beta, batch_count);
+    if (result != cudaSuccess)
+      return result;
+  }
 
   // copy device memory to host
   result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
@@ -314,7 +418,7 @@ int main() {
 
   // Expect bit-level accuracy for this simple example
   if (ref_C != result_C) {
-    std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
+    std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
     return cudaErrorUnknown;
   }
 
@@ -335,9 +439,19 @@ int main() {
     return result;
   }
 
+  return result;
+}
 
-  if (result == cudaSuccess) {
-    std::cout << "Passed." << std::endl;
+int main() {
+
+  cudaError_t result = cudaSuccess;
+  for (bool use_array : {false, true}) {
+    result = run_batched_gemm(use_array);
+    if (result == cudaSuccess) {
+      std::cout << "Passed." << std::endl;
+    } else {
+      break;
+    }
   }
 
   // Exit.
diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h
index e85f4591..fa7e8e59 100644
--- a/include/cutlass/gemm/device/gemm_array.h
+++ b/include/cutlass/gemm/device/gemm_array.h
@@ -376,8 +376,8 @@ public:
 
     cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
       args.problem_size,
-      args.batch_count,
-      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
+      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
+      args.batch_count);
 
     // Initialize the Params structure
     params_ = typename GemmKernel::Params{
diff --git a/media/docs/gemm_api.md b/media/docs/gemm_api.md
index e5b54bf6..2bc9ffac 100644
--- a/media/docs/gemm_api.md
+++ b/media/docs/gemm_api.md
@@ -81,6 +81,7 @@ has semantics similar to cuBLAS.
 The device-wide GEMM API is embodied by the following operators:
 
 - [cutlass::gemm::device::Gemm](/include/cutlass/gemm/device/gemm.h) - basic GEMM operation
+- [cutlass::gemm::device::GemmArray](/include/cutlass/gemm/device/gemm_array.h) - batched GEMM operation in which input matrices are read from arrays of pointers
 - [cutlass::gemm::device::GemmBatched](/include/cutlass/gemm/device/gemm_batched.h) - batched GEMM operation in which input matrices are separated by a constant stride
 - [cutlass::gemm::device::GemmSplitKParallel](/include/cutlass/gemm/device/gemm_splitk_parallel.h) - GEMM operation that partitions the GEMM K dimension then launches a separate reduction kernel
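
Two notes on the patch:

1. In cutlass_array_sgemm, C and ldc are passed twice because the device-level GEMM arguments take separate source and destination operands for the epilogue D = alpha * A * B + beta * C; supplying the same pointer and leading dimension for both performs the update in place.

2. The gemm_array.h hunk is a bug fix: get_tiled_shape() expects the threadblock tile shape as its second argument and the batch count as its third (the batch count becomes the k extent of the launch grid), so the previous argument order built the grid from swapped values. Below is a minimal standalone sketch of the grid computation after the fix. Only m = 520 comes from the diff; n, k, the 128x128x8 tile, and batch_count = 17 (inferred from the 17-element permutation) are assumptions for illustration, and the program itself is not part of the change.

#include <iostream>
#include "cutlass/gemm_coord.h"

// Illustrative sketch only: mirrors the fixed call
// get_tiled_shape(problem_size, tile_shape, batch_count) by ceil-dividing
// M and N by the threadblock tile and carrying batch_count in the k slot.
int main() {
  cutlass::gemm::GemmCoord problem_size(520, 219, 129);  // m from the example; n, k assumed
  cutlass::gemm::GemmCoord tile_shape(128, 128, 8);      // assumed default SGEMM threadblock tile
  int const batch_count = 17;                            // inferred from the permutation size

  cutlass::gemm::GemmCoord grid_shape(
    (problem_size.m() + tile_shape.m() - 1) / tile_shape.m(),  // ceil(520 / 128) = 5
    (problem_size.n() + tile_shape.n() - 1) / tile_shape.n(),  // ceil(219 / 128) = 2
    batch_count);                                              // one grid slice per batch entry

  // With the old argument order, batch_count would have been consumed as the
  // tile shape, producing a wrong grid.
  std::cout << grid_shape.m() << " x " << grid_shape.n() << " x "
            << grid_shape.k() << std::endl;  // prints "5 x 2 x 17"

  return 0;
}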