Ensure all arch::Mma specializations have ElementC set (#576)
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
This commit is contained in:
parent
5d05808072
commit
25ebf15d02
@ -143,16 +143,17 @@ template <
|
|||||||
/// Layout of B matrix (concept: MatrixLayout)
|
/// Layout of B matrix (concept: MatrixLayout)
|
||||||
typename LayoutB,
|
typename LayoutB,
|
||||||
/// Element type of C matrix
|
/// Element type of C matrix
|
||||||
typename ElementC,
|
typename ElementC_,
|
||||||
/// Layout of C matrix (concept: MatrixLayout)
|
/// Layout of C matrix (concept: MatrixLayout)
|
||||||
typename LayoutC,
|
typename LayoutC,
|
||||||
/// Inner product operator
|
/// Inner product operator
|
||||||
typename Operator_
|
typename Operator_
|
||||||
>
|
>
|
||||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> {
|
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> {
|
||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = Operator_;
|
using Operator = Operator_;
|
||||||
|
using ElementC = ElementC_;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
|
@ -62,6 +62,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, L
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = float;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -89,6 +90,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = double;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -117,6 +119,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = int;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -154,6 +157,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAddComplex;
|
using Operator = OpMultiplyAddComplex;
|
||||||
|
using ElementC = complex<float>;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -194,6 +198,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAddComplex;
|
using Operator = OpMultiplyAddComplex;
|
||||||
|
using ElementC = complex<float>;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -232,6 +237,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAddComplex;
|
using Operator = OpMultiplyAddComplex;
|
||||||
|
using ElementC = complex<float>;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -270,6 +276,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAddComplex;
|
using Operator = OpMultiplyAddComplex;
|
||||||
|
using ElementC = complex<double>;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -308,6 +315,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAddComplex;
|
using Operator = OpMultiplyAddComplex;
|
||||||
|
using ElementC = complex<double>;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -344,6 +352,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAddComplex;
|
using Operator = OpMultiplyAddComplex;
|
||||||
|
using ElementC = complex<double>;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -373,6 +382,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float,
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = float;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -401,6 +411,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
|
|||||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
using Element = Quaternion<float>;
|
using Element = Quaternion<float>;
|
||||||
|
using ElementC = Element;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
|
@ -62,6 +62,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<2, 1, 1>;
|
using Shape = gemm::GemmShape<2, 1, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = half_t;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -107,6 +108,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 2, 1>;
|
using Shape = gemm::GemmShape<1, 2, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = half_t;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -152,6 +154,7 @@ struct Mma <
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = half_t;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -206,6 +209,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = half_t;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -246,4 +250,3 @@ struct Mma<
|
|||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,6 +58,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 4>;
|
using Shape = gemm::GemmShape<1, 1, 4>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = int;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -106,6 +107,7 @@ struct Mma<
|
|||||||
|
|
||||||
using Shape = gemm::GemmShape<1, 1, 2>;
|
using Shape = gemm::GemmShape<1, 1, 2>;
|
||||||
using Operator = OpMultiplyAdd;
|
using Operator = OpMultiplyAdd;
|
||||||
|
using ElementC = int;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void operator()(
|
void operator()(
|
||||||
@ -138,4 +140,3 @@ struct Mma<
|
|||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user