/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief Host-side implementation of half-precision float

  Every arithmetic operation widens to fp32, computes there, and rounds the result back to
  binary16 (round-to-nearest-even).
*/

#pragma once

#include <stdint.h>

#include <cmath>
#include <cstring>
#include <iosfwd>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <utility>

// The native CUDA `half` type lives in <cuda_fp16.h>. When the compiler can probe for the header
// and it is absent, the pieces of this file that mention `half` are compiled out so the emulation
// remains usable in host-only builds. When probing is impossible the header is required.
#if defined(__has_include)
#if __has_include(<cuda_fp16.h>)
#define CUTLASS_HALF_T_HAS_CUDA_FP16 1
#endif
#else
#define CUTLASS_HALF_T_HAS_CUDA_FP16 1
#endif

#if defined(CUTLASS_HALF_T_HAS_CUDA_FP16)
#include <cuda_fp16.h>
#endif

namespace cutlass {

/// IEEE binary16 floating-point value
class half_t {
 public:
  /// Leaves the stored bits uninitialized.
  half_t();
  half_t(int);     /// conversion from integer (goes through fp32)
  half_t(float);   /// conversion from fp32
  half_t(double);  /// conversion from fp64 (narrowed to fp32 first)

  static half_t bitcast(unsigned short);        /// bitcast performs no conversion
  static half_t convert(float const&);          /// FP conversion - round toward nearest even
  static float convert(unsigned short const&);  /// floating point conversion to fp32

  static half_t zero() { return bitcast(0); }       /// +zero
  static half_t one() { return bitcast(0x3c00); }   /// one
  static half_t nan() { return bitcast(0x7fff); }   /// canonical not a number
  static half_t inf() { return bitcast(0x7c00); }   /// +infinity
  static half_t ninf() { return bitcast(0xfc00); }  /// -infinity

  /// Value with bit pattern 0x1000 (2^-11).
  /// NOTE(review): the gap between 1 and the next representable half is 2^-10 (0x1400);
  /// confirm which definition of "epsilon" callers expect before changing this constant.
  static half_t epsilon() { return bitcast(0x1000); }

  bool signbit() const;             /// sign bit - true: negative, false: positive
  int exponent() const;             /// unbiased exponent
  unsigned short mantissa() const;  /// mantissa bits

  bool isfinite() const;  /// true if neither inf nor nan
  bool isinf() const;     /// true if value is + or - infinity
  bool isnan() const;     /// true if value is not a number
  bool isnormal() const;  /// true if nonzero value is normalized
  bool iszero() const;    /// true if value is + or - zero

  // Comparisons follow IEEE semantics: +0 == -0, and any comparison involving NaN is false
  // except operator!=, which is true.
  bool operator==(half_t const&) const;
  bool operator!=(half_t const&) const;
  bool operator==(float const&) const;  /// the fp32 operand is rounded to half first
  bool operator!=(float const&) const;  /// the fp32 operand is rounded to half first
  bool operator<(half_t const&) const;
  bool operator<=(half_t const&) const;
  bool operator>(half_t const&) const;
  bool operator>=(half_t const&) const;

  half_t operator+(half_t const&) const;
  half_t operator-() const;
  half_t operator-(half_t const&) const;
  half_t operator*(half_t const&) const;
  half_t operator/(half_t const&) const;

  half_t& operator+=(half_t const&);
  half_t& operator-=(half_t const&);
  half_t& operator*=(half_t const&);
  half_t& operator/=(half_t const&);

  half_t& operator++();
  half_t& operator--();
  half_t operator++(int);
  half_t operator--(int);

  operator bool() const;   /// false if zero
  operator int() const;    /// conversion to int
  operator float() const;  /// conversion to fp32
#if defined(CUTLASS_HALF_T_HAS_CUDA_FP16)
  operator half() const;   /// conversion to half
#endif

  uint16_t& raw() { return x; }
  uint16_t raw() const { return x; }

// Host/device overloads of the same member are only legal when clang compiles CUDA sources;
// in a plain host build they would be redeclarations, so __CUDA__ is required as well.
#if defined(__clang__) && defined(__CUDA__)
  __device__ half_t operator+(half_t const&) const;
  __device__ half_t operator*(half_t const&) const;
  __device__ operator float() const;  /// conversion to fp32
#endif

