support setting load granularity for conv2d fprop
This commit is contained in:
parent
7ec3a87f22
commit
bb35a3ba6f
@ -66,10 +66,10 @@ template <
|
|||||||
int Stages,
|
int Stages,
|
||||||
typename MathOperatorTag,
|
typename MathOperatorTag,
|
||||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||||
/// whether Matrix A is 128b aligned
|
/// Access granularity of A matrix in units of elements
|
||||||
bool AlignedA = true,
|
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||||
/// whether Matrix B is 128b aligned
|
/// Access granularity of B matrix in units of elements
|
||||||
bool AlignedB = true,
|
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||||
> struct DefaultConv2dFprop;
|
> struct DefaultConv2dFprop;
|
||||||
|
|
||||||
@ -94,7 +94,9 @@ template <
|
|||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
int Stages,
|
int Stages,
|
||||||
typename MathOperatorTag
|
typename MathOperatorTag,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -113,7 +115,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
Stages,
|
Stages,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kAnalytic
|
IteratorAlgorithm::kAnalytic,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -197,7 +201,9 @@ template <
|
|||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
int Stages,
|
int Stages,
|
||||||
typename MathOperatorTag,
|
typename MathOperatorTag,
|
||||||
int InterleavedK
|
int InterleavedK,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -216,7 +222,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
Stages,
|
Stages,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kAnalytic
|
IteratorAlgorithm::kAnalytic,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -312,7 +320,9 @@ template <
|
|||||||
typename InstructionShape,
|
typename InstructionShape,
|
||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
typename MathOperatorTag
|
typename MathOperatorTag,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -331,7 +341,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
2,
|
2,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kAnalytic
|
IteratorAlgorithm::kAnalytic,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -417,7 +429,9 @@ template <
|
|||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
typename MathOperatorTag,
|
typename MathOperatorTag,
|
||||||
int InterleavedK
|
int InterleavedK,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -436,7 +450,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
2,
|
2,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kAnalytic
|
IteratorAlgorithm::kAnalytic,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -520,7 +536,7 @@ struct DefaultConv2dFprop <
|
|||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
|
/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
|
||||||
/// multistage pipeline with unaligned data
|
/// multistage pipeline.
|
||||||
template <
|
template <
|
||||||
typename ElementA,
|
typename ElementA,
|
||||||
typename LayoutA,
|
typename LayoutA,
|
||||||
@ -537,8 +553,8 @@ template <
|
|||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
int Stages,
|
int Stages,
|
||||||
typename MathOperatorTag,
|
typename MathOperatorTag,
|
||||||
bool AlignedA,
|
int AlignmentA,
|
||||||
bool AlignedB
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -558,8 +574,8 @@ struct DefaultConv2dFprop <
|
|||||||
Stages,
|
Stages,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kOptimized,
|
IteratorAlgorithm::kOptimized,
|
||||||
AlignedA,
|
AlignmentA,
|
||||||
AlignedB
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -577,7 +593,7 @@ struct DefaultConv2dFprop <
|
|||||||
ElementA,
|
ElementA,
|
||||||
LayoutA,
|
LayoutA,
|
||||||
ThreadMapA,
|
ThreadMapA,
|
||||||
AlignedA
|
AlignmentA
|
||||||
>;
|
>;
|
||||||
|
|
||||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||||
@ -590,7 +606,7 @@ struct DefaultConv2dFprop <
|
|||||||
ElementB,
|
ElementB,
|
||||||
LayoutB,
|
LayoutB,
|
||||||
ThreadMapB,
|
ThreadMapB,
|
||||||
AlignedB
|
AlignmentB
|
||||||
>;
|
>;
|
||||||
|
|
||||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||||
@ -607,114 +623,7 @@ struct DefaultConv2dFprop <
|
|||||||
arch::CacheOperation::Always,
|
arch::CacheOperation::Always,
|
||||||
IteratorB,
|
IteratorB,
|
||||||
SmemIteratorB,
|
SmemIteratorB,
|
||||||
arch::CacheOperation::Global,
|
|
||||||
MmaPolicy,
|
|
||||||
Stages
|
|
||||||
>;
|
|
||||||
|
|
||||||
// Define the epilogue
|
|
||||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
|
||||||
ThreadblockShape,
|
|
||||||
WarpMmaTensorOp,
|
|
||||||
1,
|
|
||||||
EpilogueOutputOp,
|
|
||||||
EpilogueOutputOp::kCount
|
|
||||||
>::Epilogue;
|
|
||||||
|
|
||||||
// Define the kernel
|
|
||||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
|
||||||
Mma,
|
|
||||||
Epilogue,
|
|
||||||
ThreadblockSwizzle,
|
|
||||||
conv::Operator::kFprop
|
|
||||||
>;
|
|
||||||
};
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
|
|
||||||
/// multistage pipeline.
|
|
||||||
template <
|
|
||||||
typename ElementA,
|
|
||||||
typename LayoutA,
|
|
||||||
typename ElementB,
|
|
||||||
typename LayoutB,
|
|
||||||
typename ElementC,
|
|
||||||
typename LayoutC,
|
|
||||||
typename ElementAccumulator,
|
|
||||||
typename ArchTag,
|
|
||||||
typename ThreadblockShape,
|
|
||||||
typename WarpShape,
|
|
||||||
typename InstructionShape,
|
|
||||||
typename EpilogueOutputOp,
|
|
||||||
typename ThreadblockSwizzle,
|
|
||||||
int Stages,
|
|
||||||
typename MathOperatorTag
|
|
||||||
>
|
|
||||||
struct DefaultConv2dFprop <
|
|
||||||
ElementA,
|
|
||||||
LayoutA,
|
|
||||||
ElementB,
|
|
||||||
LayoutB,
|
|
||||||
ElementC,
|
|
||||||
LayoutC,
|
|
||||||
ElementAccumulator,
|
|
||||||
arch::OpClassTensorOp,
|
|
||||||
ArchTag,
|
|
||||||
ThreadblockShape,
|
|
||||||
WarpShape,
|
|
||||||
InstructionShape,
|
|
||||||
EpilogueOutputOp,
|
|
||||||
ThreadblockSwizzle,
|
|
||||||
Stages,
|
|
||||||
MathOperatorTag,
|
|
||||||
IteratorAlgorithm::kOptimized
|
|
||||||
> {
|
|
||||||
|
|
||||||
// Define the core components from GEMM
|
|
||||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
|
||||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
|
||||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
|
||||||
Stages, MathOperatorTag
|
|
||||||
>;
|
|
||||||
|
|
||||||
// Define iterators over tiles from the A operand
|
|
||||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
|
||||||
using IteratorA =
|
|
||||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
|
||||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
|
||||||
ElementA,
|
|
||||||
LayoutA,
|
|
||||||
ThreadMapA
|
|
||||||
>;
|
|
||||||
|
|
||||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
|
||||||
|
|
||||||
// Define iterators over tiles from the B operand
|
|
||||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
|
||||||
using IteratorB =
|
|
||||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
|
||||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
|
||||||
ElementB,
|
|
||||||
LayoutB,
|
|
||||||
ThreadMapB
|
|
||||||
>;
|
|
||||||
|
|
||||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
|
||||||
|
|
||||||
// Warp-level GEMM components
|
|
||||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
|
||||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
|
||||||
|
|
||||||
// Define the Mma
|
|
||||||
using Mma = threadblock::ImplicitGemmMultistage<
|
|
||||||
ThreadblockShape,
|
|
||||||
IteratorA,
|
|
||||||
SmemIteratorA,
|
|
||||||
arch::CacheOperation::Always,
|
arch::CacheOperation::Always,
|
||||||
IteratorB,
|
|
||||||
SmemIteratorB,
|
|
||||||
arch::CacheOperation::Global,
|
|
||||||
MmaPolicy,
|
MmaPolicy,
|
||||||
Stages
|
Stages
|
||||||
>;
|
>;
|
||||||
@ -755,7 +664,9 @@ template <
|
|||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
int Stages,
|
int Stages,
|
||||||
typename MathOperatorTag,
|
typename MathOperatorTag,
|
||||||
int InterleavedK
|
int InterleavedK,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -774,7 +685,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
Stages,
|
Stages,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kOptimized
|
IteratorAlgorithm::kOptimized,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -844,10 +757,8 @@ struct DefaultConv2dFprop <
|
|||||||
>;
|
>;
|
||||||
};
|
};
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
|
/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
|
||||||
/// and 2 stage pipeline with disalignment data
|
/// and 2 stage pipeline.
|
||||||
template <
|
template <
|
||||||
typename ElementA,
|
typename ElementA,
|
||||||
typename LayoutA,
|
typename LayoutA,
|
||||||
@ -863,8 +774,8 @@ template <
|
|||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
typename MathOperatorTag,
|
typename MathOperatorTag,
|
||||||
bool AlignedA,
|
int AlignmentA,
|
||||||
bool AlignedB
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -884,8 +795,8 @@ struct DefaultConv2dFprop <
|
|||||||
2,
|
2,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kOptimized,
|
IteratorAlgorithm::kOptimized,
|
||||||
AlignedA,
|
AlignmentA,
|
||||||
AlignedB
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -903,7 +814,7 @@ struct DefaultConv2dFprop <
|
|||||||
ElementA,
|
ElementA,
|
||||||
LayoutA,
|
LayoutA,
|
||||||
ThreadMapA,
|
ThreadMapA,
|
||||||
AlignedA
|
AlignmentA
|
||||||
>
|
>
|
||||||
>;
|
>;
|
||||||
|
|
||||||
@ -918,115 +829,7 @@ struct DefaultConv2dFprop <
|
|||||||
ElementB,
|
ElementB,
|
||||||
LayoutB,
|
LayoutB,
|
||||||
ThreadMapB,
|
ThreadMapB,
|
||||||
AlignedB
|
AlignmentB
|
||||||
>
|
|
||||||
>;
|
|
||||||
|
|
||||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
|
||||||
|
|
||||||
// Warp-level GEMM components
|
|
||||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
|
||||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
|
||||||
|
|
||||||
// Define the Mma
|
|
||||||
using Mma = threadblock::ImplicitGemmPipelined<
|
|
||||||
ThreadblockShape,
|
|
||||||
IteratorA,
|
|
||||||
SmemIteratorA,
|
|
||||||
IteratorB,
|
|
||||||
SmemIteratorB,
|
|
||||||
ElementC,
|
|
||||||
LayoutC,
|
|
||||||
MmaPolicy
|
|
||||||
>;
|
|
||||||
|
|
||||||
// Define the epilogue
|
|
||||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
|
||||||
ArchTag,
|
|
||||||
ThreadblockShape,
|
|
||||||
WarpMmaTensorOp,
|
|
||||||
1,
|
|
||||||
EpilogueOutputOp
|
|
||||||
>::Epilogue;
|
|
||||||
|
|
||||||
// Define the kernel
|
|
||||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
|
||||||
Mma,
|
|
||||||
Epilogue,
|
|
||||||
ThreadblockSwizzle,
|
|
||||||
conv::Operator::kFprop
|
|
||||||
>;
|
|
||||||
};
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
|
|
||||||
/// and 2 stage pipeline.
|
|
||||||
template <
|
|
||||||
typename ElementA,
|
|
||||||
typename LayoutA,
|
|
||||||
typename ElementB,
|
|
||||||
typename LayoutB,
|
|
||||||
typename ElementC,
|
|
||||||
typename LayoutC,
|
|
||||||
typename ElementAccumulator,
|
|
||||||
typename ArchTag,
|
|
||||||
typename ThreadblockShape,
|
|
||||||
typename WarpShape,
|
|
||||||
typename InstructionShape,
|
|
||||||
typename EpilogueOutputOp,
|
|
||||||
typename ThreadblockSwizzle,
|
|
||||||
typename MathOperatorTag
|
|
||||||
>
|
|
||||||
struct DefaultConv2dFprop <
|
|
||||||
ElementA,
|
|
||||||
LayoutA,
|
|
||||||
ElementB,
|
|
||||||
LayoutB,
|
|
||||||
ElementC,
|
|
||||||
LayoutC,
|
|
||||||
ElementAccumulator,
|
|
||||||
arch::OpClassTensorOp,
|
|
||||||
ArchTag,
|
|
||||||
ThreadblockShape,
|
|
||||||
WarpShape,
|
|
||||||
InstructionShape,
|
|
||||||
EpilogueOutputOp,
|
|
||||||
ThreadblockSwizzle,
|
|
||||||
2,
|
|
||||||
MathOperatorTag,
|
|
||||||
IteratorAlgorithm::kOptimized
|
|
||||||
> {
|
|
||||||
|
|
||||||
// Define the core components from GEMM
|
|
||||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
|
||||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
|
||||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
|
||||||
2, MathOperatorTag>;
|
|
||||||
|
|
||||||
// Define iterators over tiles from the A operand
|
|
||||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
|
||||||
using IteratorA =
|
|
||||||
cutlass::conv::threadblock::TileIterator<
|
|
||||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
|
||||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
|
||||||
ElementA,
|
|
||||||
LayoutA,
|
|
||||||
ThreadMapA
|
|
||||||
>
|
|
||||||
>;
|
|
||||||
|
|
||||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
|
||||||
|
|
||||||
// Define iterators over tiles from the B operand
|
|
||||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
|
||||||
using IteratorB =
|
|
||||||
cutlass::conv::threadblock::TileIterator<
|
|
||||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
|
||||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
|
||||||
ElementB,
|
|
||||||
LayoutB,
|
|
||||||
ThreadMapB
|
|
||||||
>
|
>
|
||||||
>;
|
>;
|
||||||
|
|
||||||
@ -1083,7 +886,9 @@ template <
|
|||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
typename MathOperatorTag,
|
typename MathOperatorTag,
|
||||||
int InterleavedK
|
int InterleavedK,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -1102,7 +907,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
2,
|
2,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kOptimized
|
IteratorAlgorithm::kOptimized,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -1194,7 +1001,9 @@ template <
|
|||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
int Stages,
|
int Stages,
|
||||||
typename MathOperatorTag
|
typename MathOperatorTag,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -1213,7 +1022,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
Stages,
|
Stages,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kAnalytic
|
IteratorAlgorithm::kAnalytic,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -1299,7 +1110,9 @@ template <
|
|||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
int Stages,
|
int Stages,
|
||||||
typename MathOperatorTag
|
typename MathOperatorTag,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -1318,7 +1131,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
Stages,
|
Stages,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kOptimized
|
IteratorAlgorithm::kOptimized,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -1404,7 +1219,9 @@ template <
|
|||||||
typename InstructionShape,
|
typename InstructionShape,
|
||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
typename MathOperatorTag
|
typename MathOperatorTag,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -1423,7 +1240,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
2,
|
2,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kAnalytic
|
IteratorAlgorithm::kAnalytic,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
@ -1510,7 +1329,9 @@ template <
|
|||||||
typename InstructionShape,
|
typename InstructionShape,
|
||||||
typename EpilogueOutputOp,
|
typename EpilogueOutputOp,
|
||||||
typename ThreadblockSwizzle,
|
typename ThreadblockSwizzle,
|
||||||
typename MathOperatorTag
|
typename MathOperatorTag,
|
||||||
|
int AlignmentA,
|
||||||
|
int AlignmentB
|
||||||
>
|
>
|
||||||
struct DefaultConv2dFprop <
|
struct DefaultConv2dFprop <
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -1529,7 +1350,9 @@ struct DefaultConv2dFprop <
|
|||||||
ThreadblockSwizzle,
|
ThreadblockSwizzle,
|
||||||
2,
|
2,
|
||||||
MathOperatorTag,
|
MathOperatorTag,
|
||||||
IteratorAlgorithm::kOptimized
|
IteratorAlgorithm::kOptimized,
|
||||||
|
AlignmentA,
|
||||||
|
AlignmentB
|
||||||
> {
|
> {
|
||||||
|
|
||||||
// Define the core components from GEMM
|
// Define the core components from GEMM
|
||||||
|
@ -61,7 +61,7 @@ template <
|
|||||||
typename Element_,
|
typename Element_,
|
||||||
typename Layout_,
|
typename Layout_,
|
||||||
typename ThreadMap_,
|
typename ThreadMap_,
|
||||||
bool Aligned = true
|
int AccessSize = ThreadMap_::kElementsPerAccess
|
||||||
>
|
>
|
||||||
class Conv2dFpropActivationTileAccessIteratorOptimized {
|
class Conv2dFpropActivationTileAccessIteratorOptimized {
|
||||||
public:
|
public:
|
||||||
@ -75,7 +75,6 @@ public:
|
|||||||
using Layout = Layout_;
|
using Layout = Layout_;
|
||||||
using TensorCoord = typename Layout::TensorCoord;
|
using TensorCoord = typename Layout::TensorCoord;
|
||||||
using ThreadMap = ThreadMap_;
|
using ThreadMap = ThreadMap_;
|
||||||
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
|
|
||||||
using AccessType = AlignedArray<Element, AccessSize>;
|
using AccessType = AlignedArray<Element, AccessSize>;
|
||||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||||
using Index = typename Layout::Index;
|
using Index = typename Layout::Index;
|
||||||
@ -216,7 +215,7 @@ public:
|
|||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
set_iteration_index(0);
|
set_iteration_index(0);
|
||||||
@ -357,7 +356,7 @@ public:
|
|||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -420,7 +419,7 @@ public:
|
|||||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||||
|
|
||||||
// check alignment constraint on iterator's contiguous dimension
|
// check alignment constraint on iterator's contiguous dimension
|
||||||
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
|
if (problem_size.C % AccessSize) {
|
||||||
return Status::kErrorInvalidProblem;
|
return Status::kErrorInvalidProblem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ template <
|
|||||||
typename Element_,
|
typename Element_,
|
||||||
typename Layout_,
|
typename Layout_,
|
||||||
typename ThreadMap_,
|
typename ThreadMap_,
|
||||||
bool Aligned = true
|
int AccessSize = ThreadMap_::kElementsPerAccess
|
||||||
>
|
>
|
||||||
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
||||||
public:
|
public:
|
||||||
@ -74,7 +74,6 @@ public:
|
|||||||
using Element = Element_;
|
using Element = Element_;
|
||||||
using Layout = Layout_;
|
using Layout = Layout_;
|
||||||
using ThreadMap = ThreadMap_;
|
using ThreadMap = ThreadMap_;
|
||||||
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
|
|
||||||
using AccessType = AlignedArray<Element, AccessSize>;
|
using AccessType = AlignedArray<Element, AccessSize>;
|
||||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||||
using TensorCoord = typename Layout::TensorCoord;
|
using TensorCoord = typename Layout::TensorCoord;
|
||||||
@ -178,7 +177,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
pointer_ += (
|
pointer_ += (
|
||||||
@ -249,7 +248,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
pointer_ += next;
|
pointer_ += next;
|
||||||
@ -301,7 +300,7 @@ public:
|
|||||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||||
|
|
||||||
// check alignment constraint on iterator's contiguous dimension
|
// check alignment constraint on iterator's contiguous dimension
|
||||||
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
|
if (problem_size.C % AccessSize) {
|
||||||
return Status::kErrorInvalidProblem;
|
return Status::kErrorInvalidProblem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user