2018-09-19 07:58:03 +08:00
|
|
|
/***************************************************************************************************
|
2021-02-26 22:58:26 +08:00
|
|
|
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
2018-09-19 07:58:03 +08:00
|
|
|
*
|
|
|
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
|
|
|
* provided that the following conditions are met:
|
|
|
|
* * Redistributions of source code must retain the above copyright notice, this list of
|
|
|
|
* conditions and the following disclaimer.
|
|
|
|
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
|
|
|
* conditions and the following disclaimer in the documentation and/or other materials
|
|
|
|
* provided with the distribution.
|
|
|
|
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
|
|
|
* to endorse or promote products derived from this software without specific prior written
|
|
|
|
* permission.
|
|
|
|
*
|
|
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
|
|
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
|
|
|
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
|
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
|
|
|
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
|
|
|
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
|
|
|
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*
|
|
|
|
**************************************************************************************************/
|
|
|
|
|
|
|
|
/*
|
|
|
|
This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference
|
|
|
|
matrix multiply kernel to verify its correctness.
|
|
|
|
|
|
|
|
The CUTLASS Gemm template is instantiated in the function CutlassSgemmNN. This kernel computes
|
|
|
|
the general matrix product (GEMM) using single-precision floating-point arithmetic and assumes
|
|
|
|
all matrices have column-major layout.
|
|
|
|
|
|
|
|
The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices.
|
|
|
|
See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available
|
|
|
|
in CUTLASS.
|
|
|
|
|
|
|
|
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
|
|
|
|
|
|
|
|
Aside from defining and launching the SGEMM kernel, this example does not use any other components
|
|
|
|
or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
|
|
|
|
prevalent in the CUTLASS unit tests.
|
2019-11-20 08:55:34 +08:00
|
|
|
|
|
|
|
This example has deliberately been kept similar to the basic_gemm example from cutlass-1.3 to
|
|
|
|
highlight the minimal set of changes needed to transition to cutlass-2.0.
|
|
|
|
|
|
|
|
Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu
|
2018-09-19 07:58:03 +08:00
|
|
|
*/
|
|
|
|
|
|
|
|
// Standard Library includes
|
|
|
|
#include <iostream>
|
|
|
|
#include <sstream>
|
|
|
|
#include <vector>
|
|
|
|
|
2019-11-20 08:55:34 +08:00
|
|
|
// Helper methods to check for errors
|
|
|
|
#include "helper.h"
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
//
|
|
|
|
// CUTLASS includes needed for single-precision GEMM kernel
|
|
|
|
//
|
|
|
|
|
2019-11-20 08:55:34 +08:00
|
|
|
// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class.
|
|
|
|
#include "cutlass/gemm/device/gemm.h"
|
2019-03-21 01:49:17 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object,
|
|
|
|
// and launches it on the CUDA device.
|
|
|
|
//
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Instantiates a single-precision CUTLASS GEMM and launches it on the CUDA device.
///
/// Computes C = alpha * A * B + beta * C, where A is M-by-K, B is K-by-N and C is M-by-N.
/// All three operands are column-major ("NN": neither A nor B is transposed) and are
/// addressed through the leading dimensions lda, ldb and ldc. C is used both as the
/// epilogue source and as the destination, so the result overwrites it in place.
///
/// Returns cudaSuccess when the CUTLASS operator reports cutlass::Status::kSuccess and
/// cudaErrorUnknown for any other status (the CUTLASS status is not mapped to a more
/// specific CUDA error).
cudaError_t CutlassSgemmNN(
  int M,
  int N,
  int K,
  float alpha,
  float const *A,
  int lda,
  float const *B,
  int ldb,
  float beta,
  float *C,
  int ldc) {

  // Only the element types and layouts are spelled out. Every other template parameter
  // (tile shapes, epilogue, target architecture, ...) takes the defaults described in
  // `cutlass/gemm/device/default_gemm_configuration.h`; the example notes cite a
  // 128x128x8 threadblock tile. The full device-level interface is in
  // `cutlass/gemm/device/gemm.h`.
  using Layout = cutlass::layout::ColumnMajor;

  using SgemmNN = cutlass::gemm::device::Gemm<
    float, Layout,    // element type and layout of A
    float, Layout,    // element type and layout of B
    float, Layout>;   // element type and layout of C / D

  // CUTLASS collects everything a launch needs into a host-constructible Arguments object
  // that is passed to the kernel by value: the problem extent, one (pointer, leading
  // dimension) tensor-ref per operand, and the epilogue scalars.
  SgemmNN::Arguments arguments(
    {M, N, K},        // problem size
    {A, lda},         // source A
    {B, ldb},         // source B
    {C, ldc},         // source C
    {C, ldc},         // destination D (aliases C here; may be distinct memory in general)
    {alpha, beta});   // epilogue scalars

  SgemmNN sgemm;

  // Launch the CUTLASS GEMM kernel.
  cutlass::Status const launch_status = sgemm(arguments);

  return (launch_status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorUnknown;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// The source code after this point in the file is generic CUDA using the CUDA Runtime API
|
|
|
|
// and simple CUDA kernels to initialize matrices and compute the general matrix product.
|
|
|
|
//
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Kernel to initialize a column-major matrix with small integers.
///
/// Expects a 2-D launch: the x dimension indexes rows and the y dimension indexes columns.
/// Threads outside the rows x columns region do nothing, so the grid may overshoot the
/// matrix. Element (i, j) is stored at offset = i + j * ldm and receives
///
///     ((offset + seed) * 16807 % 16) - 8
///
/// which is a deterministic integer in [-8, 7] whenever offset + seed is non-negative.
///
/// The offset and the hash are evaluated in 64-bit arithmetic. In 32-bit int the product
/// (offset + seed) * 16807 overflows (undefined behavior for signed int) as soon as
/// offset + seed exceeds 127773, i.e. for matrices of only ~128K elements.
__global__ void InitializeMatrix_kernel(
  float *matrix,
  int ldm,
  int rows,
  int columns,
  int seed = 0) {

  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;

  if (i < rows && j < columns) {
    long long offset = static_cast<long long>(i) + static_cast<long long>(j) * ldm;

    // Generate arbitrary elements.
    long long const k = 16807;
    long long const m = 16;
    float value = float(((offset + seed) * k % m) - m / 2);

    matrix[offset] = value;
  }
}
|
|
|
|
|
|
|
|
/// Simple function to initialize a matrix to arbitrary small integers.
///
/// Launches InitializeMatrix_kernel over the leading rows x columns region of a column-major
/// device matrix with leading dimension ldm, using 16x16 thread blocks and enough blocks to
/// cover every element.
///
/// An empty matrix (rows <= 0 or columns <= 0) needs no work, so no kernel is launched and
/// cudaSuccess is returned. Launching would otherwise produce a zero-sized grid, which fails
/// with cudaErrorInvalidConfiguration.
///
/// Returns cudaGetLastError() after the launch, which reports launch failures; faults raised
/// while the kernel executes surface at a later synchronizing call.
cudaError_t InitializeMatrix(float *matrix, int ldm, int rows, int columns, int seed = 0) {

  if (rows <= 0 || columns <= 0) {
    return cudaSuccess;
  }

  dim3 block(16, 16);
  dim3 grid(
    (rows + block.x - 1) / block.x,
    (columns + block.y - 1) / block.y
  );

  InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed);

  return cudaGetLastError();
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Allocates device memory for a column-major matrix then fills it with arbitrary small integers.
///
/// Reserves ldm * columns floats and zeroes the whole allocation. It then initializes the
/// leading rows x columns region with values derived from `seed`, so any padding rows
/// (ldm > rows) remain zero.
///
/// On success *matrix holds the new device pointer. On failure:
///   - a diagnostic is printed to std::cerr;
///   - memory already allocated by this call is released;
///   - *matrix is set to nullptr;
///   - the CUDA error is returned.
/// Callers therefore never need to free a matrix whose allocation failed.
cudaError_t AllocateMatrix(float **matrix, int ldm, int rows, int columns, int seed = 0) {
  cudaError_t result;

  size_t sizeof_matrix = sizeof(float) * ldm * columns;

  // Allocate device memory.
  result = cudaMalloc(reinterpret_cast<void **>(matrix), sizeof_matrix);

  if (result != cudaSuccess) {
    std::cerr << "Failed to allocate matrix: "
      << cudaGetErrorString(result) << std::endl;
    // Do not hand back whatever pointer value a failed cudaMalloc left behind.
    *matrix = nullptr;
    return result;
  }

  // Clear the allocation.
  result = cudaMemset(*matrix, 0, sizeof_matrix);

  if (result != cudaSuccess) {
    std::cerr << "Failed to clear matrix device memory: "
      << cudaGetErrorString(result) << std::endl;
    // Release the allocation rather than leak it: callers only free matrices whose
    // allocation succeeded.
    cudaFree(*matrix);
    *matrix = nullptr;
    return result;
  }

  // Initialize matrix elements to arbitrary small integers.
  result = InitializeMatrix(*matrix, ldm, rows, columns, seed);

  if (result != cudaSuccess) {
    std::cerr << "Failed to initialize matrix: "
      << cudaGetErrorString(result) << std::endl;
    cudaFree(*matrix);
    *matrix = nullptr;
    return result;
  }

  return result;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Naive reference GEMM: each thread produces one element of C.
///
/// Computes C = alpha * A * B + beta * C for column-major operands:
///   - A is M x K with leading dimension lda;
///   - B is K x N with leading dimension ldb;
///   - C is M x N with leading dimension ldc.
/// Expects a 2-D launch whose x dimension indexes rows of C and whose y dimension indexes
/// columns. Threads that land outside M x N return immediately. The dot product is
/// accumulated in float, in increasing k.
__global__ void ReferenceGemm_kernel(
  int M, int N, int K,
  float alpha,
  float const *A, int lda,
  float const *B, int ldb,
  float beta,
  float *C, int ldc) {

  int const row = blockIdx.x * blockDim.x + threadIdx.x;
  int const col = blockIdx.y * blockDim.y + threadIdx.y;

  // The grid is rounded up to whole blocks, so it may overshoot the output matrix.
  if (row >= M || col >= N) {
    return;
  }

  float dot = 0;
  for (int k = 0; k < K; ++k) {
    dot += A[row + k * lda] * B[k + col * ldb];
  }

  // Read-modify-write of the single output element owned by this thread.
  float &c = C[row + col * ldc];
  c = alpha * dot + beta * c;
}
|
|
|
|
|
|
|
|
/// Reference GEMM computation.
///
/// Launches ReferenceGemm_kernel to compute C = alpha * A * B + beta * C with column-major
/// operands, using 16x16 thread blocks on a 2-D grid that covers the M x N output.
///
/// An empty output (M <= 0 or N <= 0) has nothing to compute, so no kernel is launched and
/// cudaSuccess is returned. A zero-sized grid would otherwise make the launch fail with
/// cudaErrorInvalidConfiguration. K == 0 is still launched; each element then becomes
/// alpha * 0 + beta * C.
///
/// Returns cudaGetLastError() after the launch, which reports launch failures; faults raised
/// while the kernel executes surface at a later synchronizing call.
cudaError_t ReferenceGemm(
  int M,
  int N,
  int K,
  float alpha,
  float const *A,
  int lda,
  float const *B,
  int ldb,
  float beta,
  float *C,
  int ldc) {

  if (M <= 0 || N <= 0) {
    return cudaSuccess;
  }

  dim3 block(16, 16);
  dim3 grid(
    (M + block.x - 1) / block.x,
    (N + block.y - 1) / block.y
  );

  ReferenceGemm_kernel<<< grid, block >>>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);

  return cudaGetLastError();
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Allocate several matrices in GPU device memory and call a single-precision
/// CUTLASS GEMM kernel, then verify its result against the naive reference GEMM.
///
/// Steps:
///   - allocate A (M x K), B (K x N) and two identical copies of C (M x N), column-major with
///     no padding (lda = M, ldb = K, ldc = M), filled with seeded values;
///   - let CUTLASS update one copy of C and the reference kernel update the other;
///   - copy both results to the host and compare them for exact equality.
///
/// Returns cudaSuccess if the results match and cudaErrorUnknown if they differ. If a step
/// fails, returns that step's CUDA error. Device memory is released on every return path.
cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
  cudaError_t result;

  //
  // Define several matrices to be used as operands to GEMM kernels.
  //

  // Compute leading dimensions for each matrix.
  int lda = M;
  int ldb = K;
  int ldc = M;

  // Number of elements and size in bytes of the C matrix. The element count is formed in
  // size_t: `ldc * N` in int arithmetic can overflow for large problems. The host vectors
  // would then be smaller than the sizeof_C bytes copied into them.
  size_t const elements_C = static_cast<size_t>(ldc) * static_cast<size_t>(N);
  size_t const sizeof_C = sizeof(float) * elements_C;

  // Pointers to matrices in GPU device memory. They stay null until allocated. cudaFree is a
  // no-op on a null pointer, so a single cleanup routine serves every exit path.
  float *A = nullptr;
  float *B = nullptr;
  float *C_cutlass = nullptr;
  float *C_reference = nullptr;

  auto free_device_matrices = [&]() {
    cudaFree(C_reference);
    cudaFree(C_cutlass);
    cudaFree(B);
    cudaFree(A);
  };

  // Allocates one matrix and publishes its pointer only when AllocateMatrix succeeds. A
  // failed allocation therefore never hands free_device_matrices() a pointer whose state
  // is unknown.
  auto allocate = [](float **matrix, int ldm, int rows, int columns, int seed) {
    float *allocation = nullptr;
    cudaError_t status = AllocateMatrix(&allocation, ldm, rows, columns, seed);
    if (status == cudaSuccess) {
      *matrix = allocation;
    }
    return status;
  };

  //
  // Allocate matrices in GPU device memory with arbitrary seeds. AllocateMatrix reports its
  // own failures on std::cerr.
  //

  result = allocate(&A, lda, M, K, 0);
  if (result == cudaSuccess) {
    result = allocate(&B, ldb, K, N, 17);
  }
  if (result == cudaSuccess) {
    result = allocate(&C_cutlass, ldc, M, N, 101);
  }
  if (result == cudaSuccess) {
    result = allocate(&C_reference, ldc, M, N, 101);
  }
  if (result != cudaSuccess) {
    free_device_matrices();
    return result;
  }

  // Start both C matrices from identical contents so the beta * C term matches.
  result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice);

  if (result != cudaSuccess) {
    std::cerr << "Failed to copy C_cutlass matrix to C_reference: "
      << cudaGetErrorString(result) << std::endl;
    free_device_matrices();
    return result;
  }

  //
  // Launch CUTLASS GEMM.
  //

  result = CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc);

  if (result != cudaSuccess) {
    std::cerr << "CUTLASS GEMM kernel failed: "
      << cudaGetErrorString(result) << std::endl;
    free_device_matrices();
    return result;
  }

  //
  // Verify.
  //

  // Launch reference GEMM
  result = ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc);

  if (result != cudaSuccess) {
    std::cerr << "Reference GEMM kernel failed: "
      << cudaGetErrorString(result) << std::endl;
    free_device_matrices();
    return result;
  }

  // Copy to host and verify equivalence. The host buffers hold exactly sizeof_C bytes.
  std::vector<float> host_cutlass(elements_C, 0.0f);
  std::vector<float> host_reference(elements_C, 0.0f);

  result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost);

  if (result != cudaSuccess) {
    std::cerr << "Failed to copy CUTLASS GEMM results: "
      << cudaGetErrorString(result) << std::endl;
    free_device_matrices();
    return result;
  }

  result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost);

  if (result != cudaSuccess) {
    std::cerr << "Failed to copy Reference GEMM results: "
      << cudaGetErrorString(result) << std::endl;
    free_device_matrices();
    return result;
  }

  //
  // Free device memory allocations.
  //

  free_device_matrices();

  //
  // Test for bit equivalence of results.
  // NOTE(review): exact comparison relies on the small-integer inputs keeping every product
  // and sum exactly representable. Non-integer alpha/beta or very large K could legitimately
  // produce rounding differences between the two kernels.
  //

  if (host_cutlass != host_reference) {
    std::cerr << "CUTLASS results incorrect." << std::endl;

    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Entry point to basic_gemm example.
//
// usage:
//
//   00_basic_gemm <M> <N> <K> <alpha> <beta>
//
// Argument handling:
//   - every argument is optional and positional, and arguments beyond the fifth are ignored;
//   - omitted values default to M = N = K = 128, alpha = 1 and beta = 0;
//   - dimensions must parse as positive integers and scalars as floats; otherwise a usage
//     message is printed and the program exits with -1.
// Returns 0 (printing "Passed.") when the CUTLASS result matches the reference GEMM, and -1
// otherwise.
//
int main(int argc, const char *arg[]) {

  char const *usage = "usage: 00_basic_gemm <M> <N> <K> <alpha> <beta>";

  //
  // Parse the command line to obtain GEMM dimensions and scalar values.
  //

  // GEMM problem dimensions.
  int problem[3] = { 128, 128, 128 };

  for (int i = 1; i < argc && i < 4; ++i) {
    std::stringstream ss(arg[i]);
    // A failed extraction stores 0. That value, like any non-positive extent, would
    // otherwise flow into the device allocations and kernel launches.
    if (!(ss >> problem[i - 1]) || problem[i - 1] <= 0) {
      std::cerr << "Invalid GEMM dimension '" << arg[i]
                << "': expected a positive integer." << std::endl
                << usage << std::endl;
      return -1;
    }
  }

  // Scalars used for linear scaling the result of the matrix product.
  float scalars[2] = { 1, 0 };

  for (int i = 4; i < argc && i < 6; ++i) {
    std::stringstream ss(arg[i]);
    if (!(ss >> scalars[i - 4])) {
      std::cerr << "Invalid scalar '" << arg[i]
                << "': expected a floating-point value." << std::endl
                << usage << std::endl;
      return -1;
    }
  }

  //
  // Run the CUTLASS GEMM test.
  //

  cudaError_t result = TestCutlassGemm(
    problem[0],     // GEMM M dimension
    problem[1],     // GEMM N dimension
    problem[2],     // GEMM K dimension
    scalars[0],     // alpha
    scalars[1]      // beta
  );

  if (result == cudaSuccess) {
    std::cout << "Passed." << std::endl;
  }

  // Exit.
  return result == cudaSuccess ? 0 : -1;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|