  //
  // Stream interactions
  //

  /// put to stream - half_t-precision types bitcast as unsigned shorts if base is hexadecimal
  friend std::ostream& operator<<(std::ostream& out, cutlass::half_t const& h) {
    if (out.flags() & std::ios::hex) {
      return out << h.x;
    } else {
      return out << float(h);
    }
  }

  /// read from stream - half_t-precision types parsed as unsigned shorts if base is hexadecimal
  friend std::istream& operator>>(std::istream& in, cutlass::half_t& h) {
    if (in.flags() & std::ios::hex) {
      unsigned short u = 0;
      in >> u;
      h = cutlass::half_t::bitcast(u);
    } else {
      float f = 0;
      in >> f;
      h = cutlass::half_t(f);
    }
    return in;
  }

 public:
  /// data (raw binary16 bit pattern)
  unsigned short x;
};

/// Packed pair of half-precision elements
class half2_t {
 public:
  /// Leaves both elements uninitialized.
  half2_t();
  half2_t(half_t lo, half_t hi);
  half2_t(std::pair<half_t, half_t> const&);
  /// Unpacks the low 16 bits into `lo` and the high 16 bits into `hi`.
  explicit half2_t(unsigned data);

  // Element-wise arithmetic
  half2_t operator+(half2_t const&) const;
  half2_t operator-(half2_t const&) const;
  half2_t operator*(half2_t const&) const;
  half2_t operator/(half2_t const&) const;

  half2_t& operator+=(half2_t const&);
  half2_t& operator-=(half2_t const&);
  half2_t& operator*=(half2_t const&);
  half2_t& operator/=(half2_t const&);

  float dot(half2_t const&) const;          /// dot product with single-precision accumulation
  float dot(half2_t const&, float) const;   /// dot product with single-precision accumulation
  half_t doth(half2_t const&) const;        /// dot product rounded to half_t
  half_t doth(half2_t const&, half_t) const;  /// dot product plus addend, rounded to half_t

  unsigned packed() const;  /// `lo` bits in [15:0], `hi` bits in [31:16]

  operator std::pair<float, float>() const;
  operator unsigned() const;

 public:
  half_t lo;
  half_t hi;
};

/// Reinterprets the object representation of `Src` as `Dest`; only the specializations below
/// are defined in this file.
template <typename Dest, typename Src>
Dest bitcast(Src const&);

template <>
float bitcast<float, unsigned>(unsigned const&);
template <>
float bitcast<float, int>(int const&);
template <>
unsigned bitcast<unsigned, float>(float const&);
template <>
half_t bitcast<half_t, unsigned short>(unsigned short const&);
template <>
unsigned short bitcast<unsigned short, half_t>(half_t const&);
#if defined(CUTLASS_HALF_T_HAS_CUDA_FP16)
template <>
half bitcast<half, unsigned short>(unsigned short const&);
#endif

}  // namespace cutlass

// Mixed fp32 (left) / half_t (right) arithmetic; computed in fp32 and rounded to half_t.
cutlass::half_t operator+(float, cutlass::half_t const&);
cutlass::half_t operator-(float, cutlass::half_t const&);
cutlass::half_t operator*(float, cutlass::half_t const&);
cutlass::half_t operator/(float, cutlass::half_t const&);

#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {

/// lexical cast from string to half_t
template <>
inline cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg);

/// lexical cast from half_t to string
template <>
inline std::string lexical_cast<std::string>(cutlass::half_t const& arg);

}  // namespace boost
#endif

#define HLF_MANT_DIG 10

