/*************************************************************************************************** * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union. */ #pragma once #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/platform/platform.h" namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Statically sized array for any data type template < typename T, int N > class Array { public: static int const kSizeBits = sizeof_bits::value * N; /// Storage type using Storage = typename platform::conditional< ((kSizeBits % 32) != 0), typename platform::conditional< ((kSizeBits % 16) != 0), uint8_t, uint16_t >::type, uint32_t >::type; /// Element type using Element = T; /// Number of logical elements per stored object static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; /// Number of storage elements static size_t const kStorageElements = N / kElementsPerStoredItem; /// Number of logical elements static size_t const kElements = N; /// Bitmask for covering one item static Storage const kMask = ((Storage(1) << sizeof_bits::value) - 1); // // C++ standard members with pointer types removed // typedef T value_type; typedef size_t size_type; typedef ptrdiff_t difference_type; typedef value_type *pointer; typedef value_type const *const_pointer; // // References // /// Reference object inserts or extracts sub-byte items class reference { /// Pointer to storage element Storage *ptr_; /// Index into elements packed into Storage object int idx_; public: /// Default ctor CUTLASS_HOST_DEVICE reference(): ptr_(nullptr), idx_(0) { } /// Ctor CUTLASS_HOST_DEVICE reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } /// Assignment CUTLASS_HOST_DEVICE reference &operator=(T x) { Storage item = (reinterpret_cast(x) & kMask); Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); return *this; } CUTLASS_HOST_DEVICE T get() const { Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & kMask); return reinterpret_cast(item); } /// Extract CUTLASS_HOST_DEVICE operator T() const { return get(); } /// Explicit cast to int CUTLASS_HOST_DEVICE explicit operator int() const { return int(get()); } /// Explicit cast to float CUTLASS_HOST_DEVICE explicit operator float() const { return float(get()); } }; /// Reference object extracts sub-byte items class const_reference { /// Pointer to storage element Storage const *ptr_; /// Index into elements packed into Storage object int idx_; public: /// Default ctor CUTLASS_HOST_DEVICE const_reference(): ptr_(nullptr), idx_(0) { } /// Ctor CUTLASS_HOST_DEVICE const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } CUTLASS_HOST_DEVICE const T get() const { Storage item = (*ptr_ >> (idx_ * sizeof_bits::value)) & kMask; return reinterpret_cast(item); } /// Extract CUTLASS_HOST_DEVICE operator T() const { Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits::value)) & kMask); return reinterpret_cast(item); } /// Explicit cast to int CUTLASS_HOST_DEVICE explicit operator int() const { return int(get()); } /// Explicit cast to float CUTLASS_HOST_DEVICE explicit operator float() const { return float(get()); } }; // // Iterators // /// Bidirectional iterator over elements class iterator { /// Pointer to storage element Storage *ptr_; /// Index into elements packed into Storage object int idx_; public: CUTLASS_HOST_DEVICE iterator(): ptr_(nullptr), idx_(0) { } CUTLASS_HOST_DEVICE iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } CUTLASS_HOST_DEVICE iterator &operator++() { ++idx_; if (idx_ == kElementsPerStoredItem) { ++ptr_; idx_ = 0; } return *this; } CUTLASS_HOST_DEVICE iterator &operator--() { if (!idx_) { --ptr_; idx_ = kElementsPerStoredItem - 1; } else { --idx_; } return *this; } CUTLASS_HOST_DEVICE iterator operator++(int) { iterator ret(*this); ++idx_; if (idx_ == kElementsPerStoredItem) { ++ptr_; idx_ = 0; } return ret; } CUTLASS_HOST_DEVICE iterator operator--(int) { iterator ret(*this); if (!idx_) { --ptr_; idx_ = kElementsPerStoredItem - 1; } else { --idx_; } return ret; } CUTLASS_HOST_DEVICE reference operator*() const { return reference(ptr_, idx_); } CUTLASS_HOST_DEVICE bool operator==(iterator const &other) const { return ptr_ == other.ptr_ && idx_ == other.idx_; } CUTLASS_HOST_DEVICE bool operator!=(iterator const &other) const { return !(*this == other); } }; /// Bidirectional constant iterator over elements class const_iterator { /// Pointer to storage element Storage const *ptr_; /// Index into elements packed into Storage object int idx_; public: CUTLASS_HOST_DEVICE const_iterator(): ptr_(nullptr), idx_(0) { } CUTLASS_HOST_DEVICE const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } CUTLASS_HOST_DEVICE iterator &operator++() { ++idx_; if (idx_ == kElementsPerStoredItem) { ++ptr_; idx_ = 0; } return *this; } CUTLASS_HOST_DEVICE iterator &operator--() { if (!idx_) { --ptr_; idx_ = kElementsPerStoredItem - 1; } else { --idx_; } return *this; } CUTLASS_HOST_DEVICE iterator operator++(int) { iterator ret(*this); ++idx_; if (idx_ == kElementsPerStoredItem) { ++ptr_; idx_ = 0; } return ret; } CUTLASS_HOST_DEVICE iterator operator--(int) { iterator ret(*this); if (!idx_) { --ptr_; idx_ = kElementsPerStoredItem - 1; } else { --idx_; } return ret; } CUTLASS_HOST_DEVICE const_reference operator*() const { return const_reference(ptr_, idx_); } CUTLASS_HOST_DEVICE bool operator==(iterator const &other) const { return ptr_ == other.ptr_ && idx_ == other.idx_; } CUTLASS_HOST_DEVICE bool operator!=(iterator const &other) const { return !(*this == other); } }; /// Bidirectional iterator over elements class reverse_iterator { /// Pointer to storage element Storage *ptr_; /// Index into elements packed into Storage object int idx_; public: CUTLASS_HOST_DEVICE reverse_iterator(): ptr_(nullptr), idx_(0) { } CUTLASS_HOST_DEVICE reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } // TODO }; /// Bidirectional constant iterator over elements class const_reverse_iterator { /// Pointer to storage element Storage const *ptr_; /// Index into elements packed into Storage object int idx_; public: CUTLASS_HOST_DEVICE const_reverse_iterator(): ptr_(nullptr), idx_(0) { } CUTLASS_HOST_DEVICE const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } // TODO }; private: /// Internal storage Storage storage[kStorageElements]; public: #if 0 CUTLASS_HOST_DEVICE Array() { } CUTLASS_HOST_DEVICE Array(Array const &x) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < int(kStorageElements); ++i) { storage[i] = x.storage[i]; } } #endif /// Efficient clear method CUTLASS_HOST_DEVICE void clear() { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < int(kStorageElements); ++i) { storage[i] = Storage(0); } } CUTLASS_HOST_DEVICE reference at(size_type pos) { return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); } CUTLASS_HOST_DEVICE const_reference at(size_type pos) const { return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); } CUTLASS_HOST_DEVICE reference operator[](size_type pos) { return at(pos); } CUTLASS_HOST_DEVICE const_reference operator[](size_type pos) const { return at(pos); } CUTLASS_HOST_DEVICE reference front() { return at(0); } CUTLASS_HOST_DEVICE const_reference front() const { return at(0); } CUTLASS_HOST_DEVICE reference back() { return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); } CUTLASS_HOST_DEVICE const_reference back() const { return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); } CUTLASS_HOST_DEVICE pointer data() { return reinterpret_cast(storage); } CUTLASS_HOST_DEVICE const_pointer data() const { return reinterpret_cast(storage); } CUTLASS_HOST_DEVICE Storage * raw_data() { return storage; } CUTLASS_HOST_DEVICE Storage const * raw_data() const { return storage; } CUTLASS_HOST_DEVICE constexpr bool empty() const { return !kElements; } CUTLASS_HOST_DEVICE constexpr size_type size() const { return kElements; } CUTLASS_HOST_DEVICE constexpr size_type max_size() const { return kElements; } CUTLASS_HOST_DEVICE void fill(T const &value) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerStoredItem; ++i) { reference ref(storage, i); ref = value; } CUTLASS_PRAGMA_UNROLL for (int i = 1; i < kStorageElements; ++i) { storage[i] = storage[0]; } } CUTLASS_HOST_DEVICE iterator begin() { return iterator(storage); } CUTLASS_HOST_DEVICE const_iterator cbegin() const { return const_iterator(storage); } CUTLASS_HOST_DEVICE iterator end() { return iterator(storage + kStorageElements); } CUTLASS_HOST_DEVICE const_iterator cend() const { return const_iterator(storage + kStorageElements); } CUTLASS_HOST_DEVICE reverse_iterator rbegin() { return reverse_iterator(storage + kStorageElements); } CUTLASS_HOST_DEVICE const_reverse_iterator crbegin() const { return const_reverse_iterator(storage + kStorageElements); } CUTLASS_HOST_DEVICE reverse_iterator rend() { return reverse_iterator(storage); } CUTLASS_HOST_DEVICE const_reverse_iterator crend() const { return const_reverse_iterator(storage); } // // Comparison operators // }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass ////////////////////////////////////////////////////////////////////////////////////////////////////