/*************************************************************************************************** * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include #include namespace cute { // // has_dereference to determine if a type is a pointer concept // template struct has_dereference : false_type { }; template struct has_dereference())>> : true_type { }; template CUTE_HOST_DEVICE constexpr T* raw_pointer_cast(T* ptr) { return ptr; } // // Extract the physical type from a logical elem type. // template struct get_raw_type { using type = T; }; template using get_raw_type_t = typename get_raw_type::type; // // Pointer categories // template struct is_gmem : false_type {}; template struct is_smem : false_type {}; // Anything that is not gmem or smem is rmem template struct is_rmem : bool_constant< not (is_gmem::value || is_smem::value)> {}; // // A very simplified wrapper for pointers -- use for constructing tagged pointers // template struct device_ptr { using value_type = T; static const uint32_t ElementsPerStoredItem = sizeof(T) * 8 / sizeof_bits_v; CUTE_HOST_DEVICE constexpr device_ptr(T* ptr) : ptr_(ptr) {} CUTE_HOST_DEVICE constexpr T* get() const { return ptr_; } CUTE_HOST_DEVICE constexpr T& operator*() const { return *ptr_; } template CUTE_HOST_DEVICE constexpr T& operator[](Index const& i) const { static_assert(sizeof_bits_v >= 8, "Use subbyte_iterator to access the element"); return ptr_[i]; } template CUTE_HOST_DEVICE constexpr DerivedType operator+(Index const& i) const { return {ptr_ + i / ElementsPerStoredItem}; } CUTE_HOST_DEVICE constexpr friend ptrdiff_t operator-(device_ptr const& a, device_ptr const& b) { return a.ptr_ - b.ptr_; } T* ptr_; }; template CUTE_HOST_DEVICE constexpr T* raw_pointer_cast(device_ptr ptr) { return ptr.get(); } // // gmem_ptr // template struct gmem_ptr : device_ptr> { using device_ptr>::device_ptr; }; template CUTE_HOST_DEVICE constexpr gmem_ptr make_gmem_ptr(T* ptr) { return {ptr}; } template CUTE_HOST_DEVICE constexpr gmem_ptr make_gmem_ptr(void* ptr) { return {reinterpret_cast(ptr)}; } template CUTE_HOST_DEVICE constexpr gmem_ptr make_gmem_ptr(void const* ptr) { return {reinterpret_cast(ptr)}; } // nullptr_t overloads are needed because otherwise, // make_gmem_ptr(nullptr) will be ambiguous, // as std::nullptr_t can be converted to any pointer // or pointer to member type. template CUTE_HOST_DEVICE constexpr gmem_ptr make_gmem_ptr(decltype(nullptr)) { // nullptr_t return {static_cast(nullptr)}; } template struct is_gmem> : true_type {}; // // smem_ptr // template struct smem_ptr : device_ptr> { using device_ptr>::device_ptr; }; template CUTE_HOST_DEVICE constexpr smem_ptr make_smem_ptr(T* ptr) { return {ptr}; } template CUTE_HOST_DEVICE constexpr smem_ptr make_smem_ptr(void* ptr) { return {reinterpret_cast(ptr)}; } template CUTE_HOST_DEVICE constexpr smem_ptr make_smem_ptr(void const* ptr) { return {reinterpret_cast(ptr)}; } template struct is_smem> : true_type {}; // // rmem_ptr // template struct rmem_ptr : device_ptr> { using device_ptr>::device_ptr; }; template CUTE_HOST_DEVICE constexpr rmem_ptr make_rmem_ptr(T* ptr) { return {ptr}; } template CUTE_HOST_DEVICE constexpr rmem_ptr make_rmem_ptr(void* ptr) { return {reinterpret_cast(ptr)}; } template CUTE_HOST_DEVICE constexpr rmem_ptr make_rmem_ptr(void const* ptr) { return {reinterpret_cast(ptr)}; } template struct is_rmem> : true_type {}; // // counting iterator -- quick and dirty // struct counting { using index_type = int; using value_type = index_type; CUTE_HOST_DEVICE constexpr counting() : n_(0) {} CUTE_HOST_DEVICE constexpr counting(index_type const& n) : n_(n) {} CUTE_HOST_DEVICE constexpr index_type operator[](index_type const& i) const { return n_ + i; } CUTE_HOST_DEVICE constexpr index_type const& operator*() const { return n_; } CUTE_HOST_DEVICE constexpr counting operator+(index_type const& i) const { return {n_ + i}; } CUTE_HOST_DEVICE constexpr counting& operator++() { ++n_; return *this; } CUTE_HOST_DEVICE constexpr bool operator==(counting const& other) const { return n_ == other.n_; } CUTE_HOST_DEVICE constexpr bool operator!=(counting const& other) const { return n_ != other.n_; } CUTE_HOST_DEVICE constexpr bool operator< (counting const& other) const { return n_ < other.n_; } index_type n_; }; // // recast // template CUTE_HOST_DEVICE constexpr auto recast(T* ptr) { return reinterpret_cast(ptr); } template CUTE_HOST_DEVICE constexpr auto recast(T const* ptr) { return reinterpret_cast(ptr); } template CUTE_HOST_DEVICE constexpr auto recast(gmem_ptr const& ptr) { return make_gmem_ptr(recast(ptr.ptr_)); } template CUTE_HOST_DEVICE constexpr auto recast(gmem_ptr const& ptr) { return make_gmem_ptr(recast(ptr.ptr_)); } template CUTE_HOST_DEVICE constexpr auto recast(smem_ptr const& ptr) { return make_smem_ptr(recast(ptr.ptr_)); } template CUTE_HOST_DEVICE constexpr auto recast(smem_ptr const& ptr) { return make_smem_ptr(recast(ptr.ptr_)); } template CUTE_HOST_DEVICE constexpr auto recast(rmem_ptr const& ptr) { return make_rmem_ptr(recast(ptr.ptr_)); } template CUTE_HOST_DEVICE constexpr auto recast(rmem_ptr const& ptr) { return make_rmem_ptr(recast(ptr.ptr_)); } // // Display utilities // template CUTE_HOST_DEVICE void print(T const* const ptr) { printf("raw_ptr_%db(%p)", int(sizeof_bits::value), ptr); } template CUTE_HOST_DEVICE void print(gmem_ptr const& ptr) { printf("gmem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); } template CUTE_HOST_DEVICE void print(smem_ptr const& ptr) { printf("smem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); } template CUTE_HOST_DEVICE void print(rmem_ptr const& ptr) { printf("rmem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr const& ptr) { return os << "gmem_ptr_" << int(sizeof_bits::value) << "b"; } template CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr const& ptr) { return os << "smem_ptr_" << int(sizeof_bits::value) << "b"; } template CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr const& ptr) { return os << "rmem_ptr_" << int(sizeof_bits::value) << "b"; } #endif // !defined(__CUDACC_RTC__) } // end namespace cute