744 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			744 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017-2018, NVIDIA CORPORATION.  All rights reserved.
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without modification, are permitted
 | |
|  * provided that the following conditions are met:
 | |
|  *     * Redistributions of source code must retain the above copyright notice, this list of
 | |
|  *       conditions and the following disclaimer.
 | |
|  *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 | |
|  *       conditions and the following disclaimer in the documentation and/or other materials
 | |
|  *       provided with the distribution.
 | |
|  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 | |
|  *       to endorse or promote products derived from this software without specific prior written
 | |
|  *       permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 | |
|  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 | |
|  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 | |
|  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 | |
|  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 | |
|  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| /*!
 | |
|     \file
 | |
|     \brief Host-side implementation of half-precision float
 | |
| */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
#include <stdint.h>

#include <cmath>
#include <cstring>
#include <iomanip>
#include <istream>
#include <limits>
#include <ostream>
#include <utility>

#include <cuda_fp16.h>
 | |
| 
 | |
| namespace cutlass {
 | |
| 
 | |
| /// IEEE binary16 floating-point value
 | |
| class half_t {
 | |
|  public:
 | |
|   half_t();
 | |
|   half_t(int);     /// conversion from integer
 | |
|   half_t(float);   /// conversion from fp32
 | |
|   half_t(double);  /// conversion from fp64
 | |
| 
 | |
|   static half_t bitcast(unsigned short);  /// bitcast performs no conversion
 | |
| 
 | |
|   static half_t convert(float const&);          /// FP conversion - round toward nearest even
 | |
|   static float convert(unsigned short const&);  /// floating point conversion to fp32
 | |
| 
 | |
|   static half_t zero() { return bitcast(0); }          /// +zero
 | |
|   static half_t one() { return bitcast(0x3c00); }      /// one
 | |
|   static half_t nan() { return bitcast(0x7fff); }      /// canonical not a number
 | |
|   static half_t inf() { return bitcast(0x7c00); }      /// +infinity
 | |
|   static half_t ninf() { return bitcast(0xfc00); }     /// -infinity
 | |
|   static half_t epsilon() { return bitcast(0x1000); }  /// Machine epsilon
 | |
| 
 | |
|   bool signbit() const;             /// sign bit - true: negative, false: positive
 | |
|   int exponent() const;             /// unbiased exponent
 | |
|   unsigned short mantissa() const;  /// mantissa bits
 | |
| 
 | |
|   bool isfinite() const;  /// true if neither inf nor nan
 | |
|   bool isinf() const;     /// true if value is + or - infinity
 | |
|   bool isnan() const;     /// true if value is not a number
 | |
|   bool isnormal() const;  /// true if nonzero value is normalized
 | |
|   bool iszero() const;    /// true if value is + or - zero
 | |
| 
 | |
|   bool operator==(half_t const&) const;
 | |
|   bool operator!=(half_t const&) const;
 | |
|   bool operator==(float const&) const;
 | |
|   bool operator!=(float const&) const;
 | |
| 
 | |
|   bool operator<(half_t const&) const;
 | |
|   bool operator<=(half_t const&) const;
 | |
|   bool operator>(half_t const&) const;
 | |
|   bool operator>=(half_t const&) const;
 | |
| 
 | |
|   half_t operator+(half_t const&) const;
 | |
|   half_t operator-() const;
 | |
|   half_t operator-(half_t const&) const;
 | |
|   half_t operator*(half_t const&)const;
 | |
|   half_t operator/(half_t const&) const;
 | |
| 
 | |
|   half_t& operator+=(half_t const&);
 | |
|   half_t& operator-=(half_t const&);
 | |
|   half_t& operator*=(half_t const&);
 | |
|   half_t& operator/=(half_t const&);
 | |
| 
 | |
|   half_t& operator++();
 | |
|   half_t& operator--();
 | |
|   half_t operator++(int);
 | |
|   half_t operator--(int);
 | |
| 
 | |
|   operator bool() const;   /// false if zero
 | |
|   operator int() const;    /// conversion to int
 | |
|   operator float() const;  /// conversion to fp32
 | |
|   operator half() const;   /// conversion to half
 | |
| 
 | |
|   uint16_t& raw() { return x; }
 | |
|   uint16_t raw() const { return x; }
 | |
| 
 | |
|  public:
 | |
|   /// data
 | |
|   unsigned short x;
 | |
| };
 | |
| 
 | |
| /// Packed pair of half-precision elements
 | |
| class half2_t {
 | |
|  public:
 | |
|   half2_t();
 | |
|   half2_t(half_t lo, half_t hi);
 | |
|   half2_t(std::pair<float, float> const&);
 | |
|   explicit half2_t(unsigned data);
 | |
| 
 | |
|   half2_t operator+(half2_t const&) const;
 | |
|   half2_t operator-(half2_t const&) const;
 | |
|   half2_t operator*(half2_t const&)const;
 | |
|   half2_t operator/(half2_t const&) const;
 | |
| 
 | |
|   half2_t& operator+=(half2_t const&);
 | |
|   half2_t& operator-=(half2_t const&);
 | |
|   half2_t& operator*=(half2_t const&);
 | |
|   half2_t& operator/=(half2_t const&);
 | |
| 
 | |
|   float dot(half2_t const&) const;         /// dot product with single-precision accumulation
 | |
|   float dot(half2_t const&, float) const;  /// dot product with single-precision accumulation
 | |
| 
 | |
|   half_t doth(half2_t const&) const;          /// dot product with half_t-precision accumulation
 | |
|   half_t doth(half2_t const&, half_t) const;  /// dot product with half_t-precision accumulation
 | |
| 
 | |
|   unsigned packed() const;
 | |
| 
 | |
|   operator std::pair<float, float>() const;
 | |
|   operator unsigned() const;
 | |
| 
 | |
|  public:
 | |
|   half_t lo;
 | |
|   half_t hi;
 | |
| };
 | |
| 
 | |
| template <typename Dest, typename Src>
 | |
| Dest bitcast(Src const&);
 | |
| template <>
 | |
| float bitcast<float, unsigned>(unsigned const&);
 | |
| template <>
 | |
| float bitcast<float, int>(int const&);
 | |
| template <>
 | |
| unsigned bitcast<unsigned, float>(float const&);
 | |
| template <>
 | |
| half_t bitcast<half_t, unsigned short>(unsigned short const&);
 | |
| template <>
 | |
| unsigned short bitcast<unsigned short, half_t>(half_t const&);
 | |
| template <>
 | |
| half bitcast<half, unsigned short>(unsigned short const&);
 | |
| }  // namespace cutlass
 | |
