/*************************************************************************************************** * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include namespace cute { template struct constant : std::integral_constant { static constexpr T value = v; using value_type = T; using type = constant; CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } }; template using integral_constant = constant; template using bool_constant = constant; using true_type = bool_constant; using false_type = bool_constant; // // Traits // // Use std::is_integral to match built-in integral types (int, int64_t, unsigned, etc) // Use cute::is_integral to match both built-in integral types AND constant template struct is_integral : bool_constant::value> {}; template struct is_integral> : true_type {}; // is_static detects if an (abstract) value is defined completely by it's type (no members) template struct is_static : bool_constant::value> {}; // is_constant detects if a type is a constant and if v is equal to a value template struct is_constant : false_type {}; template struct is_constant > : bool_constant {}; template struct is_constant const > : bool_constant {}; template struct is_constant const&> : bool_constant {}; template struct is_constant &> : bool_constant {}; template struct is_constant &&> : bool_constant {}; // // Specializations // template using Int = constant; using _m32 = Int<-32>; using _m24 = Int<-24>; using _m16 = Int<-16>; using _m12 = Int<-12>; using _m10 = Int<-10>; using _m9 = Int<-9>; using _m8 = Int<-8>; using _m7 = Int<-7>; using _m6 = Int<-6>; using _m5 = Int<-5>; using _m4 = Int<-4>; using _m3 = Int<-3>; using _m2 = Int<-2>; using _m1 = Int<-1>; using _0 = Int<0>; using _1 = Int<1>; using _2 = Int<2>; using _3 = Int<3>; using _4 = Int<4>; using _5 = Int<5>; using _6 = Int<6>; using _7 = Int<7>; using _8 = Int<8>; using _9 = Int<9>; using _10 = Int<10>; using _12 = Int<12>; using _16 = Int<16>; using _24 = Int<24>; using _32 = Int<32>; using _64 = Int<64>; using _96 = Int<96>; using _128 = Int<128>; using _192 = Int<192>; using _256 = Int<256>; using _512 = Int<512>; using _1024 = Int<1024>; using _2048 = Int<2048>; using _4096 = Int<4096>; using _8192 = Int<8192>; /***************/ /** Operators **/ /***************/ #define CUTE_LEFT_UNARY_OP(OP) \ template \ CUTE_HOST_DEVICE constexpr \ constant \ operator OP (constant) { \ return {}; \ } #define CUTE_RIGHT_UNARY_OP(OP) \ template \ CUTE_HOST_DEVICE constexpr \ constant \ operator OP (constant) { \ return {}; \ } #define CUTE_BINARY_OP(OP) \ template \ CUTE_HOST_DEVICE constexpr \ constant \ operator OP (constant, constant) { \ return {}; \ } CUTE_LEFT_UNARY_OP(+); CUTE_LEFT_UNARY_OP(-); CUTE_LEFT_UNARY_OP(~); CUTE_LEFT_UNARY_OP(!); CUTE_LEFT_UNARY_OP(*); CUTE_BINARY_OP( +); CUTE_BINARY_OP( -); CUTE_BINARY_OP( *); CUTE_BINARY_OP( /); CUTE_BINARY_OP( %); CUTE_BINARY_OP( &); CUTE_BINARY_OP( |); CUTE_BINARY_OP( ^); CUTE_BINARY_OP(<<); CUTE_BINARY_OP(>>); CUTE_BINARY_OP(&&); CUTE_BINARY_OP(||); CUTE_BINARY_OP(==); CUTE_BINARY_OP(!=); CUTE_BINARY_OP( >); CUTE_BINARY_OP( <); CUTE_BINARY_OP(>=); CUTE_BINARY_OP(<=); #undef CUTE_BINARY_OP #undef CUTE_LEFT_UNARY_OP #undef CUTE_RIGHT_UNARY_OP // // Mixed static-dynamic special cases // template ::value)> CUTE_HOST_DEVICE constexpr constant operator*(constant, U) { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr constant operator*(U, constant) { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr constant operator/(constant, U) { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr constant operator%(U, constant) { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr constant operator%(U, constant) { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr constant operator%(constant, U) { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr constant operator&(constant, U) { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr constant operator&(U, constant) { return {}; } template ::value && !bool(t))> CUTE_HOST_DEVICE constexpr constant operator&&(constant, U) { return {}; } template ::value && !bool(t))> CUTE_HOST_DEVICE constexpr constant operator&&(U, constant) { return {}; } template ::value && bool(t))> CUTE_HOST_DEVICE constexpr constant operator||(constant, U) { return {}; } template ::value && bool(t))> CUTE_HOST_DEVICE constexpr constant operator||(U, constant) { return {}; } // // Named functions from math.hpp // #define CUTE_NAMED_UNARY_FN(OP) \ template \ CUTE_HOST_DEVICE constexpr \ constant \ OP (constant) { \ return {}; \ } #define CUTE_NAMED_BINARY_FN(OP) \ template \ CUTE_HOST_DEVICE constexpr \ constant \ OP (constant, constant) { \ return {}; \ } \ \ template ::value)> \ CUTE_HOST_DEVICE constexpr \ auto \ OP (constant, U u) { \ return OP(t,u); \ } \ \ template ::value)> \ CUTE_HOST_DEVICE constexpr \ auto \ OP (T t, constant) { \ return OP(t,u); \ } CUTE_NAMED_UNARY_FN(abs); CUTE_NAMED_UNARY_FN(signum); CUTE_NAMED_UNARY_FN(has_single_bit); CUTE_NAMED_BINARY_FN(max); CUTE_NAMED_BINARY_FN(min); CUTE_NAMED_BINARY_FN(shiftl); CUTE_NAMED_BINARY_FN(shiftr); CUTE_NAMED_BINARY_FN(gcd); CUTE_NAMED_BINARY_FN(lcm); #undef CUTE_NAMED_UNARY_FN #undef CUTE_NAMED_BINARY_FN // // Other functions // template CUTE_HOST_DEVICE constexpr constant safe_div(constant, constant) { static_assert(t % u == 0, "Static safe_div requires t % u == 0"); return {}; } template ::value)> CUTE_HOST_DEVICE constexpr auto safe_div(constant, U u) { return t / u; } template ::value)> CUTE_HOST_DEVICE constexpr auto safe_div(T t, constant) { return t / u; } // cute::true_type prefers standard conversion to std::true_type // over user-defined conversion to bool template CUTE_HOST_DEVICE constexpr decltype(auto) conditional_return(std::true_type, TrueType&& t, FalseType&&) { return static_cast(t); } // cute::false_type prefers standard conversion to std::false_type // over user-defined conversion to bool template CUTE_HOST_DEVICE constexpr decltype(auto) conditional_return(std::false_type, TrueType&&, FalseType&& f) { return static_cast(f); } // TrueType and FalseType must have a common type template CUTE_HOST_DEVICE constexpr auto conditional_return(bool b, TrueType const& t, FalseType const& f) { return b ? t : f; } // // Display utilities // template CUTE_HOST_DEVICE void print(integral_constant const&) { printf("_%d", N); } template CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant const&) { return os << "_" << N; } } // end namespace cute