diff --git a/include/cutlass/reduction/device/tensor_reduce.h b/include/cutlass/reduction/device/tensor_reduce.h new file mode 100644 index 00000000..c67b205e --- /dev/null +++ b/include/cutlass/reduction/device/tensor_reduce.h @@ -0,0 +1,258 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/device/tensor_reduce_affine_strided.h" +#include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor reduction operator on specific CUTLASS layouts over exactly one index +template < + typename ElementOutput_, + typename ElementSource_, + typename Layout_, + typename ReductionOp_, + int VectorLength_ = 1, + typename ElementCompute_ = ElementOutput_ +> +struct TensorReduction { + + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using Layout = Layout_; + using ReductionOp = ReductionOp_; + static int const kVectorLength = VectorLength_; + using ElementCompute = ElementCompute_; + + using TensorCoord = typename Layout::TensorCoord; + + /// Reduction operator + using ReductionDeviceStridedOperator = TensorReductionAffineStrided< + 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute + >; + + using ReductionDeviceContiguousOperator = TensorReductionAffineContiguous< + 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute + >; + + // + // Data members + // + + ReductionDeviceStridedOperator reduction_strided; + ReductionDeviceContiguousOperator reduction_contiguous; + int reduction_index; + + // + // Methods + // + + /// + TensorReduction( + TensorCoord extent, + int reduction_index_ + ): + reduction_index(reduction_index_) { + + Coord<4> extent_affine; + + switch (reduction_index) { + case 0: + extent_affine[0] = extent[1]; + extent_affine[1] = extent[2]; + extent_affine[2] = extent[0]; + extent_affine[3] = extent[3]; + break; + case 1: + extent_affine[0] = extent[0]; + extent_affine[1] = extent[2]; + extent_affine[2] = extent[1]; + extent_affine[3] = extent[3]; + break; + case 2: + extent_affine[0] = extent[0]; + extent_affine[1] = extent[1]; + extent_affine[2] = extent[2]; + extent_affine[3] = extent[3]; + break; + case 3: + extent_affine[0] = extent[0]; + extent_affine[1] = extent[1]; + extent_affine[2] = extent[2]; + extent_affine[3] = extent[3]; + break; + default: break; + } + + if (reduction_index == 3) { + reduction_contiguous = ReductionDeviceContiguousOperator(extent_affine); + } + else { + reduction_strided = ReductionDeviceStridedOperator(extent_affine); + } + } + + /// Simple check to verify the object is initialized correctly + bool good() const { + if (reduction_index == 3) { + return reduction_contiguous.good(); + } + return reduction_strided.good(); + } + + /// Size of one workspace + int64_t workspace_stride() const { + if (reduction_index == 3) { + return reduction_contiguous.workspace_stride(); + } + else { + return reduction_strided.workspace_stride(); + } + } + + /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs + int64_t workspace_size() const { + if (reduction_index == 3) { + return reduction_contiguous.workspace_size(); + } + else { + return reduction_strided.workspace_size(); + } + } + + /// Helper to use overloaded function 
call operator + Status reduce( + TensorRef dst_ref, + TensorRef src_ref, + void *device_workspace_ptr = nullptr, + ElementCompute reduction_identity = ElementCompute(), + ReductionOp reduction_op = ReductionOp(), + cudaStream_t stream = nullptr) { + + int64_t src_stride[3]; + int64_t dst_stride[2]; + + switch (reduction_index) { + case 0: + src_stride[0] = src_ref.stride()[1]; + src_stride[1] = src_ref.stride()[0]; + src_stride[2] = src_ref.stride()[2]; + dst_stride[0] = dst_ref.stride()[1]; + dst_stride[1] = dst_ref.stride()[0]; + break; + case 1: + src_stride[0] = src_ref.stride()[2]; + src_stride[1] = src_ref.stride()[0]; + src_stride[2] = src_ref.stride()[1]; + dst_stride[0] = dst_ref.stride()[2]; + dst_stride[1] = dst_ref.stride()[0]; + break; + case 2: + src_stride[0] = src_ref.stride()[2]; + src_stride[1] = src_ref.stride()[1]; + src_stride[2] = src_ref.stride()[0]; + dst_stride[0] = dst_ref.stride()[2]; + dst_stride[1] = dst_ref.stride()[1]; + break; + case 3: + src_stride[0] = src_ref.stride()[2]; + src_stride[1] = src_ref.stride()[1]; + src_stride[2] = src_ref.stride()[0]; + + dst_stride[0] = dst_ref.stride()[2]; + dst_stride[1] = dst_ref.stride()[1]; + dst_stride[2] = dst_ref.stride()[0]; + + default: break; + } + + if (reduction_index == 3) { + return reduction_contiguous( + dst_ref.data(), + dst_stride, + src_ref.data(), + src_stride, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } + else { + return reduction_strided( + dst_ref.data(), + dst_stride, + src_ref.data(), + src_stride, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } + } + + Status operator()( + TensorRef dst_ref, + TensorRef src_ref, + void *device_workspace_ptr = nullptr, + ElementCompute reduction_identity = ElementCompute(), + ReductionOp reduction_op = ReductionOp(), + cudaStream_t stream = nullptr) { + + return reduce( + dst_ref, + src_ref, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h b/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h new file mode 100644 index 00000000..3b7ee419 --- /dev/null +++ b/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h @@ -0,0 +1,367 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
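For reference, a minimal host-side sketch of driving the TensorReduction wrapper defined in tensor_reduce.h above, summing over the contiguous C rank of a packed rank-4 NHWC tensor of float. This is not part of the change itself; the extent, pointer names, and the choice of cutlass::plus<float> are illustrative assumptions, and the destination is assumed to be a rank-4 NHWC tensor whose reduced rank has extent 1.

#include <cuda_runtime.h>

#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/reduction/device/tensor_reduce.h"

// Sums over the C rank (reduction_index == 3) of a packed NHWC float tensor.
cutlass::Status reduce_channels_sketch(
  float *d_dst,                     // device buffer of N * H * W elements (NHWC extent with C == 1)
  float *d_src,                     // device buffer of N * H * W * C elements
  cutlass::Tensor4DCoord extent,    // source extent, e.g. {8, 56, 56, 64}
  cudaStream_t stream = nullptr) {

  using Layout = cutlass::layout::TensorNHWC;
  using ReductionOp = cutlass::plus<float>;

  using TensorReduction = cutlass::reduction::device::TensorReduction<
    float,        // ElementOutput
    float,        // ElementSource
    Layout,
    ReductionOp,
    1>;           // VectorLength

  TensorReduction reduction(extent, /*reduction_index=*/3);
  if (!reduction.good()) {
    return cutlass::Status::kErrorInvalidProblem;
  }

  // A temporary workspace is needed only when the reduction is split across CTAs.
  void *workspace = nullptr;
  if (reduction.workspace_size()) {
    cudaMalloc(&workspace, reduction.workspace_size());
  }

  // The destination keeps the source extent with the reduced rank collapsed to 1.
  cutlass::Tensor4DCoord dst_extent(extent.n(), extent.h(), extent.w(), 1);

  cutlass::TensorRef<float, Layout> dst_ref(d_dst, Layout::packed(dst_extent));
  cutlass::TensorRef<float, Layout> src_ref(d_src, Layout::packed(extent));

  cutlass::Status status = reduction.reduce(
    dst_ref, src_ref, workspace, /*reduction_identity=*/0.0f, ReductionOp(), stream);

  if (workspace) {
    cudaFree(workspace);
  }
  return status;
}

Reduction indices 0 through 2 follow the same calling pattern; the wrapper simply permutes the extents and strides and dispatches to the strided implementation instead of the contiguous one.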
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor reduction operator on layouts which are affine +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (e.g. ND => 2) + typename ElementOutput_, + typename ElementSource_, + typename ReductionOp_, + int VectorLength = 1, + typename ElementCompute_ = ElementOutput_, + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineContiguous { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ReductionOp = ReductionOp_; + using ElementCompute = ElementCompute_; + + // + // Data members + // + + /// Internal status field + Status status; + + /// Extent of tensor in source layout + Coord extent; + + /// Number of points in the outer index space + int64_t outer_count; + + /// Number of elements in the inner index space + int64_t inner_count; + + /// Number of workspaces needed + int workspace_count; + + /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) + dim3 grid_shape; + + /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) + dim3 threadblock_shape; + + /// CUDA grid shape for the final reduction step if needed + dim3 grid_final; + + /// CUDA threadblock shape for the final reduction step if needed + dim3 threadblock_final; + +private: + // + // Methods + // + + /// Helper to reshape 'count' such that it is less than 2 x 'ext' + static int reshape_pow2(int ext, int count) { + if (ext > count) { + return 1; + } + int x = 1; + for (; count >= ext * 2; ) { + count >>= 1; + x <<= 1; + } + return x; + } + +public: + + /// Default ctor + TensorReductionAffineContiguous(): + 
status(Status::kErrorInvalidProblem), + extent(), + outer_count(0), + inner_count(0), + workspace_count(0), + grid_shape(0, 0, 0), + threadblock_shape(0, 0, 0) { } + + /// Constructor + TensorReductionAffineContiguous( + Coord extent_, + int target_threadblock_count = 128 + ): + status(Status::kSuccess), + extent(extent_), + outer_count(0), + inner_count(0), + workspace_count(0) { + + // + // Plan the parallel mapping strategy. + // + + outer_count = 1; + inner_count = 1; + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank; ++p) { + outer_count *= extent[p]; + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= extent[kReducedRank + p]; + } + + int cta_count_x = 1; + int cta_count_y = 1; + int cta_count_z = 1; + + int cta_threads_x = kThreads; + int cta_threads_y = 1; + int cta_threads_z = 1; + + // Determine CTA shape + int64_t inner_vector_count = inner_count / kVectorLength; + + // Priority 1. Assign threadblocks to outer indices if possible + if (outer_count > target_threadblock_count) { + cta_count_x = 1; + cta_count_y = target_threadblock_count; + cta_count_z = 1; + } + else { + + cta_count_y = int(outer_count); + int remaining_ctas = target_threadblock_count / cta_count_y; + + // Priority 2. Assign inner dimensions to one CTA + if (inner_vector_count > cta_threads_x) { + int64_t cta_z_bound = inner_vector_count / cta_threads_x; + if (cta_z_bound > remaining_ctas) { + cta_count_z = remaining_ctas; + } + else { + cta_count_z = int(cta_z_bound); + } + } + else { + cta_threads_x = reshape_pow2(int(inner_vector_count), cta_threads_x); + cta_count_z = 1; + } + } + + grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); + threadblock_shape = dim3(cta_threads_x, cta_threads_y, cta_threads_z); + + workspace_count = (cta_count_z > 1 ? 
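      // A non-zero workspace_count means the inner index space is split across grid.z CTAs,
      // so each CTA stages a partial result in device memory and the final kernel folds them.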
cta_count_z : 0); + + // Determine shape of final reduction kernel if needed + if (workspace_count) { + + int final_threads = kThreads; + int final_ctas = 1; + + if (outer_count > kThreads) { + final_ctas = int(outer_count + kThreads - 1) / kThreads; + } + else { + final_threads = int(outer_count); + } + + grid_final = dim3(final_ctas, 1, 1); + threadblock_final = dim3(final_threads, 1, 1); + } + else { + grid_final = dim3(0, 0, 0); + threadblock_final = dim3(0, 0, 0); + } + } + + /// Simple check to verify the object is initialized correctly + bool good() const { + return status == Status::kSuccess; + } + + /// Size (in bytes) of workspace elements which are densely packed together + int64_t workspace_stride() const { + + // Error condition + if (!good()) { + return 0; + } + + return outer_count * sizeof_bits::value / 8; + } + + /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs + int64_t workspace_size() const { + + // Error condition + if (!good()) { + return 0; + } + + // No reduction across CTAs + if (grid_shape.z == 1) { + return 0; + } + + return workspace_stride() * grid_shape.z; + } + + /// Performs a reduction + Status reduce( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched + + // Initial status check + if (!good()) { + return status; + } + + // Guard against null workspace + if (workspace_count > 1 && device_workspace_ptr == nullptr) { + return Status::kErrorWorkspaceNull; + } + + // Define reduction kernel + using ReductionKernel = kernel::TensorReductionAffineContiguous< + kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using FinalReductionKernel = kernel::TensorReductionAffineContiguousFinal< + kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using Params = typename ReductionKernel::Params; + + // Construct the parameters + Params params( + extent, + dst_ptr, + dst_stride, + src_ptr, + src_stride, + static_cast(device_workspace_ptr), + workspace_stride(), + workspace_count, + reduction_op, + reduction_identity); + + // Shared memory size + int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); + + // Launch the kernel + Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + + // Final reduction kernel + if (workspace_count) { + Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); + } + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + + return status; + } + + /// Helper to use overloaded function call operator + Status operator()( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + 
ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Pointer to device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched + + return reduce(dst_ptr, dst_stride, src_ptr, src_stride, device_workspace_ptr, reduction_identity, reduction_op, stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/reduction/device/tensor_reduce_affine_strided.h b/include/cutlass/reduction/device/tensor_reduce_affine_strided.h new file mode 100644 index 00000000..9368d92a --- /dev/null +++ b/include/cutlass/reduction/device/tensor_reduce_affine_strided.h @@ -0,0 +1,355 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
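The contiguous device-level operator above can also be driven directly with raw pointers and explicit element strides, without going through the TensorRef-based TensorReduction wrapper. Below is a hedged sketch of that usage, reducing the contiguous C rank of a packed rank-4 NHWC float tensor; the extents, pointer names, and the choice of cutlass::plus<float> are illustrative assumptions, and the stride arrays list the retained ranks slowest-varying first, matching what the TensorReduction wrapper passes for reduction_index == 3.

#include <cuda_runtime.h>

#include "cutlass/coord.h"
#include "cutlass/functional.h"
#include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h"

// Reduces the contiguous C rank of a packed NHWC float tensor into a packed N x H x W output.
cutlass::Status reduce_c_rank_sketch(
  float *d_dst,                     // device buffer of N * H * W elements
  float *d_src,                     // device buffer of N * H * W * C elements
  int N, int H, int W, int C,
  cudaStream_t stream = nullptr) {

  using ReductionOp = cutlass::plus<float>;

  // Rank-4 source, rank-3 retained index space: only the contiguous rank is reduced.
  using Reduction = cutlass::reduction::device::TensorReductionAffineContiguous<
    4, 3, float, float, ReductionOp, /*VectorLength=*/1, /*ElementCompute=*/float>;

  Reduction reduction(cutlass::make_Coord(N, H, W, C));
  if (!reduction.good()) {
    return cutlass::Status::kErrorInvalidProblem;
  }

  // Element strides of the retained ranks, slowest-varying first.
  int64_t src_stride[] = { int64_t(H) * W * C, int64_t(W) * C, C };
  int64_t dst_stride[] = { int64_t(H) * W,     W,              1 };

  void *workspace = nullptr;
  if (reduction.workspace_size()) {
    cudaMalloc(&workspace, reduction.workspace_size());
  }

  cutlass::Status status = reduction.reduce(
    d_dst, dst_stride, d_src, src_stride, workspace,
    /*reduction_identity=*/0.0f, ReductionOp(), stream);

  if (workspace) {
    cudaFree(workspace);
  }
  return status;
}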
\file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/kernel/tensor_reduce_affine_strided.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor reduction operator on layouts which are affine +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput_, + typename ElementSource_, + typename ReductionOp_, + int VectorLength = 1, + typename ElementCompute_ = ElementOutput_, + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineStrided { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ReductionOp = ReductionOp_; + using ElementCompute = ElementCompute_; + + // + // Data members + // + + /// Internal status field + Status status; + + /// Extent of tensor in source layout + Coord extent; + + /// Number of points in the outer index space + int64_t outer_count; + + /// Number of elements in the inner index space + int64_t inner_count; + + /// Number of workspaces needed + int workspace_count; + + /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) + dim3 grid_shape; + + /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) + dim3 threadblock_shape; + + /// CUDA grid shape for the final reduction step if needed + dim3 grid_final; + + /// CUDA threadblock shape for the final reduction step if needed + dim3 threadblock_final; + +private: + // + // Methods + // + + /// Helper to reshape 'count' such that it is less than 2 x 'ext' + static int reshape_pow2(int ext, int count) { + if (ext > count) { + return 1; + } + int x = 1; + for (; count >= ext * 2; ) { + count >>= 1; + x <<= 1; + } + return x; + } + +public: + + /// Default ctor + TensorReductionAffineStrided(): + status(Status::kErrorInvalidProblem), + extent(), + outer_count(0), + inner_count(0), + workspace_count(0), + grid_shape(0, 0, 0), + threadblock_shape(0, 0, 0) { } + + /// Constructor + TensorReductionAffineStrided( + Coord extent_, + int target_threadblock_count = 128 + ): + status(Status::kSuccess), + extent(extent_), + outer_count(0), + inner_count(0), + workspace_count(0) { + + // + // Plan the parallel mapping strategy. 
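      // As a concrete illustration (hypothetical extent; assuming the defaults kThreads = 256,
      // kVectorLength = 1, target_threadblock_count = 128, with kRank = 4 and kReducedRank = 3):
      // extent = [64, 64, 128, 32] gives outer_count = 64 * 64 = 4096, inner_count = 128, and
      // extent_c = 32. Then reshape_pow2(32, 256) = 8, so threadblock_shape = (32, 1, 8) and
      // grid_shape = (1, 128, 1); with a single CTA along z, workspace_count remains zero and
      // no inter-CTA workspace is required.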
+ // + + outer_count = 1; + inner_count = 1; + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank - 1; ++p) { + outer_count *= extent[p]; + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= extent[kReducedRank + p - 1]; + } + + // Compute plan for the reduction + int extent_c = extent[kRank - 1]; + int vectors_c = (extent_c -1 + kVectorLength) / kVectorLength; + + // Determine CTA shape + int cta_width = kThreads * kVectorLength; + int cta_ways = reshape_pow2(extent_c, cta_width); + int cta_threads_x = kThreads / cta_ways; + + threadblock_shape = dim3(cta_threads_x, 1, std::min(cta_ways, 64)); + + // This leads to an error. + if (threadblock_shape.z > 1) { + if (threadblock_shape.y != 1) { + status = Status::kErrorInternal; + return; + } + } + + // Determine grid shape + int cta_count_x = (vectors_c + cta_threads_x - 1) / cta_threads_x; + int cta_count_y = std::max(1, target_threadblock_count / cta_count_x); + + // Limit the number of CTAs assigned to outer dimension + if (int64_t(cta_count_y * threadblock_shape.y) > outer_count) { + cta_count_y = int(outer_count + threadblock_shape.y - 1) / threadblock_shape.y; + } + + // Limit the number of CTAs assigned to inner dimension + int cta_count_z = std::max(1, target_threadblock_count / cta_count_y); + if (int64_t(cta_count_z * threadblock_shape.z) > inner_count) { + cta_count_z = int(inner_count + threadblock_shape.z - 1) / threadblock_shape.z; + } + + grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); + workspace_count = (cta_count_z > 1 ? cta_count_z : 0); + + // Determine shape of final reduction kernel if needed + grid_final = dim3(cta_count_x, int(outer_count)); + threadblock_final = dim3(cta_threads_x, 1, 1); + } + + /// Simple check to verify the object is initialized correctly + bool good() const { + return status == Status::kSuccess; + } + + /// Size of one CTA's workspace + int64_t workspace_stride() const { + + // Error condition + if (!good()) { + return 0; + } + + int vector_size_bytes = kVectorLength * sizeof_bits::value / 8; + + return extent[kRank - 1] * vector_size_bytes; + } + + /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs + int64_t workspace_size() const { + + // Error condition + if (!good()) { + return 0; + } + + // No reduction across CTAs + if (grid_shape.z == 1) { + return 0; + } + + return workspace_stride() * outer_count * grid_shape.z; + } + + /// Performs a reduction + Status reduce( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched + + // Initial status check + if (!good()) { + return status; + } + + // Guard against null workspace + if (workspace_count > 1 && device_workspace_ptr == nullptr) { + return Status::kErrorWorkspaceNull; + } + + // Define reduction kernel + using ReductionKernel = kernel::TensorReductionAffineStrided< + kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using FinalReductionKernel = 
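      // Launched only when the inner index space is split across grid.z CTAs; it folds the
      // per-CTA partials staged in the device workspace into the destination tensor.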
kernel::TensorReductionAffineStridedFinal< + kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using Params = typename ReductionKernel::Params; + + // Construct the parameters + Params params( + extent, + dst_ptr, + dst_stride, + src_ptr, + src_stride, + static_cast(device_workspace_ptr), + workspace_stride(), + workspace_count, + reduction_op, + reduction_identity); + + // Shared memory size + int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); + + // Launch the kernel + Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + + // Final reduction kernel + if (workspace_count) { + + Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + } + + return status; + } + + /// Helper to use overloaded function call operator + Status operator()( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Pointer to device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched + + return reduce( + dst_ptr, + dst_stride, + src_ptr, + src_stride, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h b/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h new file mode 100644 index 00000000..25f9bbef --- /dev/null +++ b/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h @@ -0,0 +1,600 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
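For the strided operator just defined, a hedged usage sketch follows: it reduces the H and W ranks of a packed rank-4 NHWC float tensor into an N x C output, keeping the batch and contiguous channel ranks. The extents and pointer names are illustrative assumptions, as is the choice of cutlass::maximum<float> from cutlass/functional.h; the source stride array lists the N, H, and W ranks slowest-varying first, and the destination stride array holds only the N-rank stride of the packed N x C output.

#include <cuda_runtime.h>
#include <limits>

#include "cutlass/coord.h"
#include "cutlass/functional.h"
#include "cutlass/reduction/device/tensor_reduce_affine_strided.h"

// Computes a per-(n, c) maximum over the H and W ranks of a packed NHWC float tensor.
cutlass::Status reduce_hw_ranks_sketch(
  float *d_dst,                     // device buffer of N * C elements (packed N x C)
  float *d_src,                     // device buffer of N * H * W * C elements
  int N, int H, int W, int C,
  cudaStream_t stream = nullptr) {

  using ReductionOp = cutlass::maximum<float>;

  // Rank-4 source, rank-2 retained index space (N plus the contiguous C rank).
  using Reduction = cutlass::reduction::device::TensorReductionAffineStrided<
    4, 2, float, float, ReductionOp, /*VectorLength=*/1, /*ElementCompute=*/float>;

  Reduction reduction(cutlass::make_Coord(N, H, W, C));
  if (!reduction.good()) {
    return cutlass::Status::kErrorInvalidProblem;
  }

  // Element strides: source N, H, W ranks (slowest first); destination N rank only.
  int64_t src_stride[] = { int64_t(H) * W * C, int64_t(W) * C, C };
  int64_t dst_stride[] = { C };

  void *workspace = nullptr;
  if (reduction.workspace_size()) {
    cudaMalloc(&workspace, reduction.workspace_size());
  }

  cutlass::Status status = reduction.reduce(
    d_dst, dst_stride, d_src, src_stride, workspace,
    /*reduction_identity=*/-std::numeric_limits<float>::infinity(),
    ReductionOp(), stream);

  if (workspace) {
    cudaFree(workspace);
  }
  return status;
}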
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/thread/reduction_operators.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineContiguousParams { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + Coord extent; /// Extent of source tensor + FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank + int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J + int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K + int64_t workspace_stride; /// stride (units of bytes) between workspace + int workspace_count; /// number of workspaces + + uint64_t inner_count; /// Number of elements in reduced index space + uint64_t outer_count; /// Number of elements in outer index space + + ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank + ElementSource const * source; /// Poitner to source pointer of rank kRank + ReductionOp reduction_op; /// Reduction operator + ElementCompute reduction_identity; /// Identity element used by reduction operator + ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorReductionAffineContiguousParams() { + + } + + /// Ctor + TensorReductionAffineContiguousParams( + Coord extent_, 
///< Extent of source tensor + ElementOutput * dst_ptr_, ///< Output tensor data + int64_t dst_stride_[], ///< Stride (units of elements) + ElementSource const * src_ptr_, ///< Source tensor data + int64_t src_stride_[], ///< Stride (units of elements) + ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions + int64_t workspace_stride_, ///< Stride between workspaces + int workspace_count_, ///< Number of workspaces + ReductionOp reduction_op_, ///< Reduction operator + ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator + ): + extent(extent_), + inner_count(1), + outer_count(1), + destination(dst_ptr_), + source(src_ptr_), + device_workspace(device_workspace_), + workspace_stride(workspace_stride_), + workspace_count(workspace_count_), + reduction_op(reduction_op_), + reduction_identity(reduction_identity_) { + + // Initialize divisors for fast div-mod + for (int p = 1; p < kRank; ++p) { + divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); + } + + int input_size_bits = sizeof_bits::value; + int output_size_bits = sizeof_bits::value; + + // Compute strides in units of bytes + for (int p = 0; p < kReducedRank; ++p) { + dst_stride[p] = dst_stride_[p] * output_size_bits / 8; + } + + for (int p = 0; p < kRank - 1; ++p) { + src_stride[p] = src_stride_[p] * input_size_bits / 8; + } + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank; ++p) { + outer_count *= uint64_t(extent[p]); + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= uint64_t(extent[kRank - 1 - p]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous +/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. 
NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineContiguous { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory allocation used for reduction within the CTA + struct SharedStorage { + Array workspace; + }; + + /// Parameters structure + using Params = TensorReductionAffineContiguousParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_inner_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into a coordinate of rank + coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kRank - kInnerRank]); + + // Compute an offset using the souce stride + src_offset = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kInnerRank - 1; ++i) { + src_offset += coord[i] * params.src_stride[kReducedRank + i]; + } + src_offset += coord[kInnerRank - 1] * sizeof_bits::value / 8; + } + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate of rank + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offsets using destination and source strides + dst_offset = 0; + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + src_offset += params.src_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices yielding a single element + CUTLASS_DEVICE + ElementCompute reduce_indices_( + Params const ¶ms, + ElementCompute *threadblock_workspace, + char const *src_byte_ptr, + int coord_c) { + + NumericArrayConverter convert_source; + ReductionOp reduction_op(params.reduction_op); + + // + // Early exit or initialize to identity element + // + if (!params.inner_count) { + return params.reduction_identity; + } + + ComputeFragment accumulator; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < accumulator.size(); ++i) { + accumulator[i] = params.reduction_identity; + } + + // Compute the coordinate of the first access + int64_t src_byte_offset = 0; + Coord coord; + + uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + + // Load the first vector + SourceFragment source_fragment[kBatchSize]; + + bool not_done = true; + + // Iterate over vectors in a linearized reduction index space + while (not_done) { + + bool 
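      // Track which of the kBatchSize loads issued below land in bounds so that the
      // reduction step only folds in fragments that were actually read.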
guards[kBatchSize]; + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + + if (linear_idx < params.inner_count) { + source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); + guards[b] = true; + } + else { + guards[b] = false; + not_done = false; + } + + linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + } + + // Perform a batch of reduction operations + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (guards[b]) { + auto cvt = convert_source(source_fragment[b]); + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + cvt); + } + } + }; + + // + // Reduction of vectors to scalar + // + + ElementCompute reduced_accumulator = accumulator[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kVectorLength; ++i) { + reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]); + } + + // + // Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0} + // + // This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column + // + + int thread_count = blockDim.x * blockDim.z; + int thread_j = threadIdx.x + blockDim.x * threadIdx.z; + int thread_i = threadIdx.y; + + ElementCompute *frag_ptr = reinterpret_cast(threadblock_workspace) + thread_i * thread_count; + + frag_ptr[thread_j] = reduced_accumulator; + + // + // Reduce + // + CUTLASS_PRAGMA_NO_UNROLL + while (thread_count > 1) { + thread_count /= 2; + + __syncthreads(); + + if (thread_j < thread_count) { + ElementCompute other = frag_ptr[thread_j + thread_count]; + + reduced_accumulator = reduction_op(reduced_accumulator, other); + + frag_ptr[thread_j] = reduced_accumulator; + } + + __syncthreads(); + } + + + return reduced_accumulator; + } + +public: + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char const * src_byte_ptr = reinterpret_cast(params.source); + char * dst_byte_ptr = nullptr; + + // If performing a reduction across CTAs, redirect output to device workspace + if (gridDim.z == 1) { + dst_byte_ptr = reinterpret_cast(params.destination); + } + else { + dst_byte_ptr = reinterpret_cast(params.device_workspace); + } + + uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + int64_t src_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + if (gridDim.z == 1) { + + /// Complete the reduction with no workspace + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset, + coord_c); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0 && threadIdx.x == 0) { + + // Convert to output type and store + NumericConverter convert_output; + ElementOutput cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = cvt; + } + + __syncthreads(); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while + } + else { + + /// Complete the 
reduction with workspace + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset, + coord_c); + + int64_t byte_offset = + blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits::value / 8; + + // Store the result for final reduction + if (threadIdx.z == 0 && threadIdx.x == 0) { + *reinterpret_cast(dst_byte_ptr + byte_offset) = result; + } + + __syncthreads(); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + } // while + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to perform final reduction +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineContiguousFinal { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + /// Shared memory + struct SharedStorage { }; + + /// Parameters structure + using Params = TensorReductionAffineContiguousParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate of rank + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offsets using destination and source strides + dst_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ElementCompute reduce_indices_( + Params const ¶ms, + ElementCompute const *device_workspace) { + + ReductionOp reduction_op(params.reduction_op); + char const *src_byte_ptr = reinterpret_cast(device_workspace); + + // Accumulated output + ElementCompute accumulator = params.reduction_identity; + + for (int iter = 0; iter < params.workspace_count; ++iter) { + ElementCompute workspace_item = *reinterpret_cast(src_byte_ptr); + + accumulator = reduction_op(accumulator, workspace_item); + + src_byte_ptr += params.workspace_stride; + } + + return accumulator; + } + +public: + + // + // Methods + // + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x; + + char * dst_byte_ptr = reinterpret_cast(params.destination); + + // Use modulo division to compute location + Coord outer_coord; + int64_t 
dst_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + + /// Complete the reduction + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear); + + // Convert to output type and store + NumericConverter convert_output; + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = convert_output(result); + + // Update indices and pointers + idx_linear += gridDim.x * blockDim.x; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h new file mode 100644 index 00000000..1dfe7e7e --- /dev/null +++ b/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h @@ -0,0 +1,635 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
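To make the control flow around the workspace explicit, here is a hedged sketch of the two-pass launch sequence that the device-level wrappers perform with kernel-level types such as the pair defined above: the first pass writes one partial result per CTA along grid.z into the device workspace, and the Final kernel folds those partials into the destination. The helper name and its parameters are illustrative; the Params object and the launch shapes are assumed to have been prepared as in the device-level headers earlier in this change.

#include <cuda_runtime.h>

#include "cutlass/device_kernel.h"
#include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h"

// Generic two-pass launcher mirroring the device-level reduce() methods in this change.
template <typename ReductionKernel, typename FinalReductionKernel>
cudaError_t launch_two_pass_reduction(
  typename ReductionKernel::Params const &params,  // extents, pointers, strides, workspace
  dim3 grid, dim3 block,                           // partial-reduction shapes (.z spans the inner space)
  dim3 grid_final, dim3 block_final,               // final-reduction shapes
  int workspace_count,                             // per-CTA partials to fold (0 => single pass)
  cudaStream_t stream = nullptr) {

  int smem_bytes = int(sizeof(typename ReductionKernel::SharedStorage));

  // Pass 1: partial reductions; results go straight to the destination when workspace_count == 0.
  cutlass::Kernel<ReductionKernel><<< grid, block, smem_bytes, stream >>>(params);

  // Pass 2: fold the per-CTA partials staged in the device workspace.
  if (workspace_count) {
    cutlass::Kernel<FinalReductionKernel><<< grid_final, block_final, 0, stream >>>(params);
  }

  return cudaGetLastError();
}

The strided kernel pair defined in the next header follows the same sequence with its own Params structure and final launch shapes.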
\file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/thread/reduction_operators.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Parameters structure +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineStridedParams { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + Coord extent; /// Extent of source tensor + FastDivmodU64 divmod[kRank - 2]; /// FastDivmod by each strided rank + int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J + int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K + int64_t workspace_stride; /// stride (units of bytes) between workspace + int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace + int workspace_count; /// number of workspaces + + uint64_t inner_count; /// Number of elements in reduced index space + uint64_t outer_count; /// Number of elements in outer index space + + ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank + ElementSource const * source; /// Poitner to source pointer of rank kRank + ReductionOp reduction_op; /// Reduction operator + ElementCompute reduction_identity; /// Identity element for reduction operator + ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorReductionAffineStridedParams() { + + } + + /// Ctor + TensorReductionAffineStridedParams( + Coord extent_, ///< Extent of source tensor + ElementOutput * dst_ptr_, ///< Output tensor data + int64_t dst_stride_[], ///< Stride (units of elements) + ElementSource const * src_ptr_, ///< Source tensor data + int64_t src_stride_[], ///< Stride (units of elements) + ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions + int64_t workspace_stride_, ///< Stride between workspaces + int workspace_count_, ///< Number of workspaces + ReductionOp reduction_op_, ///< Reduction operator + ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator + ): + extent(extent_), + inner_count(1), + outer_count(1), + destination(dst_ptr_), + source(src_ptr_), + device_workspace(device_workspace_), + workspace_outer_stride(0), + 
workspace_stride(workspace_stride_), + workspace_count(workspace_count_), + reduction_op(reduction_op_), + reduction_identity(reduction_identity_) { + + // Initialize divisors for fast div-mod + for (int p = 1; p < kRank - 1; ++p) { + divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); + } + + int input_size_bits = sizeof_bits::value; + int output_size_bits = sizeof_bits::value; + + workspace_outer_stride = workspace_stride * workspace_count; + + // Compute strides in units of bytes + for (int p = 0; p < kReducedRank - 1; ++p) { + dst_stride[p] = dst_stride_[p] * output_size_bits / 8; + } + + for (int p = 0; p < kRank - 1; ++p) { + src_stride[p] = src_stride_[p] * input_size_bits / 8; + } + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank - 1; ++p) { + outer_count *= uint64_t(extent[p]); + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= uint64_t(extent[kReducedRank + p - 1]); + } + } +}; + +/// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous +/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineStrided { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory allocation used for reduction within the CTA + struct SharedStorage { + Array workspace; + }; + + /// Parameters structure + using Params = TensorReductionAffineStridedParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_inner_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate + coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank]); + + // Compute linear offset + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kInnerRank; ++i) { + src_offset += params.src_stride[kReducedRank + i - 1] * coord[i]; + } + } + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose linear coordinate + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offset into tensors + dst_offset = 0; + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank - 1; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + 
src_offset += params.src_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ComputeFragment reduce_indices_( + Params const ¶ms, + ElementCompute *threadblock_workspace, + char const *src_byte_ptr) { + + NumericArrayConverter convert_source; + ReductionOp reduction_op(params.reduction_op); + + // Accumulated output + ComputeFragment identity_frag; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < identity_frag.size(); ++i) { + identity_frag[i] = params.reduction_identity; + } + + if (!params.inner_count) { + return identity_frag; + } + + ComputeFragment accumulator = identity_frag; + + // Compute the coordinate of the first access + int64_t src_byte_offset = 0; + Coord coord; + + uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + + // Load the first vector + SourceFragment source_fragment[kBatchSize]; + + bool not_done = true; + + // Iterate over vectors in a linearized reduction index space + while (not_done) { + + bool guards[kBatchSize]; + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + + if (linear_idx < params.inner_count) { + source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); + guards[b] = true; + } + else { + guards[b] = false; + not_done = false; + } + + linear_idx += blockDim.z * gridDim.z; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + } + + // Perform a batch of reduction operations + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (guards[b]) { + + auto cvt = convert_source(source_fragment[b]); + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + cvt); + } + } + }; + + // Optional reduction within a CTA + if (blockDim.z > 1) { + + // Linearized thread ID + int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); + + // all threads store to workspace + ComputeFragment *frag_ptr = reinterpret_cast(threadblock_workspace); + + frag_ptr[thread_idx] = accumulator; + + __syncthreads(); + + if (threadIdx.z == 0) { + // Load all additional block indices + for (int z = 1; z < blockDim.z; ++z) { + ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y]; + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + frag); + } + } + + __syncthreads(); + } + + return accumulator; + } + +public: + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char const * src_byte_ptr = reinterpret_cast(params.source + coord_c); + char * dst_byte_ptr = nullptr; + + // If performing a reduction across CTAs, redirect output to device workspace + if (gridDim.z == 1) { + dst_byte_ptr = reinterpret_cast(params.destination + coord_c); + } + else { + dst_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); + } + + // If the C index is out of bounds, exit + if (coord_c >= params.extent[kRank - 1]) { + return; + } + + int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + int64_t src_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + if (gridDim.z == 1) { + + /// Complete the reduction with no 
workspace + while (idx_linear < params.outer_count) { + + ComputeFragment result; + + result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0) { + + // Convert to output type and store + NumericArrayConverter convert_output; + auto cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = + reinterpret_cast(cvt); + } + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while + } + else { + + /// Complete the reduction with a device workspace + while (idx_linear < params.outer_count) { + + ComputeFragment result; + + result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0) { + + int64_t byte_offset = + blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride; + + // No conversion - store in compute type + *reinterpret_cast(dst_byte_ptr + byte_offset) = + reinterpret_cast(result); + } + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while (outer index) + } // if () + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to perform final reduction +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. 
NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineStridedFinal { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory + struct SharedStorage { }; + + /// Parameters structure + using Params = TensorReductionAffineStridedParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + uint64_t linear_idx) const { + + // Decompose linear index + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute tensor offset + dst_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank - 1; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ComputeFragment reduce_indices_( + Params const ¶ms, + char *src_byte_ptr) { + + ReductionOp reduction_op(params.reduction_op); + + // Accumulated output + ComputeFragment identity_frag; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < identity_frag.size(); ++i) { + identity_frag[i] = params.reduction_identity; + } + + ComputeFragment accumulator = identity_frag; + ComputeFragment workspace_fragments[kBatchSize]; + + // Partially unrolled loop + for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) { + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (idx + b < params.workspace_count) { + workspace_fragments[b] = + *reinterpret_cast(src_byte_ptr); + } + else { + workspace_fragments[b] = identity_frag; + } + src_byte_ptr += + params.workspace_stride; + } + + // Perform a reduction + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorLength; ++i) { + accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]); + } + } + } + + return accumulator; + } + +public: + + // + // Methods + // + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char * src_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); + char * dst_byte_ptr = reinterpret_cast(params.destination + coord_c); + + // If the C index is out of bounds, exit + if (coord_c >= params.extent[kRank - 1]) { + return; + } + + int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + + /// 
Complete the reduction + while (idx_linear < params.outer_count) { + + int64_t src_byte_offset = idx_linear * params.workspace_outer_stride; + + ComputeFragment result = reduce_indices_( + params, + src_byte_ptr + src_byte_offset); + + // Convert to output type and store + NumericArrayConverter convert_output; + auto cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = + reinterpret_cast(cvt); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/reduction/device/CMakeLists.txt b/test/unit/reduction/device/CMakeLists.txt new file mode 100644 index 00000000..a13a33a2 --- /dev/null +++ b/test/unit/reduction/device/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_reduction_device + tensor_reduce_strided.cu + tensor_reduce_contiguous.cu +) + diff --git a/test/unit/reduction/device/tensor_reduce_contiguous.cu b/test/unit/reduction/device/tensor_reduce_contiguous.cu new file mode 100644 index 00000000..ec406cfd --- /dev/null +++ b/test/unit/reduction/device/tensor_reduce_contiguous.cu @@ -0,0 +1,470 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for TensorReduce family of device-wide operators +*/ + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/reduction/device/tensor_reduce.h" + +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This reduces the C dimension, transforming an NHWC tensor into NHWC with C=1. 
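+///
+/// Illustrative instantiation (a minimal sketch; the element types and functor below are
+/// assumptions chosen for the example, not requirements of this helper):
+///
+///   using Functor = cutlass::plus<float>;
+///   using TensorReduction = cutlass::reduction::device::TensorReduction<
+///     float, float, cutlass::layout::TensorNHWC, Functor, 1, float>;
+///
+///   bool passed = TestAllReduction_NHWC_reduce_c<TensorReduction>();  // sums over C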
+template <
+  typename TensorReduction,
+  typename ElementCompute = typename TensorReduction::ElementCompute
+>
+bool TestAllReduction_NHWC_reduce_c(ElementCompute reduction_identity = ElementCompute()) {
+
+  using Layout = typename TensorReduction::Layout;
+  using ElementOutput = typename TensorReduction::ElementOutput;
+  using ElementSource = typename TensorReduction::ElementSource;
+
+  int const kV = TensorReduction::kVectorLength;
+
+  int const N_indices[] = {3, 13};
+  int const H_indices[] = {5, 17};
+  int const W_indices[] = {7, 19};
+  int const C_indices[] = {2049, 2048, 2047, 384, 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1};
+
+  for (int N : N_indices) {
+    for (int H : H_indices) {
+      for (int W : W_indices) {
+        for (int Cx : C_indices) {
+
+          int C = Cx * kV;
+
+          cutlass::HostTensor<ElementSource, Layout> src_tensor({N, H, W, C});
+          cutlass::HostTensor<ElementOutput, Layout> dst_tensor({N, H, W, 1});
+
+          cutlass::reference::host::TensorFillRandomUniform(
+            src_tensor.host_view(), 17, 10, -10, 0);
+
+          dst_tensor.sync_device();
+          src_tensor.sync_device();
+
+          // Execute a tensor reduction over rank 3 (the 'C' dimension is reduced; NHWC => NHW)
+          TensorReduction reduction(src_tensor.extent(), 3);
+
+          cutlass::DeviceAllocation<uint8_t> device_workspace(reduction.workspace_size());
+
+          cutlass::Status status = reduction.reduce(
+            dst_tensor.device_ref(),
+            src_tensor.device_ref(),
+            device_workspace.get(),
+            reduction_identity
+          );
+
+          EXPECT_EQ(status, cutlass::Status::kSuccess);
+          EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess);
+
+          dst_tensor.sync_host();
+
+          typename TensorReduction::ReductionOp reduction_op;
+
+          //
+          // Reference check
+          //
+          for (int n = 0; n < src_tensor.extent().n(); ++n) {
+            for (int h = 0; h < src_tensor.extent().h(); ++h) {
+              for (int w = 0; w < src_tensor.extent().w(); ++w) {
+
+                ElementCompute c_accum = reduction_identity;
+
+                for (int c = 0; c < src_tensor.extent().c(); ++c) {
+                  c_accum = reduction_op(c_accum, ElementCompute(src_tensor.at({n, h, w, c})));
+                }
+
+                ElementCompute got = ElementCompute(dst_tensor.at({n, h, w, 0}));
+
+                bool equal = (c_accum == got);
+
+                EXPECT_TRUE(equal);
+                if (!equal) {
+
+                  std::cerr
+                    << "Error at location (" << n << ", " << h << ", " << w << ", 0)" << std::endl;
+
+                  std::cerr
+                    << "  expected: " << c_accum << std::endl
+                    << "       got: " << got << std::endl;
+
+                  std::cerr
+                    << "Problem: " << src_tensor.extent() << " -> "
+                    << dst_tensor.extent() << std::endl;
+
+                  std::cerr
+                    << "   Grid: " << reduction.reduction_strided.grid_shape
+                    << "\n  Block: " << reduction.reduction_strided.threadblock_shape << std::endl
+                    << "  Final: " << reduction.reduction_strided.grid_final
+                    << "\n  Block: " << reduction.reduction_strided.threadblock_final << "\n";
+
+                  return false;
+                }
+
+              } // w
+            } // h
+          } // n
+
+          //
+          // Next problem
+          //
+
+        } // C
+      } // W
+    } // H
+  } // N
+
+  return true;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Test tensor reduction from NHWC to NHW
+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x1) {
+
+  using Layout = cutlass::layout::TensorNHWC;
+  using ElementOutput = float;
+  using ElementSource = float;
+  using ElementCompute = float;
+  int const kV = 1;
+
+  // Define the functor
+  using Functor = cutlass::plus<ElementCompute>;
+
+  using TensorReduction = cutlass::reduction::device::TensorReduction<
+    ElementOutput,
+    ElementSource,
+    Layout,
+    Functor,
+    kV,
+    ElementCompute
+  >;
+
+  EXPECT_TRUE(TestAllReduction_NHWC_reduce_c<TensorReduction>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Test tensor reduction from NHWC to NHW
+TEST(Reduction_TensorReduce,
nhwc_reduce_c_f32x1_f16x1) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = cutlass::half_t; + using ElementCompute = float; + int const kV = 1; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x2) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + int const kV = 2; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x2_f16x2) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = cutlass::half_t; + using ElementCompute = float; + int const kV = 2; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x4) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + int const kV = 4; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x4_f16x4) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = cutlass::half_t; + using ElementCompute = float; + int const kV = 4; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_maximum_c_f32x4) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + int const kV = 4; + + // Define the functor + using Functor = cutlass::maximum; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + 
ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( -std::numeric_limits::max() )); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_minimum_c_f32x4) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + int const kV = 4; + + // Define the functor + using Functor = cutlass::minimum; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( std::numeric_limits::max() )); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_ANY_c_s32) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = int; + using ElementSource = int; + using ElementCompute = int; + int const kV = 1; + + // Define the functor + using Functor = cutlass::logical_or; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(0) )); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_ALL_c_s32) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = int; + using ElementSource = int; + using ElementCompute = int; + int const kV = 1; + + // Define the functor + using Functor = cutlass::logical_and; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(1) )); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_ANY_c_f32) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + int const kV = 1; + + // Define the functor + using Functor = cutlass::logical_or; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(0) )); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHW +TEST(Reduction_TensorReduce, nhwc_ALL_c_f32) { + + using Layout = cutlass::layout::TensorNHWC; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + int const kV = 1; + + // Define the functor + using Functor = cutlass::logical_and; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(1) )); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/test/unit/reduction/device/tensor_reduce_strided.cu b/test/unit/reduction/device/tensor_reduce_strided.cu new file mode 100644 index 00000000..fda925c8 --- /dev/null +++ b/test/unit/reduction/device/tensor_reduce_strided.cu @@ -0,0 +1,517 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for TensorReduce family of device-wide operators +*/ + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/reduction/device/tensor_reduce.h" + +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This reduces the W dimension, transforming an NHWC tensor into NHWC with W=1. 
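+///
+/// Illustrative instantiation (a minimal sketch; the functor and element types are assumptions
+/// for the example). Non-additive reductions need an explicit identity, e.g. the lowest
+/// representable value when reducing with cutlass::maximum:
+///
+///   using Functor = cutlass::maximum<float>;
+///   using TensorReduction = cutlass::reduction::device::TensorReduction<
+///     float, float, cutlass::layout::TensorNHWC, Functor, 1, float>;
+///
+///   bool passed = TestAllReduction_NHWC_reduce_w<TensorReduction>(
+///     -std::numeric_limits<float>::max());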
+template < + typename TensorReduction, + typename ElementCompute = typename TensorReduction::ElementCompute +> +bool TestAllReduction_NHWC_reduce_w(ElementCompute reduction_identity = ElementCompute()) { + + using Layout = typename TensorReduction::Layout; + using ElementOutput = typename TensorReduction::ElementOutput; + using ElementSource = typename TensorReduction::ElementSource; + + int const kV = TensorReduction::kVectorLength; + + int const N_indices[] = {1, 2, 5, 10}; + int const H_indices[] = {1, 3, 9 }; + int const W_indices[] = {1, 5, 19, 40, 224}; + int const C_indices[] = { + kV, + 2 * kV, + 5 * kV, + 9 * kV, + 17 * kV, + 39 * kV, + 257 * kV, + kV * 760 + }; + + using Element = int; + + for (int N : N_indices) { + for (int H : H_indices) { + for (int W : W_indices) { + for (int C : C_indices) { + + cutlass::HostTensor src_tensor({N, H, W, C}); + cutlass::HostTensor dst_tensor({N, H, 1, C}); + + cutlass::reference::host::TensorFillRandomUniform( + src_tensor.host_view(), 17, 10, -10, 0); + + cutlass::reference::host::BlockFillSequential( + dst_tensor.host_data(), dst_tensor.capacity()); + + dst_tensor.sync_device(); + src_tensor.sync_device(); + + // Execute a tensor reduction over rank 2 (the 'W' dimension is reduced; NHWC => NHC) + TensorReduction reduction(src_tensor.extent(), 2); + + cutlass::DeviceAllocation device_workspace(reduction.workspace_size()); + + cutlass::Status status = reduction.reduce( + dst_tensor.device_ref(), + src_tensor.device_ref(), + device_workspace.get(), + reduction_identity + ); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); + // Reference check + dst_tensor.sync_host(); + + typename TensorReduction::ReductionOp reduction_op; + + for (int n = 0; n < src_tensor.extent().n(); ++n) { + for (int h = 0; h < src_tensor.extent().h(); ++h) { + for (int c = 0; c < src_tensor.extent().c(); ++c) { + + ElementCompute w_accum = reduction_identity; + + for (int w = 0; w < src_tensor.extent().w(); ++w) { + w_accum = reduction_op(w_accum, ElementCompute(src_tensor.at({n, h, w, c}))); + } + + ElementCompute got = ElementCompute(dst_tensor.at({n, h, 0, c})); + + bool equal = (w_accum == got); + + EXPECT_TRUE(equal); + if (!equal) { + + std::cerr + << "Error at location (" << n << ", " << h << ", 0, " << c << ")" << std::endl; + + std::cerr + << " expected: " << w_accum << std::endl + << " got: " << got << std::endl; + + std::cerr + << "Problem: " << src_tensor.extent() << " -> " + << dst_tensor.extent() << std::endl; + + std::cerr + << " Grid: " << reduction.reduction_strided.grid_shape + << "\n Block: " << reduction.reduction_strided.threadblock_shape << std::endl + << " Final: " << reduction.reduction_strided.grid_final + << "\n Block: " << reduction.reduction_strided.threadblock_final << "\n"; + + return false; + } + } + } + } + } + } + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x8_f16x8) { + + int const kV = 8; + using ElementOutput = float; + using ElementSource = cutlass::half_t; + using ElementCompute = float; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); +} + +/// Test 
tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x2_f16x2) { + + int const kV = 2; + using ElementOutput = float; + using ElementSource = cutlass::half_t; + using ElementCompute = float; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); +} + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x1_f16x1) { + + int const kV = 1; + using ElementOutput = float; + using ElementSource = cutlass::half_t; + using ElementCompute = float; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); +} + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_reduce_w_s32x4) { + + int const kV = 4; + using Element = int; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + Element, + Element, + Layout, + Functor, + kV, + Element + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); +} + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_reduce_w_cf32) { + + int const kV = 1; + using ElementOutput = cutlass::complex; + using ElementSource = cutlass::complex; + using ElementCompute = cutlass::complex; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::plus; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_maximum_w_cf32) { + + int const kV = 1; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::maximum; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w( -std::numeric_limits::max() )); +} + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_minimum_w_cf32) { + + int const kV = 1; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::minimum; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(std::numeric_limits::max())); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_XOR_w_u32) { + + int const kV = 1; + using ElementOutput = int; + 
using ElementSource = int; + using ElementCompute = int; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::bit_xor; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_AND_w_s32) { + + int const kV = 1; + using ElementOutput = unsigned; + using ElementSource = unsigned; + using ElementCompute = unsigned; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::bit_and; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(0xffffffff)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_OR_w_u32) { + + int const kV = 1; + using ElementOutput = int; + using ElementSource = int; + using ElementCompute = int; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::bit_or; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_ANY_w_s32) { + + int const kV = 1; + using ElementOutput = int; + using ElementSource = int; + using ElementCompute = int; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::logical_or; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(0))); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_ALL_w_s32) { + + int const kV = 1; + using ElementOutput = int; + using ElementSource = int; + using ElementCompute = int; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::logical_and; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(1))); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test tensor reduction from NHWC to NHC +TEST(Reduction_TensorReduce, nhwc_ANY_w_f32) { + + int const kV = 1; + using ElementOutput = float; + using ElementSource = float; + using ElementCompute = float; + using Layout = cutlass::layout::TensorNHWC; + + // Define the functor + using Functor = cutlass::logical_or; + + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementOutput, + ElementSource, + Layout, + Functor, + kV, + ElementCompute + >; + + 
EXPECT_TRUE(TestAllReduction_NHWC_reduce_w<TensorReduction>(ElementCompute(0)));
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Test tensor reduction from NHWC to NHC
+TEST(Reduction_TensorReduce, nhwc_ALL_w_f32) {
+
+  int const kV = 1;
+  using ElementOutput = float;
+  using ElementSource = float;
+  using ElementCompute = float;
+  using Layout = cutlass::layout::TensorNHWC;
+
+  // Define the functor
+  using Functor = cutlass::logical_and<ElementCompute>;
+
+  using TensorReduction = cutlass::reduction::device::TensorReduction<
+    ElementOutput,
+    ElementSource,
+    Layout,
+    Functor,
+    kV,
+    ElementCompute
+  >;
+
+  EXPECT_TRUE(TestAllReduction_NHWC_reduce_w<TensorReduction>(ElementCompute(1)));
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
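A minimal host-side sketch of driving the device-wide TensorReduction operator directly, using only the API exercised by the tests above (the extent/rank-index constructor, workspace_size(), and reduce()). The tensor extents, the reduced rank (W, index 2), the plus functor, and the cutlass/util/device_memory.h include for cutlass::DeviceAllocation are assumptions made for illustration.

#include <cstdint>

#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/reduction/device/tensor_reduce.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/device_memory.h"   // assumed location of cutlass::DeviceAllocation

// Sum-reduce the W dimension of an NHWC tensor of floats (NHWC => NHC).
using Layout          = cutlass::layout::TensorNHWC;
using Functor         = cutlass::plus<float>;
using TensorReduction = cutlass::reduction::device::TensorReduction<
  float, float, Layout, Functor, 1, float>;

bool reduce_w_example() {

  // Illustrative extents: N=2, H=8, W=16, C=32.
  cutlass::HostTensor<float, Layout> src({2, 8, 16, 32});
  cutlass::HostTensor<float, Layout> dst({2, 8, 1, 32});

  src.sync_device();
  dst.sync_device();

  // Rank index 2 selects the 'W' dimension for reduction.
  TensorReduction reduction(src.extent(), 2);

  // Device workspace backing the inter-CTA partial reductions, sized by the operator itself.
  cutlass::DeviceAllocation<uint8_t> device_workspace(reduction.workspace_size());

  cutlass::Status status = reduction.reduce(
    dst.device_ref(),
    src.device_ref(),
    device_workspace.get(),
    /*reduction_identity=*/0.0f);

  dst.sync_host();

  return status == cutlass::Status::kSuccess;
}

Sizing the workspace with workspace_size() and passing it to reduce() follows the pattern used by the unit tests above; the workspace is only consumed when the operator splits the reduction across CTAs.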