namespace cutlass {

cutlass::half_t abs(cutlass::half_t const&);  /// absolute value
bool isnan(cutlass::half_t const&);           /// true if argument is NaN
bool isfinite(cutlass::half_t const&);        /// true if argument is neither NaN nor infinity
cutlass::half_t nanh(const char* = 0);        /// returns a not-a-number
bool isinf(cutlass::half_t const&);           /// returns true if argument is infinity (+ or -)
bool isnormal(cutlass::half_t const&);        /// returns true if the exponent field is neither
                                              /// all zeros nor all ones
int fpclassify(cutlass::half_t const&);       /// returns a flag classifying floating-point value
bool signbit(cutlass::half_t const&);         /// returns true if negative, false if positive
cutlass::half_t sqrt(cutlass::half_t const&);  /// square root of half_t

#if __cplusplus >= 201103L
/// magnitude of the first argument with the sign of the second
cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&);
#endif

}  // namespace cutlass

namespace std {

/// Numeric limits
template <>
struct numeric_limits<cutlass::half_t> {
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_infinity = true;
  static bool const has_quiet_NaN = true;
  static bool const has_signaling_NaN = false;
  static std::float_denorm_style const has_denorm = std::denorm_present;
  static bool const has_denorm_loss = true;
  static std::float_round_style const round_style = std::round_to_nearest;
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;

  /// Number of explicitly stored mantissa bits.
  /// NOTE(review): std::numeric_limits<T>::digits normally counts the implicit leading bit as
  /// well (11 for binary16); confirm before relying on this value.
  static int const digits = HLF_MANT_DIG;

  /// Least positive value (bit pattern 0x0001, a subnormal).
  /// NOTE(review): for floating-point types the standard defines min() as the least positive
  /// *normalized* value (0x0400 for binary16); left unchanged for existing callers.
  static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); }

  /// Most negative finite value
  static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); }

  /// Largest finite value
  static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); }

  /// Forwards to half_t::epsilon()
  static cutlass::half_t epsilon() { return cutlass::half_t::epsilon(); }

  /// Maximum rounding error (0.5)
  static cutlass::half_t round_error() { return cutlass::half_t(0.5f); }

  /// Positive infinity
  static cutlass::half_t infinity() { return cutlass::half_t::inf(); }

  /// Canonical NaN
  static cutlass::half_t quiet_NaN() { return cutlass::half_t::nan(); }

  /// Same canonical NaN as quiet_NaN() (has_signaling_NaN is false)
  static cutlass::half_t signaling_NaN() { return cutlass::half_t::nan(); }

  /// Least positive subnormal value
  static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); }
};

}  // namespace std

//
// half_t: conversions
//

/// Wraps a raw binary16 bit pattern without any numeric conversion.
inline cutlass::half_t cutlass::half_t::bitcast(unsigned short _x) {
  half_t h;
  h.x = _x;
  return h;
}

/// FP32 -> FP16 conversion - rounds to nearest even
///
/// fp32 zeros keep their sign, fp32 NaNs map to the canonical NaN 0x7fff (sign dropped), and
/// magnitudes too large for binary16 become signed infinity.
inline cutlass::half_t cutlass::half_t::convert(float const& flt) {
  // memcpy rather than a pointer cast: reading a float through an integer lvalue breaks the
  // aliasing rules.
  uint32_t bits = 0;
  std::memcpy(&bits, &flt, sizeof(bits));

  uint16_t const sign = uint16_t((bits >> 16) & 0x8000u);
  int const exp = int((bits >> 23) & 0xffu) - 127;  // unbiased fp32 exponent
  uint32_t mantissa = bits & 0x7fffffu;

  if ((bits & 0x7fffffffu) == 0) {
    // sign-preserving zero
    return bitcast(sign);
  }

  if (exp > 15) {
    if (exp == 128 && mantissa) {
      // not a number
      return bitcast(0x7fff);
    }
    // overflow to infinity
    return bitcast(uint16_t(sign | 0x7c00));
  }

  uint16_t u = 0;
  bool sticky = false;

  if (exp >= -14) {
    // normal fp32 to normal fp16: re-bias the exponent, keep the top 10 mantissa bits
    u = uint16_t(((uint32_t(exp + 15) & 0x1fu) << 10) | (mantissa >> 13));
  } else {
    // normal single-precision to subnormal half_t-precision representation
    int const rshift = -14 - exp;
    if (rshift < 32) {
      mantissa |= (1u << 23);  // make the implicit leading one explicit
      sticky = (mantissa & ((1u << rshift) - 1u)) != 0;  // remember bits shifted out
      mantissa >>= rshift;
      u = uint16_t((mantissa >> 13) & 0x3ffu);
    } else {
      // too small to influence rounding at all
      mantissa = 0;
    }
  }

  // round to nearest even; a carry out of the mantissa correctly bumps the exponent
  bool const round_bit = ((mantissa >> 12) & 1u) != 0;
  sticky = sticky || ((mantissa & 0xfffu) != 0);
  if (round_bit && (sticky || (u & 1))) {
    u = uint16_t(u + 1);
  }

  return bitcast(uint16_t(u | sign));
}

