Ensure all arch::Mma specializations have ElementC set (#576)
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
This commit is contained in:
parent
5d05808072
commit
25ebf15d02
@ -143,16 +143,17 @@ template <
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC,
|
||||
/// Inner product operator
|
||||
typename Operator_
|
||||
>
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> {
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = Operator_;
|
||||
using ElementC = ElementC_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
|
@ -62,6 +62,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, L
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = float;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -89,6 +90,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = double;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -117,6 +119,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = int;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -154,6 +157,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<float>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -194,6 +198,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<float>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -232,6 +237,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<float>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -270,6 +276,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<double>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -308,6 +315,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<double>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -344,6 +352,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<double>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -373,6 +382,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float,
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = float;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -401,6 +411,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using Element = Quaternion<float>;
|
||||
using ElementC = Element;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
|
@ -62,6 +62,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<2, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -107,6 +108,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -152,6 +154,7 @@ struct Mma <
|
||||
|
||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -206,6 +209,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -246,4 +250,3 @@ struct Mma<
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,6 +58,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 4>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = int;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -106,6 +107,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 2>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = int;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -138,4 +140,3 @@ struct Mma<
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user