| 
 | |
| cutlass::half_t operator+(float, cutlass::half_t const&);
 | |
| cutlass::half_t operator-(float, cutlass::half_t const&);
 | |
| cutlass::half_t operator*(float, cutlass::half_t const&);
 | |
| cutlass::half_t operator/(float, cutlass::half_t const&);
 | |
| 
 | |
| std::ostream& operator<<(std::ostream&, cutlass::half_t const&);  /// writes a half_t
 | |
| std::istream& operator>>(std::istream&, cutlass::half_t&);        /// reads a half_t
 | |
| 
 | |
#ifdef BOOST_LEXICAL_CAST_INCLUDED
// Declared only when boost/lexical_cast.hpp has already been included.
namespace boost {

/// Lexical cast: string -> half_t
template <>
cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg);

/// Lexical cast: half_t -> string
template <>
std::string lexical_cast<std::string>(cutlass::half_t const& arg);

}  // namespace boost
#endif
 | |
| 
 | |
/// Mantissa digit count for half_t; used as std::numeric_limits<cutlass::half_t>::digits.
/// The value matches the 10 mantissa bits that half_t::mantissa() extracts (mask 0x3ff).
#define HLF_MANT_DIG 10
 | |
| 
 | |
| namespace std {
 | |
| 
 | |
| cutlass::half_t abs(cutlass::half_t const&);  /// absolute value
 | |
| 
 | |
| bool isnan(cutlass::half_t const&);  /// true if argument is NaN
 | |
| 
 | |
| bool isfinite(cutlass::half_t const&);  /// true if argument is neither NaN nor infinity
 | |
| 
 | |
| cutlass::half_t nanh(const char* = 0);  /// returns a not-a-number
 | |
| 
 | |
| bool isinf(cutlass::half_t const&);  /// returns true if argument is infinitey (+ or -)
 | |
| 
 | |
| bool isnormal(
 | |
|     cutlass::half_t const&);  /// returns true if argument is normal (neither zero nor infinity)
 | |
| 
 | |
| int fpclassify(cutlass::half_t const&);  /// returns a flag classifying floating-point value
 | |
| 
 | |
| bool signbit(cutlass::half_t const&);  /// returns true if negative, false if positive
 | |
| 
 | |
| cutlass::half_t sqrt(cutlass::half_t const&);  /// square root of half_t
 | |
| 
 | |
| /// Numeric limits
 | |
| template <>
 | |
| struct numeric_limits<cutlass::half_t> {
 | |
|   static bool const is_specialized = true;
 | |
|   static bool const is_signed = true;
 | |
|   static bool const is_integer = false;
 | |
|   static bool const is_exact = false;
 | |
|   static bool const has_infinity = true;
 | |
|   static bool const has_quiet_NaN = true;
 | |
|   static bool const has_signaling_NaN = false;
 | |
|   static std::float_denorm_style const has_denorm = std::denorm_present;
 | |
|   static bool const has_denorm_loss = true;
 | |
|   static std::float_round_style const round_style = std::round_to_nearest;
 | |
|   static bool const is_iec559 = false;
 | |
|   static bool const is_bounded = true;
 | |
|   static bool const is_modulo = false;
 | |
|   static int const digits = HLF_MANT_DIG;
 | |
| 
 | |
|   static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); }
 | |
