Ensure all arch::Mma specializations have ElementC set (#576)

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
This commit is contained in:
dan_the_3rd 2022-07-23 05:53:03 +02:00 committed by GitHub
parent 5d05808072
commit 25ebf15d02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 50 deletions

View File

@ -143,16 +143,17 @@ template <
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Inner product operator
typename Operator_
>
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> {
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = Operator_;
using ElementC = ElementC_;
CUTLASS_HOST_DEVICE
void operator()(

View File

@ -62,6 +62,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, L
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = float;
CUTLASS_HOST_DEVICE
void operator()(
@ -89,6 +90,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = double;
CUTLASS_HOST_DEVICE
void operator()(
@ -117,6 +119,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = int;
CUTLASS_HOST_DEVICE
void operator()(
@ -154,6 +157,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<float>;
CUTLASS_HOST_DEVICE
void operator()(
@ -194,6 +198,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<float>;
CUTLASS_HOST_DEVICE
void operator()(
@ -232,6 +237,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<float>;
CUTLASS_HOST_DEVICE
void operator()(
@ -270,6 +276,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<double>;
CUTLASS_HOST_DEVICE
void operator()(
@ -308,6 +315,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<double>;
CUTLASS_HOST_DEVICE
void operator()(
@ -344,6 +352,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<double>;
CUTLASS_HOST_DEVICE
void operator()(
@ -373,6 +382,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float,
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = float;
CUTLASS_HOST_DEVICE
void operator()(
@ -401,6 +411,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using Element = Quaternion<float>;
using ElementC = Element;
CUTLASS_HOST_DEVICE
void operator()(

View File

@ -62,6 +62,7 @@ struct Mma<
using Shape = gemm::GemmShape<2, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
@ -107,6 +108,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 2, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
@ -152,6 +154,7 @@ struct Mma <
using Shape = gemm::GemmShape<2, 2, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
@ -206,6 +209,7 @@ struct Mma<
using Shape = gemm::GemmShape<2, 2, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
@ -246,4 +250,3 @@ struct Mma<
}
}

View File

@ -58,6 +58,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 4>;
using Operator = OpMultiplyAdd;
using ElementC = int;
CUTLASS_HOST_DEVICE
void operator()(
@ -106,6 +107,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 1, 2>;
using Operator = OpMultiplyAdd;
using ElementC = int;
CUTLASS_HOST_DEVICE
void operator()(
@ -138,4 +140,3 @@ struct Mma<
}
}