Ensure all arch::Mma specializations have ElementC set (#576)
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
This commit is contained in:
parent
5d05808072
commit
25ebf15d02
@ -143,16 +143,17 @@ template <
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC,
|
||||
/// Inner product operator
|
||||
typename Operator_
|
||||
>
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> {
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = Operator_;
|
||||
using ElementC = ElementC_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -218,8 +219,8 @@ struct SparseMma;
|
||||
#include "cutlass/arch/mma_sm50.h"
|
||||
#include "cutlass/arch/mma_sm60.h"
|
||||
#include "cutlass/arch/mma_sm61.h"
|
||||
#include "cutlass/arch/mma_sm70.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm70.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm80.h"
|
||||
#include "cutlass/arch/mma_sparse_sm80.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -62,6 +62,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, L
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = float;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -89,6 +90,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = double;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -117,6 +119,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = int;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -144,16 +147,17 @@ template <
|
||||
struct Mma<
|
||||
gemm::GemmShape<1, 1, 1>,
|
||||
1,
|
||||
complex<float>,
|
||||
LayoutA,
|
||||
complex<float>,
|
||||
LayoutB,
|
||||
complex<float>,
|
||||
LayoutC,
|
||||
complex<float>,
|
||||
LayoutA,
|
||||
complex<float>,
|
||||
LayoutB,
|
||||
complex<float>,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<float>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -184,16 +188,17 @@ template <
|
||||
struct Mma<
|
||||
gemm::GemmShape<1, 1, 1>,
|
||||
1,
|
||||
complex<float>,
|
||||
LayoutA,
|
||||
float,
|
||||
LayoutB,
|
||||
complex<float>,
|
||||
LayoutC,
|
||||
complex<float>,
|
||||
LayoutA,
|
||||
float,
|
||||
LayoutB,
|
||||
complex<float>,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<float>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -222,16 +227,17 @@ template <
|
||||
struct Mma<
|
||||
gemm::GemmShape<1, 1, 1>,
|
||||
1,
|
||||
float,
|
||||
LayoutA,
|
||||
complex<float>,
|
||||
LayoutB,
|
||||
complex<float>,
|
||||
LayoutC,
|
||||
float,
|
||||
LayoutA,
|
||||
complex<float>,
|
||||
LayoutB,
|
||||
complex<float>,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<float>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -260,16 +266,17 @@ template <
|
||||
struct Mma<
|
||||
gemm::GemmShape<1, 1, 1>,
|
||||
1,
|
||||
complex<double>,
|
||||
LayoutA,
|
||||
complex<double>,
|
||||
LayoutB,
|
||||
complex<double>,
|
||||
LayoutC,
|
||||
complex<double>,
|
||||
LayoutA,
|
||||
complex<double>,
|
||||
LayoutB,
|
||||
complex<double>,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<double>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -298,16 +305,17 @@ template <
|
||||
struct Mma<
|
||||
gemm::GemmShape<1, 1, 1>,
|
||||
1,
|
||||
complex<double>,
|
||||
LayoutA,
|
||||
double,
|
||||
LayoutB,
|
||||
complex<double>,
|
||||
LayoutC,
|
||||
complex<double>,
|
||||
LayoutA,
|
||||
double,
|
||||
LayoutB,
|
||||
complex<double>,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<double>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -334,16 +342,17 @@ template <
|
||||
struct Mma<
|
||||
gemm::GemmShape<1, 1, 1>,
|
||||
1,
|
||||
double,
|
||||
LayoutA,
|
||||
complex<double>,
|
||||
LayoutB,
|
||||
complex<double>,
|
||||
LayoutC,
|
||||
double,
|
||||
LayoutA,
|
||||
complex<double>,
|
||||
LayoutB,
|
||||
complex<double>,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
using ElementC = complex<double>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -373,7 +382,8 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float,
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
using ElementC = float;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
Array<float, 1> &d,
|
||||
@ -401,6 +411,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using Element = Quaternion<float>;
|
||||
using ElementC = Element;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -412,7 +423,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
|
||||
multiply_add<Element, Element, Element> op;
|
||||
d[0] = op(a[0], b[0], c[0]);
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -62,6 +62,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<2, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -107,6 +108,7 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<1, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -152,6 +154,7 @@ struct Mma <
|
||||
|
||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -206,7 +209,8 @@ struct Mma<
|
||||
|
||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
Array<half_t, 4> &d,
|
||||
@ -220,12 +224,12 @@ struct Mma<
|
||||
__half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));
|
||||
__half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));
|
||||
__half2 const & B = reinterpret_cast<__half2 const &>(b);
|
||||
|
||||
|
||||
__half2 const *C = reinterpret_cast<__half2 const *>(&c);
|
||||
|
||||
__half2 Dlo = __hfma2(Alo, B, C[0]);
|
||||
__half2 Dhi = __hfma2(Ahi, B, C[0]);
|
||||
|
||||
|
||||
Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);
|
||||
|
||||
D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);
|
||||
@ -246,4 +250,3 @@ struct Mma<
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,10 +55,11 @@ struct Mma<
|
||||
int,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 4>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
using ElementC = int;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
Array<int, 1> &d,
|
||||
@ -103,9 +104,10 @@ struct Mma<
|
||||
int,
|
||||
LayoutC,
|
||||
OpMultiplyAdd> {
|
||||
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 2>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
using ElementC = int;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -138,4 +140,3 @@ struct Mma<
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user