Support unaligned input for conv2d fprop (stage=2). Fix for issue #242

This commit is contained in:
mengchi.hmc 2021-04-21 14:28:58 +08:00
parent c77a524459
commit 7ec3a87f22
27 changed files with 444 additions and 72 deletions
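
For orientation before the per-file hunks: the new AlignedA/AlignedB template parameters default to true, so existing instantiations are unchanged; passing false selects the unaligned specializations added below, which fall back to element-wide global accesses. A minimal sketch of an instantiation for operands whose channel count is not 128b-aligned (tile shapes, epilogue, and swizzle here are illustrative placeholders, not taken from this commit):

// Sketch only: all shapes/types below are assumptions for illustration.
using UnalignedFprop = typename cutlass::conv::kernel::DefaultConv2dFprop<
  cutlass::half_t, cutlass::layout::TensorNHWC,    // ElementA, LayoutA
  cutlass::half_t, cutlass::layout::TensorNHWC,    // ElementB, LayoutB
  cutlass::half_t, cutlass::layout::TensorNHWC,    // ElementC, LayoutC
  float,                                           // ElementAccumulator
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 32>,          // ThreadblockShape
  cutlass::gemm::GemmShape<64, 64, 32>,            // WarpShape
  cutlass::gemm::GemmShape<16, 8, 16>,             // InstructionShape
  cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  2,                                               // Stages
  cutlass::arch::OpMultiplyAdd,                    // MathOperatorTag
  cutlass::conv::IteratorAlgorithm::kOptimized,
  false,                                           // AlignedA: A not 128b aligned
  false                                            // AlignedB: B not 128b aligned
>::Kernel;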

View File

@ -66,6 +66,10 @@ template <
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
/// whether Matrix A is 128b aligned
bool AlignedA = true,
/// whether Matrix B is 128b aligned
bool AlignedB = true,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv2dFprop;
@ -515,6 +519,119 @@ struct DefaultConv2dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline with unaligned data
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
bool AlignedA,
bool AlignedB
>
struct DefaultConv2dFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
AlignedA,
AlignedB
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag
>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
ThreadMapA,
AlignedA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
ThreadMapB,
AlignedB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Global,
MmaPolicy,
Stages
>;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
1,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop
>;
};
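
The whole unaligned path hinges on the iterators' AccessType: when an operand is unaligned, AccessSize drops to 1 element, and kAccessesPerVector = kElementsPerAccess / AccessType::kElements accesses together cover one thread-map vector. A worked sketch under assumed numbers (half_t, one 128b vector = 8 elements per thread-map access):

// Sketch of the access-size arithmetic in the new iterators.
// Assumed for illustration: half_t elements, kElementsPerAccess = 8.
static bool const Aligned = false;
static int const kElementsPerAccess = 8;
static int const AccessSize = Aligned ? kElementsPerAccess : 1;        // -> 1
using AccessType = cutlass::AlignedArray<cutlass::half_t, AccessSize>; // 2B accesses
static int const kAccessesPerVector =
    kElementsPerAccess / AccessType::kElements;                        // 8 / 1 = 8
// Aligned case for comparison: AccessSize = 8 (one 16B access),
// kAccessesPerVector = 8 / 8 = 1, i.e. the pre-existing behavior.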
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline.
template <
@ -729,6 +846,120 @@ struct DefaultConv2dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
/// and 2-stage pipeline with unaligned data
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
bool AlignedA,
bool AlignedB
>
struct DefaultConv2dFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
AlignedA,
AlignedB
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
ThreadMapA,
AlignedA
>
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
ThreadMapB,
AlignedB
>
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma
using Mma = threadblock::ImplicitGemmPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;
// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
EpilogueOutputOp
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop
>;
};
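
One structural difference from the multistage specialization above: ImplicitGemmPipelined consumes tile iterators rather than raw access iterators, so the unaligned access iterators are wrapped in cutlass::conv::threadblock::TileIterator, whose load path (patched in a later hunk of this commit) drains all Iterations x kAccessesPerVector guarded accesses into a register fragment. A hedged sketch of that relationship (template arguments are placeholders, construction elided):

// Sketch, not a verbatim API transcript: the wrapper turns per-access
// iteration into whole-tile fragment loads for the 2-stage mainloop.
using AccessIt = cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
    Shape, Element, Layout, ThreadMap, /*Aligned=*/false>;  // placeholder params
using TileIt = cutlass::conv::threadblock::TileIterator<AccessIt>;
typename TileIt::Fragment frag;  // holds Iterations x kAccessesPerVector AccessTypes
// ... construct `it` as a TileIt over the activation tensor ...
// it.load(frag);                // issues one guarded load per vector access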
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
/// and 2-stage pipeline.
template <

View File

@ -90,6 +90,8 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -82,6 +82,8 @@ public:
static StrideSupport const kStrideSupport = StrideSupport_;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure

View File

@ -111,6 +111,8 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -100,6 +100,8 @@ public:
using Params = Conv2dDgradOutputGradientIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv2dDgradOutputGradientIteratorOptimizedParams const &params_;

View File

@ -95,6 +95,8 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -60,7 +60,8 @@ template <
typename Shape_,
typename Element_,
typename Layout_,
typename ThreadMap_
typename ThreadMap_,
bool Aligned = true
>
class Conv2dFpropActivationTileAccessIteratorOptimized {
public:
@ -74,7 +75,8 @@ public:
using Layout = Layout_;
using TensorCoord = typename Layout::TensorCoord;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
using AccessType = AlignedArray<Element, AccessSize>;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
@ -97,12 +99,15 @@ public:
using Params = Conv2dFpropActivationIteratorOptimizedParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv2dFpropActivationIteratorOptimizedParams<Layout> const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
// One pointer per access
char const *pointer_[ThreadMap::Iterations::kStrided];
@ -112,7 +117,7 @@ private:
int filter_s_;
int filter_c_;
Index masks_[ThreadMap::Iterations::kStrided][2];
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
public:
@ -180,7 +185,11 @@ public:
int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h;
bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H);
masks_[s_idx][0] |= (pred << r);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
masks_[s_idx][v_idx][0] |= (pred << r);
}
}
}
@ -197,13 +206,18 @@ public:
int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w;
bool pred = (w >= 0 && w < problem_size_.W);
masks_[s_idx][1] |= (pred << s);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
masks_[s_idx][v_idx][1] |= (pred << s);
}
}
}
if (filter_c_ >= problem_size.C) {
clear_mask();
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
}
set_iteration_index(0);
}
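
The mask array gains a per-vector dimension: bit r of masks_[s][v][0] records whether input row h(r) is in bounds for strided iteration s, bit s of masks_[s][v][1] the same for column w(s), and the v index lets a trailing partial channel group be cleared independently (the clear_mask_(filter_c_ + v_idx >= C, v_idx) loop above). A standalone sketch of the layout, with extents assumed for illustration:

#include <cstdint>
// Assumed extents: 4 strided iterations, kAccessesPerVector = 8, R = S = 3.
uint32_t masks[4][8][2] = {};
void build_masks(bool h_valid[3], bool w_valid[3]) {
  for (int s_idx = 0; s_idx < 4; ++s_idx)
    for (int v = 0; v < 8; ++v) {
      for (int r = 0; r < 3; ++r) masks[s_idx][v][0] |= uint32_t(h_valid[r]) << r;
      for (int s = 0; s < 3; ++s) masks[s_idx][v][1] |= uint32_t(w_valid[s]) << s;
    }
}
// An access (strided, vector) at filter position (r, s) is valid iff
// both bits are set; see valid() later in this file.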
@ -250,7 +264,7 @@ private:
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask_(bool clear) {
void clear_mask_(bool clear, int index) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
@ -268,10 +282,10 @@ private:
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(masks_[s][0])
"=r"(masks_[s][index][0])
:
"r"((int)clear),
"r"(masks_[s][0])
"r"(masks_[s][index][0])
);
asm volatile(
"{\n"
@ -283,15 +297,15 @@ private:
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(masks_[s][1])
"=r"(masks_[s][index][1])
:
"r"((int)clear),
"r"(masks_[s][1])
"r"(masks_[s][index][1])
);
#else
if (clear) {
masks_[s][0] = 0;
masks_[s][1] = 0;
masks_[s][index][0] = 0;
masks_[s][index][1] = 0;
}
#endif
}
@ -302,8 +316,11 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
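
set_iteration_index now decomposes the linear access index with the vector position varying fastest, matching the order operator++ advances: vector, then contiguous, then strided. For example, with kAccessesPerVector = 8 and Iterations::kContiguous = 1 (assumed numbers):

// Worked decomposition of index = 10 under the assumed extents.
int index = 10;
int vector     = index % 8;    // = 2  (fastest varying)
int residual   = index / 8;    // = 1
int contiguous = residual % 1; // = 0
int strided    = residual / 1; // = 1  (slowest varying)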
/// Adds a pointer offset in units of element
@ -338,7 +355,10 @@ public:
filter_c_ += params_.filter_c_delta;
}
clear_mask_(filter_c_ >= problem_size_.C);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
}
}
/// Clears the predicates
@ -346,8 +366,11 @@ public:
void clear_mask() {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
masks_[s][0] = Mask(0);
masks_[s][1] = Mask(0);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
masks_[s][v][0] = Mask(0);
masks_[s][v][1] = Mask(0);
}
}
}
@ -355,21 +378,28 @@ public:
bool valid() {
return
(masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
(masks_[iteration_strided_][1] & (Index(1) << filter_s_));
(masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
(masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dFpropActivationTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -390,7 +420,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
return Status::kErrorInvalidProblem;
}
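
can_implement is the user-visible effect of the template parameter: the 128b channel-alignment test is skipped when Aligned is false, so channel counts that previously returned Status::kErrorInvalidProblem are now admissible, at the cost of scalar accesses. Worked numbers for half_t (sizeof_bits = 16, so one 128b vector holds 8 elements), as a hedged standalone sketch:

// Sketch of the admissibility rule for half_t (assumed element width 16b).
bool fprop_can_implement(int C, bool aligned) {
  int const elements_per_128b = 128 / 16;  // 8
  if (aligned && (C % elements_per_128b)) {
    return false;  // Status::kErrorInvalidProblem in the real iterator
  }
  return true;
}
// fprop_can_implement(30, true)  -> false (30 % 8 != 0)
// fprop_can_implement(30, false) -> true  (unaligned path accepts any C)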

View File

@ -94,6 +94,8 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -60,7 +60,8 @@ template <
typename Shape_,
typename Element_,
typename Layout_,
typename ThreadMap_
typename ThreadMap_,
bool Aligned = true
>
class Conv2dFpropFilterTileAccessIteratorOptimized{
public:
@ -73,7 +74,8 @@ public:
using Element = Element_;
using Layout = Layout_;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
using AccessType = AlignedArray<Element, AccessSize>;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -89,6 +91,8 @@ public:
static_assert(ThreadMap::Iterations::kContiguous == 1,
"Require Iterations::kContiguous == 1");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//
@ -127,9 +131,10 @@ private:
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
uint32_t predicates_;
uint32_t predicates_[kAccessesPerVector];
int filter_rs_;
int filter_c_;
@ -154,7 +159,7 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
predicates_(0),
predicates_{0},
filter_rs_(0),
filter_c_(0) {
@ -166,11 +171,14 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
predicates_ |= (pred << s);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
predicates_[v_idx] |= (pred << s);
}
}
if (filter_c_ >= problem_size.C) {
predicates_ = 0u;
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
}
pointer_ += (
@ -180,11 +188,44 @@ public:
set_iteration_index(0);
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask_(bool clear, int index) {
// We are using inline PTX assembly here to avoid a CUDA C++ compilation
// artifact in which control flow instructions are generated. Instead, our
// intent is to predicate the mov instructions.
#if defined(__CUDA_ARCH__)
asm volatile(
"{\n"
" .reg .pred p;\n"
" .reg .u32 m;"
" mov.u32 m, %2;"
" setp.ne.b32 p, %1, 0;\n"
" @p mov.u32 m, 0;\n"
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(predicates_[index])
:
"r"((int)clear),
"r"(predicates_[index])
);
#else
if (clear) {
predicates_[index] = 0;
}
#endif
}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -206,9 +247,9 @@ public:
next = params_.inc_next_c;
filter_c_ += params_.filter_c_delta;
}
if (filter_c_ >= problem_size_.C) {
predicates_ = 0;
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
}
pointer_ += next;
@ -217,18 +258,25 @@ public:
/// Returns true if the current coordinate is within the filter tensor W
CUTLASS_HOST_DEVICE
bool valid() {
return (predicates_ & (1u << iteration_strided_));
return (predicates_[iteration_vector_] & (1u << iteration_strided_));
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
return reinterpret_cast<AccessType const *>(pointer_);
return reinterpret_cast<AccessType const *>(pointer_) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dFpropFilterTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -253,7 +301,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
return Status::kErrorInvalidProblem;
}

View File

@ -131,16 +131,19 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[c + s * ThreadMap::Iterations::kContiguous],
tile_access_iterator_.get() + pointer_offset,
tile_access_iterator_.valid()
);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) {
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v],
tile_access_iterator_.get() + pointer_offset,
tile_access_iterator_.valid()
);
++tile_access_iterator_;
++tile_access_iterator_;
}
}
}
}
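
The fragment layout in the tile iterator's load path grows by the same factor: each (contiguous, strided) slot now holds kAccessesPerVector consecutive AccessType entries, giving the destination index (c + s * Iterations::kContiguous) * kAccessesPerVector + v. A worked example with assumed extents:

// Assumed extents: Iterations::kContiguous = 1, kAccessesPerVector = 8.
int const kContiguous = 1, kAccessesPerVector = 8;
int c = 0, s = 2, v = 5;
int frag_idx = (c + s * kContiguous) * kAccessesPerVector + v;  // = 21
// Aligned case: kAccessesPerVector = 1 collapses this to the old
// indexing, frag_idx = c + s * kContiguous.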

View File

@ -89,6 +89,8 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -88,6 +88,8 @@ public:
using Params = Conv2dWgradActivationIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv2dWgradActivationIteratorOptimizedParams const &params_;

View File

@ -89,6 +89,8 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -88,6 +88,8 @@ public:
using Params = Conv2dWgradOutputGradientIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv2dWgradOutputGradientIteratorOptimizedParams const &params_;

View File

@ -82,6 +82,8 @@ public:
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or larger.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure

View File

@ -82,6 +82,8 @@ public:
static StrideSupport const kStrideSupport = StrideSupport_;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure

View File

@ -96,6 +96,8 @@ public:
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Simplifying assertions

View File

@ -101,6 +101,8 @@ public:
using Params = Conv3dDgradOutputGradientIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -94,6 +94,8 @@ public:
using Params = Conv3dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -97,6 +97,8 @@ public:
using Params = Conv3dFpropActivationIteratorOptimizedParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv3dFpropActivationIteratorOptimizedParams<Layout> const &params_;

View File

@ -93,6 +93,8 @@ public:
using Params = Conv3dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -89,6 +89,8 @@ public:
static_assert(ThreadMap::Iterations::kContiguous == 1,
"Require Iterations::kContiguous == 1");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -82,6 +82,8 @@ public:
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -82,6 +82,8 @@ public:
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -82,6 +82,8 @@ public:
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -82,6 +82,8 @@ public:
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -195,7 +195,8 @@ public:
IteratorA &iterator_A, IteratorB &iterator_B,
int group_start_A = 0, int group_start_B = 0) {
iterator_A.set_iteration_index(group_start_A);
iterator_A.set_iteration_index(group_start_A *
IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
@ -208,18 +209,22 @@ public:
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess / 8;
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr, iterator_A.get(), iterator_A.valid());
++iterator_A;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B);
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
@ -232,12 +237,15 @@ public:
this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess / 8;
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr, iterator_B.get(), iterator_B.valid());
++iterator_B;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
}
@ -279,14 +287,18 @@ public:
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr, iterator_A.get(), iterator_A.valid());
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
++iterator_A;
}
++this->smem_iterator_A_;
}
@ -300,14 +312,18 @@ public:
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B_.get());
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr, iterator_B.get(), iterator_B.valid());
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
++iterator_B;
}
++this->smem_iterator_B_;
}
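
In the multistage mainloop each cp.async now moves 1/kAccessesPerVector of a thread-map vector, i.e. kSrcBytes = sizeof_bits<Element> * kElementsPerAccess / kAccessesPerVector / 8, and the new inner v loop issues kAccessesPerVector copies (with dst_ptr + v and a re-evaluated iterator get()/valid()) where one sufficed before. Worked numbers under the same half_t assumptions as earlier:

// Assumed: half_t (16b elements), kElementsPerAccess = 8.
int const element_bits = 16, elements_per_access = 8;
int kSrcBytes_aligned   = element_bits * elements_per_access / 1 / 8;  // 16 bytes, 1 copy
int kSrcBytes_unaligned = element_bits * elements_per_access / 8 / 8;  //  2 bytes, 8 copies
// Total bytes per vector are unchanged (16); only the granularity differs.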