CUDA 9 lacks a host-side conversion from float to half. Instead, we must reinterpret_cast<half const &> the cutlass::half_t scalars (alpha and beta) to CUDA's half.

This commit is contained in:
akerr 2018-09-29 15:04:20 -07:00
parent 6877595a5e
commit cfe4b933ef

View File

@ -144,18 +144,18 @@ cudaError_t Cutlass_FP16_SgemmNN(
typename Gemm::Params params;
int result = params.initialize(
M, // GEMM M dimension
N, // GEMM N dimension
K, // GEMM K dimension
half(float(alpha)), // scalar alpha - This is a legal conversion from cutlass::half_t to CUDA's half.
A, // matrix A operand
M, // GEMM M dimension
N, // GEMM N dimension
K, // GEMM K dimension
reinterpret_cast<half const &>(alpha), // scalar alpha - This is a legal conversion from cutlass::half_t to CUDA's half.
A, // matrix A operand
lda,
B, // matrix B operand
B, // matrix B operand
ldb,
half(float(beta)), // scalar beta - This is a legal conversion from cutlass::half_t to CUDA's half.
C, // source matrix C
reinterpret_cast<half const &>(beta), // scalar beta - This is a legal conversion from cutlass::half_t to CUDA's half.
C, // source matrix C
ldc,
C, // destination matrix C (may be different memory than source C matrix)
C, // destination matrix C (may be different memory than source C matrix)
ldc
);