Make Shape public from Mma_HFMA2.
Signed-off-by: Peter Han <fujun.han@iluvatar.ai>
This commit is contained in:
parent
0f1056390d
commit
169181f30f
@ -70,15 +70,17 @@ struct Mma_HFMA2;
|
||||
// Specialization for NNN //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2 <
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::ColumnMajor,
|
||||
layout::ColumnMajor,
|
||||
layout::ColumnMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % 2),
|
||||
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
||||
@ -159,15 +161,17 @@ struct Mma_HFMA2 <
|
||||
// Specialization for NNT //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2<
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::ColumnMajor,
|
||||
layout::ColumnMajor,
|
||||
layout::RowMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kN % 2),
|
||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||
@ -253,15 +257,17 @@ struct Mma_HFMA2<
|
||||
// Specialization for NTN //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2 <
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::ColumnMajor,
|
||||
layout::RowMajor,
|
||||
layout::ColumnMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % 2),
|
||||
"Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
|
||||
@ -342,15 +348,17 @@ struct Mma_HFMA2 <
|
||||
// Specialization for NTT //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2<
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::ColumnMajor,
|
||||
layout::RowMajor,
|
||||
layout::RowMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kN % 2),
|
||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||
@ -431,15 +439,17 @@ struct Mma_HFMA2<
|
||||
// Specialization for TNN //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2 <
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::RowMajor,
|
||||
layout::ColumnMajor,
|
||||
layout::ColumnMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % 2),
|
||||
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
||||
@ -524,15 +534,17 @@ struct Mma_HFMA2 <
|
||||
// Specialization for TNT //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2 <
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::RowMajor,
|
||||
layout::ColumnMajor,
|
||||
layout::RowMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kN % 2),
|
||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||
@ -617,15 +629,17 @@ struct Mma_HFMA2 <
|
||||
// Specialization for TTN //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2 <
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::RowMajor,
|
||||
layout::RowMajor,
|
||||
layout::ColumnMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % 2),
|
||||
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
||||
@ -711,15 +725,17 @@ struct Mma_HFMA2 <
|
||||
// Specialization for TTT //
|
||||
/////////////////////////////
|
||||
|
||||
template <typename Shape>
|
||||
template <typename Shape_>
|
||||
struct Mma_HFMA2<
|
||||
Shape,
|
||||
Shape_,
|
||||
layout::RowMajor,
|
||||
layout::RowMajor,
|
||||
layout::RowMajor,
|
||||
true
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kN % 2),
|
||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||
@ -800,15 +816,17 @@ struct Mma_HFMA2<
|
||||
// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape, typename LayoutA, typename LayoutB>
|
||||
template <typename Shape_, typename LayoutA, typename LayoutB>
|
||||
struct Mma_HFMA2<
|
||||
Shape,
|
||||
Shape_,
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
layout::RowMajor,
|
||||
false
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kK % 2),
|
||||
"Mma_HFMA2 requires the K dimension to be divisible by 2."
|
||||
@ -882,15 +900,17 @@ struct Mma_HFMA2<
|
||||
// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape, typename LayoutA, typename LayoutB>
|
||||
template <typename Shape_, typename LayoutA, typename LayoutB>
|
||||
struct Mma_HFMA2<
|
||||
Shape,
|
||||
Shape_,
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
layout::ColumnMajor,
|
||||
false
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kK % 2),
|
||||
"Mma_HFMA2 requires the K dimension to be divisible by 2."
|
||||
|
Loading…
Reference in New Issue
Block a user