Make Shape public from Mma_HFMA2.
Signed-off-by: Peter Han <fujun.han@iluvatar.ai>
This commit is contained in:
parent
0f1056390d
commit
169181f30f
@ -70,15 +70,17 @@ struct Mma_HFMA2;
|
|||||||
// Specialization for NNN //
|
// Specialization for NNN //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2 <
|
struct Mma_HFMA2 <
|
||||||
Shape,
|
Shape_,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kM % 2),
|
!(Shape::kM % 2),
|
||||||
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
||||||
@ -159,15 +161,17 @@ struct Mma_HFMA2 <
|
|||||||
// Specialization for NNT //
|
// Specialization for NNT //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2<
|
struct Mma_HFMA2<
|
||||||
Shape,
|
Shape_,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kN % 2),
|
!(Shape::kN % 2),
|
||||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||||
@ -253,15 +257,17 @@ struct Mma_HFMA2<
|
|||||||
// Specialization for NTN //
|
// Specialization for NTN //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2 <
|
struct Mma_HFMA2 <
|
||||||
Shape,
|
Shape_,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kM % 2),
|
!(Shape::kM % 2),
|
||||||
"Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
|
"Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
|
||||||
@ -342,15 +348,17 @@ struct Mma_HFMA2 <
|
|||||||
// Specialization for NTT //
|
// Specialization for NTT //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2<
|
struct Mma_HFMA2<
|
||||||
Shape,
|
Shape_,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kN % 2),
|
!(Shape::kN % 2),
|
||||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||||
@ -431,15 +439,17 @@ struct Mma_HFMA2<
|
|||||||
// Specialization for TNN //
|
// Specialization for TNN //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2 <
|
struct Mma_HFMA2 <
|
||||||
Shape,
|
Shape_,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kM % 2),
|
!(Shape::kM % 2),
|
||||||
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
||||||
@ -524,15 +534,17 @@ struct Mma_HFMA2 <
|
|||||||
// Specialization for TNT //
|
// Specialization for TNT //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2 <
|
struct Mma_HFMA2 <
|
||||||
Shape,
|
Shape_,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kN % 2),
|
!(Shape::kN % 2),
|
||||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||||
@ -617,15 +629,17 @@ struct Mma_HFMA2 <
|
|||||||
// Specialization for TTN //
|
// Specialization for TTN //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2 <
|
struct Mma_HFMA2 <
|
||||||
Shape,
|
Shape_,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kM % 2),
|
!(Shape::kM % 2),
|
||||||
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
"Mma_HFMA2 requires the M dimension to be divisible by 2."
|
||||||
@ -711,15 +725,17 @@ struct Mma_HFMA2 <
|
|||||||
// Specialization for TTT //
|
// Specialization for TTT //
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
template <typename Shape>
|
template <typename Shape_>
|
||||||
struct Mma_HFMA2<
|
struct Mma_HFMA2<
|
||||||
Shape,
|
Shape_,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
true
|
true
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kN % 2),
|
!(Shape::kN % 2),
|
||||||
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
"Mma_HFMA2 requires the N dimension to be divisible by 2."
|
||||||
@ -800,15 +816,17 @@ struct Mma_HFMA2<
|
|||||||
// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
|
// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename Shape, typename LayoutA, typename LayoutB>
|
template <typename Shape_, typename LayoutA, typename LayoutB>
|
||||||
struct Mma_HFMA2<
|
struct Mma_HFMA2<
|
||||||
Shape,
|
Shape_,
|
||||||
LayoutA,
|
LayoutA,
|
||||||
LayoutB,
|
LayoutB,
|
||||||
layout::RowMajor,
|
layout::RowMajor,
|
||||||
false
|
false
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kK % 2),
|
!(Shape::kK % 2),
|
||||||
"Mma_HFMA2 requires the K dimension to be divisible by 2."
|
"Mma_HFMA2 requires the K dimension to be divisible by 2."
|
||||||
@ -882,15 +900,17 @@ struct Mma_HFMA2<
|
|||||||
// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
|
// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <typename Shape, typename LayoutA, typename LayoutB>
|
template <typename Shape_, typename LayoutA, typename LayoutB>
|
||||||
struct Mma_HFMA2<
|
struct Mma_HFMA2<
|
||||||
Shape,
|
Shape_,
|
||||||
LayoutA,
|
LayoutA,
|
||||||
LayoutB,
|
LayoutB,
|
||||||
layout::ColumnMajor,
|
layout::ColumnMajor,
|
||||||
false
|
false
|
||||||
> {
|
> {
|
||||||
|
|
||||||
|
using Shape = Shape_;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
!(Shape::kK % 2),
|
!(Shape::kK % 2),
|
||||||
"Mma_HFMA2 requires the K dimension to be divisible by 2."
|
"Mma_HFMA2 requires the K dimension to be divisible by 2."
|
||||||
|
Loading…
Reference in New Issue
Block a user