350 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
		
		
			
		
	
	
			350 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
|   | /*************************************************************************************************** | ||
|  |  * Copyright (c) 2017-2018, NVIDIA CORPORATION.  All rights reserved. | ||
|  |  * | ||
|  |  * Redistribution and use in source and binary forms, with or without modification, are permitted | ||
|  |  * provided that the following conditions are met: | ||
|  |  *     * Redistributions of source code must retain the above copyright notice, this list of | ||
|  |  *       conditions and the following disclaimer. | ||
|  |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of | ||
|  |  *       conditions and the following disclaimer in the documentation and/or other materials | ||
|  |  *       provided with the distribution. | ||
|  |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used | ||
|  |  *       to endorse or promote products derived from this software without specific prior written | ||
|  |  *       permission. | ||
|  |  * | ||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR | ||
|  |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | ||
|  |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE | ||
|  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | ||
|  |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; | ||
|  |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, | ||
|  |  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
|  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|  |  * | ||
|  |  **************************************************************************************************/ | ||
|  | 
 | ||
|  | #include <iostream> | ||
|  | #include <vector> | ||
|  | #include "cutlass/cutlass.h" | ||
|  | #include "cutlass/gemm/gemm.h" | ||
|  | #include "cutlass/gemm/sgemm_traits.h" | ||
|  | 
 | ||
|  | /* | ||
|  | This example demonstrates how to use cutlass to compute a batched strided gemm. | ||
|  | In this example, both the A and B matrices are non-transposed and stored in column-major order | ||
|  | batched_C = batched_A x batched_B | ||
|  | As an example, matrix C can be seen as | ||
|  | ----------------------------------------------------------- | ||
|  | (0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) | | ||
|  | ----------------------------------------------------------- | ||
|  | (0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) | | ||
|  | ----------------------------------------------------------- | ||
|  | (0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) | | ||
|  | ----------------------------------------------------------- | ||
|  | (0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) | | ||
|  | ----------------------------------------------------------- | ||
|  | (0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) | | ||
|  | ----------------------------------------------------------- | ||
|  | (0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) | | ||
|  | ----------------------------------------------------------- | ||
|  |            batch 0          |           batch 1 | ||
|  | where we denote each element with (batch_idx, row_idx, column_idx) | ||
|  | In this example, batch size is 2, M is 6 and N is 3 | ||
|  | The stride (batch_stride_C) between the first element of two batches is ldc * n | ||
|  | 
 | ||
|  | matrix A can be seen as | ||
|  | --------------------------------------- | ||
|  | (0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) | | ||
|  | --------------------------------------- | ||
|  | (0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) | | ||
|  | --------------------------------------- | ||
|  | (0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) | | ||
|  | --------------------------------------- | ||
|  | (0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) | | ||
|  | --------------------------------------- | ||
|  | (0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) | | ||
|  | --------------------------------------- | ||
|  | (0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) | | ||
|  | --------------------------------------- | ||
|  |      batch 0      |      batch 1 | ||
|  | , where batch size is 2, M is 6 and K is 2 | ||
|  | The stride (batch_stride_A) between the first element of two batches is lda * k | ||
|  | 
 | ||
|  | matrix B can be seen as | ||
|  | ----------------------------- | ||
|  | (0,0,0) | (0,0,1) | (0,0,2) | | ||
|  | ----------------------------- batch 0 | ||
|  | (0,1,0) | (0,1,1) | (0,1,2) | | ||
|  | ------------------------------------- | ||
|  | (1,0,0) | (1,0,1) | (1,0,2) | | ||
|  | ----------------------------- batch 1 | ||
|  | (1,1,0) | (1,1,1) | (1,1,2) | | ||
|  | ----------------------------- | ||
|  | , where the batch size is 2, N is 3 and K is 2 | ||
|  | The stride (batch_stride_B) between the first element of two batches is k | ||
|  | 
 | ||
|  | 
 | ||
|  | */ | ||
|  | 
 | ||
