Make Shape public from Mma_HFMA2.

Signed-off-by: Peter Han <fujun.han@iluvatar.ai>
This commit is contained in:
Peter Han 2021-03-04 11:05:16 +08:00
parent 0f1056390d
commit 169181f30f

View File

@ -70,15 +70,17 @@ struct Mma_HFMA2;
// Specialization for NNN //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2 <
Shape,
Shape_,
layout::ColumnMajor,
layout::ColumnMajor,
layout::ColumnMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kM % 2),
"Mma_HFMA2 requires the M dimension to be divisible by 2."
@ -159,15 +161,17 @@ struct Mma_HFMA2 <
// Specialization for NNT //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2<
Shape,
Shape_,
layout::ColumnMajor,
layout::ColumnMajor,
layout::RowMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -253,15 +257,17 @@ struct Mma_HFMA2<
// Specialization for NTN //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2 <
Shape,
Shape_,
layout::ColumnMajor,
layout::RowMajor,
layout::ColumnMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kM % 2),
"Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
@ -342,15 +348,17 @@ struct Mma_HFMA2 <
// Specialization for NTT //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2<
Shape,
Shape_,
layout::ColumnMajor,
layout::RowMajor,
layout::RowMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -431,15 +439,17 @@ struct Mma_HFMA2<
// Specialization for TNN //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2 <
Shape,
Shape_,
layout::RowMajor,
layout::ColumnMajor,
layout::ColumnMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kM % 2),
"Mma_HFMA2 requires the M dimension to be divisible by 2."
@ -524,15 +534,17 @@ struct Mma_HFMA2 <
// Specialization for TNT //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2 <
Shape,
Shape_,
layout::RowMajor,
layout::ColumnMajor,
layout::RowMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -617,15 +629,17 @@ struct Mma_HFMA2 <
// Specialization for TTN //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2 <
Shape,
Shape_,
layout::RowMajor,
layout::RowMajor,
layout::ColumnMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kM % 2),
"Mma_HFMA2 requires the M dimension to be divisible by 2."
@ -711,15 +725,17 @@ struct Mma_HFMA2 <
// Specialization for TTT //
/////////////////////////////
template <typename Shape>
template <typename Shape_>
struct Mma_HFMA2<
Shape,
Shape_,
layout::RowMajor,
layout::RowMajor,
layout::RowMajor,
true
> {
using Shape = Shape_;
static_assert(
!(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -800,15 +816,17 @@ struct Mma_HFMA2<
// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
/////////////////////////////////////////////////////////////////////
template <typename Shape, typename LayoutA, typename LayoutB>
template <typename Shape_, typename LayoutA, typename LayoutB>
struct Mma_HFMA2<
Shape,
Shape_,
LayoutA,
LayoutB,
layout::RowMajor,
false
> {
using Shape = Shape_;
static_assert(
!(Shape::kK % 2),
"Mma_HFMA2 requires the K dimension to be divisible by 2."
@ -882,15 +900,17 @@ struct Mma_HFMA2<
// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
/////////////////////////////////////////////////////////////////////
template <typename Shape, typename LayoutA, typename LayoutB>
template <typename Shape_, typename LayoutA, typename LayoutB>
struct Mma_HFMA2<
Shape,
Shape_,
LayoutA,
LayoutB,
layout::ColumnMajor,
false
> {
using Shape = Shape_;
static_assert(
!(Shape::kK % 2),
"Mma_HFMA2 requires the K dimension to be divisible by 2."