Support unaligned input for conv2d fprop (Stages = 2). Fix for issue #242
commit 7ec3a87f22 (parent c77a524459)
@@ -66,6 +66,10 @@ template <
   int Stages,
   typename MathOperatorTag,
   conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
+  /// whether Matrix A is 128b aligned
+  bool AlignedA = true,
+  /// whether Matrix B is 128b aligned
+  bool AlignedB = true,
   conv::StrideSupport StrideSupport = StrideSupport::kStrided
 > struct DefaultConv2dFprop;
 
@@ -515,6 +519,119 @@ struct DefaultConv2dFprop <
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
+/// multistage pipeline with unaligned data
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename EpilogueOutputOp,
+  typename ThreadblockSwizzle,
+  int Stages,
+  typename MathOperatorTag,
+  bool AlignedA,
+  bool AlignedB
+>
+struct DefaultConv2dFprop <
+  ElementA,
+  LayoutA,
+  ElementB,
+  LayoutB,
+  ElementC,
+  LayoutC,
+  ElementAccumulator,
+  arch::OpClassTensorOp,
+  ArchTag,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  EpilogueOutputOp,
+  ThreadblockSwizzle,
+  Stages,
+  MathOperatorTag,
+  IteratorAlgorithm::kOptimized,
+  AlignedA,
+  AlignedB
+> {
+
+  // Define the core components from GEMM
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
+    Stages, MathOperatorTag
+  >;
+
+  // Define iterators over tiles from the A operand
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using IteratorA =
+    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
+      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+      ElementA,
+      LayoutA,
+      ThreadMapA,
+      AlignedA
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // Define iterators over tiles from the B operand
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using IteratorB =
+    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
+      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+      ElementB,
+      LayoutB,
+      ThreadMapB,
+      AlignedB
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components
+  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Define the Mma
+  using Mma = threadblock::ImplicitGemmMultistage<
+    ThreadblockShape,
+    IteratorA,
+    SmemIteratorA,
+    arch::CacheOperation::Always,
+    IteratorB,
+    SmemIteratorB,
+    arch::CacheOperation::Global,
+    MmaPolicy,
+    Stages
+  >;
+
+  // Define the epilogue
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
+    ThreadblockShape,
+    WarpMmaTensorOp,
+    1,
+    EpilogueOutputOp,
+    EpilogueOutputOp::kCount
+  >::Epilogue;
+
+  // Define the kernel
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma,
+    Epilogue,
+    ThreadblockSwizzle,
+    conv::Operator::kFprop
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and
 /// multistage pipeline.
 template <
@@ -729,6 +846,120 @@ struct DefaultConv2dFprop <
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
+/// and 2 stage pipeline with disalignment data
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename EpilogueOutputOp,
+  typename ThreadblockSwizzle,
+  typename MathOperatorTag,
+  bool AlignedA,
+  bool AlignedB
+>
+struct DefaultConv2dFprop <
+  ElementA,
+  LayoutA,
+  ElementB,
+  LayoutB,
+  ElementC,
+  LayoutC,
+  ElementAccumulator,
+  arch::OpClassTensorOp,
+  ArchTag,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  EpilogueOutputOp,
+  ThreadblockSwizzle,
+  2,
+  MathOperatorTag,
+  IteratorAlgorithm::kOptimized,
+  AlignedA,
+  AlignedB
+> {
+
+  // Define the core components from GEMM
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
+    2, MathOperatorTag>;
+
+  // Define iterators over tiles from the A operand
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using IteratorA =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+        ElementA,
+        LayoutA,
+        ThreadMapA,
+        AlignedA
+      >
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // Define iterators over tiles from the B operand
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using IteratorB =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+        ElementB,
+        LayoutB,
+        ThreadMapB,
+        AlignedB
+      >
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components
+  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Define the Mma
+  using Mma = threadblock::ImplicitGemmPipelined<
+    ThreadblockShape,
+    IteratorA,
+    SmemIteratorA,
+    IteratorB,
+    SmemIteratorB,
+    ElementC,
+    LayoutC,
+    MmaPolicy
+  >;
+
+  // Define the epilogue
+  using Epilogue = typename detail::DefaultConvEpilogue<
+    ArchTag,
+    ThreadblockShape,
+    WarpMmaTensorOp,
+    1,
+    EpilogueOutputOp
+  >::Epilogue;
+
+  // Define the kernel
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma,
+    Epilogue,
+    ThreadblockSwizzle,
+    conv::Operator::kFprop
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm
 /// and 2 stage pipeline.
 template <
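For reference, the sketch below shows how one might instantiate the new unaligned specializations. Every concrete element type, layout, tile shape, epilogue, and swizzle below is an illustrative assumption, not something taken from this commit; only the parameter order (with AlignedA/AlignedB after the iterator algorithm) follows the diff above.

// Hypothetical instantiation sketch: selects the 2-stage optimized specialization
// with unaligned A and B operands. All concrete choices here are assumptions.
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

using UnalignedFpropKernel = cutlass::conv::kernel::DefaultConv2dFprop<
  cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementA, LayoutA
  cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementB, LayoutB
  cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementC, LayoutC
  float,                                          // ElementAccumulator
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 32>,         // ThreadblockShape
  cutlass::gemm::GemmShape<64, 64, 32>,           // WarpShape
  cutlass::gemm::GemmShape<16, 8, 16>,            // InstructionShape
  cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  2,                                              // Stages: picks the 2-stage pipeline
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::IteratorAlgorithm::kOptimized,
  false,                                          // AlignedA
  false                                           // AlignedB
>::Kernel;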
@@ -90,6 +90,8 @@ public:
 
   using Params = Conv2dAnalyticParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -83,6 +83,8 @@ public:
   static int const kConvDim = 2;
   using ConvProblemSize = typename conv::Conv2dProblemSize;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -111,6 +111,8 @@ public:
 
   using Params = Conv2dAnalyticParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -100,6 +100,8 @@ public:
 
   using Params = Conv2dDgradOutputGradientIteratorOptimizedParams;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Conv2dDgradOutputGradientIteratorOptimizedParams const &params_;
@@ -95,6 +95,8 @@ public:
 
   using Params = Conv2dAnalyticParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -60,7 +60,8 @@ template <
   typename Shape_,
   typename Element_,
   typename Layout_,
-  typename ThreadMap_
+  typename ThreadMap_,
+  bool Aligned = true
 >
 class Conv2dFpropActivationTileAccessIteratorOptimized {
 public:
@@ -74,7 +75,8 @@ public:
   using Layout = Layout_;
   using TensorCoord = typename Layout::TensorCoord;
   using ThreadMap = ThreadMap_;
-  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
+  static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
+  using AccessType = AlignedArray<Element, AccessSize>;
   using TensorRef = cutlass::TensorRef<Element, Layout>;
   using Index = typename Layout::Index;
   using LongIndex = typename Layout::LongIndex;
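The AccessSize/AccessType change above pairs with the kAccessesPerVector constant added a few lines below: when Aligned is false the iterator issues several narrow accesses in place of one wide access per thread-map vector. A stand-alone arithmetic sketch follows; the half-precision numbers are assumptions for illustration, not values from the commit.

#include <cstdio>

// Stand-alone sketch (not CUTLASS code): how the Aligned flag trades one wide
// access for several one-element accesses per thread-map vector.
int main() {
  int const kElementsPerAccess = 8;  // assumed ThreadMap::kElementsPerAccess (8 x half_t = 128b)

  for (int aligned = 1; aligned >= 0; --aligned) {
    int access_size = aligned ? kElementsPerAccess : 1;           // AccessSize in the diff
    int accesses_per_vector = kElementsPerAccess / access_size;   // kAccessesPerVector in the diff
    std::printf("Aligned=%d -> AccessType holds %d element(s), %d access(es) per vector\n",
                aligned, access_size, accesses_per_vector);
  }
  return 0;
}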
@@ -97,12 +99,15 @@ public:
 
   using Params = Conv2dFpropActivationIteratorOptimizedParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Conv2dFpropActivationIteratorOptimizedParams<Layout> const &params_;
   Conv2dProblemSize const &problem_size_;
   LongIndex iteration_contiguous_;
   LongIndex iteration_strided_;
+  LongIndex iteration_vector_;
 
   // One pointer per access
   char const *pointer_[ThreadMap::Iterations::kStrided];
@@ -112,7 +117,7 @@ private:
   int filter_s_;
   int filter_c_;
 
-  Index masks_[ThreadMap::Iterations::kStrided][2];
+  Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
 
 public:
 
@@ -180,7 +185,11 @@ public:
         int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h;
 
         bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H);
-        masks_[s_idx][0] |= (pred << r);
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+          masks_[s_idx][v_idx][0] |= (pred << r);
+        }
       }
     }
 
@@ -197,12 +206,17 @@ public:
         int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w;
 
         bool pred = (w >= 0 && w < problem_size_.W);
-        masks_[s_idx][1] |= (pred << s);
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+          masks_[s_idx][v_idx][1] |= (pred << s);
+        }
       }
     }
 
-    if (filter_c_ >= problem_size.C) {
-      clear_mask();
+    CUTLASS_PRAGMA_UNROLL
+    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+      clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
     }
 
     set_iteration_index(0);
@@ -250,7 +264,7 @@ private:
 
   /// Clears the predicates
   CUTLASS_HOST_DEVICE
-  void clear_mask_(bool clear) {
+  void clear_mask_(bool clear, int index) {
     CUTLASS_PRAGMA_UNROLL
     for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
 
@@ -268,10 +282,10 @@ private:
           " mov.u32 %0, m;\n"
           "}\n"
         :
-          "=r"(masks_[s][0])
+          "=r"(masks_[s][index][0])
         :
           "r"((int)clear),
-          "r"(masks_[s][0])
+          "r"(masks_[s][index][0])
       );
       asm volatile(
         "{\n"
@@ -283,15 +297,15 @@ private:
           " mov.u32 %0, m;\n"
          "}\n"
        :
-          "=r"(masks_[s][1])
+          "=r"(masks_[s][index][1])
        :
          "r"((int)clear),
-          "r"(masks_[s][1])
+          "r"(masks_[s][index][1])
      );
 #else
       if (clear) {
-        masks_[s][0] = 0;
-        masks_[s][1] = 0;
+        masks_[s][index][0] = 0;
+        masks_[s][index][1] = 0;
       }
 #endif
     }
@@ -302,8 +316,11 @@ public:
   /// Overrides the internal iteration index
   CUTLASS_HOST_DEVICE
   void set_iteration_index(Index index) {
-    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
-    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
+    iteration_vector_ = index % kAccessesPerVector;
+    int residual_access = index / kAccessesPerVector;
+
+    iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
+    iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
   }
 
   /// Adds a pointer offset in units of element
@@ -338,7 +355,10 @@ public:
       filter_c_ += params_.filter_c_delta;
     }
 
-    clear_mask_(filter_c_ >= problem_size_.C);
+    CUTLASS_PRAGMA_UNROLL
+    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+      clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
+    }
   }
 
   /// Clears the predicates
@@ -346,8 +366,11 @@ public:
   void clear_mask() {
     CUTLASS_PRAGMA_UNROLL
     for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
-      masks_[s][0] = Mask(0);
-      masks_[s][1] = Mask(0);
+      CUTLASS_PRAGMA_UNROLL
+      for (int v = 0; v < kAccessesPerVector; ++v) {
+        masks_[s][v][0] = Mask(0);
+        masks_[s][v][1] = Mask(0);
+      }
     }
   }
 
@@ -355,21 +378,28 @@ public:
   bool valid() {
 
     return
-      (masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
-      (masks_[iteration_strided_][1] & (Index(1) << filter_s_));
+      (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
+      (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
   }
 
   /// Returns a pointer to the vector starting at the current coordinate
   CUTLASS_HOST_DEVICE
   AccessType const *get() const {
 
-    return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
+    return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
   }
 
   /// Increments to the next memory access
   CUTLASS_HOST_DEVICE
   Conv2dFpropActivationTileAccessIteratorOptimized &operator++() {
 
+    ++iteration_vector_;
+    if (iteration_vector_ < kAccessesPerVector) {
+      return *this;
+    }
+
+    iteration_vector_ = 0;
+
     ++iteration_contiguous_;
     if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
       return *this;
@@ -390,7 +420,7 @@ public:
   static Status can_implement(Conv2dProblemSize const &problem_size) {
 
     // check alignment constraint on iterator's contiguous dimension
-    if (problem_size.C % (128/sizeof_bits<Element>::value)) {
+    if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
     }
 
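The reworked set_iteration_index and operator++ above advance the vector index fastest, then the contiguous index, then the strided index. A host-side sketch of the same decode order (iteration counts are assumptions chosen for illustration):

#include <cstdio>

// Sketch of the iteration order implied by the new set_iteration_index:
// vector index varies fastest, then contiguous, then strided.
int main() {
  int const kAccessesPerVector = 4;  // assumed
  int const kContiguous = 1;         // assumed ThreadMap::Iterations::kContiguous
  int const kStrided = 2;            // assumed ThreadMap::Iterations::kStrided

  for (int index = 0; index < kAccessesPerVector * kContiguous * kStrided; ++index) {
    int iteration_vector = index % kAccessesPerVector;
    int residual_access  = index / kAccessesPerVector;
    int iteration_contiguous = residual_access % kContiguous;
    int iteration_strided    = residual_access / kContiguous;
    std::printf("index=%d -> vector=%d contiguous=%d strided=%d\n",
                index, iteration_vector, iteration_contiguous, iteration_strided);
  }
  return 0;
}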
|
@ -94,6 +94,8 @@ public:
|
|||||||
|
|
||||||
using Params = Conv2dAnalyticParams<Layout>;
|
using Params = Conv2dAnalyticParams<Layout>;
|
||||||
|
|
||||||
|
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Params const ¶ms_;
|
Params const ¶ms_;
|
||||||
|
@ -60,7 +60,8 @@ template <
|
|||||||
typename Shape_,
|
typename Shape_,
|
||||||
typename Element_,
|
typename Element_,
|
||||||
typename Layout_,
|
typename Layout_,
|
||||||
typename ThreadMap_
|
typename ThreadMap_,
|
||||||
|
bool Aligned = true
|
||||||
>
|
>
|
||||||
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
||||||
public:
|
public:
|
||||||
@ -73,7 +74,8 @@ public:
|
|||||||
using Element = Element_;
|
using Element = Element_;
|
||||||
using Layout = Layout_;
|
using Layout = Layout_;
|
||||||
using ThreadMap = ThreadMap_;
|
using ThreadMap = ThreadMap_;
|
||||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1;
|
||||||
|
using AccessType = AlignedArray<Element, AccessSize>;
|
||||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||||
using TensorCoord = typename Layout::TensorCoord;
|
using TensorCoord = typename Layout::TensorCoord;
|
||||||
using Index = typename Layout::Index;
|
using Index = typename Layout::Index;
|
||||||
@ -89,6 +91,8 @@ public:
|
|||||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||||
"Require Iterations::kContiguous == 1");
|
"Require Iterations::kContiguous == 1");
|
||||||
|
|
||||||
|
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Parameters structure
|
// Parameters structure
|
||||||
//
|
//
|
||||||
@ -127,9 +131,10 @@ private:
|
|||||||
Conv2dProblemSize const &problem_size_;
|
Conv2dProblemSize const &problem_size_;
|
||||||
LongIndex iteration_contiguous_;
|
LongIndex iteration_contiguous_;
|
||||||
LongIndex iteration_strided_;
|
LongIndex iteration_strided_;
|
||||||
|
LongIndex iteration_vector_;
|
||||||
char const *pointer_;
|
char const *pointer_;
|
||||||
|
|
||||||
uint32_t predicates_;
|
uint32_t predicates_[kAccessesPerVector];
|
||||||
int filter_rs_;
|
int filter_rs_;
|
||||||
int filter_c_;
|
int filter_c_;
|
||||||
|
|
||||||
@ -154,7 +159,7 @@ public:
|
|||||||
params_(params),
|
params_(params),
|
||||||
problem_size_(problem_size),
|
problem_size_(problem_size),
|
||||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||||
predicates_(0),
|
predicates_{0},
|
||||||
filter_rs_(0),
|
filter_rs_(0),
|
||||||
filter_c_(0) {
|
filter_c_(0) {
|
||||||
|
|
||||||
@ -166,11 +171,14 @@ public:
|
|||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||||
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
|
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
|
||||||
predicates_ |= (pred << s);
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||||
|
predicates_[v_idx] |= (pred << s);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filter_c_ >= problem_size.C) {
|
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||||
predicates_ = 0u;
|
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
pointer_ += (
|
pointer_ += (
|
||||||
@ -180,11 +188,44 @@ public:
|
|||||||
set_iteration_index(0);
|
set_iteration_index(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Clears the predicates
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
void clear_mask_(bool clear, int index) {
|
||||||
|
// We are using inline PTX assembly here to avoid an CUDA C++ compilation
|
||||||
|
// artifact in which control flow instructions are generated. Instead, our
|
||||||
|
// intent is to predicate the mov instructions.
|
||||||
|
#if defined(__CUDA_ARCH__)
|
||||||
|
asm volatile(
|
||||||
|
"{\n"
|
||||||
|
" .reg .pred p;\n"
|
||||||
|
" .reg .u32 m;"
|
||||||
|
" mov.u32 m, %2;"
|
||||||
|
" setp.ne.b32 p, %1, 0;\n"
|
||||||
|
" @p mov.u32 m, 0;\n"
|
||||||
|
" mov.u32 %0, m;\n"
|
||||||
|
"}\n"
|
||||||
|
:
|
||||||
|
"=r"(predicates_[index])
|
||||||
|
:
|
||||||
|
"r"((int)clear),
|
||||||
|
"r"(predicates_[index])
|
||||||
|
);
|
||||||
|
#else
|
||||||
|
if (clear) {
|
||||||
|
predicates_[index] = 0;
|
||||||
|
predicates_[index] = 0;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
/// Overrides the internal iteration index
|
/// Overrides the internal iteration index
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void set_iteration_index(Index index) {
|
void set_iteration_index(Index index) {
|
||||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
iteration_vector_ = index % kAccessesPerVector;
|
||||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
int residual_access = index / kAccessesPerVector;
|
||||||
|
|
||||||
|
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||||
|
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Adds a pointer offset in units of Element
|
/// Adds a pointer offset in units of Element
|
||||||
@ -207,8 +248,8 @@ public:
|
|||||||
filter_c_ += params_.filter_c_delta;
|
filter_c_ += params_.filter_c_delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filter_c_ >= problem_size_.C) {
|
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||||
predicates_ = 0;
|
clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
pointer_ += next;
|
pointer_ += next;
|
||||||
@ -217,18 +258,25 @@ public:
|
|||||||
/// Returns true if the current coordinate is within the filter tensor W
|
/// Returns true if the current coordinate is within the filter tensor W
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
bool valid() {
|
bool valid() {
|
||||||
return (predicates_ & (1u << iteration_strided_));
|
return (predicates_[iteration_vector_] & (1u << iteration_strided_));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a pointer to the vector starting at the current coordinate
|
/// Returns a pointer to the vector starting at the current coordinate
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
AccessType const *get() const {
|
AccessType const *get() const {
|
||||||
return reinterpret_cast<AccessType const *>(pointer_);
|
return reinterpret_cast<AccessType const *>(pointer_) + iteration_vector_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Increments to the next memory access
|
/// Increments to the next memory access
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Conv2dFpropFilterTileAccessIteratorOptimized &operator++() {
|
Conv2dFpropFilterTileAccessIteratorOptimized &operator++() {
|
||||||
|
++iteration_vector_;
|
||||||
|
if (iteration_vector_ < kAccessesPerVector) {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
iteration_vector_ = 0;
|
||||||
|
|
||||||
++iteration_contiguous_;
|
++iteration_contiguous_;
|
||||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||||
return *this;
|
return *this;
|
||||||
@ -253,7 +301,7 @@ public:
|
|||||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||||
|
|
||||||
// check alignment constraint on iterator's contiguous dimension
|
// check alignment constraint on iterator's contiguous dimension
|
||||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
if (Aligned && problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||||
return Status::kErrorInvalidProblem;
|
return Status::kErrorInvalidProblem;
|
||||||
}
|
}
|
||||||
|
|
||||||
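After this change the filter iterator keeps one predicate word per vector access: bit s of predicates_[v] gates the s-th strided iteration, and validity is read as (predicates_[iteration_vector_] & (1u << iteration_strided_)). A stand-alone sketch of that layout, with sizes assumed for illustration:

#include <cstdio>

// Sketch of the per-vector predicate layout (sizes are assumptions).
int main() {
  int const kAccessesPerVector = 2;  // assumed
  int const kStrided = 4;            // assumed ThreadMap::Iterations::kStrided
  unsigned predicates[kAccessesPerVector] = {0u, 0u};

  // Mark strided iterations 0 and 2 as valid for every vector access.
  for (int v = 0; v < kAccessesPerVector; ++v) {
    predicates[v] |= (1u << 0);
    predicates[v] |= (1u << 2);
  }

  for (int v = 0; v < kAccessesPerVector; ++v)
    for (int s = 0; s < kStrided; ++s)
      std::printf("v=%d s=%d valid=%d\n", v, s,
                  (predicates[v] & (1u << s)) ? 1 : 0);
  return 0;
}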
@@ -131,11 +131,13 @@ public:
       CUTLASS_PRAGMA_UNROLL
       for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
 
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) {
           cutlass::arch::global_load<
             AccessType,
             sizeof(AccessType)
           >(
-            frag_ptr[c + s * ThreadMap::Iterations::kContiguous],
+            frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v],
             tile_access_iterator_.get() + pointer_offset,
             tile_access_iterator_.valid()
           );
@@ -144,6 +146,7 @@ public:
       }
     }
   }
+  }
 
   /// Loads a fragment from memory
   CUTLASS_DEVICE
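With the inner loop over kAccessesPerVector, the destination fragment slot becomes (c + s * Iterations::kContiguous) * kAccessesPerVector + v, so the narrow accesses belonging to one logical access land in consecutive slots. A quick stand-alone sketch of that mapping, with sizes assumed:

#include <cstdio>

// Sketch of the fragment indexing used in the updated load loop (sizes assumed).
int main() {
  int const kContiguous = 2, kStrided = 2, kAccessesPerVector = 2;
  for (int s = 0; s < kStrided; ++s)
    for (int c = 0; c < kContiguous; ++c)
      for (int v = 0; v < kAccessesPerVector; ++v)
        std::printf("s=%d c=%d v=%d -> frag_ptr[%d]\n",
                    s, c, v, (c + s * kContiguous) * kAccessesPerVector + v);
  return 0;
}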
@@ -89,6 +89,8 @@ public:
 
   using Params = Conv2dAnalyticParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -88,6 +88,8 @@ public:
 
   using Params = Conv2dWgradActivationIteratorOptimizedParams;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Conv2dWgradActivationIteratorOptimizedParams const &params_;
@@ -89,6 +89,8 @@ public:
 
   using Params = Conv2dAnalyticParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -88,6 +88,8 @@ public:
 
   using Params = Conv2dWgradOutputGradientIteratorOptimizedParams;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Conv2dWgradOutputGradientIteratorOptimizedParams const &params_;
@@ -83,6 +83,8 @@ public:
   static_assert(sizeof_bits<Element>::value >= 8,
     "DGRAD requires elements of size 8b or larger.");
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -83,6 +83,8 @@ public:
   static int const kConvDim = 3;
   using ConvProblemSize = typename conv::Conv3dProblemSize;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -97,6 +97,8 @@ public:
   static_assert(sizeof_bits<Element>::value >= 8,
     "DGRAD requires elements of size 8b or greater.");
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Simpligying assertions
   //
@@ -101,6 +101,8 @@ public:
 
   using Params = Conv3dDgradOutputGradientIteratorOptimizedParams;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -94,6 +94,8 @@ public:
 
   using Params = Conv3dAnalyticParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -97,6 +97,8 @@ public:
 
   using Params = Conv3dFpropActivationIteratorOptimizedParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Conv3dFpropActivationIteratorOptimizedParams<Layout> const &params_;
@@ -93,6 +93,8 @@ public:
 
   using Params = Conv3dAnalyticParams<Layout>;
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
 private:
 
   Params const &params_;
@@ -89,6 +89,8 @@ public:
   static_assert(ThreadMap::Iterations::kContiguous == 1,
     "Require Iterations::kContiguous == 1");
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -82,6 +82,8 @@ public:
   static_assert(sizeof_bits<Element>::value >= 8,
     "WGRAD requires elements of size 8b or greater.");
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -82,6 +82,8 @@ public:
   static_assert(sizeof_bits<Element>::value >= 8,
     "WGRAD requires elements of size 8b or greater.");
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -82,6 +82,8 @@ public:
   static_assert(sizeof_bits<Element>::value >= 8,
     "WGRAD requires elements of size 8b or greater.");
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -82,6 +82,8 @@ public:
   static_assert(sizeof_bits<Element>::value >= 8,
     "WGRAD requires elements of size 8b or greater.");
 
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
   //
   // Parameters structure
   //
@@ -195,7 +195,8 @@ public:
       IteratorA &iterator_A, IteratorB &iterator_B,
       int group_start_A = 0, int group_start_B = 0) {
 
-    iterator_A.set_iteration_index(group_start_A);
+    iterator_A.set_iteration_index(group_start_A *
+                                   IteratorA::kAccessesPerVector);
     this->smem_iterator_A_.set_iteration_index(group_start_A);
 
     // Async Copy for operand A
@@ -208,18 +209,22 @@ public:
               this->smem_iterator_A_.get());
 
           int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
-                                IteratorA::ThreadMap::kElementsPerAccess / 8;
+                                IteratorA::ThreadMap::kElementsPerAccess /
+                                IteratorA::kAccessesPerVector / 8;
 
+          CUTLASS_PRAGMA_UNROLL
+          for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
             cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
-                dst_ptr, iterator_A.get(), iterator_A.valid());
+                dst_ptr + v, iterator_A.get(), iterator_A.valid());
 
             ++iterator_A;
+          }
 
           ++this->smem_iterator_A_;
         }
       }
 
-      iterator_B.set_iteration_index(group_start_B);
+      iterator_B.set_iteration_index(group_start_B *
+                                     IteratorB::kAccessesPerVector);
+
       this->smem_iterator_B_.set_iteration_index(group_start_B);
 
@@ -232,12 +237,15 @@ public:
               this->smem_iterator_B_.get());
 
           int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
-                                IteratorB::ThreadMap::kElementsPerAccess / 8;
+                                IteratorB::ThreadMap::kElementsPerAccess /
+                                IteratorB::kAccessesPerVector / 8;
 
+          CUTLASS_PRAGMA_UNROLL
+          for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
-               dst_ptr, iterator_B.get(), iterator_B.valid());
+               dst_ptr + v, iterator_B.get(), iterator_B.valid());
 
            ++iterator_B;
+          }
           ++this->smem_iterator_B_;
         }
       }
@@ -279,14 +287,18 @@ public:
           reinterpret_cast<typename IteratorA::AccessType *>(
               this->smem_iterator_A_.get());
 
+      CUTLASS_PRAGMA_UNROLL
+      for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
         int const kSrcBytes =
             sizeof_bits<typename IteratorA::Element>::value *
-            IteratorA::ThreadMap::kElementsPerAccess / 8;
+            IteratorA::ThreadMap::kElementsPerAccess /
+            IteratorA::kAccessesPerVector / 8;
 
         cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
-            dst_ptr, iterator_A.get(), iterator_A.valid());
+            dst_ptr + v, iterator_A.get(), iterator_A.valid());
 
         ++iterator_A;
+      }
       ++this->smem_iterator_A_;
     }
 
@@ -300,14 +312,18 @@ public:
           reinterpret_cast<typename IteratorB::AccessType *>(
              this->smem_iterator_B_.get());
 
+      CUTLASS_PRAGMA_UNROLL
+      for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
        int const kSrcBytes =
            sizeof_bits<typename IteratorB::Element>::value *
-            IteratorB::ThreadMap::kElementsPerAccess / 8;
+            IteratorB::ThreadMap::kElementsPerAccess /
+            IteratorB::kAccessesPerVector / 8;
 
        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
-           dst_ptr, iterator_B.get(), iterator_B.valid());
+           dst_ptr + v, iterator_B.get(), iterator_B.valid());
 
        ++iterator_B;
+      }
       ++this->smem_iterator_B_;
     }
 
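In the multistage pipeline each cp_async now moves one AccessType-sized chunk, so kSrcBytes shrinks by a factor of kAccessesPerVector while dst_ptr advances by v AccessType elements. A small arithmetic check of that split; the element width and access width below are assumptions:

#include <cstdio>

// Arithmetic sketch for the resized cp_async transfers (values assumed):
// total bytes per thread-map access stay the same, split across
// kAccessesPerVector copies.
int main() {
  int const element_bits = 16;       // e.g. half_t
  int const kElementsPerAccess = 8;  // assumed ThreadMap::kElementsPerAccess
  for (int accesses_per_vector : {1, 8}) {
    int src_bytes = element_bits * kElementsPerAccess / accesses_per_vector / 8;
    std::printf("kAccessesPerVector=%d -> %d bytes per cp_async, %d bytes total\n",
                accesses_per_vector, src_bytes, src_bytes * accesses_per_vector);
  }
  return 0;
}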