/// FP16 bit pattern -> FP32. Exact for every finite input; any NaN becomes 0x7fffffff.
inline float cutlass::half_t::convert(unsigned short const& h) {
  // Unsigned arithmetic throughout so that shifting the sign into bit 31 is well defined.
  uint32_t const sign = (uint32_t(h) >> 15) & 1u;
  uint32_t exp = (uint32_t(h) >> 10) & 0x1fu;
  uint32_t mantissa = uint32_t(h) & 0x3ffu;
  uint32_t f = 0;

  if (exp > 0 && exp < 31) {
    // normal: re-bias exponent (127 - 15 = 112)
    f = (sign << 31) | ((exp + 112u) << 23) | (mantissa << 13);
  } else if (exp == 0) {
    if (mantissa) {
      // subnormal: normalize by shifting until the implicit bit appears
      exp = 113u;
      while ((mantissa & (1u << 10)) == 0) {
        mantissa <<= 1;
        --exp;
      }
      mantissa &= 0x3ffu;
      f = (sign << 31) | (exp << 23) | (mantissa << 13);
    } else {
      // sign-preserving zero
      f = (sign << 31);
    }
  } else {
    if (mantissa) {
      f = 0x7fffffffu;  // not a number
    } else {
      f = (0xffu << 23) | (sign << 31);  // inf
    }
  }

  float result;
  std::memcpy(&result, &f, sizeof(result));
  return result;
}

inline cutlass::half_t::half_t() {}

inline cutlass::half_t::half_t(int i) { x = convert(float(i)).x; }

inline cutlass::half_t::half_t(float f) { x = convert(f).x; }

inline cutlass::half_t::half_t(double d) { x = convert(float(d)).x; }

//
// half_t: field access and classification
//

inline bool cutlass::half_t::signbit() const { return (x >> 15) & 1; }

inline int cutlass::half_t::exponent() const { return ((x >> 10) & 0x1f) - 15; }

inline unsigned short cutlass::half_t::mantissa() const { return x & 0x3ff; }

inline cutlass::half_t::operator bool() const { return (x & 0x7fff) != 0; }

inline cutlass::half_t::operator int() const { return static_cast<int>(convert(x)); }

inline cutlass::half_t::operator float() const { return convert(x); }

#if defined(CUTLASS_HALF_T_HAS_CUDA_FP16)
inline cutlass::half_t::operator half() const {
  return cutlass::bitcast<half, unsigned short>(x);
}
#endif

inline bool cutlass::half_t::iszero() const { return (x & 0x7fff) == 0; }

inline bool cutlass::half_t::isfinite() const { return (exponent() < 16); }

inline bool cutlass::half_t::isnan() const {
  return ((x >> 10) & 0x1f) == 0x1f && (x & 0x3ff) != 0;
}

inline bool cutlass::half_t::isinf() const {
  return ((x >> 10) & 0x1f) == 0x1f && (x & 0x3ff) == 0;
}

inline bool cutlass::half_t::isnormal() const {
  int exp = exponent();
  return exp > -15 && exp < 16;
}

//
// half_t: comparisons
//

inline bool cutlass::half_t::operator==(cutlass::half_t const& h) const {
  if (isnan() || h.isnan()) {
    return false;  // NaN never compares equal, not even to itself
  }
  if (iszero() && h.iszero()) {
    return true;  // +0 == -0
  }
  return x == h.x;
}

