diff --git a/include/cutlass/block_striped.h b/include/cutlass/block_striped.h
index 2ffd59b1..598ffa7b 100644
--- a/include/cutlass/block_striped.h
+++ b/include/cutlass/block_striped.h
@@ -146,7 +146,6 @@ struct StripedAccessType<
 template <
   int BlockThreads,
   typename ArrayT,
-  typename T,
   typename AccessT = StripedAccessType<ArrayT> >
 struct BlockStriped
 {
@@ -156,7 +155,7 @@ struct BlockStriped
 
   /// Load
   CUTLASS_DEVICE
-  static void load(ArrayT &data, T *ptr, int thread_idx)
+  static void load(ArrayT &data, ArrayT *ptr, int thread_idx)
   {
     AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
     AccessT *access_data = reinterpret_cast<AccessT*>(&data);
@@ -169,7 +168,7 @@ struct BlockStriped
 
   /// Load & Add
   CUTLASS_DEVICE
-  static void load_add(ArrayT &data, T *ptr, int thread_idx)
+  static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx)
   {
     AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
     AccessT *access_data = reinterpret_cast<AccessT*>(&data);
@@ -185,7 +184,7 @@ struct BlockStriped
 
   /// Store
   CUTLASS_DEVICE
-  static void store(T *ptr, const ArrayT &data, int thread_idx)
+  static void store(ArrayT *ptr, const ArrayT &data, int thread_idx)
   {
     AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
     const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);
@@ -210,19 +209,24 @@ struct BlockStriped
 template <
   int BlockThreads,
   typename ArrayT,
-  typename T>
-struct BlockStripedReduce : BlockStriped<BlockThreads, ArrayT, T>
+  typename ElementT = typename StripedAccessType<ArrayT>::Element>
+struct BlockStripedReduce :
+  BlockStriped<
+    BlockThreads,
+    ArrayT,
+    ElementT>
 {
   /// Reduce
   CUTLASS_DEVICE
-  static void reduce(T *ptr, const ArrayT &data, int thread_idx)
+  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
   {
-    cutlass::red<T> reduce;
-    const T *access_data = reinterpret_cast<const T*>(&data);
+    cutlass::red<ElementT> reduce;
+    ElementT *access_output = reinterpret_cast<ElementT*>(ptr);
+    const ElementT *access_data = reinterpret_cast<const ElementT*>(&data);
 
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < BlockStripedReduce::kStripes; ++i) {
-      reduce(ptr + (BlockThreads * i) + thread_idx, access_data[i]);
+      reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
     }
   }
 };
@@ -234,13 +238,17 @@ struct BlockStripedReduce : BlockStriped
 template <
   int BlockThreads,
   typename ArrayT>
-struct BlockStripedReduce<BlockThreads, ArrayT, half_t> : BlockStriped<BlockThreads, ArrayT, half2>
+struct BlockStripedReduce<BlockThreads, ArrayT, half_t> :
+  BlockStriped<
+    BlockThreads,
+    ArrayT,
+    half2>
 {
   static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length");
 
   /// Reduce
   CUTLASS_DEVICE
-  static void reduce(half_t *ptr, const ArrayT &data, int thread_idx)
+  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
   {
     cutlass::red<half2> reduce;
     half2 *access_output = reinterpret_cast<half2*>(ptr);
diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h
index 91005832..993de892 100644
--- a/include/cutlass/epilogue/threadblock/epilogue.h
+++ b/include/cutlass/epilogue/threadblock/epilogue.h
@@ -222,7 +222,7 @@ public:
       int peer_idx_begin,
       int peer_idx_end,
       int reduce_fragment_idx,
-      ElementAccumulator *element_workspace,
+      void *element_workspace,
       OutputOp const &output_op,                    ///< Output operator
       OutputTileIterator destination_iterator,      ///< Tile iterator for destination
       OutputTileIterator source_iterator)           ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
diff --git a/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h b/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h
index 45b0fd27..f0a20acb 100644
--- a/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h
+++ b/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h
@@ -53,19 +53,20 @@ template <
   typename Shape,                          ///< Shape of threadblock tile (concept: GemmShape)
   int PartitionsK,
   typename WarpMmaOperator,                ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
-  typename AccumulatorFragmentIterator>    ///< Fragment iterator selecting accumulators
+  typename AccumulatorFragmentIterator>    ///< Iterator for enumerating fragments within the per-thread tile of raw accumulators
 class EpilogueBaseStreamK
 {
 
 protected:
 
-  /// The complete warp-level accumulator tile
+  /// The per-thread tile of raw accumulators
   using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
 
   /// Number of warps
   using WarpCount = gemm::GemmShape<
     Shape::kM / WarpMmaOperator::Shape::kM,
-    Shape::kN / WarpMmaOperator::Shape::kN, PartitionsK>;
+    Shape::kN / WarpMmaOperator::Shape::kN,
+    PartitionsK>;
 
   /// Number of threads per block
   static int const kBlockThreads = 32 * WarpCount::kCount;
@@ -76,25 +77,26 @@ protected:
   /// Fragment type used by the accumulator tile's fragment iterator
   using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
 
-  /// Block-striped transfer utility for sharing AccumulatorFragment
-  using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment, ElementAccumulator>;
-
-  /// Number of elements per fragment
-  static int const kFragmentElements = sizeof(AccumulatorFragment) / sizeof(ElementAccumulator);
-
 public:
 
-  /// Number of fragments per accumulator tile
+  /// Number of AccumulatorTile fragments per thread
   static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations;
 
-  /// Number of workspace accumulation elements shared per output tile
-  static int const kPeerAccumulators = WarpMmaOperator::Shape::kMN * WarpCount::kCount;
-
 protected:
 
-  /// ElementAccumulator stride in the shared workspace between different peer blocks (two: each peer block can share accumulators for up to two tiles)
-  static const int kPeerStride = kPeerAccumulators * 2;
+  /// Number of AccumulatorTile fragments per block output tile
+  static int const kOutputTileFragments = kBlockThreads * kAccumulatorFragments;
 
+  /// Block-striped transfer utility for sharing AccumulatorFragment
+  using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment>;
+
+  /// AccumulatorFragment stride in the shared workspace between different peer blocks (each thread block can share accumulators for up to two block output tiles)
+  static const int kPeerFragmentStride = kOutputTileFragments * 2;
+
+public:
+
+  /// Workspace bytes per thread block
+  static size_t const kWorkspaceBytesPerBlock = sizeof(AccumulatorFragment) * kPeerFragmentStride;
 
 public:
@@ -119,28 +121,29 @@ public:
       int peer_idx_begin,
       int peer_idx_end,
       int reduce_fragment_idx,
-      ElementAccumulator *element_workspace)
+      void *workspace_ptr)
   {
     plus<AccumulatorFragment> add_fragments;
 
-    int accum_set_offset =
-      (peer_idx_begin * kPeerStride) +
-      (reduce_fragment_idx * kBlockThreads * kFragmentElements);
+    AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
+
+    int fragment_offset = (peer_idx_begin * kPeerFragmentStride) + (reduce_fragment_idx * kBlockThreads);
 
     // Load first peer fragment
-    BlockStripedT::load(accum_fragment, element_workspace + accum_set_offset, this->thread_idx);
+    BlockStripedT::load(accum_fragment, fragment_workspace + fragment_offset, this->thread_idx);
 
-    accum_set_offset += kPeerStride;         // Move to next peer
-    accum_set_offset += kPeerAccumulators;   // Move to non-starting accumulator set for peer
+    fragment_offset += kPeerFragmentStride;    // Move to next peer
+    fragment_offset += kOutputTileFragments;   // Move to the set of fragments for this peer's "non-started" output tile
 
-    // Reduce additional peer fragments
+    // Reduce fragments from additional peers
     #pragma unroll 2
-    while (accum_set_offset < peer_idx_end * kPeerStride)
+    for (; fragment_offset < peer_idx_end * kPeerFragmentStride; fragment_offset += kPeerFragmentStride)
     {
+      // Load peer fragment
       AccumulatorFragment addend_fragment;
-      BlockStripedT::load(addend_fragment, element_workspace + accum_set_offset, this->thread_idx);
-      accum_set_offset += kPeerStride;
+      BlockStripedT::load(addend_fragment, fragment_workspace + fragment_offset, this->thread_idx);
 
+      // Add peer fragment
       accum_fragment = add_fragments(accum_fragment, addend_fragment);
     }
   }
@@ -150,19 +153,22 @@ public:
   CUTLASS_DEVICE
   void share(
       int peer_idx,
-      ElementAccumulator *element_workspace,    ///< Output pointer for writing this block's accumulator set to
-      AccumulatorTile const &accumulators,      ///< Complete warp-level accumulator tile
-      bool started_tile)
+      void *workspace_ptr,
+      AccumulatorTile const &accumulators,
+      bool started_tile)                        ///< Whether this thread block computed the first work volume for the current output tile
   {
-    int accum_set_offset = peer_idx * kPeerStride;
+    AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
+
+    int fragment_offset = peer_idx * kPeerFragmentStride;
 
     if (!started_tile) {
-      // Move to non-starting accumulator set
-      accum_set_offset += kPeerAccumulators;
+      // Move to the set of fragments for the "non-started" output tile
+      fragment_offset += kOutputTileFragments;
     }
 
     AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
 
+    // Convert raw accumulator tile to fragments and store
     CUTLASS_PRAGMA_UNROLL
     for (int iter = 0; iter < kAccumulatorFragments; ++iter)
     {
@@ -172,9 +178,9 @@ public:
       ++accum_fragment_iterator;
 
       // Store accumulator fragment
-      BlockStripedT::store(element_workspace + accum_set_offset, accum_fragment, this->thread_idx);
+      BlockStripedT::store(fragment_workspace + fragment_offset, accum_fragment, this->thread_idx);
 
-      accum_set_offset += (kFragmentElements * kBlockThreads);
+      fragment_offset += kBlockThreads;
     }
   }
diff --git a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h
index 0bf9cddf..84d469e9 100644
--- a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h
+++ b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h
@@ -172,7 +172,7 @@ public:
       int peer_idx_begin,
       int peer_idx_end,
       int reduce_fragment_idx,
-      ElementAccumulator *element_workspace,
+      void *element_workspace,
       OutputOp const &output_op,                    ///< Output operator
       OutputTileIterator destination_iterator,      ///< Tile iterator for destination
       OutputTileIterator source_iterator)           ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h
index e277e4a4..831b5a5f 100644
--- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h
+++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h
@@ -77,7 +77,9 @@ public:
   using LayoutB = typename Mma::IteratorB::Layout;
   using ElementC = typename Epilogue::OutputTileIterator::Element;
   using LayoutC = typename Epilogue::OutputTileIterator::Layout;
-  using ElementAccumulator = typename Mma::ElementC;
+
+  /// The per-thread tile of raw accumulators
+  using AccumulatorTile = typename Mma::FragmentC;
 
   static ComplexTransform const kTransformA = Mma::kTransformA;
   static ComplexTransform const kTransformB = Mma::kTransformB;
@@ -94,21 +96,18 @@ public:
   static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
   static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
 
-  /// Number of workspace accumulation elements shared per per block
-  static int const kPeerAccumulators = Epilogue::kPeerAccumulators;
-
-  /// Number of fragments per (thread) accumulator tile
-  static int const kAccumulatorFragments = Epilogue::kAccumulatorFragments;
-
-  /// Number of numeric accumulation elements per fragment
-  static int const kAccumTileElements = sizeof(typename Mma::FragmentC) / sizeof(ElementAccumulator);
-
   /// Warp count (concept: GemmShape)
   using WarpCount = typename Mma::WarpCount;
   static int const kThreadCount = 32 * WarpCount::kCount;
 
+  /// Workspace bytes per thread block
+  static size_t const kWorkspaceBytesPerBlock =
+    __NV_STD_MAX(
+      kThreadCount * sizeof(AccumulatorTile),
+      Epilogue::kWorkspaceBytesPerBlock);
+
   /// Block-striped reduction utility
-  using BlockStripedReduceT = BlockStripedReduce<kThreadCount, typename Mma::FragmentC, ElementAccumulator>;
+  using BlockStripedReduceT = BlockStripedReduce<kThreadCount, AccumulatorTile>;
@@ -276,7 +275,7 @@ public:
     int64_t batch_stride_D;
 
     void *barrier_workspace;
-    ElementAccumulator *partials_workspace;
+    void *partials_workspace;
 
   protected:
@@ -305,11 +304,8 @@ public:
     /// Get the workspace size needed for intermediate partial sums
     size_t get_partials_workspace_size() const
     {
-      // For atomic reduction, each SK-block can share one accumulator tile. For parallel reduction,
-      // each SK-block can share up to two accumulator tiles.
-      size_t tile_bytes_accumulators = sizeof(ElementAccumulator) * kPeerAccumulators * 2;
       int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region;
-      return cacheline_align_up(tile_bytes_accumulators * sk_blocks);
+      return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
     }
@@ -388,7 +384,7 @@ public:
         if (!workspace) {
           return Status::kErrorWorkspaceNull;
         }
-        partials_workspace = reinterpret_cast<ElementAccumulator *>(ptr);
+        partials_workspace = ptr;
         ptr += partials_workspace_bytes;
       }
@@ -400,7 +396,7 @@ public:
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }
-       barrier_workspace = reinterpret_cast<void *>(ptr);
+       barrier_workspace = ptr;
        ptr += barrier_workspace_bytes;
      }
@@ -630,14 +626,12 @@ public:
     return Status::kSuccess;
   }
 
-
   /// Determines whether the GEMM problem satisfies this kernel's
   /// alignment requirements
   static Status can_implement(Arguments const &args) {
     return can_implement(args.problem_size);
   }
 
-
 protected:
@@ -762,14 +756,16 @@ protected:
   /// Share accumulators with peers
   CUTLASS_DEVICE
-  void share_accumulators(typename Mma::FragmentC const &accumulator_tile, int first_block_idx)
+  void share_accumulators(AccumulatorTile const &accumulator_tile, int first_block_idx)
   {
-    int block_tile_offset = first_block_idx * kPeerAccumulators;
+    AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
+
+    int accum_tile_offset = first_block_idx * kThreadCount;
 
     if (block_idx == first_block_idx)
     {
       // First peer initializes the workspace partials
-      BlockStripedReduceT::store(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx);
+      BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
     }
     else
     {
@@ -787,7 +783,7 @@ protected:
       }
 
       // Perform reduction in workspace
-      BlockStripedReduceT::reduce(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx);
+      BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
     }
 
     // Signal our arrival
@@ -797,17 +793,19 @@ protected:
   /// Acquire accumulators from peers
   CUTLASS_DEVICE
-  void acquire_accumulators_atomic(
-    typename Mma::FragmentC &accumulator_tile,
+  void acquire_accumulators(
+    AccumulatorTile &accumulator_tile,
     int first_block_idx)
   {
+    AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
+
     // Wait for arrival
     int num_carry_in = block_idx - first_block_idx;
     Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in);
 
     // Load and add peer-partials accumulator tile to local accumulator tile
-    int block_tile_offset = first_block_idx * kPeerAccumulators;
-    BlockStripedReduceT::load_add(accumulator_tile, params.partials_workspace + block_tile_offset, thread_idx);
+    int accum_tile_offset = first_block_idx * kThreadCount;
+    BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx);
   }
@@ -815,7 +813,7 @@ protected:
   CUTLASS_DEVICE
   void do_epilogue(
     TileWorkDesc &tile_work,
-    typename Mma::FragmentC &accumulator_tile)
+    AccumulatorTile &accumulator_tile)
   {
     ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
     ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
@@ -867,8 +865,8 @@ protected:
     int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx;
 
     // Reduce by sk-tile (every tile contributed to by one or more blocks)
-    reduce_tile_idx = reduce_idx / kAccumulatorFragments;
-    reduce_fragment_idx = reduce_idx % kAccumulatorFragments;
+    reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments;
+    reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments;
 
     int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile;
     int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile - 1;
@@ -882,7 +880,7 @@ protected:
     Barrier::wait_eq_reset(
       params.barrier_workspace,
       thread_idx,
-      (reduce_tile_idx * kAccumulatorFragments) + reduce_fragment_idx,
+      (reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx,
       num_peers);
 
     /// The location of this tile (in threadblock-tile coordinates) in the output matrix
@@ -946,7 +944,7 @@ protected:
     typename Mma::IteratorB iterator_B = init_iterator_B(tile_work);
 
     // Initialize accumulators
-    typename Mma::FragmentC accumulator_tile;
+    AccumulatorTile accumulator_tile;
     accumulator_tile.clear();
 
     // Perform this tile's range of multiply-accumulate (MAC) iterations
@@ -978,7 +976,7 @@ protected:
       if (!tile_work.tile_started())
      {
        // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks
-        acquire_accumulators_atomic(accumulator_tile, first_block_idx);
+        acquire_accumulators(accumulator_tile, first_block_idx);
      }
 
      do_epilogue(tile_work, accumulator_tile);
@@ -997,8 +995,8 @@ protected:
      Barrier::arrive_range_inc(
        params.barrier_workspace,
        thread_idx,
-        tile_work.tile_idx * kAccumulatorFragments,
-        kAccumulatorFragments);
+        tile_work.tile_idx * Epilogue::kAccumulatorFragments,
+        Epilogue::kAccumulatorFragments);
    }
  }
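
Note (not part of the patch): the sketch below is a minimal, self-contained CUDA illustration of the block-striped layout that this refactoring standardizes on, where stripe i of thread t lands at workspace[(BlockThreads * i) + t]. The kernel name, block size, and stripe count are assumptions for illustration, not CUTLASS source.

// Minimal sketch, assuming a standalone CUDA program (not CUTLASS code):
// each thread owns a small array, and the block stores all of them to a
// global workspace in a block-striped layout.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kBlockThreads = 128;  // assumed block size
constexpr int kStripes      = 4;    // assumed per-thread array length

__global__ void block_striped_store(float *workspace)
{
  int thread_idx = threadIdx.x;

  // Per-thread data (stands in for an accumulator fragment)
  float data[kStripes];
  for (int i = 0; i < kStripes; ++i) {
    data[i] = static_cast<float>(thread_idx * kStripes + i);
  }

  // Block-striped store: stripe i of all threads forms one contiguous,
  // coalesced segment of the workspace.
  for (int i = 0; i < kStripes; ++i) {
    workspace[(kBlockThreads * i) + thread_idx] = data[i];
  }
}

int main()
{
  float *workspace = nullptr;
  cudaMalloc(&workspace, sizeof(float) * kBlockThreads * kStripes);

  block_striped_store<<<1, kBlockThreads>>>(workspace);
  cudaDeviceSynchronize();

  // Inspect the first stripe: it holds element 0 of every thread's array.
  float host[kBlockThreads];
  cudaMemcpy(host, workspace, sizeof(host), cudaMemcpyDeviceToHost);
  std::printf("stripe 0, thread 1 -> %f\n", host[1]);  // expected 4.0 (= 1 * kStripes + 0)

  cudaFree(workspace);
  return 0;
}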
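A second illustrative note, also outside the patch: the new kWorkspaceBytesPerBlock sizes the partials workspace as the larger of two per-block footprints (one AccumulatorTile per thread for the atomic-reduction path, or up to two output tiles' worth of AccumulatorFragments for the separate-reduction path). The host-side sketch below mirrors that arithmetic with assumed sizes; every constant here is hypothetical.

// Minimal sketch, assuming example sizes (not values from the library).

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
  constexpr int kThreadCount = 256;                    // threads per block, assumed
  constexpr std::size_t kAccumulatorTileBytes = 512;   // sizeof(AccumulatorTile), assumed
  constexpr std::size_t kFragmentBytes = 64;           // sizeof(AccumulatorFragment), assumed
  constexpr int kAccumulatorFragments = 8;             // fragments per thread, assumed
  constexpr int sk_blocks = 12;                        // number of stream-K blocks, assumed

  // Epilogue-side footprint: each block may share fragments for up to two output tiles
  const int kOutputTileFragments = kThreadCount * kAccumulatorFragments;
  const int kPeerFragmentStride = kOutputTileFragments * 2;
  const std::size_t epilogue_bytes_per_block = kFragmentBytes * kPeerFragmentStride;

  // Kernel-side footprint: one AccumulatorTile per thread for atomic reduction
  const std::size_t mma_bytes_per_block = kThreadCount * kAccumulatorTileBytes;

  // The per-block workspace must cover whichever path needs more space
  const std::size_t workspace_bytes_per_block =
      std::max(mma_bytes_per_block, epilogue_bytes_per_block);

  std::printf("partials workspace: %zu bytes (%zu per block x %d blocks)\n",
              workspace_bytes_per_block * sk_blocks, workspace_bytes_per_block, sk_blocks);

  // Fragment-granularity offset used when reducing peer partials
  // (mirrors: peer_idx_begin * kPeerFragmentStride + reduce_fragment_idx * kBlockThreads)
  const int peer_idx_begin = 3, reduce_fragment_idx = 2;
  const int fragment_offset =
      peer_idx_begin * kPeerFragmentStride + reduce_fragment_idx * kThreadCount;
  std::printf("first peer fragment offset: %d fragments\n", fragment_offset);
  return 0;
}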