/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Execution environment
*/

#include <cstring>

#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/library/util.h"

#include "device_allocation.h"

namespace cutlass {
namespace profiler {

/////////////////////////////////////////////////////////////////////////////////////////////////

size_t DeviceAllocation::bytes(library::NumericTypeID type, size_t capacity) {
  return size_t(cutlass::library::sizeof_bits(type)) * capacity / 8;
}
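// bytes() converts a count of *elements* into a storage size using the element
// width in bits, so sub-byte types pack densely. A minimal sketch of the
// arithmetic (illustrative values, not part of the profiler API surface):
//
//   // 1024 x int4b_t  ->  (4 bits * 1024) / 8  ==  512 bytes
//   size_t storage = DeviceAllocation::bytes(library::NumericTypeID::kS4, 1024);

/////////////////////////////////////////////////////////////////////////////////////////////////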
template <typename Layout>
static std::vector<int> get_packed_layout_stride(std::vector<int> const &extent) {

  typename Layout::TensorCoord extent_coord;
  typename Layout::Stride stride_coord;

  if (extent.size() != size_t(Layout::kRank)) {
    throw std::runtime_error("Layout does not have same rank as extent vector.");
  }

  for (int i = 0; i < Layout::kRank; ++i) {
    extent_coord[i] = extent.at(i);
  }

  std::vector<int> stride;
  stride.resize(Layout::kStrideRank, 0);

  Layout layout = Layout::packed(extent_coord);
  stride_coord = layout.stride();

  for (int i = 0; i < Layout::kStrideRank; ++i) {
    stride.at(i) = stride_coord[i];
  }

  return stride;
}

/// Returns the stride of a packed layout
std::vector<int> DeviceAllocation::get_packed_layout(
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent) {

  std::vector<int> stride;

  switch (layout_id) {
    case library::LayoutTypeID::kColumnMajor:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajor>(extent);
      break;
    case library::LayoutTypeID::kRowMajor:
      stride = get_packed_layout_stride<cutlass::layout::RowMajor>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK2:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<2>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK2:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<2>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK4:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<4>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK4:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<4>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK16:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<16>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK16:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<16>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK32:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<32>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK32:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<32>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK64:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<64>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK64:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<64>>(extent);
      break;
    case library::LayoutTypeID::kTensorNCHW:
      stride = get_packed_layout_stride<cutlass::layout::TensorNCHW>(extent);
      break;
    case library::LayoutTypeID::kTensorNHWC:
      stride = get_packed_layout_stride<cutlass::layout::TensorNHWC>(extent);
      break;
    case library::LayoutTypeID::kTensorNDHWC:
      stride = get_packed_layout_stride<cutlass::layout::TensorNDHWC>(extent);
      break;
    case library::LayoutTypeID::kTensorNC32HW32:
      stride = get_packed_layout_stride<cutlass::layout::TensorNCxHWx<32>>(extent);
      break;
    case library::LayoutTypeID::kTensorNC64HW64:
      stride = get_packed_layout_stride<cutlass::layout::TensorNCxHWx<64>>(extent);
      break;
    case library::LayoutTypeID::kTensorC32RSK32:
      stride = get_packed_layout_stride<cutlass::layout::TensorCxRSKx<32>>(extent);
      break;
    case library::LayoutTypeID::kTensorC64RSK64:
      stride = get_packed_layout_stride<cutlass::layout::TensorCxRSKx<64>>(extent);
      break;
    default:
      break;
  }

  return stride;
}
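// Illustrative use of get_packed_layout() (a sketch; the extent ordering follows
// the corresponding CUTLASS layout's TensorCoord):
//
//   // Packed NHWC strides for a 2x7x7x64 tensor: {C, W*C, H*W*C} = {64, 448, 3136}
//   std::vector<int> stride = DeviceAllocation::get_packed_layout(
//     library::LayoutTypeID::kTensorNHWC, {2, 7, 7, 64});

/////////////////////////////////////////////////////////////////////////////////////////////////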
/// Uses CUTLASS Layout functions to construct a layout object in memory and compute its capacity
template <typename Layout>
static size_t construct_layout_(
  void *bytes,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> &stride) {

  if (extent.size() != Layout::kRank) {
    throw std::runtime_error(
      "Layout must have same rank as extent vector.");
  }

  if (Layout::kStrideRank && stride.empty()) {

    stride = get_packed_layout_stride<Layout>(extent);

    return construct_layout_<Layout>(
      bytes,
      layout_id,
      extent,
      stride);
  }
  else if (Layout::kStrideRank && stride.size() != Layout::kStrideRank) {
    throw std::runtime_error(
      "Layout requires either empty stride or stride vector matching Layout::kStrideRank");
  }

  typename Layout::Stride stride_coord;
  for (int i = 0; i < Layout::kStrideRank; ++i) {
    stride_coord[i] = stride.at(i);
  }

  typename Layout::TensorCoord extent_coord;
  for (int i = 0; i < Layout::kRank; ++i) {
    extent_coord[i] = extent.at(i);
  }

  // Construct the CUTLASS layout object from the stride object
  Layout layout(stride_coord);

  // Pack it into bytes
  if (bytes) {
    *reinterpret_cast<Layout *>(bytes) = layout;
  }

  // Return capacity
  size_t capacity_ = layout.capacity(extent_coord);

  return capacity_;
}

/// Returns the capacity (in elements) needed by the layout
size_t DeviceAllocation::construct_layout(
  void *bytes,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> &stride) {

  switch (layout_id) {
    case library::LayoutTypeID::kColumnMajor:
      return construct_layout_<cutlass::layout::ColumnMajor>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kRowMajor:
      return construct_layout_<cutlass::layout::RowMajor>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kColumnMajorInterleavedK2:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<2>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kRowMajorInterleavedK2:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<2>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kColumnMajorInterleavedK4:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<4>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kRowMajorInterleavedK4:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<4>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kColumnMajorInterleavedK16:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<16>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kRowMajorInterleavedK16:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<16>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kColumnMajorInterleavedK32:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<32>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kRowMajorInterleavedK32:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<32>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kColumnMajorInterleavedK64:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<64>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kRowMajorInterleavedK64:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<64>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kTensorNCHW:
      return construct_layout_<cutlass::layout::TensorNCHW>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kTensorNHWC:
      return construct_layout_<cutlass::layout::TensorNHWC>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kTensorNDHWC:
      return construct_layout_<cutlass::layout::TensorNDHWC>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kTensorNC32HW32:
      return construct_layout_<cutlass::layout::TensorNCxHWx<32>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kTensorNC64HW64:
      return construct_layout_<cutlass::layout::TensorNCxHWx<64>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kTensorC32RSK32:
      return construct_layout_<cutlass::layout::TensorCxRSKx<32>>(bytes, layout_id, extent, stride);
    case library::LayoutTypeID::kTensorC64RSK64:
      return construct_layout_<cutlass::layout::TensorCxRSKx<64>>(bytes, layout_id, extent, stride);
    default:
      break;
  }

  return 0;
}
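// Because construct_layout_() only writes through `bytes` when it is non-null,
// construct_layout() doubles as a pure capacity query. A sketch under that
// observation (values illustrative):
//
//   std::vector<int> stride;  // empty: the packed stride is derived from the extent
//   size_t elements = DeviceAllocation::construct_layout(
//     nullptr, library::LayoutTypeID::kRowMajor, {128, 256}, stride);
//   // elements == 128 * 256, and stride == {256} on return

/////////////////////////////////////////////////////////////////////////////////////////////////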
DeviceAllocation::DeviceAllocation():
  type_(library::NumericTypeID::kInvalid),
  batch_stride_(0),
  capacity_(0),
  pointer_(nullptr),
  layout_(library::LayoutTypeID::kUnknown),
  batch_count_(1) {

}

DeviceAllocation::DeviceAllocation(
  library::NumericTypeID type,
  size_t capacity
):
  type_(type),
  batch_stride_(capacity),
  capacity_(capacity),
  pointer_(nullptr),
  layout_(library::LayoutTypeID::kUnknown),
  batch_count_(1) {

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity));

  if (result != cudaSuccess) {
    type_ = library::NumericTypeID::kInvalid;
    capacity_ = 0;
    pointer_ = nullptr;
    throw std::bad_alloc();
  }
}

DeviceAllocation::DeviceAllocation(
  library::NumericTypeID type,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> const &stride,
  int batch_count
):
  type_(type),
  batch_stride_(size_t(0)),
  capacity_(size_t(0)),
  pointer_(nullptr),
  batch_count_(1) {

  reset(type, layout_id, extent, stride, batch_count);
}

DeviceAllocation::~DeviceAllocation() {
  if (pointer_) {
    cudaFree(pointer_);
  }
}

DeviceAllocation &DeviceAllocation::reset() {
  if (pointer_) {
    cudaFree(pointer_);
  }

  type_ = library::NumericTypeID::kInvalid;
  batch_stride_ = 0;
  capacity_ = 0;
  pointer_ = nullptr;
  layout_ = library::LayoutTypeID::kUnknown;
  stride_.clear();
  extent_.clear();
  tensor_ref_buffer_.clear();
  batch_count_ = 1;

  return *this;
}

DeviceAllocation &DeviceAllocation::reset(library::NumericTypeID type, size_t capacity) {

  reset();

  type_ = type;
  batch_stride_ = capacity;
  capacity_ = capacity;

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type_, capacity_));
  if (result != cudaSuccess) {
    throw std::bad_alloc();
  }

  layout_ = library::LayoutTypeID::kUnknown;
  stride_.clear();
  extent_.clear();
  batch_count_ = 1;

  tensor_ref_buffer_.resize(sizeof(pointer_), 0);
  std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

  return *this;
}

/// Allocates memory for a given layout and tensor
DeviceAllocation &DeviceAllocation::reset(
  library::NumericTypeID type,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> const &stride,
  int batch_count) {

  reset();

  tensor_ref_buffer_.resize(sizeof(pointer_) + (sizeof(int) * library::get_layout_stride_rank(layout_id)), 0);

  type_ = type;

  layout_ = layout_id;
  stride_ = stride;
  extent_ = extent;
  batch_count_ = batch_count;

  batch_stride_ = construct_layout(
    tensor_ref_buffer_.data() + sizeof(pointer_),
    layout_id,
    extent,
    stride_);

  capacity_ = batch_stride_ * batch_count_;

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity_));
  if (result != cudaSuccess) {
    throw std::bad_alloc();
  }

  std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

  return *this;
}

bool DeviceAllocation::good() const {
  return (capacity_ && pointer_);
}

library::NumericTypeID DeviceAllocation::type() const {
  return type_;
}

void *DeviceAllocation::data() const {
  return pointer_;
}

void *DeviceAllocation::batch_data(int batch_idx) const {
  return static_cast<char *>(data()) + batch_stride_bytes() * batch_idx;
}

library::LayoutTypeID DeviceAllocation::layout() const {
  return layout_;
}

std::vector<int> const & DeviceAllocation::stride() const {
  return stride_;
}

/// Gets the extent vector
std::vector<int> const & DeviceAllocation::extent() const {
  return extent_;
}

/// Gets the number of adjacent tensors in memory
int DeviceAllocation::batch_count() const {
  return batch_count_;
}

/// Gets the stride (in units of elements) between items
int64_t DeviceAllocation::batch_stride() const {
  return batch_stride_;
}

/// Gets the stride (in units of bytes) between items
int64_t DeviceAllocation::batch_stride_bytes() const {
  return bytes(type_, batch_stride_);
}

size_t DeviceAllocation::capacity() const {
  return capacity_;
}

size_t DeviceAllocation::bytes() const {
  return bytes(type_, capacity_);
}

/// Copies from an equivalent-sized tensor in device memory
void DeviceAllocation::copy_from_device(void const *ptr) {
  cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyDeviceToDevice);
  if (result != cudaSuccess) {
    throw std::runtime_error("Failed device-to-device copy");
  }
}

/// Copies from an equivalent-sized tensor in host memory
void DeviceAllocation::copy_from_host(void const *ptr) {
  cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    throw std::runtime_error("Failed host-to-device copy");
  }
}

/// Copies to an equivalent-sized tensor in host memory
void DeviceAllocation::copy_to_host(void *ptr) {
  cudaError_t result = cudaMemcpy(ptr, data(), bytes(), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    throw std::runtime_error("Failed device-to-host copy");
  }
}
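// Typical allocate / fill / read-back round trip (an illustrative sketch using
// only the methods defined above):
//
//   DeviceAllocation block(library::NumericTypeID::kF32, 4096);
//   std::vector<float> host(4096, 1.0f);
//   block.copy_from_host(host.data());   // host -> device, block.bytes() bytes
//   block.copy_to_host(host.data());     // device -> host read-back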
void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  // Instantiate calls to CURAND here. This file takes a long time to compile for
  // this reason.

  switch (type_) {
  case library::NumericTypeID::kF16:
    cutlass::reference::device::BlockFillRandom<cutlass::half_t>(
      reinterpret_cast<cutlass::half_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kBF16:
    cutlass::reference::device::BlockFillRandom<cutlass::bfloat16_t>(
      reinterpret_cast<cutlass::bfloat16_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kTF32:
    cutlass::reference::device::BlockFillRandom<cutlass::tfloat32_t>(
      reinterpret_cast<cutlass::tfloat32_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kF32:
    cutlass::reference::device::BlockFillRandom<float>(
      reinterpret_cast<float *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCBF16:
    cutlass::reference::device::BlockFillRandom<cutlass::complex<cutlass::bfloat16_t>>(
      reinterpret_cast<cutlass::complex<cutlass::bfloat16_t> *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCTF32:
    cutlass::reference::device::BlockFillRandom<cutlass::complex<cutlass::tfloat32_t>>(
      reinterpret_cast<cutlass::complex<cutlass::tfloat32_t> *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCF32:
    cutlass::reference::device::BlockFillRandom<cutlass::complex<float>>(
      reinterpret_cast<cutlass::complex<float> *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kF64:
    cutlass::reference::device::BlockFillRandom<double>(
      reinterpret_cast<double *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCF64:
    cutlass::reference::device::BlockFillRandom<cutlass::complex<double>>(
      reinterpret_cast<cutlass::complex<double> *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS2:
    cutlass::reference::device::BlockFillRandom<int2b_t>(
      reinterpret_cast<int2b_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS4:
    cutlass::reference::device::BlockFillRandom<int4b_t>(
      reinterpret_cast<int4b_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS8:
    cutlass::reference::device::BlockFillRandom<int8_t>(
      reinterpret_cast<int8_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS16:
    cutlass::reference::device::BlockFillRandom<int16_t>(
      reinterpret_cast<int16_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS32:
    cutlass::reference::device::BlockFillRandom<int32_t>(
      reinterpret_cast<int32_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS64:
    cutlass::reference::device::BlockFillRandom<int64_t>(
      reinterpret_cast<int64_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kB1:
    cutlass::reference::device::BlockFillRandom<uint1b_t>(
      reinterpret_cast<uint1b_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU2:
    cutlass::reference::device::BlockFillRandom<uint2b_t>(
      reinterpret_cast<uint2b_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU4:
    cutlass::reference::device::BlockFillRandom<uint4b_t>(
      reinterpret_cast<uint4b_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU8:
    cutlass::reference::device::BlockFillRandom<uint8_t>(
      reinterpret_cast<uint8_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU16:
    cutlass::reference::device::BlockFillRandom<uint16_t>(
      reinterpret_cast<uint16_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU32:
    cutlass::reference::device::BlockFillRandom<uint32_t>(
      reinterpret_cast<uint32_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU64:
    cutlass::reference::device::BlockFillRandom<uint64_t>(
      reinterpret_cast<uint64_t *>(pointer_),
      capacity_,
      seed,
      dist
    );
    break;
  default:
    break;
  }
}
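// Sketch of a device-side random fill. The Distribution setup assumes the
// set_uniform() interface of cutlass::Distribution (cutlass/util/distribution.h);
// the seed value is arbitrary:
//
//   Distribution dist;
//   dist.set_uniform(-4, 4);                           // uniform in [-4, 4]
//   allocation.initialize_random_device(2019, dist);   // fills in device memory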
void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  std::vector<uint8_t> host_data(bytes());

  switch (type_) {
  case library::NumericTypeID::kF16:
    cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
      reinterpret_cast<cutlass::half_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kBF16:
    cutlass::reference::host::BlockFillRandom<cutlass::bfloat16_t>(
      reinterpret_cast<cutlass::bfloat16_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kTF32:
    cutlass::reference::host::BlockFillRandom<cutlass::tfloat32_t>(
      reinterpret_cast<cutlass::tfloat32_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kF32:
    cutlass::reference::host::BlockFillRandom<float>(
      reinterpret_cast<float *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCF16:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::half_t>>(
      reinterpret_cast<cutlass::complex<cutlass::half_t> *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCBF16:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::bfloat16_t>>(
      reinterpret_cast<cutlass::complex<cutlass::bfloat16_t> *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCTF32:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::tfloat32_t>>(
      reinterpret_cast<cutlass::complex<cutlass::tfloat32_t> *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCF32:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<float>>(
      reinterpret_cast<cutlass::complex<float> *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kF64:
    cutlass::reference::host::BlockFillRandom<double>(
      reinterpret_cast<double *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kCF64:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<double>>(
      reinterpret_cast<cutlass::complex<double> *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS2:
    cutlass::reference::host::BlockFillRandom<int2b_t>(
      reinterpret_cast<int2b_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS4:
    cutlass::reference::host::BlockFillRandom<int4b_t>(
      reinterpret_cast<int4b_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS8:
    cutlass::reference::host::BlockFillRandom<int8_t>(
      reinterpret_cast<int8_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS16:
    cutlass::reference::host::BlockFillRandom<int16_t>(
      reinterpret_cast<int16_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS32:
    cutlass::reference::host::BlockFillRandom<int32_t>(
      reinterpret_cast<int32_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kS64:
    cutlass::reference::host::BlockFillRandom<int64_t>(
      reinterpret_cast<int64_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kB1:
    cutlass::reference::host::BlockFillRandom<uint1b_t>(
      reinterpret_cast<uint1b_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU2:
    cutlass::reference::host::BlockFillRandom<uint2b_t>(
      reinterpret_cast<uint2b_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU4:
    cutlass::reference::host::BlockFillRandom<uint4b_t>(
      reinterpret_cast<uint4b_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU8:
    cutlass::reference::host::BlockFillRandom<uint8_t>(
      reinterpret_cast<uint8_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU16:
    cutlass::reference::host::BlockFillRandom<uint16_t>(
      reinterpret_cast<uint16_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU32:
    cutlass::reference::host::BlockFillRandom<uint32_t>(
      reinterpret_cast<uint32_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  case library::NumericTypeID::kU64:
    cutlass::reference::host::BlockFillRandom<uint64_t>(
      reinterpret_cast<uint64_t *>(host_data.data()),
      capacity_,
      seed,
      dist
    );
    break;
  default:
    break;
  }

  copy_from_host(host_data.data());
}
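// Note that host_data above is sized in bytes while BlockFillRandom receives
// capacity_ in elements: for sub-byte types the cast to the element pointer
// type packs multiple elements per byte, so both sizes remain consistent.
// Illustrative arithmetic: capacity_ == 1024 int4b_t elements -> bytes() == 512.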
void DeviceAllocation::initialize_random_sparsemeta_device(int seed, int MetaSizeInBits) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  // Instantiate calls to CURAND here. This file takes a long time to compile for
  // this reason.

  switch (type_) {
  case library::NumericTypeID::kU16:
    cutlass::reference::device::BlockFillRandomSparseMeta<uint16_t>(
      reinterpret_cast<uint16_t *>(pointer_),
      capacity_,
      seed,
      MetaSizeInBits
    );
    break;
  case library::NumericTypeID::kU32:
    cutlass::reference::device::BlockFillRandomSparseMeta<uint32_t>(
      reinterpret_cast<uint32_t *>(pointer_),
      capacity_,
      seed,
      MetaSizeInBits
    );
    break;
  default:
    break;
  }
}

void DeviceAllocation::initialize_random_sparsemeta_host(int seed, int MetaSizeInBits) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  std::vector<uint8_t> host_data(bytes());

  switch (type_) {
  case library::NumericTypeID::kS16:
    cutlass::reference::host::BlockFillRandomSparseMeta<int16_t>(
      reinterpret_cast<int16_t *>(host_data.data()),
      capacity_,
      seed,
      MetaSizeInBits
    );
    break;
  case library::NumericTypeID::kS32:
    cutlass::reference::host::BlockFillRandomSparseMeta<int32_t>(
      reinterpret_cast<int32_t *>(host_data.data()),
      capacity_,
      seed,
      MetaSizeInBits
    );
    break;
  default:
    break;
  }

  copy_from_host(host_data.data());
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Returns true if two blocks have exactly the same value
bool DeviceAllocation::block_compare_equal(
  library::NumericTypeID numeric_type,
  void const *ptr_A,
  void const *ptr_B,
  size_t capacity) {

  switch (numeric_type) {
  case library::NumericTypeID::kF16:
    return reference::device::BlockCompareEqual<half_t>(
      reinterpret_cast<half_t const *>(ptr_A),
      reinterpret_cast<half_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kBF16:
    return reference::device::BlockCompareEqual<bfloat16_t>(
      reinterpret_cast<bfloat16_t const *>(ptr_A),
      reinterpret_cast<bfloat16_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kTF32:
    return reference::device::BlockCompareEqual<tfloat32_t>(
      reinterpret_cast<tfloat32_t const *>(ptr_A),
      reinterpret_cast<tfloat32_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kF32:
    return reference::device::BlockCompareEqual<float>(
      reinterpret_cast<float const *>(ptr_A),
      reinterpret_cast<float const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kCF32:
    return reference::device::BlockCompareEqual<cutlass::complex<float> >(
      reinterpret_cast<complex<float> const *>(ptr_A),
      reinterpret_cast<complex<float> const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kCF16:
    return reference::device::BlockCompareEqual<complex<half_t>>(
      reinterpret_cast<complex<half_t> const *>(ptr_A),
      reinterpret_cast<complex<half_t> const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kCBF16:
    return reference::device::BlockCompareEqual<complex<bfloat16_t>>(
      reinterpret_cast<complex<bfloat16_t> const *>(ptr_A),
      reinterpret_cast<complex<bfloat16_t> const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kCTF32:
    return reference::device::BlockCompareEqual<complex<tfloat32_t>>(
      reinterpret_cast<complex<tfloat32_t> const *>(ptr_A),
      reinterpret_cast<complex<tfloat32_t> const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kF64:
    return reference::device::BlockCompareEqual<double>(
      reinterpret_cast<double const *>(ptr_A),
      reinterpret_cast<double const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kCF64:
    return reference::device::BlockCompareEqual<complex<double>>(
      reinterpret_cast<complex<double> const *>(ptr_A),
      reinterpret_cast<complex<double> const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kS2:
    return reference::device::BlockCompareEqual<int2b_t>(
      reinterpret_cast<int2b_t const *>(ptr_A),
      reinterpret_cast<int2b_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kS4:
    return reference::device::BlockCompareEqual<int4b_t>(
      reinterpret_cast<int4b_t const *>(ptr_A),
      reinterpret_cast<int4b_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kS8:
    return reference::device::BlockCompareEqual<int8_t>(
      reinterpret_cast<int8_t const *>(ptr_A),
      reinterpret_cast<int8_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kS16:
    return reference::device::BlockCompareEqual<int16_t>(
      reinterpret_cast<int16_t const *>(ptr_A),
      reinterpret_cast<int16_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kS32:
    return reference::device::BlockCompareEqual<int32_t>(
      reinterpret_cast<int32_t const *>(ptr_A),
      reinterpret_cast<int32_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kS64:
    return reference::device::BlockCompareEqual<int64_t>(
      reinterpret_cast<int64_t const *>(ptr_A),
      reinterpret_cast<int64_t const *>(ptr_B),
      capacity);
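  // (Comparison sketch for callers of this static method, with `a` and `b`
  // hypothetical DeviceAllocation instances of equal type and capacity:
  //
  //   bool equal = DeviceAllocation::block_compare_equal(
  //     a.type(), a.data(), b.data(), a.capacity());
  // )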
  case library::NumericTypeID::kB1:
    return reference::device::BlockCompareEqual<uint1b_t>(
      reinterpret_cast<uint1b_t const *>(ptr_A),
      reinterpret_cast<uint1b_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kU2:
    return reference::device::BlockCompareEqual<uint2b_t>(
      reinterpret_cast<uint2b_t const *>(ptr_A),
      reinterpret_cast<uint2b_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kU4:
    return reference::device::BlockCompareEqual<uint4b_t>(
      reinterpret_cast<uint4b_t const *>(ptr_A),
      reinterpret_cast<uint4b_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kU8:
    return reference::device::BlockCompareEqual<uint8_t>(
      reinterpret_cast<uint8_t const *>(ptr_A),
      reinterpret_cast<uint8_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kU16:
    return reference::device::BlockCompareEqual<uint16_t>(
      reinterpret_cast<uint16_t const *>(ptr_A),
      reinterpret_cast<uint16_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kU32:
    return reference::device::BlockCompareEqual<uint32_t>(
      reinterpret_cast<uint32_t const *>(ptr_A),
      reinterpret_cast<uint32_t const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kU64:
    return reference::device::BlockCompareEqual<uint64_t>(
      reinterpret_cast<uint64_t const *>(ptr_A),
      reinterpret_cast<uint64_t const *>(ptr_B),
      capacity);
  default:
    throw std::runtime_error("Unsupported numeric type");
  }
}

/// Returns true if two blocks have approximately the same value
bool DeviceAllocation::block_compare_relatively_equal(
  library::NumericTypeID numeric_type,
  void const *ptr_A,
  void const *ptr_B,
  size_t capacity,
  double epsilon,
  double nonzero_floor) {

  switch (numeric_type) {
  case library::NumericTypeID::kF16:
    return reference::device::BlockCompareRelativelyEqual<half_t>(
      reinterpret_cast<half_t const *>(ptr_A),
      reinterpret_cast<half_t const *>(ptr_B),
      capacity,
      static_cast<half_t>(epsilon),
      static_cast<half_t>(nonzero_floor));
  case library::NumericTypeID::kBF16:
    return reference::device::BlockCompareRelativelyEqual<bfloat16_t>(
      reinterpret_cast<bfloat16_t const *>(ptr_A),
      reinterpret_cast<bfloat16_t const *>(ptr_B),
      capacity,
      static_cast<bfloat16_t>(epsilon),
      static_cast<bfloat16_t>(nonzero_floor));
  case library::NumericTypeID::kTF32:
    return reference::device::BlockCompareRelativelyEqual<tfloat32_t>(
      reinterpret_cast<tfloat32_t const *>(ptr_A),
      reinterpret_cast<tfloat32_t const *>(ptr_B),
      capacity,
      static_cast<tfloat32_t>(epsilon),
      static_cast<tfloat32_t>(nonzero_floor));
  case library::NumericTypeID::kF32:
    return reference::device::BlockCompareRelativelyEqual<float>(
      reinterpret_cast<float const *>(ptr_A),
      reinterpret_cast<float const *>(ptr_B),
      capacity,
      static_cast<float>(epsilon),
      static_cast<float>(nonzero_floor));
  case library::NumericTypeID::kF64:
    return reference::device::BlockCompareRelativelyEqual<double>(
      reinterpret_cast<double const *>(ptr_A),
      reinterpret_cast<double const *>(ptr_B),
      capacity,
      static_cast<double>(epsilon),
      static_cast<double>(nonzero_floor));
  case library::NumericTypeID::kS2:
    return reference::device::BlockCompareRelativelyEqual<int2b_t>(
      reinterpret_cast<int2b_t const *>(ptr_A),
      reinterpret_cast<int2b_t const *>(ptr_B),
      capacity,
      static_cast<int2b_t>(epsilon),
      static_cast<int2b_t>(nonzero_floor));
  case library::NumericTypeID::kS4:
    return reference::device::BlockCompareRelativelyEqual<int4b_t>(
      reinterpret_cast<int4b_t const *>(ptr_A),
      reinterpret_cast<int4b_t const *>(ptr_B),
      capacity,
      static_cast<int4b_t>(epsilon),
      static_cast<int4b_t>(nonzero_floor));
  case library::NumericTypeID::kS8:
    return reference::device::BlockCompareRelativelyEqual<int8_t>(
      reinterpret_cast<int8_t const *>(ptr_A),
      reinterpret_cast<int8_t const *>(ptr_B),
      capacity,
      static_cast<int8_t>(epsilon),
      static_cast<int8_t>(nonzero_floor));
  case library::NumericTypeID::kS16:
    return reference::device::BlockCompareRelativelyEqual<int16_t>(
      reinterpret_cast<int16_t const *>(ptr_A),
      reinterpret_cast<int16_t const *>(ptr_B),
      capacity,
      static_cast<int16_t>(epsilon),
      static_cast<int16_t>(nonzero_floor));
  case library::NumericTypeID::kS32:
    return reference::device::BlockCompareRelativelyEqual<int32_t>(
      reinterpret_cast<int32_t const *>(ptr_A),
      reinterpret_cast<int32_t const *>(ptr_B),
      capacity,
      static_cast<int32_t>(epsilon),
      static_cast<int32_t>(nonzero_floor));
  case library::NumericTypeID::kS64:
    return reference::device::BlockCompareRelativelyEqual<int64_t>(
      reinterpret_cast<int64_t const *>(ptr_A),
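      // (The element-wise criterion is relative: roughly |a - b| is compared
      // against epsilon scaled by the operand magnitudes, with nonzero_floor
      // guarding comparisons near zero. This is a paraphrase of the reference
      // kernel's behavior. An illustrative call:
      //
      //   bool passed = DeviceAllocation::block_compare_relatively_equal(
      //     library::NumericTypeID::kF32, reference.data(), computed.data(),
      //     reference.capacity(), 1e-5 /*epsilon*/, 1e-7 /*nonzero_floor*/);
      // )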
      reinterpret_cast<int64_t const *>(ptr_B),
      capacity,
      static_cast<int64_t>(epsilon),
      static_cast<int64_t>(nonzero_floor));
  case library::NumericTypeID::kB1:
    return reference::device::BlockCompareRelativelyEqual<uint1b_t>(
      reinterpret_cast<uint1b_t const *>(ptr_A),
      reinterpret_cast<uint1b_t const *>(ptr_B),
      capacity,
      static_cast<uint1b_t>(epsilon),
      static_cast<uint1b_t>(nonzero_floor));
  case library::NumericTypeID::kU2:
    return reference::device::BlockCompareRelativelyEqual<uint2b_t>(
      reinterpret_cast<uint2b_t const *>(ptr_A),
      reinterpret_cast<uint2b_t const *>(ptr_B),
      capacity,
      static_cast<uint2b_t>(epsilon),
      static_cast<uint2b_t>(nonzero_floor));
  case library::NumericTypeID::kU4:
    return reference::device::BlockCompareRelativelyEqual<uint4b_t>(
      reinterpret_cast<uint4b_t const *>(ptr_A),
      reinterpret_cast<uint4b_t const *>(ptr_B),
      capacity,
      static_cast<uint4b_t>(epsilon),
      static_cast<uint4b_t>(nonzero_floor));
  case library::NumericTypeID::kU8:
    return reference::device::BlockCompareRelativelyEqual<uint8_t>(
      reinterpret_cast<uint8_t const *>(ptr_A),
      reinterpret_cast<uint8_t const *>(ptr_B),
      capacity,
      static_cast<uint8_t>(epsilon),
      static_cast<uint8_t>(nonzero_floor));
  case library::NumericTypeID::kU16:
    return reference::device::BlockCompareRelativelyEqual<uint16_t>(
      reinterpret_cast<uint16_t const *>(ptr_A),
      reinterpret_cast<uint16_t const *>(ptr_B),
      capacity,
      static_cast<uint16_t>(epsilon),
      static_cast<uint16_t>(nonzero_floor));
  case library::NumericTypeID::kU32:
    return reference::device::BlockCompareRelativelyEqual<uint32_t>(
      reinterpret_cast<uint32_t const *>(ptr_A),
      reinterpret_cast<uint32_t const *>(ptr_B),
      capacity,
      static_cast<uint32_t>(epsilon),
      static_cast<uint32_t>(nonzero_floor));
  case library::NumericTypeID::kU64:
    return reference::device::BlockCompareRelativelyEqual<uint64_t>(
      reinterpret_cast<uint64_t const *>(ptr_A),
      reinterpret_cast<uint64_t const *>(ptr_B),
      capacity,
      static_cast<uint64_t>(epsilon),
      static_cast<uint64_t>(nonzero_floor));

  // No relatively equal comparison for complex numbers.
  //
  // As a simplification, we can require bitwise equality. This avoids false positives.
  // (i.e. "pass" really means passing. "Fail" may not actually mean failure given appropriate epsilon.)
  //
  case library::NumericTypeID::kCF16:
    return reference::device::BlockCompareEqual<cutlass::complex<half_t> >(
      reinterpret_cast<complex<half_t> const *>(ptr_A),
      reinterpret_cast<complex<half_t> const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kCF32:
    return reference::device::BlockCompareEqual<cutlass::complex<float> >(
      reinterpret_cast<complex<float> const *>(ptr_A),
      reinterpret_cast<complex<float> const *>(ptr_B),
      capacity);
  case library::NumericTypeID::kCF64:
    return reference::device::BlockCompareEqual<cutlass::complex<double> >(
      reinterpret_cast<complex<double> const *>(ptr_A),
      reinterpret_cast<complex<double> const *>(ptr_B),
      capacity);
  default: {
    throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type));
  }
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord, int Rank>
struct vector_to_coord {

  vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {

    coord[Rank - 1] = vec.at(Rank - 1);

    if (Rank > 1) {
      vector_to_coord<TensorCoord, Rank - 1>(coord, vec);
    }
  }
};

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 1> {

  vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {

    coord[0] = vec.at(0);
  }
};

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 0> {

  vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {

  }
};
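// The recursion unrolls at compile time: the rank-N constructor copies element
// N-1 and delegates to rank N-1, terminating at the specializations above.
// Illustrative expansion for a rank-3 coordinate:
//
//   Coord<3> coord;
//   vector_to_coord<Coord<3>, 3>(coord, std::vector<int>{2, 7, 64});
//   // coord[2] = 64, then coord[1] = 7, then coord[0] = 2

/////////////////////////////////////////////////////////////////////////////////////////////////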
template <typename Element, typename Layout>
static void write_tensor_csv_static_tensor_view(
  std::ostream &out,
  DeviceAllocation &allocation) {

  Coord<Layout::kRank> extent;
  Coord<Layout::kStrideRank> stride;

  if (allocation.extent().size() != Layout::kRank) {
    throw std::runtime_error("Allocation extent has invalid rank");
  }

  if (allocation.stride().size() != Layout::kStrideRank) {
    throw std::runtime_error("Allocation stride has invalid rank");
  }

  vector_to_coord<Coord<Layout::kRank>, Layout::kRank>(extent, allocation.extent());
  vector_to_coord<Coord<Layout::kStrideRank>, Layout::kStrideRank>(stride, allocation.stride());

  Layout layout(stride);

  HostTensor<Element, Layout> host_tensor(extent, layout, false);

  if (host_tensor.capacity() != allocation.batch_stride()) {
    throw std::runtime_error("Tensor capacity expected to equal allocation batch stride.");
  }

  host_tensor.copy_in_device_to_host(
    static_cast<Element const *>(allocation.data()),
    allocation.batch_stride());

  TensorViewWrite(out, host_tensor.host_view());

  out << "\n\n";
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
static void write_tensor_csv_static_type(
  std::ostream &out,
  DeviceAllocation &allocation) {

  switch (allocation.layout()) {
    case library::LayoutTypeID::kRowMajor:
      write_tensor_csv_static_tensor_view<T, layout::RowMajor>(out, allocation);
      break;
    case library::LayoutTypeID::kColumnMajor:
      write_tensor_csv_static_tensor_view<T, layout::ColumnMajor>(out, allocation);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK2:
      write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<2>>(out, allocation);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK2:
      write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<2>>(out, allocation);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK4:
      write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<4>>(out, allocation);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK4:
      write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<4>>(out, allocation);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK16:
      write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<16>>(out, allocation);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK16:
      write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<16>>(out, allocation);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK32:
      write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<32>>(out, allocation);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK32:
      write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<32>>(out, allocation);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK64:
      write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<64>>(out, allocation);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK64:
      write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<64>>(out, allocation);
      break;
    case library::LayoutTypeID::kTensorNHWC:
      write_tensor_csv_static_tensor_view<T, layout::TensorNHWC>(out, allocation);
      break;
    case library::LayoutTypeID::kTensorNDHWC:
      write_tensor_csv_static_tensor_view<T, layout::TensorNDHWC>(out, allocation);
      break;
    case library::LayoutTypeID::kTensorNC32HW32:
      write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<32>>(out, allocation);
      break;
    case library::LayoutTypeID::kTensorNC64HW64:
      write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<64>>(out, allocation);
      break;
    case library::LayoutTypeID::kTensorC32RSK32:
      write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<32>>(out, allocation);
      break;
    case library::LayoutTypeID::kTensorC64RSK64:
      write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<64>>(out, allocation);
      break;
    default:
      throw std::runtime_error("Unhandled layout");
  }
}
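// write_tensor_csv() below resolves the runtime (NumericTypeID, LayoutTypeID)
// pair to a static (Element, Layout) instantiation in two steps: its switch
// selects the element type, then write_tensor_csv_static_type selects the
// layout. Illustrative use (assumes <fstream> is available to the caller):
//
//   std::ofstream csv("tensor.csv");
//   allocation.write_tensor_csv(csv);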
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Writes a tensor to csv
void DeviceAllocation::write_tensor_csv(
  std::ostream &out) {

  switch (this->type()) {
  case library::NumericTypeID::kF16:
    write_tensor_csv_static_type<half_t>(out, *this);
    break;
  case library::NumericTypeID::kBF16:
    write_tensor_csv_static_type<bfloat16_t>(out, *this);
    break;
  case library::NumericTypeID::kTF32:
    write_tensor_csv_static_type<tfloat32_t>(out, *this);
    break;
  case library::NumericTypeID::kF32:
    write_tensor_csv_static_type<float>(out, *this);
    break;
  case library::NumericTypeID::kF64:
    write_tensor_csv_static_type<double>(out, *this);
    break;
  case library::NumericTypeID::kS2:
    write_tensor_csv_static_type<int2b_t>(out, *this);
    break;
  case library::NumericTypeID::kS4:
    write_tensor_csv_static_type<int4b_t>(out, *this);
    break;
  case library::NumericTypeID::kS8:
    write_tensor_csv_static_type<int8_t>(out, *this);
    break;
  case library::NumericTypeID::kS16:
    write_tensor_csv_static_type<int16_t>(out, *this);
    break;
  case library::NumericTypeID::kS32:
    write_tensor_csv_static_type<int32_t>(out, *this);
    break;
  case library::NumericTypeID::kS64:
    write_tensor_csv_static_type<int64_t>(out, *this);
    break;
  case library::NumericTypeID::kB1:
    write_tensor_csv_static_type<uint1b_t>(out, *this);
    break;
  case library::NumericTypeID::kU2:
    write_tensor_csv_static_type<uint2b_t>(out, *this);
    break;
  case library::NumericTypeID::kU4:
    write_tensor_csv_static_type<uint4b_t>(out, *this);
    break;
  case library::NumericTypeID::kU8:
    write_tensor_csv_static_type<uint8_t>(out, *this);
    break;
  case library::NumericTypeID::kU16:
    write_tensor_csv_static_type<uint16_t>(out, *this);
    break;
  case library::NumericTypeID::kU32:
    write_tensor_csv_static_type<uint32_t>(out, *this);
    break;
  case library::NumericTypeID::kU64:
    write_tensor_csv_static_type<uint64_t>(out, *this);
    break;
  case library::NumericTypeID::kCF16:
    write_tensor_csv_static_type<cutlass::complex<half_t> >(out, *this);
    break;
  case library::NumericTypeID::kCF32:
    write_tensor_csv_static_type<cutlass::complex<float> >(out, *this);
    break;
  case library::NumericTypeID::kCF64:
    write_tensor_csv_static_type<cutlass::complex<double> >(out, *this);
    break;
  default:
    throw std::runtime_error("Unsupported numeric type");
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace profiler
} // namespace cutlass