inline bool cutlass::half_t::operator!=(cutlass::half_t const& h) const { return !(*this == h); }

inline bool cutlass::half_t::operator==(float const& b) const { return *this == half_t(b); }

inline bool cutlass::half_t::operator!=(float const& b) const { return !(*this == half_t(b)); }

// Ordering is delegated to fp32: the widening is exact, and fp32 comparison already handles
// negative values, signed zeros and NaN correctly. Comparing raw magnitudes does not.

inline bool cutlass::half_t::operator<(half_t const& h) const { return convert(x) < convert(h.x); }

inline bool cutlass::half_t::operator<=(half_t const& h) const {
  return convert(x) <= convert(h.x);
}

inline bool cutlass::half_t::operator>(half_t const& h) const { return convert(x) > convert(h.x); }

inline bool cutlass::half_t::operator>=(half_t const& h) const {
  return convert(x) >= convert(h.x);
}

//
// half_t: arithmetic (computed in fp32, rounded back to half)
//

inline cutlass::half_t cutlass::half_t::operator+(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) + float(b));
}

/// Negation flips the sign bit only.
inline cutlass::half_t cutlass::half_t::operator-() const { return bitcast(x ^ 0x8000); }

inline cutlass::half_t cutlass::half_t::operator-(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) - float(b));
}

inline cutlass::half_t cutlass::half_t::operator*(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) * float(b));
}

inline cutlass::half_t cutlass::half_t::operator/(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) / float(b));
}

inline cutlass::half_t& cutlass::half_t::operator+=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) + float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator-=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) - float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator*=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) * float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator/=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) / float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator++() {
  *this = cutlass::half_t(float(*this) + 1.0f);
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator--() {
  *this = cutlass::half_t(float(*this) - 1.0f);
  return *this;
}

inline cutlass::half_t cutlass::half_t::operator++(int) {
  half_t h = *this;
  *this = cutlass::half_t(float(*this) + 1.0f);
  return h;
}

inline cutlass::half_t cutlass::half_t::operator--(int) {
  half_t h = *this;
  *this = cutlass::half_t(float(*this) - 1.0f);
  return h;
}

inline cutlass::half_t operator+(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a + float(b));
}

inline cutlass::half_t operator-(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a - float(b));
}

inline cutlass::half_t operator*(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a * float(b));
}

inline cutlass::half_t operator/(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a / float(b));
}

//
// half2_t
//

inline cutlass::half2_t::half2_t() {}

inline cutlass::half2_t::half2_t(half_t lo, half_t hi) : lo(lo), hi(hi) {}

inline cutlass::half2_t::half2_t(std::pair<half_t, half_t> const& p)
    : lo(p.first), hi(p.second) {}

inline cutlass::half2_t::half2_t(unsigned data)
    : lo(half_t::bitcast(uint16_t(data & 0x0ffff))),
      hi(half_t::bitcast(uint16_t((data >> 16) & 0x0ffff))) {}

inline cutlass::half2_t cutlass::half2_t::operator+(half2_t const& b) const {
  return half2_t(lo + b.lo, hi + b.hi);
}

inline cutlass::half2_t cutlass::half2_t::operator-(half2_t const& b) const {
  return half2_t(lo - b.lo, hi - b.hi);
}

inline cutlass::half2_t cutlass::half2_t::operator*(half2_t const& b) const {
  return half2_t(lo * b.lo, hi * b.hi);
}

inline cutlass::half2_t cutlass::half2_t::operator/(half2_t const& b) const {
  return half2_t(lo / b.lo, hi / b.hi);
}

inline cutlass::half2_t& cutlass::half2_t::operator+=(half2_t const& b) {
  lo += b.lo;
  hi += b.hi;
  return *this;
}

inline cutlass::half2_t& cutlass::half2_t::operator-=(half2_t const& b) {
  lo -= b.lo;
  hi -= b.hi;
  return *this;
}

inline cutlass::half2_t& cutlass::half2_t::operator*=(half2_t const& b) {
  lo *= b.lo;
  hi *= b.hi;
  return *this;
}

