Ensure all arch::Mma specializations have ElementC set (#576)

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
This commit is contained in:
dan_the_3rd 2022-07-23 05:53:03 +02:00 committed by GitHub
parent 5d05808072
commit 25ebf15d02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 50 deletions

View File

@ -143,16 +143,17 @@ template <
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Inner product operator
typename Operator_
>
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> {
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = Operator_;
using ElementC = ElementC_;
CUTLASS_HOST_DEVICE
void operator()(
@ -218,8 +219,8 @@ struct SparseMma;
#include "cutlass/arch/mma_sm50.h"
#include "cutlass/arch/mma_sm60.h"
#include "cutlass/arch/mma_sm61.h"
#include "cutlass/arch/mma_sm70.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm70.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sparse_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -62,6 +62,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, L
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = float;
CUTLASS_HOST_DEVICE
void operator()(
@ -89,6 +90,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = double;
CUTLASS_HOST_DEVICE
void operator()(
@ -117,6 +119,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = int;
CUTLASS_HOST_DEVICE
void operator()(
@ -144,16 +147,17 @@ template <
struct Mma<
gemm::GemmShape<1, 1, 1>,
1,
complex<float>,
LayoutA,
complex<float>,
LayoutB,
complex<float>,
LayoutC,
complex<float>,
LayoutA,
complex<float>,
LayoutB,
complex<float>,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<float>;
CUTLASS_HOST_DEVICE
void operator()(
@ -184,16 +188,17 @@ template <
struct Mma<
gemm::GemmShape<1, 1, 1>,
1,
complex<float>,
LayoutA,
float,
LayoutB,
complex<float>,
LayoutC,
complex<float>,
LayoutA,
float,
LayoutB,
complex<float>,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<float>;
CUTLASS_HOST_DEVICE
void operator()(
@ -222,16 +227,17 @@ template <
struct Mma<
gemm::GemmShape<1, 1, 1>,
1,
float,
LayoutA,
complex<float>,
LayoutB,
complex<float>,
LayoutC,
float,
LayoutA,
complex<float>,
LayoutB,
complex<float>,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<float>;
CUTLASS_HOST_DEVICE
void operator()(
@ -260,16 +266,17 @@ template <
struct Mma<
gemm::GemmShape<1, 1, 1>,
1,
complex<double>,
LayoutA,
complex<double>,
LayoutB,
complex<double>,
LayoutC,
complex<double>,
LayoutA,
complex<double>,
LayoutB,
complex<double>,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<double>;
CUTLASS_HOST_DEVICE
void operator()(
@ -298,16 +305,17 @@ template <
struct Mma<
gemm::GemmShape<1, 1, 1>,
1,
complex<double>,
LayoutA,
double,
LayoutB,
complex<double>,
LayoutC,
complex<double>,
LayoutA,
double,
LayoutB,
complex<double>,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<double>;
CUTLASS_HOST_DEVICE
void operator()(
@ -334,16 +342,17 @@ template <
struct Mma<
gemm::GemmShape<1, 1, 1>,
1,
double,
LayoutA,
complex<double>,
LayoutB,
complex<double>,
LayoutC,
double,
LayoutA,
complex<double>,
LayoutB,
complex<double>,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
using ElementC = complex<double>;
CUTLASS_HOST_DEVICE
void operator()(
@ -373,7 +382,8 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float,
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = float;
CUTLASS_HOST_DEVICE
void operator()(
Array<float, 1> &d,
@ -401,6 +411,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using Element = Quaternion<float>;
using ElementC = Element;
CUTLASS_HOST_DEVICE
void operator()(
@ -412,7 +423,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
multiply_add<Element, Element, Element> op;
d[0] = op(a[0], b[0], c[0]);
}
};
}

View File

@ -62,6 +62,7 @@ struct Mma<
using Shape = gemm::GemmShape<2, 1, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
@ -107,6 +108,7 @@ struct Mma<
using Shape = gemm::GemmShape<1, 2, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
@ -152,6 +154,7 @@ struct Mma <
using Shape = gemm::GemmShape<2, 2, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
@ -206,7 +209,8 @@ struct Mma<
using Shape = gemm::GemmShape<2, 2, 1>;
using Operator = OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
Array<half_t, 4> &d,
@ -220,12 +224,12 @@ struct Mma<
__half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));
__half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));
__half2 const & B = reinterpret_cast<__half2 const &>(b);
__half2 const *C = reinterpret_cast<__half2 const *>(&c);
__half2 Dlo = __hfma2(Alo, B, C[0]);
__half2 Dhi = __hfma2(Ahi, B, C[0]);
Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);
D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);
@ -246,4 +250,3 @@ struct Mma<
}
}

View File

@ -55,10 +55,11 @@ struct Mma<
int,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 4>;
using Operator = OpMultiplyAdd;
using ElementC = int;
CUTLASS_HOST_DEVICE
void operator()(
Array<int, 1> &d,
@ -103,9 +104,10 @@ struct Mma<
int,
LayoutC,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 2>;
using Operator = OpMultiplyAdd;
using ElementC = int;
CUTLASS_HOST_DEVICE
void operator()(
@ -138,4 +140,3 @@ struct Mma<
}
}