| 
 | |
|   static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); }
 | |
| 
 | |
|   static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); }
 | |
| 
 | |
|   /// Returns smallest finite value
 | |
|   static cutlass::half_t epsilon() { return cutlass::half_t::epsilon(); }
 | |
| 
 | |
|   /// Returns smallest finite value
 | |
|   static cutlass::half_t round_error() { return cutlass::half_t(0.5f); }
 | |
| 
 | |
|   /// Returns smallest finite value
 | |
|   static cutlass::half_t infinity() { return cutlass::half_t::inf(); }
 | |
| 
 | |
|   /// Returns smallest finite value
 | |
|   static cutlass::half_t quiet_NaN() { return cutlass::half_t::nan(); }
 | |
| 
 | |
|   /// Returns smallest finite value
 | |
|   static cutlass::half_t signaling_NaN() { return cutlass::half_t::nan(); }
 | |
| 
 | |
|   /// Returns smallest finite value
 | |
|   static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); }
 | |
| };
 | |
| }  // namespace std
 | |
| 
 | |
| //
 | |
| //
 | |
| //
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::bitcast(unsigned short _x) {
 | |
|   half_t h;
 | |
|   h.x = _x;
 | |
|   return h;
 | |
| }
 | |
| 
 | |
| /// FP32 -> FP16 conversion - rounds to nearest even
 | |
| inline cutlass::half_t cutlass::half_t::convert(float const& flt) {
 | |
|   // software implementation rounds toward nearest even
 | |
|   unsigned const& s = *reinterpret_cast<unsigned const*>(&flt);
 | |
|   uint16_t sign = uint16_t((s >> 16) & 0x8000);
 | |
|   int16_t exp = uint16_t(((s >> 23) & 0xff) - 127);
 | |
|   int mantissa = s & 0x7fffff;
 | |
|   uint16_t u = 0;
 | |
| 
 | |
|   if ((s & 0x7fffffff) == 0) {
 | |
|     // sign-preserving zero
 | |
|     return cutlass::half_t::bitcast(sign);
 | |
|   }
 | |
| 
 | |
|   if (exp > 15) {
 | |
|     if (exp == 128 && mantissa) {
 | |
|       // not a number
 | |
|       u = 0x7fff;
 | |
|     } else {
 | |
|       // overflow to infinity
 | |
|       u = sign | 0x7c00;
 | |
|     }
 | |
|     return cutlass::half_t::bitcast(u);
 | |
|   }
 | |
| 
 | |
|   int sticky_bit = 0;
 | |
| 
 | |
|   if (exp >= -14) {
 | |
|     // normal fp32 to normal fp16
 | |
|     exp = uint16_t(exp + uint16_t(15));
 | |
|     u = uint16_t(((exp & 0x1f) << 10));
 | |
|     u = uint16_t(u | (mantissa >> 13));
 | |
|   } else {
 | |
|     // normal single-precision to subnormal half_t-precision representation
 | |
|     int rshift = (-14 - exp);
 | |
|     if (rshift < 32) {
 | |
|       mantissa |= (1 << 23);
 | |
| 
 | |
|       sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0);
 | |
| 
 | |
|       mantissa = (mantissa >> rshift);
 | |
|       u = (uint16_t(mantissa >> 13) & 0x3ff);
 | |
|     } else {
 | |
|       mantissa = 0;
 | |
|       u = 0;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // round to nearest even
 | |
|   int round_bit = ((mantissa >> 12) & 1);
 | |
|   sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0);
 | |
| 
 | |
|   if ((round_bit && sticky_bit) || (round_bit && (u & 1))) {
 | |
|     u = uint16_t(u + 1);
 | |
|   }
 | |
| 
 | |
|   u |= sign;
 | |
| 
 | |
|   return cutlass::half_t::bitcast(u);
 | |
| }
 | |
