Make Shape public from Mma_HFMA2.

Signed-off-by: Peter Han <fujun.han@iluvatar.ai>
This commit is contained in:
Peter Han 2021-03-04 11:05:16 +08:00
parent 0f1056390d
commit 169181f30f

View File

@ -70,15 +70,17 @@ struct Mma_HFMA2;
// Specialization for NNN // // Specialization for NNN //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2 < struct Mma_HFMA2 <
Shape, Shape_,
layout::ColumnMajor, layout::ColumnMajor,
layout::ColumnMajor, layout::ColumnMajor,
layout::ColumnMajor, layout::ColumnMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kM % 2), !(Shape::kM % 2),
"Mma_HFMA2 requires the M dimension to be divisible by 2." "Mma_HFMA2 requires the M dimension to be divisible by 2."
@ -159,15 +161,17 @@ struct Mma_HFMA2 <
// Specialization for NNT // // Specialization for NNT //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2< struct Mma_HFMA2<
Shape, Shape_,
layout::ColumnMajor, layout::ColumnMajor,
layout::ColumnMajor, layout::ColumnMajor,
layout::RowMajor, layout::RowMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kN % 2), !(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2." "Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -253,15 +257,17 @@ struct Mma_HFMA2<
// Specialization for NTN // // Specialization for NTN //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2 < struct Mma_HFMA2 <
Shape, Shape_,
layout::ColumnMajor, layout::ColumnMajor,
layout::RowMajor, layout::RowMajor,
layout::ColumnMajor, layout::ColumnMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kM % 2), !(Shape::kM % 2),
"Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
@ -342,15 +348,17 @@ struct Mma_HFMA2 <
// Specialization for NTT // // Specialization for NTT //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2< struct Mma_HFMA2<
Shape, Shape_,
layout::ColumnMajor, layout::ColumnMajor,
layout::RowMajor, layout::RowMajor,
layout::RowMajor, layout::RowMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kN % 2), !(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2." "Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -431,15 +439,17 @@ struct Mma_HFMA2<
// Specialization for TNN // // Specialization for TNN //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2 < struct Mma_HFMA2 <
Shape, Shape_,
layout::RowMajor, layout::RowMajor,
layout::ColumnMajor, layout::ColumnMajor,
layout::ColumnMajor, layout::ColumnMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kM % 2), !(Shape::kM % 2),
"Mma_HFMA2 requires the M dimension to be divisible by 2." "Mma_HFMA2 requires the M dimension to be divisible by 2."
@ -524,15 +534,17 @@ struct Mma_HFMA2 <
// Specialization for TNT // // Specialization for TNT //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2 < struct Mma_HFMA2 <
Shape, Shape_,
layout::RowMajor, layout::RowMajor,
layout::ColumnMajor, layout::ColumnMajor,
layout::RowMajor, layout::RowMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kN % 2), !(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2." "Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -617,15 +629,17 @@ struct Mma_HFMA2 <
// Specialization for TTN // // Specialization for TTN //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2 < struct Mma_HFMA2 <
Shape, Shape_,
layout::RowMajor, layout::RowMajor,
layout::RowMajor, layout::RowMajor,
layout::ColumnMajor, layout::ColumnMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kM % 2), !(Shape::kM % 2),
"Mma_HFMA2 requires the M dimension to be divisible by 2." "Mma_HFMA2 requires the M dimension to be divisible by 2."
@ -711,15 +725,17 @@ struct Mma_HFMA2 <
// Specialization for TTT // // Specialization for TTT //
///////////////////////////// /////////////////////////////
template <typename Shape> template <typename Shape_>
struct Mma_HFMA2< struct Mma_HFMA2<
Shape, Shape_,
layout::RowMajor, layout::RowMajor,
layout::RowMajor, layout::RowMajor,
layout::RowMajor, layout::RowMajor,
true true
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kN % 2), !(Shape::kN % 2),
"Mma_HFMA2 requires the N dimension to be divisible by 2." "Mma_HFMA2 requires the N dimension to be divisible by 2."
@ -800,15 +816,17 @@ struct Mma_HFMA2<
// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T // // Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
template <typename Shape, typename LayoutA, typename LayoutB> template <typename Shape_, typename LayoutA, typename LayoutB>
struct Mma_HFMA2< struct Mma_HFMA2<
Shape, Shape_,
LayoutA, LayoutA,
LayoutB, LayoutB,
layout::RowMajor, layout::RowMajor,
false false
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kK % 2), !(Shape::kK % 2),
"Mma_HFMA2 requires the K dimension to be divisible by 2." "Mma_HFMA2 requires the K dimension to be divisible by 2."
@ -882,15 +900,17 @@ struct Mma_HFMA2<
// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N // // Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
template <typename Shape, typename LayoutA, typename LayoutB> template <typename Shape_, typename LayoutA, typename LayoutB>
struct Mma_HFMA2< struct Mma_HFMA2<
Shape, Shape_,
LayoutA, LayoutA,
LayoutB, LayoutB,
layout::ColumnMajor, layout::ColumnMajor,
false false
> { > {
using Shape = Shape_;
static_assert( static_assert(
!(Shape::kK % 2), !(Shape::kK % 2),
"Mma_HFMA2 requires the K dimension to be divisible by 2." "Mma_HFMA2 requires the K dimension to be divisible by 2."