 e96f00586c
			
		
	
	
		e96f00586c
		
			
		
	
	
	
	
		
			
			* Fix the build of cutlass/gemm/device/gemm_array.h and add a demo for GemmArray * Add a reference to GemmArray to the docs Co-authored-by: Ivan Komarov <dfyz@yandex-team.ru>
		
			
				
	
	
		
			460 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			460 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without modification, are permitted
 | |
|  * provided that the following conditions are met:
 | |
|  *     * Redistributions of source code must retain the above copyright notice, this list of
 | |
|  *       conditions and the following disclaimer.
 | |
|  *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 | |
|  *       conditions and the following disclaimer in the documentation and/or other materials
 | |
|  *       provided with the distribution.
 | |
|  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 | |
|  *       to endorse or promote products derived from this software without specific prior written
 | |
|  *       permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 | |
|  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 | |
|  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 | |
|  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 | |
|  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 | |
|  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| 
 | |
| #include <iostream>
 | |
| #include <vector>
 | |
| 
 | |
| #include "cutlass/cutlass.h"
 | |
| #include "cutlass/layout/matrix.h"
 | |
| #include "cutlass/gemm/device/gemm_array.h"
 | |
| #include "cutlass/gemm/device/gemm_batched.h"
 | |
| 
 | |
| #pragma warning( disable : 4503)
 | |
| 
 | |
| /*
 | |
| This example demonstrates how to use cutlass to compute a batched gemm in two different ways:
 | |
|   1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
 | |
|      matrices of the batch (this is called a strided batched gemm).
 | |
|   2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
 | |
| In this example, both A and B matrix are non-transpose and column major matrix
 | |
| batched_C = batched_A x batched_B
 | |
| As an example, matrix C can be seen as
 | |
| -----------------------------------------------------------
 | |
| (0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
 | |
| -----------------------------------------------------------
 | |
| (0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
 | |
| -----------------------------------------------------------
 | |
| (0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
 | |
| -----------------------------------------------------------
 | |
| (0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
 | |
| -----------------------------------------------------------
 | |
| (0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
 | |
| -----------------------------------------------------------
 | |
| (0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
 | |
| -----------------------------------------------------------
 | |
|            batch 0          |           batch 1
 | |
| where we denote each element with (batch_idx, row_idx, column_idx)
 | |
| In this example, batch size is 2, M is 6 and N is 3
 | |
| The stride (batch_stride_C) between the first element of two batches is ldc * n
 | |
| 
 | |
| matrix A can be seen as
 | |
| ---------------------------------------
 | |
| (0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
 | |
| ---------------------------------------
 | |
| (0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
 | |
| ---------------------------------------
 | |
| (0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
 | |
| ---------------------------------------
 | |
| (0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
 | |
| ---------------------------------------
 | |
| (0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
 | |
| ---------------------------------------
 | |
| (0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
 | |
| ---------------------------------------
 | |
|      batch 0      |      batch 1
 | |
| , where batch size is 2, M is 6 and K is 2
 | |
| The stride (batch_stride_B) between the first element of two batches is lda * k
 | |
| 
 | |
| matrix B can be seen as
 | |
| -----------------------------
 | |
| (0,0,0) | (0,0,1) | (0,0,2) |
 | |
| ----------------------------- batch 0
 | |
| (0,1,0) | (0,1,1) | (0,1,2) |
 | |
| -------------------------------------
 | |
| (1,0,0) | (1,0,1) | (1,0,2) |
 | |
| ----------------------------- batch 1
 | |
| (1,1,0) | (1,1,1) | (1,1,2) |
 | |
| -----------------------------
 | |
| , where the batch size is 2, N is 3 and K is 2
 | |
| The stride (batch_stride_C) between the first element of two batches is k
 | |
| 
 | |
| 
 | |
| */
 | |
| 
 | |