/// Runs a strided batched SGEMM through CUTLASS.
///
/// The operands are forwarded unchanged to Gemm::Params::initialize together with their leading
/// dimensions and batch strides. C is supplied as both the source and the destination matrix,
/// so it is updated in place. The traits select column-major layouts for A and B.
///
/// \param A, B, C          operand pointers handed to the kernel (main() passes cudaMalloc
///                         allocations)
/// \param lda, ldb, ldc    leading dimensions of A, B and C
/// \param batch_stride_A, batch_stride_B, batch_stride_C
///                         distance in elements between the first element of neighboring batches
/// \param alpha, beta      scalars forwarded to the GEMM
/// \param m, n, k          GEMM dimensions of each batch
/// \param batch_count      number of batches
///
/// \return cudaErrorInvalidValue if the parameter object cannot be initialized; otherwise the
///         first error reported by the launch, by cudaDeviceSynchronize(), or by a final
///         cudaGetLastError(); cudaSuccess when none of them reports a failure.
cudaError_t cutlass_strided_batched_sgemm(float const *A,
  int lda,
  long long int batch_stride_A,
  float const *B,
  int ldb,
  long long int batch_stride_B,
  float *C,
  int ldc,
  long long int batch_stride_C,
  float alpha,
  float beta,
  int m,
  int n,
  int k,
  int batch_count) {
  // CUTLASS traits: column-major A and B with an 8x128x128 shape.
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
    cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
    SgemmTraits;

  // CUTLASS GEMM type built from the traits above.
  typedef cutlass::gemm::Gemm<SgemmTraits> Gemm;

  // Construct and initialize the CUTLASS GEMM parameters object.
  Gemm::Params params;

  // initialize() reports its status as an int; keep it separate from the CUDA error codes below.
  int init_status = params.initialize(
    m,                  // M dimension for each batch
    n,                  // N dimension for each batch
    k,                  // K dimension for each batch
    alpha,              // scalar alpha
    A,
    lda,
    batch_stride_A,     // distance in memory between the first element of neighboring batch
    B,
    ldb,
    batch_stride_B,     // distance in memory between the first element of neighboring batch
    beta,               // scalar beta
    C,                  // source matrix C
    ldc,
    batch_stride_C,     // distance in memory between the first element of neighboring batch
    C,                  // destination matrix C (may be different memory than source C matrix)
    ldc,
    batch_stride_C,     // distance in memory between the first element of neighboring batch
    batch_count
  );

  if (init_status != 0) {
    std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
    return cudaErrorInvalidValue;
  }

  // Launch the CUTLASS GEMM kernel and keep its status: a launch failure must not be dropped.
  cudaError_t status = Gemm::launch(params);
  if (status != cudaSuccess) {
    std::cerr << "kernel launch result = " << status << std::endl;
    return status;
  }

  // Wait for the kernel so that execution errors surface here and are returned to the caller.
  status = cudaDeviceSynchronize();
  if (status != cudaSuccess) {
    std::cerr << "kernel launch result = " << status << std::endl;
    return status;
  }

  return cudaGetLastError();
}
|  | 
 | ||
/// Returns true when every element of a strided batch of column-major matrices lies inside a
/// buffer of `size` elements. Element (row, col) of batch b is read at index
/// b * batch_stride + col * ld + row, so only the last such index has to be checked.
static inline bool strided_batched_extent_fits(size_t size,
  int rows,
  int cols,
  int ld,
  long long int batch_stride,
  int batch_count) {
  if (rows <= 0 || cols <= 0 || batch_count <= 0) {
    return true;  // no element of this operand is accessed
  }
  if (ld < 0 || batch_stride < 0) {
    return false;  // negative steps would index before the start of the buffer
  }
  long long int const last_index =
    static_cast<long long int>(batch_count - 1) * batch_stride
    + static_cast<long long int>(cols - 1) * ld
    + (rows - 1);
  return static_cast<unsigned long long>(last_index) < size;
}

/// Host reference for a strided batched GEMM with non-transposed, column-major A and B:
/// for every batch, C = alpha * A * B + beta * C, updated in place.
///
/// Element (row, col) of batch b of a matrix with leading dimension ld and batch stride s is
/// stored at index b * s + col * ld + row of its vector.
///
/// \param A, B              input operands (m-by-k and k-by-n per batch)
/// \param C                 input/output operand (m-by-n per batch)
/// \param lda, ldb, ldc     leading dimensions
/// \param batch_stride_A, batch_stride_B, batch_stride_C
///                          distance in elements between the first element of neighboring batches
/// \param alpha, beta       scalars of the GEMM
/// \param m, n, k           dimensions of each batch
/// \param batch_count       number of batches
///
/// \return cudaErrorInvalidValue (after printing a message) if any element the computation would
///         touch lies outside A, B or C; cudaSuccess otherwise. C is not modified on failure.
template<typename T>
cudaError_t strided_batched_gemm_nn_reference(std::vector<T> const &A,
  int lda,
  long long int batch_stride_A,
  std::vector<T> const &B,
  int ldb,
  long long int batch_stride_B,
  std::vector<T> &C,
  int ldc,
  long long int batch_stride_C,
  T alpha,
  T beta,
  int m,
  int n,
  int k,
  int batch_count) {
  // With an empty output there is nothing to compute and no element is accessed.
  if (m <= 0 || n <= 0 || batch_count <= 0) {
    return cudaSuccess;
  }

  // Validate against the indices the loops below really use, including the batch strides.
  if (!strided_batched_extent_fits(A.size(), m, k, lda, batch_stride_A, batch_count)) {
    std::cout << "the size of A is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (!strided_batched_extent_fits(B.size(), k, n, ldb, batch_stride_B, batch_count)) {
    std::cout << "the size of B is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (!strided_batched_extent_fits(C.size(), m, n, ldc, batch_stride_C, batch_count)) {
    std::cout << "the size of C is too small" << std::endl;
    return cudaErrorInvalidValue;
  }

  for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
    // 64-bit offsets so that large leading dimensions cannot overflow int arithmetic.
    long long int const a_batch = batch_idx * batch_stride_A;
    long long int const b_batch = batch_idx * batch_stride_B;
    long long int const c_batch = batch_idx * batch_stride_C;

    for (int n_idx = 0; n_idx < n; n_idx++) {
      long long int const b_column = b_batch + static_cast<long long int>(n_idx) * ldb;
      long long int const c_column = c_batch + static_cast<long long int>(n_idx) * ldc;

      for (int m_idx = 0; m_idx < m; m_idx++) {
        T accum = beta * C[c_column + m_idx];
        for (int k_idx = 0; k_idx < k; k_idx++) {
          // Evaluation order (alpha * a) * b is kept as in the original reference.
          accum += alpha
            * A[a_batch + static_cast<long long int>(k_idx) * lda + m_idx]
            * B[b_column + k_idx];
        }
        C[c_column + m_idx] = accum;
      }
    }
  }

  return cudaSuccess;
}
|  | 
 | ||
