From 7ec3a87f22344bf11f8b411c72cd8759583da374 Mon Sep 17 00:00:00 2001 From: "mengchi.hmc" Date: Wed, 21 Apr 2021 14:28:58 +0800 Subject: [PATCH] support unalignment input for conv2d fprop stage=2 Fix for issue #242 --- .../conv/kernel/default_conv2d_fprop.h | 231 ++++++++++++++++++ ...rad_filter_tile_access_iterator_analytic.h | 2 + ...ad_filter_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 78 ++++-- ...rop_filter_tile_access_iterator_analytic.h | 2 + ...op_filter_tile_access_iterator_optimized.h | 78 ++++-- .../conv/threadblock/conv2d_tile_iterator.h | 21 +- ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + ...rad_filter_tile_access_iterator_analytic.h | 2 + ...ad_filter_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 2 + ...rop_filter_tile_access_iterator_analytic.h | 2 + ...op_filter_tile_access_iterator_optimized.h | 2 + ...activation_tile_access_iterator_analytic.h | 2 + ...ctivation_tile_access_iterator_optimized.h | 2 + ...t_gradient_tile_access_iterator_analytic.h | 2 + ..._gradient_tile_access_iterator_optimized.h | 2 + .../threadblock/implicit_gemm_multistage.h | 64 +++-- 27 files changed, 444 insertions(+), 72 deletions(-) diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index d22fb7f0..030e5ca5 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -66,6 +66,10 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// whether Matrix A is 128b aligned + bool AlignedA = true, + /// whether Matrix B is 128b aligned + bool AlignedB = true, conv::StrideSupport StrideSupport = StrideSupport::kStrided > struct DefaultConv2dFprop; @@ -515,6 +519,119 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and +/// multistage pipeline with unaligned data +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + bool AlignedA, + bool AlignedB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + AlignedA, + AlignedB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AlignedA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AlignedB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv2dFprop specialzation for Optimzed IteratorAlgorithm and /// multistage pipeline. template < @@ -729,6 +846,120 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm +/// and 2 stage pipeline with disalignment data +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + bool AlignedA, + bool AlignedB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + AlignedA, + AlignedB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AlignedA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AlignedB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv2dFprop specialzation for Optimized IteratorAlgorithm /// and 2 stage pipeline. template < diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h index 8afb4968..026b2b2f 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -90,6 +90,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h index 937216d5..86e3140e 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static StrideSupport const kStrideSupport = StrideSupport_; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Parameters structure diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index e33e4ccb..edc42df1 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -111,6 +111,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h index 078c9e7f..06c3ecf4 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -100,6 +100,8 @@ public: using Params = Conv2dDgradOutputGradientIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h index 51a51504..4943b9b7 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -95,6 +95,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index 573255da..bb720cf7 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -60,7 +60,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + bool Aligned = true > class Conv2dFpropActivationTileAccessIteratorOptimized { public: @@ -74,7 +75,8 @@ public: using Layout = Layout_; using TensorCoord = typename Layout::TensorCoord; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1; + using AccessType = AlignedArray; using TensorRef = cutlass::TensorRef; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; @@ -97,12 +99,15 @@ public: using Params = Conv2dFpropActivationIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dFpropActivationIteratorOptimizedParams const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; // One pointer per access char const *pointer_[ThreadMap::Iterations::kStrided]; @@ -112,7 +117,7 @@ private: int filter_s_; int filter_c_; - Index masks_[ThreadMap::Iterations::kStrided][2]; + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; public: @@ -180,7 +185,11 @@ public: int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H); - masks_[s_idx][0] |= (pred << r); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } } } @@ -197,13 +206,18 @@ public: int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; bool pred = (w >= 0 && w < problem_size_.W); - masks_[s_idx][1] |= (pred << s); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } } } - if (filter_c_ >= problem_size.C) { - clear_mask(); - } + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + } set_iteration_index(0); } @@ -250,7 +264,7 @@ private: /// Clears the predicates CUTLASS_HOST_DEVICE - void clear_mask_(bool clear) { + void clear_mask_(bool clear, int index) { CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { @@ -268,10 +282,10 @@ private: " mov.u32 %0, m;\n" "}\n" : - "=r"(masks_[s][0]) + "=r"(masks_[s][index][0]) : "r"((int)clear), - "r"(masks_[s][0]) + "r"(masks_[s][index][0]) ); asm volatile( "{\n" @@ -283,15 +297,15 @@ private: " mov.u32 %0, m;\n" "}\n" : - "=r"(masks_[s][1]) + "=r"(masks_[s][index][1]) : "r"((int)clear), - "r"(masks_[s][1]) + "r"(masks_[s][index][1]) ); #else if (clear) { - masks_[s][0] = 0; - masks_[s][1] = 0; + masks_[s][index][0] = 0; + masks_[s][index][1] = 0; } #endif } @@ -302,8 +316,11 @@ public: /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of element @@ -338,7 +355,10 @@ public: filter_c_ += params_.filter_c_delta; } - clear_mask_(filter_c_ >= problem_size_.C); + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); + } } /// Clears the predicates @@ -346,8 +366,11 @@ public: void clear_mask() { CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - masks_[s][0] = Mask(0); - masks_[s][1] = Mask(0); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = Mask(0); + masks_[s][v][1] = Mask(0); + } } } @@ -355,21 +378,28 @@ public: bool valid() { return - (masks_[iteration_strided_][0] & (Index(1) << filter_r_)) && - (masks_[iteration_strided_][1] & (Index(1) << filter_s_)); + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); } /// Returns a pointer to the vector starting at the current coordinate CUTLASS_HOST_DEVICE AccessType const *get() const { - return reinterpret_cast(pointer_[iteration_strided_]); + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dFpropActivationTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -390,7 +420,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (Aligned && problem_size.C % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index b0a89ada..48a51935 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -94,6 +94,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 2f12e41f..9781e42f 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -60,7 +60,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + bool Aligned = true > class Conv2dFpropFilterTileAccessIteratorOptimized{ public: @@ -73,7 +74,8 @@ public: using Element = Element_; using Layout = Layout_; using ThreadMap = ThreadMap_; - using AccessType = AlignedArray; + static int const AccessSize = Aligned ? ThreadMap::kElementsPerAccess : 1; + using AccessType = AlignedArray; using TensorRef = cutlass::TensorRef; using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; @@ -89,6 +91,8 @@ public: static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // @@ -127,9 +131,10 @@ private: Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; + LongIndex iteration_vector_; char const *pointer_; - uint32_t predicates_; + uint32_t predicates_[kAccessesPerVector]; int filter_rs_; int filter_c_; @@ -154,7 +159,7 @@ public: params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), - predicates_(0), + predicates_{0}, filter_rs_(0), filter_c_(0) { @@ -166,11 +171,14 @@ public: CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); - predicates_ |= (pred << s); + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + predicates_[v_idx] |= (pred << s); + } } - if (filter_c_ >= problem_size.C) { - predicates_ = 0u; + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); } pointer_ += ( @@ -180,11 +188,44 @@ public: set_iteration_index(0); } + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask_(bool clear, int index) { + // We are using inline PTX assembly here to avoid an CUDA C++ compilation + // artifact in which control flow instructions are generated. Instead, our + // intent is to predicate the mov instructions. + #if defined(__CUDA_ARCH__) + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(predicates_[index]) + : + "r"((int)clear), + "r"(predicates_[index]) + ); + #else + if (clear) { + predicates_[index] = 0; + predicates_[index] = 0; + } + #endif + } + /// Overrides the internal iteration index CUTLASS_HOST_DEVICE void set_iteration_index(Index index) { - iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; - iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; } /// Adds a pointer offset in units of Element @@ -206,9 +247,9 @@ public: next = params_.inc_next_c; filter_c_ += params_.filter_c_delta; } - - if (filter_c_ >= problem_size_.C) { - predicates_ = 0; + + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask_(filter_c_ + v_idx >= problem_size_.C, v_idx); } pointer_ += next; @@ -217,18 +258,25 @@ public: /// Returns true if the current coordinate is within the filter tensor W CUTLASS_HOST_DEVICE bool valid() { - return (predicates_ & (1u << iteration_strided_)); + return (predicates_[iteration_vector_] & (1u << iteration_strided_)); } /// Returns a pointer to the vector starting at the current coordinate CUTLASS_HOST_DEVICE AccessType const *get() const { - return reinterpret_cast(pointer_); + return reinterpret_cast(pointer_) + iteration_vector_; } /// Increments to the next memory access CUTLASS_HOST_DEVICE Conv2dFpropFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { return *this; @@ -253,7 +301,7 @@ public: static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (Aligned && problem_size.C % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h index 61f02d19..68fec78a 100644 --- a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -131,16 +131,19 @@ public: CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[c + s * ThreadMap::Iterations::kContiguous], - tile_access_iterator_.get() + pointer_offset, - tile_access_iterator_.valid() - ); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) { + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v], + tile_access_iterator_.get() + pointer_offset, + tile_access_iterator_.valid() + ); - ++tile_access_iterator_; + ++tile_access_iterator_; + } } } } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h index 1e3a5837..cb79844d 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -89,6 +89,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h index 7762d619..aae011b0 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -88,6 +88,8 @@ public: using Params = Conv2dWgradActivationIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h index 53fc9205..d9e12f87 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -89,6 +89,8 @@ public: using Params = Conv2dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h index f138ef59..f4d7c7d4 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -88,6 +88,8 @@ public: using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h index 01437547..fcbba130 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or larger."); + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Parameters structure diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h index ee532ff6..8683d1d5 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static StrideSupport const kStrideSupport = StrideSupport_; static int const kConvDim = 3; using ConvProblemSize = typename conv::Conv3dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Parameters structure diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h index 1d70ab3d..92782550 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -96,6 +96,8 @@ public: static_assert(sizeof_bits::value >= 8, "DGRAD requires elements of size 8b or greater."); + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; // // Simpligying assertions diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h index 2a62c292..53d1beec 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -101,6 +101,8 @@ public: using Params = Conv3dDgradOutputGradientIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h index 7cadf860..1c148d5b 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -94,6 +94,8 @@ public: using Params = Conv3dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h index 9246c592..74c559c1 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -97,6 +97,8 @@ public: using Params = Conv3dFpropActivationIteratorOptimizedParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index a7f54368..272fc246 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -93,6 +93,8 @@ public: using Params = Conv3dAnalyticParams; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index 5d814890..0fd161c5 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -89,6 +89,8 @@ public: static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h index 396d856a..0e7ab37e 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 2835480d..0052fd67 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h index b8af8efa..73e96d4a 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index d3b356e0..0bf96aff 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -82,6 +82,8 @@ public: static_assert(sizeof_bits::value >= 8, "WGRAD requires elements of size 8b or greater."); + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + // // Parameters structure // diff --git a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index aefdcd6d..cbc35f32 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -195,7 +195,8 @@ public: IteratorA &iterator_A, IteratorB &iterator_B, int group_start_A = 0, int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A); + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); this->smem_iterator_A_.set_iteration_index(group_start_A); // Async Copy for operand A @@ -208,18 +209,22 @@ public: this->smem_iterator_A_.get()); int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + ++iterator_A; + } ++this->smem_iterator_A_; } } - iterator_B.set_iteration_index(group_start_B); + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); this->smem_iterator_B_.set_iteration_index(group_start_B); @@ -232,12 +237,15 @@ public: this->smem_iterator_B_.get()); int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + ++iterator_B; + } ++this->smem_iterator_B_; } } @@ -279,14 +287,18 @@ public: reinterpret_cast( this->smem_iterator_A_.get()); - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / 8; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_A.get(), iterator_A.valid()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); - ++iterator_A; + ++iterator_A; + } ++this->smem_iterator_A_; } @@ -300,14 +312,18 @@ public: reinterpret_cast( this->smem_iterator_B_.get()); - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / 8; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr, iterator_B.get(), iterator_B.valid()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - ++iterator_B; + ++iterator_B; + } ++this->smem_iterator_B_; }