support setting load granularity for conv2d fprop
This commit is contained in:
parent
7ec3a87f22
commit
bb35a3ba6f
@ -66,10 +66,10 @@ template <
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
/// whether Matrix A is 128b aligned
|
||||
bool AlignedA = true,
|
||||
/// whether Matrix B is 128b aligned
|
||||
bool AlignedB = true,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv2dFprop;
|
||||
|
||||
@ -94,7 +94,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -113,7 +115,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -197,7 +201,9 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
int InterleavedK,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -216,7 +222,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -312,7 +320,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -331,7 +341,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -417,7 +429,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
int InterleavedK,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -436,7 +450,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -520,7 +536,7 @@ struct DefaultConv2dFprop <
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
|
||||
/// multistage pipeline with unaligned data
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
@ -537,8 +553,8 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
bool AlignedA,
|
||||
bool AlignedB
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -558,8 +574,8 @@ struct DefaultConv2dFprop <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignedA,
|
||||
AlignedB
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -577,7 +593,7 @@ struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA,
|
||||
AlignedA
|
||||
AlignmentA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
@ -590,7 +606,7 @@ struct DefaultConv2dFprop <
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB,
|
||||
AlignedB
|
||||
AlignmentB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -607,114 +623,7 @@ struct DefaultConv2dFprop <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@ -755,7 +664,9 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
int InterleavedK,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -774,7 +685,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -844,10 +757,8 @@ struct DefaultConv2dFprop <
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline with disalignment data
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
@ -863,8 +774,8 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
bool AlignedA,
|
||||
bool AlignedB
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -884,8 +795,8 @@ struct DefaultConv2dFprop <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignedA,
|
||||
AlignedB
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -903,7 +814,7 @@ struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA,
|
||||
AlignedA
|
||||
AlignmentA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -918,115 +829,7 @@ struct DefaultConv2dFprop <
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB,
|
||||
AlignedB
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB
|
||||
AlignmentB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -1083,7 +886,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
int InterleavedK,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1102,7 +907,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1194,7 +1001,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1213,7 +1022,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1299,7 +1110,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1318,7 +1131,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1404,7 +1219,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1423,7 +1240,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1510,7 +1329,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1529,7 +1350,9 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
|
@ -61,7 +61,7 @@ template <
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_,
|
||||
bool Aligned = true
|
||||
int AccessSize = ThreadMap_::kElementsPerAccess
|
||||
>
|
||||
class Conv2dFpropActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
@ -75,7 +75,6 @@ public:
|
||||
using Layout = Layout_;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
|
||||
using AccessType = AlignedArray<Element, AccessSize>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
@ -216,7 +215,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
@ -357,7 +356,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
}
|
||||
}
|
||||
|
||||
@ -420,7 +419,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessSize) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,7 @@ template <
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_,
|
||||
bool Aligned = true
|
||||
int AccessSize = ThreadMap_::kElementsPerAccess
|
||||
>
|
||||
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
||||
public:
|
||||
@ -74,7 +74,6 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
|
||||
using AccessType = AlignedArray<Element, AccessSize>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
@ -178,7 +177,7 @@ public:
|
||||
}
|
||||
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
}
|
||||
|
||||
pointer_ += (
|
||||
@ -249,7 +248,7 @@ public:
|
||||
}
|
||||
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
|
||||
}
|
||||
|
||||
pointer_ += next;
|
||||
@ -301,7 +300,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessSize) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user