diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index b2231bf7..7d47a5c5 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -186,11 +186,14 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X /// Clamping constant value - ElementCompute const kClamp = - ElementCompute((1U << (sizeof_bits::value - 1)) - 1); + ElementCompute const kClampMax = + ElementCompute(platform::numeric_limits::max()); - intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1)); - intermediate = min_accumulator(intermediate, kClamp); + ElementCompute const kClampMin = + ElementCompute(platform::numeric_limits::lowest()); + + intermediate = max_accumulator(intermediate, kClampMin); + intermediate = min_accumulator(intermediate, kClampMax); // Convert to destination numeric type NumericArrayConverter destination_converter; diff --git a/include/cutlass/half.h b/include/cutlass/half.h index 5503f5b3..caa66657 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -69,6 +69,7 @@ enum #include #include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -545,9 +546,9 @@ half_t copysign(half_t const& a, half_t const& b) { // /////////////////////////////////////////////////////////////////////////////////////////////////// -namespace std { +namespace cutlass { +namespace platform { -#if !defined(__CUDACC_RTC__) /// Numeric limits template <> struct numeric_limits { @@ -593,9 +594,9 @@ struct numeric_limits { /// Returns smallest finite value static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } }; -#endif -} // namespace std +} // namespace platform +} // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index bd8a6a01..93aeb2b9 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -200,4 +200,25 @@ struct sizeof_bits { /////////////////////////////////////////////////////////////////////////////////////////////////// +namespace platform { + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static cutlass::int4b_t const lowest() noexcept { return -8;} + CUTLASS_HOST_DEVICE + static cutlass::int4b_t const max() noexcept { return 7;} +}; + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static cutlass::uint4b_t const lowest() noexcept { return 0;} + CUTLASS_HOST_DEVICE + static cutlass::uint4b_t const max() noexcept { return 15;} +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace platform } // namespace cutlass diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 57f3984b..5f00688c 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -483,69 +483,16 @@ struct NumericConverterClamp { using result_type = T; using source_type = S; - static_assert((platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - "Clamp is only needed for integer types"); - CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { NumericConverter convert_op; - result_type const kClamp_max = - (0x1U << (sizeof_bits::value - 1)) - 1; - result_type const kClamp_min = -kClamp_max - 1; - bool is_int_min = !(s > kClamp_min); - bool is_int_max = !(s < kClamp_max); - return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s)); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { - return convert(s); - } -}; - -/// Partial specialization for clamping from a single-precision float. -template < - typename T -> -struct NumericConverterClamp { - - using result_type = T; - using source_type = float; - - static_assert((platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value), - "Clamp is only needed for integer types"); - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { - - NumericConverter convert_op; - double kClamp_max, kClamp_min; - - if (platform::is_same::value || - platform::is_same::value || - platform::is_same::value || - platform::is_same::value) { - kClamp_max = double((1LLU << (sizeof_bits::value - 1)) - 1); - kClamp_min = -kClamp_max - 1; - } else { - kClamp_max = double((1LLU << (sizeof_bits::value)) - 1); - kClamp_min = 0; - } - - double source = s; - - source = fmax(source, kClamp_min); - source = fmin(source, kClamp_max); - - return convert_op(source); + result_type const kClamp_max = platform::numeric_limits::max(); + result_type const kClamp_min = platform::numeric_limits::lowest(); + if (s < (source_type)kClamp_min) + return kClamp_min; + if (s > (source_type)kClamp_max) + return kClamp_max; + return convert_op(s); } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index e9ccae2e..dfe5e2da 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -783,5 +783,58 @@ void swap(unique_ptr& lhs, unique_ptr& rhs) noexcept { } #endif +/// std::numeric_limits +template +struct numeric_limits; + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr int32_t lowest() noexcept { return -2147483647 - 1;} + CUTLASS_HOST_DEVICE + static constexpr int32_t max() noexcept { return 2147483647;} +}; + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr int16_t lowest() noexcept { return -32768;} + CUTLASS_HOST_DEVICE + static constexpr int16_t max() noexcept { return 32767;} +}; + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr int8_t lowest() noexcept { return -128;} + CUTLASS_HOST_DEVICE + static constexpr int8_t max() noexcept { return 127;} +}; + + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr uint32_t lowest() noexcept { return 0;} + CUTLASS_HOST_DEVICE + static constexpr uint32_t max() noexcept { return 4294967295;} +}; + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr uint16_t lowest() noexcept { return 0;} + CUTLASS_HOST_DEVICE + static constexpr uint16_t max() noexcept { return 65535;} +}; + +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr uint8_t lowest() noexcept { return 0;} + CUTLASS_HOST_DEVICE + static constexpr uint8_t max() noexcept { return 255;} +}; + } // namespace platform } // namespace cutlass