/*************************************************************************************************** * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include #include #include #include #include /* This implements a swizzle pointer of the form * InvolutionFn o PtrAdd * where the InvolutionFn need not be linear. * * This differs subtly from swizzle_layout because the smem pointer is used * as the offset. That means that swizzle_layout will implement position-independent * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. * Arch chose to design hardware with position-dependent swizzles. * * For clarity: * NormalLayout : DeRef <- PtrAdd <- [Layout] * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout * * Furthermore, for known swizzles, this pointer attempts to decay itself * to a normal-pointer with a new layout containing dynamic or static strides. * This is possible by determining the subdomain of the InvolutionFn * that is identity and testing if the Layout's codomain is contained * within it. */ namespace cute { template struct smem_ptr_swizzle { static_assert(is_empty::value, "Swizzle can't have state."); CUTE_HOST_DEVICE constexpr T* get() const { return ptr_; } CUTE_HOST_DEVICE constexpr static Swizzle get_swizzle() { return {}; } CUTE_HOST_DEVICE constexpr static T* apply_swizzle(T* ptr) { return reinterpret_cast(Swizzle::apply(reinterpret_cast(ptr))); } CUTE_HOST_DEVICE constexpr T& operator*() const { return *apply_swizzle(get()); } template CUTE_HOST_DEVICE constexpr T& operator[](Int const& i) const { return *apply_swizzle(get() + i); } template CUTE_HOST_DEVICE constexpr smem_ptr_swizzle operator+(Int const& i) const { return {ptr_ + i}; } T* ptr_; }; template struct is_smem> : true_type {}; // Make a swizzle pointer template CUTE_HOST_DEVICE constexpr auto make_smem_ptr(T* ptr, Swizzle const&) { return smem_ptr_swizzle{ptr}; } // Specialization for immediate decay template CUTE_HOST_DEVICE constexpr auto make_smem_ptr(T* ptr, Swizzle<0,M,S> const&) { return make_smem_ptr(ptr); } // A model of a nullptr smem_ptr with B == sizeof_bits::value // That represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr template struct smem_ptr_flag_bits : Int<0> {}; using smem_ptr_flag = smem_ptr_flag_bits<1>; // A flagged construction method to transform ComposedLayout // Make a swizzle pointer tensor and check that the intended type size matches template CUTE_HOST_DEVICE constexpr auto make_tensor(smem_ptr const& ptr, ComposedLayout,Layout> const& layout) { static_assert(B == sizeof_bits::value, "Expected a B-bit pointer type."); return make_tensor(make_smem_ptr(ptr.get(), layout.layout_a()), layout.layout_b()); } // NOTE: To preserve smem_ptr_flag_bits under recast ops template CUTE_HOST_DEVICE constexpr auto upcast(ComposedLayout,Layout> const& layout) { return composition(layout.layout_a(), smem_ptr_flag_bits{}, upcast(layout.layout_b())); } template CUTE_HOST_DEVICE constexpr auto downcast(ComposedLayout,Layout> const& layout) { return composition(layout.layout_a(), smem_ptr_flag_bits{}, downcast(layout.layout_b())); } // // Recast // Swizzle operates on the pointer address, so it doesn't care about the type // template CUTE_HOST_DEVICE constexpr auto recast(smem_ptr_swizzle const& ptr) { return smem_ptr_swizzle{recast(ptr.ptr_)}; } template CUTE_HOST_DEVICE constexpr auto recast(smem_ptr_swizzle const& ptr) { return smem_ptr_swizzle{recast(ptr.ptr_)}; } template CUTE_HOST_DEVICE constexpr T* raw_pointer_cast(smem_ptr_swizzle ptr) { return ptr.get(); } // // Conversion with swizzle_layout // template CUTE_HOST_DEVICE auto as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) { return composition(recast,uint_bit_t>(layout.layout_a()), Int<0>{}, layout.layout_b()); } template CUTE_HOST_DEVICE auto as_position_independent_swizzle_tensor(Tensor>, Layout> const& tensor) { { uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle } auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); } template CUTE_HOST_DEVICE auto as_position_independent_swizzle_tensor(Tensor>, Layout>& tensor) { { [[maybe_unused]] uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); [[maybe_unused]] uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle } auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); } template CUTE_HOST_DEVICE auto as_position_independent_swizzle_tensor(Tensor>, Layout>&& tensor) { return as_position_independent_swizzle_tensor(tensor); } // Pass through everything else // Used if the tensor doesn't have a swizzled layout, e.g. Layout_MN_INTER_Atom, Layout_K_INTER_Atom template CUTE_HOST_DEVICE constexpr auto as_position_independent_swizzle_tensor(Tensor const& tensor) { return tensor; } template CUTE_HOST_DEVICE constexpr auto as_position_independent_swizzle_tensor(Tensor&& tensor) { return tensor; } // // Print // // Capture and cast smem_ptr_flag Layouts to offset-0 layouts template CUTE_HOST_DEVICE void print_latex(ComposedLayout,Layout> const& layout) { auto new_swizzle = recast,uint_bit_t>(layout.layout_a()); print_latex(composition(new_swizzle, Int<0>{}, layout.layout_b())); } template CUTE_HOST_DEVICE void print(smem_ptr_flag_bits const& ptr) { printf("smem_ptr_%db(unset)", B); } template CUTE_HOST_DEVICE void print(smem_ptr_swizzle> const& ptr) { printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle> const&) { return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b"; } #endif } // end namespace cute