| cudaError_t cutlass_array_sgemm(
 | |
|   int m,
 | |
|   int n,
 | |
|   int k,
 | |
|   float alpha,
 | |
|   float const * const *A,
 | |
|   int lda,
 | |
|   float const * const *B,
 | |
|   int ldb,
 | |
|   float * const *C,
 | |
|   int ldc,
 | |
|   float beta,
 | |
|   int batch_count) {
 | |
| 
 | |
|   using Gemm = cutlass::gemm::device::GemmArray<
 | |
|     float, cutlass::layout::ColumnMajor,
 | |
|     float, cutlass::layout::ColumnMajor,
 | |
|     float, cutlass::layout::ColumnMajor
 | |
|   >;
 | |
| 
 | |
|   Gemm gemm_op;
 | |
| 
 | |
|   cutlass::Status status = gemm_op({
 | |
|     {m, n, k},
 | |
|     A, lda,
 | |
|     B, ldb,
 | |
|     C, ldc,
 | |
|     C, ldc,
 | |
|     {alpha, beta},
 | |
|     batch_count
 | |
|   });
 | |
| 
 | |
|   if (status != cutlass::Status::kSuccess) {
 | |
|     return cudaErrorUnknown;
 | |
|   }
 | |
| 
 | |
|   return cudaSuccess;
 | |
| }
 | |
| 
 | |
| cudaError_t cutlass_strided_batched_sgemm(
 | |
|   int m, 
 | |
|   int n,
 | |
|   int k,
 | |
|   float alpha,
 | |
|   float const *A,
 | |
|   int lda,
 | |
|   long long int batch_stride_A,
 | |
|   float const *B,
 | |
|   int ldb,
 | |
|   long long int batch_stride_B,
 | |
|   float *C,
 | |
|   int ldc,
 | |
|   long long int batch_stride_C,
 | |
|   float beta,
 | |
|   int batch_count) {
 | |
| 
 | |
|   using Gemm = cutlass::gemm::device::GemmBatched<
 | |
|     float, cutlass::layout::ColumnMajor,
 | |
|     float, cutlass::layout::ColumnMajor,
 | |
|     float, cutlass::layout::ColumnMajor
 | |
|   >;
 | |
| 
 | |
|   Gemm gemm_op;
 | |
| 
 | |
|   cutlass::Status status = gemm_op({
 | |
|     {m, n, k},
 | |
|     {A, lda}, 
 | |
|     batch_stride_A,
 | |
|     {B, ldb}, 
 | |
|     batch_stride_B,
 | |
|     {C, ldc}, 
 | |
|     batch_stride_C,
 | |
|     {C, ldc}, 
 | |
|     batch_stride_C,
 | |
|     {alpha, beta},
 | |
|     batch_count
 | |
|   });
 | |
| 
 | |
|   if (status != cutlass::Status::kSuccess) {
 | |
|     return cudaErrorUnknown;
 | |
|   }
 | |
| 
 | |
|   return cudaSuccess;
 | |
| }
 | |
| 
 | |
| template<typename T> 
 | |