/// Example driver: fills three batches of column-major operands on the host, runs the CUTLASS
/// strided batched SGEMM on the device, recomputes the result with the host reference and
/// compares both outputs element by element.
///
/// Returns 0 after printing "Passed." when the outputs match and all buffers were released,
/// the failing cudaError_t value when a CUDA call or the reference fails, and cudaErrorUnknown
/// when the two results differ.
int main() {
  // Owns the device allocations so that every exit path below releases them.
  struct DeviceBuffers {
    float *A;
    float *B;
    float *C;

    DeviceBuffers() : A(nullptr), B(nullptr), C(nullptr) {}
    ~DeviceBuffers() { release(); }

    // Frees every buffer that is still allocated, reports each failure and returns the first one.
    cudaError_t release() {
      cudaError_t first_error = cudaSuccess;
      float **slots[3] = {&A, &B, &C};
      for (int i = 0; i < 3; i++) {
        if (*slots[i] == nullptr) {
          continue;
        }
        cudaError_t status = cudaFree(*slots[i]);
        *slots[i] = nullptr;
        if (status != cudaSuccess) {
          std::cerr << "cudaFree result = " << status << std::endl;
          if (first_error == cudaSuccess) {
            first_error = status;
          }
        }
      }
      return first_error;
    }
  };

  // Dimensions of every batch: A is m-by-k, B is k-by-n, C is m-by-n.
  int const m = 16;
  int const n = 24;
  int const k = 8;
  int const batch_count = 3;

  // A, B are non-transpose, column major
  int const lda = m;
  int const ldb = k * batch_count;
  int const ldc = m;

  int const count_A = batch_count * lda * k;
  int const count_B = ldb * n;
  int const count_C = batch_count * ldc * n;

  // A and C place whole batches one after another; B interleaves its batches inside each
  // column (ldb = k * batch_count), so consecutive B batches start k elements apart.
  long long int const batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
  long long int const batch_stride_B = static_cast<long long int>(k);
  long long int const batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);

  // alpha and beta
  float const alpha = 1.0f;
  float const beta = 2.0f;

  cudaError_t result = cudaSuccess;

  // allocate the host memory
  std::vector<float> host_A(count_A);
  std::vector<float> host_B(count_B);
  std::vector<float> host_C(count_C);
  std::vector<float> result_C(count_C);

  // allocate the device memory
  DeviceBuffers device;

  result = cudaMalloc(&device.A, count_A * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&device.B, count_B * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&device.C, count_C * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }

  // fill A: every element holds its own linear index
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < k; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        int const index = row_idx + col_idx * lda + b_idx * lda * k;
        host_A[index] = static_cast<float>(index);
      }
    }
  }
  // fill B: a constant minus the element's linear index
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < k; row_idx++) {
        int const index = row_idx + col_idx * ldb + b_idx * k;
        host_B[index] = static_cast<float>(n + k * ldb + batch_count * k) - static_cast<float>(index);
      }
    }
  }
  // fill C
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
      }
    }
  }

  // The reference updates C in place, so it works on its own copy of the initial values.
  std::vector<float> ref_C(host_C);

  // copy host memory to device
  result = cudaMemcpy(device.A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(device.B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(device.C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // run cutlass
  result = cutlass_strided_batched_sgemm(device.A, lda, batch_stride_A, device.B, ldb, batch_stride_B,
    device.C, ldc, batch_stride_C, alpha, beta, m, n, k, batch_count);
  if (result != cudaSuccess) {
    return result;
  }

  // copy device memory to host
  result = cudaMemcpy(result_C.data(), device.C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // compare with reference code
  result = strided_batched_gemm_nn_reference(host_A, lda, batch_stride_A, host_B, ldb, batch_stride_B,
    ref_C, ldc, batch_stride_C, alpha, beta, m, n, k, batch_count);
  if (result != cudaSuccess) {
    return result;
  }

  // Exact comparison, as in the original example: the inputs are small integer-valued floats.
  if (ref_C != result_C) {
    std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
    return cudaErrorUnknown;
  }

  // free memory explicitly so that a failing cudaFree is still reported to the caller
  result = device.release();
  if (result != cudaSuccess) {
    return result;
  }

  std::cout << "Passed." << std::endl;

  // Exit.
  return 0;
}