| 
 | |
| inline float cutlass::half_t::convert(unsigned short const& h) {
 | |
|   int sign = ((h >> 15) & 1);
 | |
|   int exp = ((h >> 10) & 0x1f);
 | |
|   int mantissa = (h & 0x3ff);
 | |
|   unsigned f = 0;
 | |
| 
 | |
|   if (exp > 0 && exp < 31) {
 | |
|     // normal
 | |
|     exp += 112;
 | |
|     f = (sign << 31) | (exp << 23) | (mantissa << 13);
 | |
|   } else if (exp == 0) {
 | |
|     if (mantissa) {
 | |
|       // subnormal
 | |
|       exp += 113;
 | |
|       while ((mantissa & (1 << 10)) == 0) {
 | |
|         mantissa <<= 1;
 | |
|         exp--;
 | |
|       }
 | |
|       mantissa &= 0x3ff;
 | |
|       f = (sign << 31) | (exp << 23) | (mantissa << 13);
 | |
|     } else {
 | |
|       // sign-preserving zero
 | |
|       f = (sign << 31);
 | |
|     }
 | |
|   } else if (exp == 31) {
 | |
|     if (mantissa) {
 | |
|       f = 0x7fffffff;  // not a number
 | |
|     } else {
 | |
|       f = (0xff << 23) | (sign << 31);  //  inf
 | |
|     }
 | |
|   }
 | |
|   return *reinterpret_cast<float const*>(&f);
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t::half_t() {}
 | |
| 
 | |
| inline cutlass::half_t::half_t(int i) { x = convert(float(i)).x; }
 | |
| 
 | |
| inline cutlass::half_t::half_t(float f) { x = convert(f).x; }
 | |
| 
 | |
| inline cutlass::half_t::half_t(double d) { x = convert(float(d)).x; }
 | |
| 
 | |
| inline bool cutlass::half_t::signbit() const { return (x >> 15) & 1; }
 | |
| 
 | |
| inline int cutlass::half_t::exponent() const { return ((x >> 10) & 0x1f) - 15; }
 | |
| 
 | |
| inline unsigned short cutlass::half_t::mantissa() const { return x & 0x3ff; }
 | |
| 
 | |
| inline cutlass::half_t::operator bool() const { return (x & 0x7fff) != 0; }
 | |
| 
 | |
| inline cutlass::half_t::operator int() const { return static_cast<int>(convert(x)); }
 | |
| 
 | |
| inline cutlass::half_t::operator float() const { return convert(x); }
 | |
| 
 | |
| inline cutlass::half_t::operator half() const { return cutlass::bitcast<half, unsigned short>(x); }
 | |
| 
 | |
| inline bool cutlass::half_t::operator==(cutlass::half_t const& h) const {
 | |
|   if (iszero() && h.iszero()) {
 | |
|     return true;
 | |
|   }
 | |
|   return x == h.x;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::operator!=(cutlass::half_t const& h) const {
 | |
|   if (iszero() && h.iszero()) {
 | |
|     return false;
 | |
|   }
 | |
|   return x != h.x;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::operator==(float const& b) const { return x == half_t(b).x; }
 | |
| 
 | |
| inline bool cutlass::half_t::operator!=(float const& b) const { return x != half_t(b).x; }
 | |
| 
 | |
| inline bool cutlass::half_t::iszero() const { return (x & 0x7fff) == 0; }
 | |
| 
 | |
| inline bool cutlass::half_t::isfinite() const { return (exponent() < 16); }
 | |
| 
 | |
| inline bool cutlass::half_t::isnan() const {
 | |
|   int exp = ((x >> 10) & 0x1f);
 | |
|   if (exp == 0x1f) {
 | |
|     return (x & 0x3ff) != 0;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::isinf() const {
 | |
|   int exp = ((x >> 10) & 0x1f);
 | |
|   if (exp == 0x1f) {
 | |
|     return (x & 0x3ff) == 0;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::isnormal() const {
 | |
|   int exp = exponent();
 | |
|   return exp > -15 && exp < 16;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::operator<(half_t const& h) const {
 | |
|   int sign = ((x >> 15) & 1);
 | |
|   int h_sign = ((h.x >> 15) & 1);
 | |
|   if (sign == h_sign) {
 | |
|     return (x & 0x7fff) < (h.x & 0x7fff);
 | |
|   } else if (sign) {
 | |
|     return true;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::operator<=(half_t const& h) const {
 | |
|   int sign = ((x >> 15) & 1);
 | |
|   int h_sign = ((h.x >> 15) & 1);
 | |
|   if (sign == h_sign) {
 | |
|     return (x & 0x7fff) <= (h.x & 0x7fff);
 | |
|   } else if (sign) {
 | |
|     return true;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::operator>(half_t const& h) const {
 | |
|   int sign = ((x >> 15) & 1);
 | |
|   int h_sign = ((h.x >> 15) & 1);
 | |
|   if (sign == h_sign) {
 | |
|     return (x & 0x7fff) > (h.x & 0x7fff);
 | |
|   } else if (h_sign) {
 | |
|     return true;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| inline bool cutlass::half_t::operator>=(half_t const& h) const {
 | |
|   int sign = ((x >> 15) & 1);
 | |
|   int h_sign = ((h.x >> 15) & 1);
 | |
|   if (sign == h_sign) {
 | |
|     return (x & 0x7fff) >= (h.x & 0x7fff);
 | |
|   } else if (h_sign) {
 | |
|     return true;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::operator+(cutlass::half_t const& b) const {
 | |
|   return cutlass::half_t(float(*this) + float(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::operator-() const { return bitcast(x ^ 0x8000); }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::operator-(cutlass::half_t const& b) const {
 | |
|   return cutlass::half_t(float(*this) - float(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::operator*(cutlass::half_t const& b) const {
 | |
|   return cutlass::half_t(float(*this) * float(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::operator/(cutlass::half_t const& b) const {
 | |
|   return cutlass::half_t(float(*this) / float(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t& cutlass::half_t::operator+=(cutlass::half_t const& b) {
 | |
|   *this = cutlass::half_t(float(*this) + float(b));
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t& cutlass::half_t::operator-=(cutlass::half_t const& b) {
 | |
|   *this = cutlass::half_t(float(*this) - float(b));
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t& cutlass::half_t::operator*=(cutlass::half_t const& b) {
 | |
|   *this = cutlass::half_t(float(*this) * float(b));
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t& cutlass::half_t::operator/=(cutlass::half_t const& b) {
 | |
|   *this = cutlass::half_t(float(*this) / float(b));
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t& cutlass::half_t::operator++() {
 | |
|   *this = cutlass::half_t(float(*this) + 1.0f);
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t& cutlass::half_t::operator--() {
 | |
|   *this = cutlass::half_t(float(*this) - 1.0f);
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::operator++(int) {
 | |
|   half_t h = *this;
 | |
|   *this = cutlass::half_t(float(*this) + 1.0f);
 | |
|   return h;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half_t::operator--(int) {
 | |
|   half_t h = *this;
 | |
|   *this = cutlass::half_t(float(*this) - 1.0f);
 | |
|   return h;
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t operator+(float a, cutlass::half_t const& b) {
 | |
|   return cutlass::half_t(a + float(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t operator-(float a, cutlass::half_t const& b) {
 | |
|   return cutlass::half_t(a - float(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t operator*(float a, cutlass::half_t const& b) {
 | |
|   return cutlass::half_t(a * float(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t operator/(float a, cutlass::half_t const& b) {
 | |
|   return cutlass::half_t(a / float(b));
 | |
| }
 | |
| 
 | |
| //
 | |
| //
 | |
| //
 | |
| 
 | |
| inline cutlass::half2_t::half2_t() {}
 | |
| 
 | |
| inline cutlass::half2_t::half2_t(half_t lo, half_t hi) : lo(lo), hi(hi) {}
 | |
| 
 | |
| inline cutlass::half2_t::half2_t(std::pair<float, float> const& p) : lo(p.first), hi(p.second) {}
 | |
| 
 | |
| inline cutlass::half2_t::half2_t(unsigned data)
 | |
|     : lo(half_t::bitcast(uint16_t(data & 0x0ffff))),
 | |
|       hi(half_t::bitcast(uint16_t((data >> 16) & 0x0ffff))) {}
 | |
| 
 | |
| inline cutlass::half2_t cutlass::half2_t::operator+(half2_t const& b) const {
 | |
|   return half2_t(lo + b.lo, hi + b.hi);
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t cutlass::half2_t::operator-(half2_t const& b) const {
 | |
|   return half2_t(lo - b.lo, hi - b.hi);
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t cutlass::half2_t::operator*(half2_t const& b) const {
 | |
|   return half2_t(lo * b.lo, hi * b.hi);
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t cutlass::half2_t::operator/(half2_t const& b) const {
 | |
|   return half2_t(lo / b.lo, hi / b.hi);
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t& cutlass::half2_t::operator+=(half2_t const& b) {
 | |
|   lo += b.lo;
 | |
|   hi += b.hi;
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t& cutlass::half2_t::operator-=(half2_t const& b) {
 | |
|   lo -= b.lo;
 | |
|   hi -= b.hi;
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t& cutlass::half2_t::operator*=(half2_t const& b) {
 | |
|   lo *= b.lo;
 | |
|   hi *= b.hi;
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t& cutlass::half2_t::operator/=(half2_t const& b) {
 | |
|   lo /= b.lo;
 | |
|   hi /= b.hi;
 | |
|   return *this;
 | |
| }
 | |
| 
 | |
| inline float cutlass::half2_t::dot(half2_t const& b) const {
 | |
|   return float(lo) * float(b.lo) + float(hi) * float(b.hi);
 | |
| }
 | |
| 
 | |
| inline float cutlass::half2_t::dot(half2_t const& b, float c) const { return c + dot(b); }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b) const {
 | |
|   return cutlass::half_t(dot(b));
 | |
| }
 | |
| 
 | |
| inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b, half_t c) const {
 | |
|   return cutlass::half_t(dot(b, float(c)));
 | |
| }
 | |
| 
 | |
| inline cutlass::half2_t::operator std::pair<float, float>() const {
 | |
|   return std::pair<float, float>(float(lo), float(hi));
 | |
| }
 | |
| 
 | |
| inline unsigned cutlass::half2_t::packed() const { return (lo.x | (hi.x << 16)); }
 | |
| 
 | |
| inline cutlass::half2_t::operator unsigned() const { return packed(); }
 | |
| 
 | |
| //
 | |
| //
 | |
| //
 | |
| 
 | |
| template <>
 | |
| inline float cutlass::bitcast<float, unsigned>(unsigned const& u) {
 | |
|   return *reinterpret_cast<float const*>(&u);
 | |
| }
 | |
| 
 | |
| template <>
 | |
| inline float cutlass::bitcast<float, int>(int const& i) {
 | |
|   return *reinterpret_cast<float const*>(&i);
 | |
| }
 | |
| 
 | |
| template <>
 | |
| inline unsigned cutlass::bitcast<unsigned, float>(float const& f) {
 | |
|   return *reinterpret_cast<unsigned const*>(&f);
 | |
| }
 | |
| 
 | |
| template <>
 | |
| inline cutlass::half_t cutlass::bitcast<cutlass::half_t, unsigned short>(unsigned short const& s) {
 | |
|   return *reinterpret_cast<cutlass::half_t const*>(&s);
 | |
| }
 | |
| 
 | |
| template <>
 | |
| inline unsigned short cutlass::bitcast<unsigned short, cutlass::half_t>(cutlass::half_t const& h) {
 | |
|   return *reinterpret_cast<unsigned short const*>(&h);
 | |
| }
 | |
| 
 | |
| template <>
 | |
| inline half cutlass::bitcast<half, unsigned short>(unsigned short const& s) {
 | |
|   return *reinterpret_cast<half const*>(&s);
 | |
| }
 | |
| 
 | |
| //
 | |
| // Lexical casts
 | |
| //
 | |
| 
 | |
#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {

// Both specializations are defined in this header, so they are marked inline: an explicit
// specialization of a function template is not implicitly inline, and a non-inline definition
// would be emitted by every translation unit that includes this file.

/// Lexical cast from string to half_t: parses the text as fp32, then rounds to half precision
template <>
inline cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg) {
  return cutlass::half_t(boost::lexical_cast<float>(arg));
}

/// Lexical cast from half_t to string: formats the fp32 value of the argument
template <>
inline std::string lexical_cast<std::string>(cutlass::half_t const& arg) {
  return boost::lexical_cast<std::string>(float(arg));
}

}  // namespace boost
#endif
 | |
| 
 | |
| //
 | |
| // Standard Library Operations
 | |
| //
 | |
| 
 | |
| // std
 | |
| namespace std {
 | |
| 
 | |
| inline cutlass::half_t abs(cutlass::half_t const& h) {
 | |
|   return cutlass::half_t::bitcast(h.x & 0x7fff);
 | |
| }
 | |
| 
 | |
| inline bool isnan(cutlass::half_t const& h) { return h.isnan(); }
 | |
| 
 | |
| inline bool isfinite(cutlass::half_t const& h) { return h.isfinite(); }
 | |
| 
 | |
| inline cutlass::half_t nanh(const char*) { return cutlass::half_t::nan(); }
 | |
| 
 | |
| inline bool isinf(cutlass::half_t const& h) { return h.isinf(); }
 | |
| 
 | |
| inline bool isnormal(cutlass::half_t const& h) { return h.isnormal(); }
 | |
| 
 | |
| inline int fpclassify(cutlass::half_t const& h) {
 | |
|   int exp = h.exponent();
 | |
|   unsigned short mantissa = h.mantissa();
 | |
|   if (exp < -14) {
 | |
|     if (mantissa == 0) {
 | |
|       return FP_ZERO;
 | |
|     } else {
 | |
|       return FP_SUBNORMAL;
 | |
|     }
 | |
|   } else if (exp > 15) {
 | |
|     if (mantissa == 0) {
 | |
|       return FP_INFINITE;
 | |
|     } else {
 | |
|       return FP_NAN;
 | |
|     }
 | |
|   }
 | |
|   return FP_NORMAL;
 | |
| }
 | |
| 
 | |
| inline bool signbit(cutlass::half_t const& h) { return h.signbit(); }
 | |
| 
 | |
| inline cutlass::half_t sqrt(cutlass::half_t const& h) {
 | |
|   return cutlass::half_t(std::sqrt(float(h)));
 | |
| }
 | |
| }  // namespace std
 | |
| 
 | |
| //
 | |
| // Stream interactions
 | |
| //
 | |
| 
 | |
| /// put to stream - half_t-precision types bitcast as unsigned shorts if base is hexadecimal
 | |
| inline std::ostream& operator<<(std::ostream& out, cutlass::half_t const& h) {
 | |
|   if (out.flags() & std::ios::hex) {
 | |
|     return out << h.x;
 | |
|   } else {
 | |
|     return out << float(h);
 | |
|   }
 | |
| }
 | |
| 
 | |
| /// read from stream - half_t-precision types parsed as unsigned shorts if base is hexadecimal
 | |
| inline std::istream& operator>>(std::istream& in, cutlass::half_t& h) {
 | |
|   if (in.flags() & std::ios::hex) {
 | |
|     unsigned short u = 0;
 | |
|     in >> u;
 | |
|     h = cutlass::half_t::bitcast(u);
 | |
|   } else {
 | |
|     float f = 0;
 | |
|     in >> f;
 | |
|     h = cutlass::half_t(f);
 | |
|   }
 | |
|   return in;
 | |
| }
 | 
