From 25ebf15d02f5a201a44ef6f436939d7f44530eea Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Sat, 23 Jul 2022 05:53:03 +0200 Subject: [PATCH] Ensure all arch::Mma specializations have ElementC set (#576) Co-authored-by: danthe3rd --- include/cutlass/arch/mma.h | 9 ++-- include/cutlass/arch/mma_sm50.h | 87 +++++++++++++++++++-------------- include/cutlass/arch/mma_sm60.h | 11 +++-- include/cutlass/arch/mma_sm61.h | 9 ++-- 4 files changed, 66 insertions(+), 50 deletions(-) diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index e79a4099..ce3e02f3 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -143,16 +143,17 @@ template < /// Layout of B matrix (concept: MatrixLayout) typename LayoutB, /// Element type of C matrix - typename ElementC, + typename ElementC_, /// Layout of C matrix (concept: MatrixLayout) typename LayoutC, /// Inner product operator typename Operator_ > -struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> { +struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> { using Shape = gemm::GemmShape<1, 1, 1>; using Operator = Operator_; + using ElementC = ElementC_; CUTLASS_HOST_DEVICE void operator()( @@ -218,8 +219,8 @@ struct SparseMma; #include "cutlass/arch/mma_sm50.h" #include "cutlass/arch/mma_sm60.h" #include "cutlass/arch/mma_sm61.h" -#include "cutlass/arch/mma_sm70.h" -#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm70.h" +#include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" #include "cutlass/arch/mma_sparse_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sm50.h b/include/cutlass/arch/mma_sm50.h index 96977c41..f5458fc8 100644 --- a/include/cutlass/arch/mma_sm50.h +++ b/include/cutlass/arch/mma_sm50.h @@ -62,6 +62,7 @@ struct Mma, 1, float, LayoutA, float, LayoutB, float, L using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAdd; + using ElementC = float; CUTLASS_HOST_DEVICE void operator()( @@ -89,6 +90,7 @@ struct Mma, 1, double, LayoutA, double, LayoutB, double using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAdd; + using ElementC = double; CUTLASS_HOST_DEVICE void operator()( @@ -117,6 +119,7 @@ struct Mma, 1, int, LayoutA, int, LayoutB, int, LayoutC using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAdd; + using ElementC = int; CUTLASS_HOST_DEVICE void operator()( @@ -144,16 +147,17 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAddComplex; + using ElementC = complex; CUTLASS_HOST_DEVICE void operator()( @@ -184,16 +188,17 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - float, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + float, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAddComplex; + using ElementC = complex; CUTLASS_HOST_DEVICE void operator()( @@ -222,16 +227,17 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - float, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + float, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAddComplex; + using ElementC = complex; CUTLASS_HOST_DEVICE void operator()( @@ -260,16 +266,17 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAddComplex; + using ElementC = complex; CUTLASS_HOST_DEVICE void operator()( @@ -298,16 +305,17 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - double, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + double, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAddComplex; + using ElementC = complex; CUTLASS_HOST_DEVICE void operator()( @@ -334,16 +342,17 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - double, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + double, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAddComplex; + using ElementC = complex; CUTLASS_HOST_DEVICE void operator()( @@ -373,7 +382,8 @@ struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAdd; - + using ElementC = float; + CUTLASS_HOST_DEVICE void operator()( Array &d, @@ -401,6 +411,7 @@ struct Mma, 1, Quaternion, LayoutA, Quaternion; using Operator = OpMultiplyAdd; using Element = Quaternion; + using ElementC = Element; CUTLASS_HOST_DEVICE void operator()( @@ -412,7 +423,7 @@ struct Mma, 1, Quaternion, LayoutA, Quaternion op; d[0] = op(a[0], b[0], c[0]); } - + }; } diff --git a/include/cutlass/arch/mma_sm60.h b/include/cutlass/arch/mma_sm60.h index 2d8f9ab9..6fa8b6f7 100644 --- a/include/cutlass/arch/mma_sm60.h +++ b/include/cutlass/arch/mma_sm60.h @@ -62,6 +62,7 @@ struct Mma< using Shape = gemm::GemmShape<2, 1, 1>; using Operator = OpMultiplyAdd; + using ElementC = half_t; CUTLASS_HOST_DEVICE void operator()( @@ -107,6 +108,7 @@ struct Mma< using Shape = gemm::GemmShape<1, 2, 1>; using Operator = OpMultiplyAdd; + using ElementC = half_t; CUTLASS_HOST_DEVICE void operator()( @@ -152,6 +154,7 @@ struct Mma < using Shape = gemm::GemmShape<2, 2, 1>; using Operator = OpMultiplyAdd; + using ElementC = half_t; CUTLASS_HOST_DEVICE void operator()( @@ -206,7 +209,8 @@ struct Mma< using Shape = gemm::GemmShape<2, 2, 1>; using Operator = OpMultiplyAdd; - + using ElementC = half_t; + CUTLASS_HOST_DEVICE void operator()( Array &d, @@ -220,12 +224,12 @@ struct Mma< __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a)); __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a)); __half2 const & B = reinterpret_cast<__half2 const &>(b); - + __half2 const *C = reinterpret_cast<__half2 const *>(&c); __half2 Dlo = __hfma2(Alo, B, C[0]); __half2 Dhi = __hfma2(Ahi, B, C[0]); - + Array * D = reinterpret_cast *>(&d); D[0] = reinterpret_cast &>(Dlo); @@ -246,4 +250,3 @@ struct Mma< } } - diff --git a/include/cutlass/arch/mma_sm61.h b/include/cutlass/arch/mma_sm61.h index 274c0acb..dc90d786 100644 --- a/include/cutlass/arch/mma_sm61.h +++ b/include/cutlass/arch/mma_sm61.h @@ -55,10 +55,11 @@ struct Mma< int, LayoutC, OpMultiplyAdd> { - + using Shape = gemm::GemmShape<1, 1, 4>; using Operator = OpMultiplyAdd; - + using ElementC = int; + CUTLASS_HOST_DEVICE void operator()( Array &d, @@ -103,9 +104,10 @@ struct Mma< int, LayoutC, OpMultiplyAdd> { - + using Shape = gemm::GemmShape<1, 1, 2>; using Operator = OpMultiplyAdd; + using ElementC = int; CUTLASS_HOST_DEVICE void operator()( @@ -138,4 +140,3 @@ struct Mma< } } -