diff --git a/include/cutlass/gemm/thread/mma_sm60.h b/include/cutlass/gemm/thread/mma_sm60.h index 562c682e..839e07a7 100644 --- a/include/cutlass/gemm/thread/mma_sm60.h +++ b/include/cutlass/gemm/thread/mma_sm60.h @@ -70,15 +70,17 @@ struct Mma_HFMA2; // Specialization for NNN // ///////////////////////////// -template +template struct Mma_HFMA2 < - Shape, + Shape_, layout::ColumnMajor, layout::ColumnMajor, layout::ColumnMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kM % 2), "Mma_HFMA2 requires the M dimension to be divisible by 2." @@ -159,15 +161,17 @@ struct Mma_HFMA2 < // Specialization for NNT // ///////////////////////////// -template +template struct Mma_HFMA2< - Shape, + Shape_, layout::ColumnMajor, layout::ColumnMajor, layout::RowMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kN % 2), "Mma_HFMA2 requires the N dimension to be divisible by 2." @@ -253,15 +257,17 @@ struct Mma_HFMA2< // Specialization for NTN // ///////////////////////////// -template +template struct Mma_HFMA2 < - Shape, + Shape_, layout::ColumnMajor, layout::RowMajor, layout::ColumnMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kM % 2), "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." @@ -342,15 +348,17 @@ struct Mma_HFMA2 < // Specialization for NTT // ///////////////////////////// -template +template struct Mma_HFMA2< - Shape, + Shape_, layout::ColumnMajor, layout::RowMajor, layout::RowMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kN % 2), "Mma_HFMA2 requires the N dimension to be divisible by 2." @@ -431,15 +439,17 @@ struct Mma_HFMA2< // Specialization for TNN // ///////////////////////////// -template +template struct Mma_HFMA2 < - Shape, + Shape_, layout::RowMajor, layout::ColumnMajor, layout::ColumnMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kM % 2), "Mma_HFMA2 requires the M dimension to be divisible by 2." @@ -524,15 +534,17 @@ struct Mma_HFMA2 < // Specialization for TNT // ///////////////////////////// -template +template struct Mma_HFMA2 < - Shape, + Shape_, layout::RowMajor, layout::ColumnMajor, layout::RowMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kN % 2), "Mma_HFMA2 requires the N dimension to be divisible by 2." @@ -617,15 +629,17 @@ struct Mma_HFMA2 < // Specialization for TTN // ///////////////////////////// -template +template struct Mma_HFMA2 < - Shape, + Shape_, layout::RowMajor, layout::RowMajor, layout::ColumnMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kM % 2), "Mma_HFMA2 requires the M dimension to be divisible by 2." @@ -711,15 +725,17 @@ struct Mma_HFMA2 < // Specialization for TTT // ///////////////////////////// -template +template struct Mma_HFMA2< - Shape, + Shape_, layout::RowMajor, layout::RowMajor, layout::RowMajor, true > { + using Shape = Shape_; + static_assert( !(Shape::kN % 2), "Mma_HFMA2 requires the N dimension to be divisible by 2." @@ -800,15 +816,17 @@ struct Mma_HFMA2< // Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T // ///////////////////////////////////////////////////////////////////// -template +template struct Mma_HFMA2< - Shape, + Shape_, LayoutA, LayoutB, layout::RowMajor, false > { + using Shape = Shape_; + static_assert( !(Shape::kK % 2), "Mma_HFMA2 requires the K dimension to be divisible by 2." @@ -882,15 +900,17 @@ struct Mma_HFMA2< // Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N // ///////////////////////////////////////////////////////////////////// -template +template struct Mma_HFMA2< - Shape, + Shape_, LayoutA, LayoutB, layout::ColumnMajor, false > { + using Shape = Shape_; + static_assert( !(Shape::kK % 2), "Mma_HFMA2 requires the K dimension to be divisible by 2."