Support unaligned input for conv2d fprop (stages = 2). Fix for issue #242
This commit is contained in:
parent
c77a524459
commit
7ec3a87f22
@ -66,6 +66,10 @@ template <
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
/// whether Matrix A is 128b aligned
|
||||
bool AlignedA = true,
|
||||
/// whether Matrix B is 128b aligned
|
||||
bool AlignedB = true,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv2dFprop;
|
||||
|
||||
@ -515,6 +519,119 @@ struct DefaultConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
|
||||
/// multistage pipeline with unaligned data
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
bool AlignedA,
|
||||
bool AlignedB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignedA,
|
||||
AlignedB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA,
|
||||
AlignedA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB,
|
||||
AlignedB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
@ -729,6 +846,120 @@ struct DefaultConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline with disalignment data
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
bool AlignedA,
|
||||
bool AlignedB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
AlignedA,
|
||||
AlignedB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA,
|
||||
AlignedA
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB,
|
||||
AlignedB
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
|
@ -90,6 +90,8 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -82,6 +82,8 @@ public:
|
||||
static StrideSupport const kStrideSupport = StrideSupport_;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
|
@ -111,6 +111,8 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -100,6 +100,8 @@ public:
|
||||
|
||||
using Params = Conv2dDgradOutputGradientIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dDgradOutputGradientIteratorOptimizedParams const &params_;
|
||||
|
@ -95,6 +95,8 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
bool Aligned = true
|
||||
>
|
||||
class Conv2dFpropActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
@ -74,7 +75,8 @@ public:
|
||||
using Layout = Layout_;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
|
||||
using AccessType = AlignedArray<Element, AccessSize>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
@ -97,12 +99,15 @@ public:
|
||||
|
||||
using Params = Conv2dFpropActivationIteratorOptimizedParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dFpropActivationIteratorOptimizedParams<Layout> const &params_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
|
||||
// One pointer per access
|
||||
char const *pointer_[ThreadMap::Iterations::kStrided];
|
||||
@ -112,7 +117,7 @@ private:
|
||||
int filter_s_;
|
||||
int filter_c_;
|
||||
|
||||
Index masks_[ThreadMap::Iterations::kStrided][2];
|
||||
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
|
||||
|
||||
public:
|
||||
|
||||
@ -180,7 +185,11 @@ public:
|
||||
int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h;
|
||||
|
||||
bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H);
|
||||
masks_[s_idx][0] |= (pred << r);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][0] |= (pred << r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,13 +206,18 @@ public:
|
||||
int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w;
|
||||
|
||||
bool pred = (w >= 0 && w < problem_size_.W);
|
||||
masks_[s_idx][1] |= (pred << s);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][1] |= (pred << s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size.C) {
|
||||
clear_mask();
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
@ -250,7 +264,7 @@ private:
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask_(bool clear) {
|
||||
void clear_mask_(bool clear, int index) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
@ -268,10 +282,10 @@ private:
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][0])
|
||||
"=r"(masks_[s][index][0])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][0])
|
||||
"r"(masks_[s][index][0])
|
||||
);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
@ -283,15 +297,15 @@ private:
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][1])
|
||||
"=r"(masks_[s][index][1])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][1])
|
||||
"r"(masks_[s][index][1])
|
||||
);
|
||||
#else
|
||||
if (clear) {
|
||||
masks_[s][0] = 0;
|
||||
masks_[s][1] = 0;
|
||||
masks_[s][index][0] = 0;
|
||||
masks_[s][index][1] = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -302,8 +316,11 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of element
|
||||
@ -338,7 +355,10 @@ public:
|
||||
filter_c_ += params_.filter_c_delta;
|
||||
}
|
||||
|
||||
clear_mask_(filter_c_ >= problem_size_.C);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
@ -346,8 +366,11 @@ public:
|
||||
void clear_mask() {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][0] = Mask(0);
|
||||
masks_[s][1] = Mask(0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
masks_[s][v][0] = Mask(0);
|
||||
masks_[s][v][1] = Mask(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -355,21 +378,28 @@ public:
|
||||
bool valid() {
|
||||
|
||||
return
|
||||
(masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][1] & (Index(1) << filter_s_));
|
||||
(masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationTileAccessIteratorOptimized &operator++() {
|
||||
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -390,7 +420,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
@ -94,6 +94,8 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
bool Aligned = true
|
||||
>
|
||||
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
||||
public:
|
||||
@ -73,7 +74,8 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
|
||||
using AccessType = AlignedArray<Element, AccessSize>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -89,6 +91,8 @@ public:
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
@ -127,9 +131,10 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_rs_;
|
||||
int filter_c_;
|
||||
|
||||
@ -154,7 +159,7 @@ public:
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
predicates_{0},
|
||||
filter_rs_(0),
|
||||
filter_c_(0) {
|
||||
|
||||
@ -166,11 +171,14 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
|
||||
predicates_ |= (pred << s);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
predicates_[v_idx] |= (pred << s);
|
||||
}
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size.C) {
|
||||
predicates_ = 0u;
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
}
|
||||
|
||||
pointer_ += (
|
||||
@ -180,11 +188,44 @@ public:
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask_(bool clear, int index) {
|
||||
// We are using inline PTX assembly here to avoid an CUDA C++ compilation
|
||||
// artifact in which control flow instructions are generated. Instead, our
|
||||
// intent is to predicate the mov instructions.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(predicates_[index])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(predicates_[index])
|
||||
);
|
||||
#else
|
||||
if (clear) {
|
||||
predicates_[index] = 0;
|
||||
predicates_[index] = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -206,9 +247,9 @@ public:
|
||||
next = params_.inc_next_c;
|
||||
filter_c_ += params_.filter_c_delta;
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size_.C) {
|
||||
predicates_ = 0;
|
||||
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||
}
|
||||
|
||||
pointer_ += next;
|
||||
@ -217,18 +258,25 @@ public:
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
return (predicates_ & (1u << iteration_strided_));
|
||||
return (predicates_[iteration_vector_] & (1u << iteration_strided_));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_);
|
||||
return reinterpret_cast<AccessType const *>(pointer_) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -253,7 +301,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
@ -131,16 +131,19 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[c + s * ThreadMap::Iterations::kContiguous],
|
||||
tile_access_iterator_.get() + pointer_offset,
|
||||
tile_access_iterator_.valid()
|
||||
);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) {
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v],
|
||||
tile_access_iterator_.get() + pointer_offset,
|
||||
tile_access_iterator_.valid()
|
||||
);
|
||||
|
||||
++tile_access_iterator_;
|
||||
++tile_access_iterator_;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -89,6 +89,8 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -88,6 +88,8 @@ public:
|
||||
|
||||
using Params = Conv2dWgradActivationIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dWgradActivationIteratorOptimizedParams const &params_;
|
||||
|
@ -89,6 +89,8 @@ public:
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -88,6 +88,8 @@ public:
|
||||
|
||||
using Params = Conv2dWgradOutputGradientIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dWgradOutputGradientIteratorOptimizedParams const &params_;
|
||||
|
@ -82,6 +82,8 @@ public:
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or larger.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
|
@ -82,6 +82,8 @@ public:
|
||||
static StrideSupport const kStrideSupport = StrideSupport_;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
|
@ -96,6 +96,8 @@ public:
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
|
@ -101,6 +101,8 @@ public:
|
||||
|
||||
using Params = Conv3dDgradOutputGradientIteratorOptimizedParams;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -94,6 +94,8 @@ public:
|
||||
|
||||
using Params = Conv3dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -97,6 +97,8 @@ public:
|
||||
|
||||
using Params = Conv3dFpropActivationIteratorOptimizedParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Conv3dFpropActivationIteratorOptimizedParams<Layout> const &params_;
|
||||
|
@ -93,6 +93,8 @@ public:
|
||||
|
||||
using Params = Conv3dAnalyticParams<Layout>;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
private:
|
||||
|
||||
Params const &params_;
|
||||
|
@ -89,6 +89,8 @@ public:
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
@ -82,6 +82,8 @@ public:
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
@ -82,6 +82,8 @@ public:
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
@ -82,6 +82,8 @@ public:
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
@ -82,6 +82,8 @@ public:
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
@ -195,7 +195,8 @@ public:
|
||||
IteratorA &iterator_A, IteratorB &iterator_B,
|
||||
int group_start_A = 0, int group_start_B = 0) {
|
||||
|
||||
iterator_A.set_iteration_index(group_start_A);
|
||||
iterator_A.set_iteration_index(group_start_A *
|
||||
IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
@ -208,18 +209,22 @@ public:
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B);
|
||||
iterator_B.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
@ -232,12 +237,15 @@ public:
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
@ -279,14 +287,18 @@ public:
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
++iterator_A;
|
||||
}
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
@ -300,14 +312,18 @@ public:
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user