/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/sgemm_traits.h"
/*
  This example demonstrates how to use CUTLASS to compute a batched strided GEMM.
  In this example, both the A and B operands are non-transposed, column-major matrices:

    batched_C = batched_A x batched_B

  As an example, matrix C can be seen as

  -----------------------------------------------------------
  (0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
  -----------------------------------------------------------
  (0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
  -----------------------------------------------------------
  (0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
  -----------------------------------------------------------
  (0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
  -----------------------------------------------------------
  (0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
  -----------------------------------------------------------
  (0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
  -----------------------------------------------------------
           batch 0           |           batch 1

  where we denote each element with (batch_idx, row_idx, column_idx).
  In this example, the batch size is 2, M is 6 and N is 3.
  The stride (batch_stride_C) between the first elements of two batches is ldc * n.

  Matrix A can be seen as

  ---------------------------------------
  (0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
  ---------------------------------------
  (0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
  ---------------------------------------
  (0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
  ---------------------------------------
  (0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
  ---------------------------------------
  (0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
  ---------------------------------------
  (0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
  ---------------------------------------
       batch 0      |      batch 1

  where the batch size is 2, M is 6 and K is 2.
  The stride (batch_stride_A) between the first elements of two batches is lda * k.

  Matrix B can be seen as

  -----------------------------
  (0,0,0) | (0,0,1) | (0,0,2) |
  ----------------------------- batch 0
  (0,1,0) | (0,1,1) | (0,1,2) |
  -------------------------------------
  (1,0,0) | (1,0,1) | (1,0,2) |
  ----------------------------- batch 1
  (1,1,0) | (1,1,1) | (1,1,2) |
  -----------------------------

  where the batch size is 2, N is 3 and K is 2. The batches of B are interleaved along the
  K dimension, so the stride (batch_stride_B) between the first elements of two batches is k.
*/
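// Editorial illustration (not part of the original example): the tiny helper below simply
// restates the index arithmetic described in the comment above. For a column-major matrix
// with leading dimension `ld` that is batched with stride `batch_stride`, element
// (batch_idx, row_idx, col_idx) lives at the linear offset computed here. With the strides
// used in this example, A uses batch_stride = lda * k, B uses batch_stride = k, and C uses
// batch_stride = ldc * n. The helper is unused by the example itself.
static inline long long int batched_column_major_offset(int batch_idx,
                                                        int row_idx,
                                                        int col_idx,
                                                        int ld,
                                                        long long int batch_stride) {
  // batch offset + column offset within the batch + row offset within the column
  return static_cast<long long int>(batch_idx) * batch_stride +
         static_cast<long long int>(col_idx) * static_cast<long long int>(ld) +
         static_cast<long long int>(row_idx);
}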
cudaError_t cutlass_strided_batched_sgemm(
    float const *A, int lda, long long int batch_stride_A,
    float const *B, int ldb, long long int batch_stride_B,
    float *C, int ldc, long long int batch_stride_C,
    float alpha, float beta,
    int m, int n, int k, int batch_count) {

  // create a CUTLASS traits type for a column-major x column-major SGEMM
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::Shape<8, 128, 128> >
      SgemmTraits;

  // create a CUTLASS GEMM type from the traits
  typedef cutlass::gemm::Gemm<SgemmTraits> Gemm;

  // Construct and initialize CUTLASS GEMM parameters object.
  typename Gemm::Params params;

  int result = params.initialize(
      m,                // M dimension for each batch
      n,                // N dimension for each batch
      k,                // K dimension for each batch
      alpha,            // scalar alpha
      A,
      lda,
      batch_stride_A,   // distance in memory between the first elements of neighboring batches
      B,
      ldb,
      batch_stride_B,   // distance in memory between the first elements of neighboring batches
      beta,             // scalar beta
      C,                // source matrix C
      ldc,
      batch_stride_C,   // distance in memory between the first elements of neighboring batches
      C,                // destination matrix C (may be different memory than source C matrix)
      ldc,
      batch_stride_C,   // distance in memory between the first elements of neighboring batches
      batch_count);

  if (result != 0) {
    std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
    return cudaErrorInvalidValue;
  }

  // Launch the CUTLASS GEMM kernel.
  Gemm::launch(params);

  // Wait for the kernel to complete and report any error from the launch.
  result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "kernel launch result = " << result << std::endl;
  }

  return cudaGetLastError();
}

template <typename T>
cudaError_t strided_batched_gemm_nn_reference(
    std::vector<T> const &A, int lda, long long int batch_stride_A,
    std::vector<T> const &B, int ldb, long long int batch_stride_B,
    std::vector<T> &C, int ldc, long long int batch_stride_C,
    T alpha, T beta,
    int m, int n, int k, int batch_count) {
  /* strided batched gemm NN */

  cudaError_t result = cudaSuccess;

  if (A.size() < lda * k * batch_count) {
    std::cout << "the size of A is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (B.size() < ldb * n) {
    std::cout << "the size of B is too small" << std::endl;
    return cudaErrorInvalidValue;
  }
  if (C.size() < ldc * n * batch_count) {
    std::cout << "the size of C is too small" << std::endl;
    return cudaErrorInvalidValue;
  }

  for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
    for (int n_idx = 0; n_idx < n; n_idx++) {
      for (int m_idx = 0; m_idx < m; m_idx++) {
        T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx];
        for (int k_idx = 0; k_idx < k; k_idx++) {
          accum += alpha *
                   A[batch_idx * batch_stride_A + k_idx * lda + m_idx] *
                   B[batch_idx * batch_stride_B + n_idx * ldb + k_idx];
        }
        C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum;
      }
    }
  }

  return result;
}
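// Editorial note (added to this walkthrough, not part of the original example): main()
// below stores B with all batches interleaved along the K dimension. That is why
// ldb = k * batch_count while the distance between the first elements of neighboring
// batches (batch_stride_B) is only k: for batch b, element (row, col) of B sits at
// b * k + col * ldb + row, which is exactly what the reference implementation above
// indexes with batch_idx * batch_stride_B + n_idx * ldb + k_idx.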
int main() {

  int const m = 16;
  int const n = 24;
  int const k = 8;
  int const batch_count = 3;

  // A, B are non-transpose, column major
  int const lda = m;
  int const ldb = k * batch_count;
  int const ldc = m;

  int const count_A = batch_count * lda * k;
  int const count_B = ldb * n;
  int const count_C = batch_count * ldc * n;

  // the memory for B is batched along the K dimension
  long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
  long long int batch_stride_B = static_cast<long long int>(k);
  long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);

  // alpha and beta
  float alpha = 1.0f;
  float beta = 2.0f;

  cudaError_t result = cudaSuccess;

  // allocate the host memory
  std::vector<float> host_A(count_A);
  std::vector<float> host_B(count_B);
  std::vector<float> host_C(count_C);
  std::vector<float> result_C(count_C);

  // allocate the device memory
  float *A;
  float *B;
  float *C;

  result = cudaMalloc(&A, count_A * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&B, count_B * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }
  result = cudaMalloc(&C, count_C * sizeof(float));
  if (result != cudaSuccess) {
    std::cerr << "cudaMalloc result = " << result << std::endl;
    return result;
  }

  // fill A
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < k; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_A[row_idx + col_idx * lda + b_idx * lda * k] =
            static_cast<float>(row_idx + col_idx * lda + b_idx * lda * k);
      }
    }
  }
  // fill B
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < k; row_idx++) {
        host_B[row_idx + col_idx * ldb + b_idx * k] =
            static_cast<float>(n + k * ldb + batch_count * k) -
            static_cast<float>(row_idx + col_idx * ldb + b_idx * k);
      }
    }
  }
  // fill C
  for (int b_idx = 0; b_idx < batch_count; b_idx++) {
    for (int col_idx = 0; col_idx < n; col_idx++) {
      for (int row_idx = 0; row_idx < m; row_idx++) {
        host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
      }
    }
  }

  // ref memory
  std::vector<float> ref_A(host_A);
  std::vector<float> ref_B(host_B);
  std::vector<float> ref_C(host_C);

  // copy host memory to device
  result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }
  result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // run cutlass
  result = cutlass_strided_batched_sgemm(A, lda, batch_stride_A,
                                         B, ldb, batch_stride_B,
                                         C, ldc, batch_stride_C,
                                         alpha, beta,
                                         m, n, k, batch_count);
  if (result != cudaSuccess)
    return result;

  // copy device memory to host
  result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    std::cerr << "cudaMemcpy result = " << result << std::endl;
    return result;
  }

  // compare with the reference code
  result = strided_batched_gemm_nn_reference(ref_A, lda, batch_stride_A,
                                             ref_B, ldb, batch_stride_B,
                                             ref_C, ldc, batch_stride_C,
                                             alpha, beta,
                                             m, n, k, batch_count);
  if (result != 0)
    return result;

  if (ref_C != result_C) {
    std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
    return cudaErrorUnknown;
  }

  // free memory
  result = cudaFree(A);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }
  result = cudaFree(B);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }
  result = cudaFree(C);
  if (result != cudaSuccess) {
    std::cerr << "cudaFree result = " << result << std::endl;
    return result;
  }

  if (result == cudaSuccess) {
    std::cout << "Passed." << std::endl;
  }

  // Exit.
  return result == cudaSuccess ? 0 : -1;
}
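/*
  Build sketch (editorial addition, not part of the original example). Assuming this file is
  saved as batched_gemm.cu and the CUTLASS 1.x headers are available under /path/to/cutlass,
  something along these lines should compile it; adjust the include path and the -arch value
  for your GPU as needed:

    nvcc -I/path/to/cutlass -std=c++11 -arch=sm_60 batched_gemm.cu -o batched_gemm
    ./batched_gemm

  On success the program prints "Passed." and exits with code 0; any mismatch against the
  host reference implementation produces a nonzero exit code.
*/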