inline cutlass::half2_t& cutlass::half2_t::operator/=(half2_t const& b) {
  lo /= b.lo;
  hi /= b.hi;
  return *this;
}

inline float cutlass::half2_t::dot(half2_t const& b) const {
  return float(lo) * float(b.lo) + float(hi) * float(b.hi);
}

inline float cutlass::half2_t::dot(half2_t const& b, float c) const { return c + dot(b); }

inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b) const {
  return cutlass::half_t(dot(b));
}

inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b, half_t c) const {
  return cutlass::half_t(dot(b, float(c)));
}

inline cutlass::half2_t::operator std::pair<float, float>() const {
  return std::pair<float, float>(float(lo), float(hi));
}

inline unsigned cutlass::half2_t::packed() const {
  // Widen to unsigned before shifting so bit 15 of `hi` never lands in an int's sign bit.
  return unsigned(lo.x) | (unsigned(hi.x) << 16);
}

inline cutlass::half2_t::operator unsigned() const { return packed(); }

//
// bitcast specializations
//

template <>
inline float cutlass::bitcast<float, unsigned>(unsigned const& u) {
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

template <>
inline float cutlass::bitcast<float, int>(int const& i) {
  float f;
  std::memcpy(&f, &i, sizeof(f));
  return f;
}

template <>
inline unsigned cutlass::bitcast<unsigned, float>(float const& f) {
  unsigned u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

template <>
inline cutlass::half_t cutlass::bitcast<cutlass::half_t, unsigned short>(unsigned short const& s) {
  return cutlass::half_t::bitcast(s);
}

template <>
inline unsigned short cutlass::bitcast<unsigned short, cutlass::half_t>(cutlass::half_t const& h) {
  return h.x;
}

#if defined(CUTLASS_HALF_T_HAS_CUDA_FP16)
template <>
inline half cutlass::bitcast<half, unsigned short>(unsigned short const& s) {
  return *reinterpret_cast<half const*>(&s);
}
#endif

//
// Lexical casts
//

#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {

// `inline` is required: a full specialization of a function template defined in a header is an
// ordinary function and would otherwise be defined once per translation unit.

template <>
inline cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg) {
  return cutlass::half_t(boost::lexical_cast<float>(arg));
}

template <>
inline std::string lexical_cast<std::string>(cutlass::half_t const& arg) {
  return boost::lexical_cast<std::string>(float(arg));
}

}  // namespace boost
#endif

//
// Standard Library Operations
//

namespace cutlass {

/// Clears the sign bit.
inline cutlass::half_t abs(cutlass::half_t const& h) {
  return cutlass::half_t::bitcast(h.x & 0x7fff);
}

inline bool isnan(cutlass::half_t const& h) { return h.isnan(); }

inline bool isfinite(cutlass::half_t const& h) { return h.isfinite(); }

/// The tag argument is ignored; always returns the canonical NaN.
inline cutlass::half_t nanh(const char*) { return cutlass::half_t::nan(); }

inline bool isinf(cutlass::half_t const& h) { return h.isinf(); }

inline bool isnormal(cutlass::half_t const& h) { return h.isnormal(); }

/// Returns one of FP_ZERO, FP_SUBNORMAL, FP_INFINITE, FP_NAN or FP_NORMAL.
inline int fpclassify(cutlass::half_t const& h) {
  int exp = h.exponent();
  unsigned short mantissa = h.mantissa();
  if (exp < -14) {
    // exponent field is all zeros
    return mantissa == 0 ? FP_ZERO : FP_SUBNORMAL;
  }
  if (exp > 15) {
    // exponent field is all ones
    return mantissa == 0 ? FP_INFINITE : FP_NAN;
  }
  return FP_NORMAL;
}

inline bool signbit(cutlass::half_t const& h) { return h.signbit(); }

inline cutlass::half_t sqrt(cutlass::half_t const& h) {
  return cutlass::half_t(std::sqrt(float(h)));
}

#if __cplusplus >= 201103L
inline cutlass::half_t copysign(cutlass::half_t const& a, cutlass::half_t const& b) {
  return cutlass::half_t(std::copysign(float(a), float(b)));
}
#endif

}  // namespace cutlass