| cudaError_t strided_batched_gemm_nn_reference(
 | |
|   int m,
 | |
|   int n,
 | |
|   int k,
 | |
|   T alpha,
 | |
|   std::vector<T> const &A, 
 | |
|   int lda,
 | |
|   long long int batch_stride_A,
 | |
|   std::vector<T> const &B, 
 | |
|   int ldb,
 | |
|   long long int batch_stride_B,
 | |
|   std::vector<T> &C, 
 | |
|   int ldc,
 | |
|   long long int batch_stride_C,
 | |
|   T beta,
 | |
|   int batch_count) {
 | |
|   /*
 | |
|   strided batched gemm NN
 | |
|   */
 | |
|   
 | |
|   cudaError_t result = cudaSuccess;
 | |
| 
 | |
|   if (A.size() < lda * k * batch_count) {
 | |
|     std::cout << "the size of A is too small" << std::endl;
 | |
|     return cudaErrorInvalidValue;
 | |
|   }
 | |
|   if (B.size() < ldb * n) {
 | |
|     std::cout << "the size of B is too small" << std::endl;
 | |
|     return cudaErrorInvalidValue;
 | |
|   }
 | |
|   if (C.size() < ldc * n * batch_count) {
 | |
|     std::cout << "the size of C is too small" << std::endl;
 | |
|     return cudaErrorInvalidValue;
 | |
|   }
 | |
|   
 | |
|   for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
 | |
|     for (int n_idx = 0; n_idx < n; n_idx++) {
 | |
|       for (int m_idx = 0; m_idx < m; m_idx++) {
 | |
|         T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx];
 | |
|         for (int k_idx = 0; k_idx < k; k_idx++) {
 | |
|           accum += alpha 
 | |
|             * A[batch_idx * batch_stride_A + k_idx * lda + m_idx]
 | |
|             * B[batch_idx * batch_stride_B + n_idx * ldb + k_idx];
 | |
|         }
 | |
|         C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   return result;
 | |
| }
 | |
| 
 | |
| cudaError_t run_batched_gemm(bool use_array) {
 | |
| 
 | |
|   const char* gemm_desc = use_array ? "array" : "strided batched";
 | |
|   std::cout << "Running " << gemm_desc << " gemm" << std::endl;
 | |
| 
 | |
|   // Arbitrary problem size
 | |
|   int const m = 520;
 | |
|   int const n = 219;
 | |
|   int const k = 129;
 | |
|   int const batch_count = 17;
 | |
| 
 | |
|   // A, B are non-transpose, column major
 | |
|   int const lda = m;
 | |
|   int const ldb = k * batch_count;
 | |
|   int const ldc = m;
 | |
| 
 | |
|   int const count_A = batch_count * lda * k;
 | |
|   int const count_B = ldb * n;
 | |
|   int const count_C = batch_count * ldc * n;
 | |
| 
 | |
|   // the memory is batched along K dimension
 | |
|   long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
 | |
|   long long int batch_stride_B = static_cast<long long int>(k);
 | |
|   long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);
 | |
| 
 | |
|   // alpha and beta
 | |
|   float alpha = 1.0f;
 | |
|   float beta = 2.0f;
 | |
| 
 | |
|   cudaError_t result = cudaSuccess;
 | |
| 
 | |
|   // allocate the host memory
 | |
|   std::vector<float> host_A(count_A);
 | |
|   std::vector<float> host_B(count_B);
 | |
|   std::vector<float> host_C(count_C);
 | |
|   std::vector<float> result_C(count_C);
 | |
| 
 | |
|   // allocate the device memory
 | |
|   float *A;
 | |
|   float *B;
 | |
|   float *C;
 | |
| 
 | |
|   result = cudaMalloc(&A, count_A * sizeof(float));
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaMalloc result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
|   result = cudaMalloc(&B, count_B * sizeof(float));
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaMalloc result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
|   result = cudaMalloc(&C, count_C * sizeof(float));
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaMalloc result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   // Limit range to avoid floating-point errors
 | |
|   int const kRange = 8;
 | |
| 
 | |
|   // fill A
 | |
|   for (int b_idx = 0; b_idx < batch_count; b_idx++) {
 | |
|     for (int col_idx = 0; col_idx < k; col_idx++) {
 | |
|       for (int row_idx = 0; row_idx < m; row_idx++) {
 | |
|         host_A[row_idx + col_idx * lda + b_idx * lda * k] = static_cast<float>((row_idx + col_idx * lda + b_idx * lda * k) % kRange);
 | |
|       }
 | |
|     }
 | |
|   }
 | |
|   // fill B
 | |
|   for (int b_idx = 0; b_idx < batch_count; b_idx++) {
 | |
|     for (int col_idx = 0; col_idx < n; col_idx++) {
 | |
|       for (int row_idx = 0; row_idx < k; row_idx++) {
 | |
|         host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast<float>(((n + k * ldb + batch_count * k) - (row_idx + col_idx * ldb + b_idx * k)) % kRange);
 | |
|       }
 | |
|     }
 | |
|   }
 | |
|   // fill C
 | |
|   for (int b_idx = 0; b_idx < batch_count; b_idx++) {
 | |
|     for (int col_idx = 0; col_idx < n; col_idx++) {
 | |
|       for (int row_idx = 0; row_idx < m; row_idx++) {
 | |
|         host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // ref memory
 | |
|   std::vector<float> ref_A(host_A);
 | |
|   std::vector<float> ref_B(host_B);
 | |
|   std::vector<float> ref_C(host_C);
 | |
|   // copy host memory to device
 | |
|   result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaMemcpy result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
|   result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaMemcpy result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
|   result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaMemcpy result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   // run cutlass
 | |
|   if (use_array) {
 | |
|     // allocate the host memory for the pointers to the matrices of the batch
 | |
|     std::vector<float*> host_ptr_A(batch_count);
 | |
|     std::vector<float*> host_ptr_B(batch_count);
 | |
|     std::vector<float*> host_ptr_C(batch_count);
 | |
| 
 | |
|     // permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
 | |
|     std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
 | |
|     for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
 | |
|       host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
 | |
|       host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
 | |
|       host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
 | |
|     }
 | |
| 
 | |
|     // allocate the corresponding device memory
 | |
|     float const **ptr_A;
 | |
|     float const **ptr_B;
 | |
|     float **ptr_C;
 | |
| 
 | |
|     result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
 | |
|     if (result != cudaSuccess) {
 | |
|       std::cerr << "cudaMalloc result = " << result << std::endl;
 | |
|       return result;
 | |
|     }
 | |
|     result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
 | |
|     if (result != cudaSuccess) {
 | |
|       std::cerr << "cudaMalloc result = " << result << std::endl;
 | |
|       return result;
 | |
|     }
 | |
|     result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
 | |
|     if (result != cudaSuccess) {
 | |
|       std::cerr << "cudaMalloc result = " << result << std::endl;
 | |
|       return result;
 | |
|     }
 | |
| 
 | |
|     // copy the matrix pointers to the device
 | |
|     result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
 | |
|     if (result != cudaSuccess) {
 | |
|       std::cerr << "cudaMemcpy result = " << result << std::endl;
 | |
|       return result;
 | |
|     }
 | |
|     result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
 | |
|     if (result != cudaSuccess) {
 | |
|       std::cerr << "cudaMemcpy result = " << result << std::endl;
 | |
|       return result;
 | |
|     }
 | |
|     result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
 | |
|     if (result != cudaSuccess) {
 | |
|       std::cerr << "cudaMemcpy result = " << result << std::endl;
 | |
|       return result;
 | |
|     }
 | |
| 
 | |
|     result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
 | |
| 
 | |
|     if (result != cudaSuccess)
 | |
|       return result;
 | |
|   } else {
 | |
|     result = cutlass_strided_batched_sgemm(
 | |
|       m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
 | |
|       beta, batch_count);
 | |
|     if (result != cudaSuccess)
 | |
|       return result;
 | |
|   }
 | |
| 
 | |
|   // copy device memory to host
 | |
|   result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaMemcpy result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   //compare with reference code
 | |
|   result = strided_batched_gemm_nn_reference(m, n, k, alpha, ref_A, lda, batch_stride_A, ref_B, ldb, batch_stride_B, ref_C, ldc, batch_stride_C,
 | |
|     beta, batch_count);
 | |
|   if (result != 0)
 | |
|     return result;
 | |
| 
 | |
|   // Expect bit-level accuracy for this simple example
 | |
|   if (ref_C != result_C) {
 | |
|     std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
 | |
|     return cudaErrorUnknown;
 | |
|   }
 | |
| 
 | |
|   // free memory
 | |
|   result = cudaFree(A);
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaFree result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
|   result = cudaFree(B);
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaFree result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
|   result = cudaFree(C);
 | |
|   if (result != cudaSuccess) {
 | |
|     std::cerr << "cudaFree result = " << result << std::endl;
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   return result;
 | |
| }
 | |
| 
 | |
| int main() {
 | |
| 
 | |
|   cudaError_t result = cudaSuccess;
 | |
|   for (bool use_array : {false, true}) {
 | |
|     result = run_batched_gemm(use_array);
 | |
|     if (result == cudaSuccess) {
 | |
|       std::cout << "Passed." << std::endl;
 | |
|     } else {
 | |
|       break;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // Exit.
 | |
|   return result == cudaSuccess ? 0 : -1;
 | |
| }
 |