744 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
		
		
			
		
	
	
			744 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
|   | /***************************************************************************************************
 | ||
|  |  * Copyright (c) 2017-2018, NVIDIA CORPORATION.  All rights reserved. | ||
|  |  * | ||
|  |  * Redistribution and use in source and binary forms, with or without modification, are permitted | ||
|  |  * provided that the following conditions are met: | ||
|  |  *     * Redistributions of source code must retain the above copyright notice, this list of | ||
|  |  *       conditions and the following disclaimer. | ||
|  |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of | ||
|  |  *       conditions and the following disclaimer in the documentation and/or other materials | ||
|  |  *       provided with the distribution. | ||
|  |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used | ||
|  |  *       to endorse or promote products derived from this software without specific prior written | ||
|  |  *       permission. | ||
|  |  * | ||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR | ||
|  |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | ||
|  |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE | ||
|  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | ||
|  |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; | ||
|  |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, | ||
|  |  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
|  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|  |  * | ||
|  |  **************************************************************************************************/ | ||
|  | /*!
 | ||
|  |     \file | ||
|  |     \brief Host-side implementation of half-precision float | ||
|  | */ | ||
|  | 
 | ||
|  | #pragma once
 | ||
|  | 
 | ||
#include <stdint.h>

#include <cmath>
#include <cstring>
#include <iomanip>
#include <istream>
#include <limits>
#include <ostream>
#include <utility>
#include <utility>

#include <cuda_fp16.h>
|  | 
 | ||
|  | namespace cutlass { | ||
|  | 
 | ||
|  | /// IEEE binary16 floating-point value
 | ||
|  | class half_t { | ||
|  |  public: | ||
|  |   half_t(); | ||
|  |   half_t(int);     /// conversion from integer
 | ||
|  |   half_t(float);   /// conversion from fp32
 | ||
|  |   half_t(double);  /// conversion from fp64
 | ||
|  | 
 | ||
|  |   static half_t bitcast(unsigned short);  /// bitcast performs no conversion
 | ||
|  | 
 | ||
|  |   static half_t convert(float const&);          /// FP conversion - round toward nearest even
 | ||
|  |   static float convert(unsigned short const&);  /// floating point conversion to fp32
 | ||
|  | 
 | ||
|  |   static half_t zero() { return bitcast(0); }          /// +zero
 | ||
|  |   static half_t one() { return bitcast(0x3c00); }      /// one
 | ||
|  |   static half_t nan() { return bitcast(0x7fff); }      /// canonical not a number
 | ||
|  |   static half_t inf() { return bitcast(0x7c00); }      /// +infinity
 | ||
|  |   static half_t ninf() { return bitcast(0xfc00); }     /// -infinity
 | ||
|  |   static half_t epsilon() { return bitcast(0x1000); }  /// Machine epsilon
 | ||
|  | 
 | ||
|  |   bool signbit() const;             /// sign bit - true: negative, false: positive
 | ||
|  |   int exponent() const;             /// unbiased exponent
 | ||
|  |   unsigned short mantissa() const;  /// mantissa bits
 | ||
|  | 
 | ||
|  |   bool isfinite() const;  /// true if neither inf nor nan
 | ||
|  |   bool isinf() const;     /// true if value is + or - infinity
 | ||
|  |   bool isnan() const;     /// true if value is not a number
 | ||
|  |   bool isnormal() const;  /// true if nonzero value is normalized
 | ||
|  |   bool iszero() const;    /// true if value is + or - zero
 | ||
|  | 
 | ||
|  |   bool operator==(half_t const&) const; | ||
|  |   bool operator!=(half_t const&) const; | ||
|  |   bool operator==(float const&) const; | ||
|  |   bool operator!=(float const&) const; | ||
|  | 
 | ||
|  |   bool operator<(half_t const&) const; | ||
|  |   bool operator<=(half_t const&) const; | ||
|  |   bool operator>(half_t const&) const; | ||
|  |   bool operator>=(half_t const&) const; | ||
|  | 
 | ||
|  |   half_t operator+(half_t const&) const; | ||
|  |   half_t operator-() const; | ||
|  |   half_t operator-(half_t const&) const; | ||
|  |   half_t operator*(half_t const&)const; | ||
|  |   half_t operator/(half_t const&) const; | ||
|  | 
 | ||
|  |   half_t& operator+=(half_t const&); | ||
|  |   half_t& operator-=(half_t const&); | ||
|  |   half_t& operator*=(half_t const&); | ||
|  |   half_t& operator/=(half_t const&); | ||
|  | 
 | ||
|  |   half_t& operator++(); | ||
|  |   half_t& operator--(); | ||
|  |   half_t operator++(int); | ||
|  |   half_t operator--(int); | ||
|  | 
 | ||
|  |   operator bool() const;   /// false if zero
 | ||
|  |   operator int() const;    /// conversion to int
 | ||
|  |   operator float() const;  /// conversion to fp32
 | ||
|  |   operator half() const;   /// conversion to half
 | ||
|  | 
 | ||
|  |   uint16_t& raw() { return x; } | ||
|  |   uint16_t raw() const { return x; } | ||
|  | 
 | ||
|  |  public: | ||
|  |   /// data
 | ||
|  |   unsigned short x; | ||
|  | }; | ||
|  | 
 | ||
|  | /// Packed pair of half-precision elements
 | ||
|  | class half2_t { | ||
|  |  public: | ||
|  |   half2_t(); | ||
|  |   half2_t(half_t lo, half_t hi); | ||
|  |   half2_t(std::pair<float, float> const&); | ||
|  |   explicit half2_t(unsigned data); | ||
|  | 
 | ||
|  |   half2_t operator+(half2_t const&) const; | ||
|  |   half2_t operator-(half2_t const&) const; | ||
|  |   half2_t operator*(half2_t const&)const; | ||
|  |   half2_t operator/(half2_t const&) const; | ||
|  | 
 | ||
|  |   half2_t& operator+=(half2_t const&); | ||
|  |   half2_t& operator-=(half2_t const&); | ||
|  |   half2_t& operator*=(half2_t const&); | ||
|  |   half2_t& operator/=(half2_t const&); | ||
|  | 
 | ||
|  |   float dot(half2_t const&) const;         /// dot product with single-precision accumulation
 | ||
|  |   float dot(half2_t const&, float) const;  /// dot product with single-precision accumulation
 | ||
|  | 
 | ||
|  |   half_t doth(half2_t const&) const;          /// dot product with half_t-precision accumulation
 | ||
|  |   half_t doth(half2_t const&, half_t) const;  /// dot product with half_t-precision accumulation
 | ||
|  | 
 | ||
|  |   unsigned packed() const; | ||
|  | 
 | ||
|  |   operator std::pair<float, float>() const; | ||
|  |   operator unsigned() const; | ||
|  | 
 | ||
|  |  public: | ||
|  |   half_t lo; | ||
|  |   half_t hi; | ||
|  | }; | ||
|  | 
 | ||
|  | template <typename Dest, typename Src> | ||
|  | Dest bitcast(Src const&); | ||
|  | template <> | ||
|  | float bitcast<float, unsigned>(unsigned const&); | ||
|  | template <> | ||
|  | float bitcast<float, int>(int const&); | ||
|  | template <> | ||
|  | unsigned bitcast<unsigned, float>(float const&); | ||
|  | template <> | ||
|  | half_t bitcast<half_t, unsigned short>(unsigned short const&); | ||
|  | template <> | ||
|  | unsigned short bitcast<unsigned short, half_t>(half_t const&); | ||
|  | template <> | ||
|  | half bitcast<half, unsigned short>(unsigned short const&); | ||
|  | }  // namespace cutlass
 | ||
|  | 
 | ||
|  | cutlass::half_t operator+(float, cutlass::half_t const&); | ||
|  | cutlass::half_t operator-(float, cutlass::half_t const&); | ||
|  | cutlass::half_t operator*(float, cutlass::half_t const&); | ||
|  | cutlass::half_t operator/(float, cutlass::half_t const&); | ||
|  | 
 | ||
|  | std::ostream& operator<<(std::ostream&, cutlass::half_t const&);  /// writes a half_t
 | ||
|  | std::istream& operator>>(std::istream&, cutlass::half_t&);        /// reads a half_t
 | ||
|  | 
 | ||
#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {

/// String -> half_t: the text is parsed as fp32 and then rounded to half_t.
template <>
cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg);

/// half_t -> string: the value is widened to fp32 and formatted.
template <>
std::string lexical_cast<std::string>(cutlass::half_t const& arg);

}  // namespace boost
#endif
|  | 
 | ||
|  | #define HLF_MANT_DIG 10
 | ||
|  | 
 | ||
|  | namespace std { | ||
|  | 
 | ||
|  | cutlass::half_t abs(cutlass::half_t const&);  /// absolute value
 | ||
|  | 
 | ||
|  | bool isnan(cutlass::half_t const&);  /// true if argument is NaN
 | ||
|  | 
 | ||
|  | bool isfinite(cutlass::half_t const&);  /// true if argument is neither NaN nor infinity
 | ||
|  | 
 | ||
|  | cutlass::half_t nanh(const char* = 0);  /// returns a not-a-number
 | ||
|  | 
 | ||
|  | bool isinf(cutlass::half_t const&);  /// returns true if argument is infinitey (+ or -)
 | ||
|  | 
 | ||
|  | bool isnormal( | ||
|  |     cutlass::half_t const&);  /// returns true if argument is normal (neither zero nor infinity)
 | ||
|  | 
 | ||
|  | int fpclassify(cutlass::half_t const&);  /// returns a flag classifying floating-point value
 | ||
|  | 
 | ||
|  | bool signbit(cutlass::half_t const&);  /// returns true if negative, false if positive
 | ||
|  | 
 | ||
|  | cutlass::half_t sqrt(cutlass::half_t const&);  /// square root of half_t
 | ||
|  | 
 | ||
|  | /// Numeric limits
 | ||
|  | template <> | ||
|  | struct numeric_limits<cutlass::half_t> { | ||
|  |   static bool const is_specialized = true; | ||
|  |   static bool const is_signed = true; | ||
|  |   static bool const is_integer = false; | ||
|  |   static bool const is_exact = false; | ||
|  |   static bool const has_infinity = true; | ||
|  |   static bool const has_quiet_NaN = true; | ||
|  |   static bool const has_signaling_NaN = false; | ||
|  |   static std::float_denorm_style const has_denorm = std::denorm_present; | ||
|  |   static bool const has_denorm_loss = true; | ||
|  |   static std::float_round_style const round_style = std::round_to_nearest; | ||
|  |   static bool const is_iec559 = false; | ||
|  |   static bool const is_bounded = true; | ||
|  |   static bool const is_modulo = false; | ||
|  |   static int const digits = HLF_MANT_DIG; | ||
|  | 
 | ||
|  |   static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } | ||
|  | 
 | ||
|  |   static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } | ||
|  | 
 | ||
|  |   static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } | ||
|  | 
 | ||
|  |   /// Returns smallest finite value
 | ||
|  |   static cutlass::half_t epsilon() { return cutlass::half_t::epsilon(); } | ||
|  | 
 | ||
|  |   /// Returns smallest finite value
 | ||
|  |   static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } | ||
|  | 
 | ||
|  |   /// Returns smallest finite value
 | ||
|  |   static cutlass::half_t infinity() { return cutlass::half_t::inf(); } | ||
|  | 
 | ||
|  |   /// Returns smallest finite value
 | ||
|  |   static cutlass::half_t quiet_NaN() { return cutlass::half_t::nan(); } | ||
|  | 
 | ||
|  |   /// Returns smallest finite value
 | ||
|  |   static cutlass::half_t signaling_NaN() { return cutlass::half_t::nan(); } | ||
|  | 
 | ||
|  |   /// Returns smallest finite value
 | ||
|  |   static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } | ||
|  | }; | ||
|  | }  // namespace std
 | ||
|  | 
 | ||
|  | //
 | ||
|  | //
 | ||
|  | //
 | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::bitcast(unsigned short _x) { | ||
|  |   half_t h; | ||
|  |   h.x = _x; | ||
|  |   return h; | ||
|  | } | ||
|  | 
 | ||
|  | /// FP32 -> FP16 conversion - rounds to nearest even
 | ||
|  | inline cutlass::half_t cutlass::half_t::convert(float const& flt) { | ||
|  |   // software implementation rounds toward nearest even
 | ||
|  |   unsigned const& s = *reinterpret_cast<unsigned const*>(&flt); | ||
|  |   uint16_t sign = uint16_t((s >> 16) & 0x8000); | ||
|  |   int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); | ||
|  |   int mantissa = s & 0x7fffff; | ||
|  |   uint16_t u = 0; | ||
|  | 
 | ||
|  |   if ((s & 0x7fffffff) == 0) { | ||
|  |     // sign-preserving zero
 | ||
|  |     return cutlass::half_t::bitcast(sign); | ||
|  |   } | ||
|  | 
 | ||
|  |   if (exp > 15) { | ||
|  |     if (exp == 128 && mantissa) { | ||
|  |       // not a number
 | ||
|  |       u = 0x7fff; | ||
|  |     } else { | ||
|  |       // overflow to infinity
 | ||
|  |       u = sign | 0x7c00; | ||
|  |     } | ||
|  |     return cutlass::half_t::bitcast(u); | ||
|  |   } | ||
|  | 
 | ||
|  |   int sticky_bit = 0; | ||
|  | 
 | ||
|  |   if (exp >= -14) { | ||
|  |     // normal fp32 to normal fp16
 | ||
|  |     exp = uint16_t(exp + uint16_t(15)); | ||
|  |     u = uint16_t(((exp & 0x1f) << 10)); | ||
|  |     u = uint16_t(u | (mantissa >> 13)); | ||
|  |   } else { | ||
|  |     // normal single-precision to subnormal half_t-precision representation
 | ||
|  |     int rshift = (-14 - exp); | ||
|  |     if (rshift < 32) { | ||
|  |       mantissa |= (1 << 23); | ||
|  | 
 | ||
|  |       sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); | ||
|  | 
 | ||
|  |       mantissa = (mantissa >> rshift); | ||
|  |       u = (uint16_t(mantissa >> 13) & 0x3ff); | ||
|  |     } else { | ||
|  |       mantissa = 0; | ||
|  |       u = 0; | ||
|  |     } | ||
|  |   } | ||
|  | 
 | ||
|  |   // round to nearest even
 | ||
|  |   int round_bit = ((mantissa >> 12) & 1); | ||
|  |   sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0); | ||
|  | 
 | ||
|  |   if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { | ||
|  |     u = uint16_t(u + 1); | ||
|  |   } | ||
|  | 
 | ||
|  |   u |= sign; | ||
|  | 
 | ||
|  |   return cutlass::half_t::bitcast(u); | ||
|  | } | ||
|  | 
 | ||
|  | inline float cutlass::half_t::convert(unsigned short const& h) { | ||
|  |   int sign = ((h >> 15) & 1); | ||
|  |   int exp = ((h >> 10) & 0x1f); | ||
|  |   int mantissa = (h & 0x3ff); | ||
|  |   unsigned f = 0; | ||
|  | 
 | ||
|  |   if (exp > 0 && exp < 31) { | ||
|  |     // normal
 | ||
|  |     exp += 112; | ||
|  |     f = (sign << 31) | (exp << 23) | (mantissa << 13); | ||
|  |   } else if (exp == 0) { | ||
|  |     if (mantissa) { | ||
|  |       // subnormal
 | ||
|  |       exp += 113; | ||
|  |       while ((mantissa & (1 << 10)) == 0) { | ||
|  |         mantissa <<= 1; | ||
|  |         exp--; | ||
|  |       } | ||
|  |       mantissa &= 0x3ff; | ||
|  |       f = (sign << 31) | (exp << 23) | (mantissa << 13); | ||
|  |     } else { | ||
|  |       // sign-preserving zero
 | ||
|  |       f = (sign << 31); | ||
|  |     } | ||
|  |   } else if (exp == 31) { | ||
|  |     if (mantissa) { | ||
|  |       f = 0x7fffffff;  // not a number
 | ||
|  |     } else { | ||
|  |       f = (0xff << 23) | (sign << 31);  //  inf
 | ||
|  |     } | ||
|  |   } | ||
|  |   return *reinterpret_cast<float const*>(&f); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t::half_t() {} | ||
|  | 
 | ||
|  | inline cutlass::half_t::half_t(int i) { x = convert(float(i)).x; } | ||
|  | 
 | ||
|  | inline cutlass::half_t::half_t(float f) { x = convert(f).x; } | ||
|  | 
 | ||
|  | inline cutlass::half_t::half_t(double d) { x = convert(float(d)).x; } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::signbit() const { return (x >> 15) & 1; } | ||
|  | 
 | ||
|  | inline int cutlass::half_t::exponent() const { return ((x >> 10) & 0x1f) - 15; } | ||
|  | 
 | ||
|  | inline unsigned short cutlass::half_t::mantissa() const { return x & 0x3ff; } | ||
|  | 
 | ||
|  | inline cutlass::half_t::operator bool() const { return (x & 0x7fff) != 0; } | ||
|  | 
 | ||
|  | inline cutlass::half_t::operator int() const { return static_cast<int>(convert(x)); } | ||
|  | 
 | ||
|  | inline cutlass::half_t::operator float() const { return convert(x); } | ||
|  | 
 | ||
|  | inline cutlass::half_t::operator half() const { return cutlass::bitcast<half, unsigned short>(x); } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator==(cutlass::half_t const& h) const { | ||
|  |   if (iszero() && h.iszero()) { | ||
|  |     return true; | ||
|  |   } | ||
|  |   return x == h.x; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator!=(cutlass::half_t const& h) const { | ||
|  |   if (iszero() && h.iszero()) { | ||
|  |     return false; | ||
|  |   } | ||
|  |   return x != h.x; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator==(float const& b) const { return x == half_t(b).x; } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator!=(float const& b) const { return x != half_t(b).x; } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::iszero() const { return (x & 0x7fff) == 0; } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::isfinite() const { return (exponent() < 16); } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::isnan() const { | ||
|  |   int exp = ((x >> 10) & 0x1f); | ||
|  |   if (exp == 0x1f) { | ||
|  |     return (x & 0x3ff) != 0; | ||
|  |   } | ||
|  |   return false; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::isinf() const { | ||
|  |   int exp = ((x >> 10) & 0x1f); | ||
|  |   if (exp == 0x1f) { | ||
|  |     return (x & 0x3ff) == 0; | ||
|  |   } | ||
|  |   return false; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::isnormal() const { | ||
|  |   int exp = exponent(); | ||
|  |   return exp > -15 && exp < 16; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator<(half_t const& h) const { | ||
|  |   int sign = ((x >> 15) & 1); | ||
|  |   int h_sign = ((h.x >> 15) & 1); | ||
|  |   if (sign == h_sign) { | ||
|  |     return (x & 0x7fff) < (h.x & 0x7fff); | ||
|  |   } else if (sign) { | ||
|  |     return true; | ||
|  |   } | ||
|  |   return false; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator<=(half_t const& h) const { | ||
|  |   int sign = ((x >> 15) & 1); | ||
|  |   int h_sign = ((h.x >> 15) & 1); | ||
|  |   if (sign == h_sign) { | ||
|  |     return (x & 0x7fff) <= (h.x & 0x7fff); | ||
|  |   } else if (sign) { | ||
|  |     return true; | ||
|  |   } | ||
|  |   return false; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator>(half_t const& h) const { | ||
|  |   int sign = ((x >> 15) & 1); | ||
|  |   int h_sign = ((h.x >> 15) & 1); | ||
|  |   if (sign == h_sign) { | ||
|  |     return (x & 0x7fff) > (h.x & 0x7fff); | ||
|  |   } else if (h_sign) { | ||
|  |     return true; | ||
|  |   } | ||
|  |   return false; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool cutlass::half_t::operator>=(half_t const& h) const { | ||
|  |   int sign = ((x >> 15) & 1); | ||
|  |   int h_sign = ((h.x >> 15) & 1); | ||
|  |   if (sign == h_sign) { | ||
|  |     return (x & 0x7fff) >= (h.x & 0x7fff); | ||
|  |   } else if (h_sign) { | ||
|  |     return true; | ||
|  |   } | ||
|  |   return false; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::operator+(cutlass::half_t const& b) const { | ||
|  |   return cutlass::half_t(float(*this) + float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::operator-() const { return bitcast(x ^ 0x8000); } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::operator-(cutlass::half_t const& b) const { | ||
|  |   return cutlass::half_t(float(*this) - float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::operator*(cutlass::half_t const& b) const { | ||
|  |   return cutlass::half_t(float(*this) * float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::operator/(cutlass::half_t const& b) const { | ||
|  |   return cutlass::half_t(float(*this) / float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t& cutlass::half_t::operator+=(cutlass::half_t const& b) { | ||
|  |   *this = cutlass::half_t(float(*this) + float(b)); | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t& cutlass::half_t::operator-=(cutlass::half_t const& b) { | ||
|  |   *this = cutlass::half_t(float(*this) - float(b)); | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t& cutlass::half_t::operator*=(cutlass::half_t const& b) { | ||
|  |   *this = cutlass::half_t(float(*this) * float(b)); | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t& cutlass::half_t::operator/=(cutlass::half_t const& b) { | ||
|  |   *this = cutlass::half_t(float(*this) / float(b)); | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t& cutlass::half_t::operator++() { | ||
|  |   *this = cutlass::half_t(float(*this) + 1.0f); | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t& cutlass::half_t::operator--() { | ||
|  |   *this = cutlass::half_t(float(*this) - 1.0f); | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::operator++(int) { | ||
|  |   half_t h = *this; | ||
|  |   *this = cutlass::half_t(float(*this) + 1.0f); | ||
|  |   return h; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half_t::operator--(int) { | ||
|  |   half_t h = *this; | ||
|  |   *this = cutlass::half_t(float(*this) - 1.0f); | ||
|  |   return h; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t operator+(float a, cutlass::half_t const& b) { | ||
|  |   return cutlass::half_t(a + float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t operator-(float a, cutlass::half_t const& b) { | ||
|  |   return cutlass::half_t(a - float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t operator*(float a, cutlass::half_t const& b) { | ||
|  |   return cutlass::half_t(a * float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t operator/(float a, cutlass::half_t const& b) { | ||
|  |   return cutlass::half_t(a / float(b)); | ||
|  | } | ||
|  | 
 | ||
|  | //
 | ||
|  | //
 | ||
|  | //
 | ||
|  | 
 | ||
|  | inline cutlass::half2_t::half2_t() {} | ||
|  | 
 | ||
|  | inline cutlass::half2_t::half2_t(half_t lo, half_t hi) : lo(lo), hi(hi) {} | ||
|  | 
 | ||
|  | inline cutlass::half2_t::half2_t(std::pair<float, float> const& p) : lo(p.first), hi(p.second) {} | ||
|  | 
 | ||
|  | inline cutlass::half2_t::half2_t(unsigned data) | ||
|  |     : lo(half_t::bitcast(uint16_t(data & 0x0ffff))), | ||
|  |       hi(half_t::bitcast(uint16_t((data >> 16) & 0x0ffff))) {} | ||
|  | 
 | ||
|  | inline cutlass::half2_t cutlass::half2_t::operator+(half2_t const& b) const { | ||
|  |   return half2_t(lo + b.lo, hi + b.hi); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t cutlass::half2_t::operator-(half2_t const& b) const { | ||
|  |   return half2_t(lo - b.lo, hi - b.hi); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t cutlass::half2_t::operator*(half2_t const& b) const { | ||
|  |   return half2_t(lo * b.lo, hi * b.hi); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t cutlass::half2_t::operator/(half2_t const& b) const { | ||
|  |   return half2_t(lo / b.lo, hi / b.hi); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t& cutlass::half2_t::operator+=(half2_t const& b) { | ||
|  |   lo += b.lo; | ||
|  |   hi += b.hi; | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t& cutlass::half2_t::operator-=(half2_t const& b) { | ||
|  |   lo -= b.lo; | ||
|  |   hi -= b.hi; | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t& cutlass::half2_t::operator*=(half2_t const& b) { | ||
|  |   lo *= b.lo; | ||
|  |   hi *= b.hi; | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t& cutlass::half2_t::operator/=(half2_t const& b) { | ||
|  |   lo /= b.lo; | ||
|  |   hi /= b.hi; | ||
|  |   return *this; | ||
|  | } | ||
|  | 
 | ||
|  | inline float cutlass::half2_t::dot(half2_t const& b) const { | ||
|  |   return float(lo) * float(b.lo) + float(hi) * float(b.hi); | ||
|  | } | ||
|  | 
 | ||
|  | inline float cutlass::half2_t::dot(half2_t const& b, float c) const { return c + dot(b); } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b) const { | ||
|  |   return cutlass::half_t(dot(b)); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b, half_t c) const { | ||
|  |   return cutlass::half_t(dot(b, float(c))); | ||
|  | } | ||
|  | 
 | ||
|  | inline cutlass::half2_t::operator std::pair<float, float>() const { | ||
|  |   return std::pair<float, float>(float(lo), float(hi)); | ||
|  | } | ||
|  | 
 | ||
|  | inline unsigned cutlass::half2_t::packed() const { return (lo.x | (hi.x << 16)); } | ||
|  | 
 | ||
|  | inline cutlass::half2_t::operator unsigned() const { return packed(); } | ||
|  | 
 | ||
|  | //
 | ||
|  | //
 | ||
|  | //
 | ||
|  | 
 | ||
|  | template <> | ||
|  | inline float cutlass::bitcast<float, unsigned>(unsigned const& u) { | ||
|  |   return *reinterpret_cast<float const*>(&u); | ||
|  | } | ||
|  | 
 | ||
|  | template <> | ||
|  | inline float cutlass::bitcast<float, int>(int const& i) { | ||
|  |   return *reinterpret_cast<float const*>(&i); | ||
|  | } | ||
|  | 
 | ||
|  | template <> | ||
|  | inline unsigned cutlass::bitcast<unsigned, float>(float const& f) { | ||
|  |   return *reinterpret_cast<unsigned const*>(&f); | ||
|  | } | ||
|  | 
 | ||
|  | template <> | ||
|  | inline cutlass::half_t cutlass::bitcast<cutlass::half_t, unsigned short>(unsigned short const& s) { | ||
|  |   return *reinterpret_cast<cutlass::half_t const*>(&s); | ||
|  | } | ||
|  | 
 | ||
|  | template <> | ||
|  | inline unsigned short cutlass::bitcast<unsigned short, cutlass::half_t>(cutlass::half_t const& h) { | ||
|  |   return *reinterpret_cast<unsigned short const*>(&h); | ||
|  | } | ||
|  | 
 | ||
|  | template <> | ||
|  | inline half cutlass::bitcast<half, unsigned short>(unsigned short const& s) { | ||
|  |   return *reinterpret_cast<half const*>(&s); | ||
|  | } | ||
|  | 
 | ||
|  | //
 | ||
|  | // Lexical casts
 | ||
|  | //
 | ||
|  | 
 | ||
#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {

// Both definitions are `inline`: explicit specializations of function templates are NOT
// implicitly inline. Defining them in a header without `inline` gives every including
// translation unit its own definition, which is an ODR violation (typically a
// duplicate-symbol link error).

/// String -> half_t: parses the text as fp32, then rounds to half_t.
template <>
inline cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg) {
  return cutlass::half_t(boost::lexical_cast<float>(arg));
}

/// half_t -> string: formats the value after widening it to fp32.
template <>
inline std::string lexical_cast<std::string>(cutlass::half_t const& arg) {
  return boost::lexical_cast<std::string>(float(arg));
}

}  // namespace boost
#endif
|  | 
 | ||
|  | //
 | ||
|  | // Standard Library Operations
 | ||
|  | //
 | ||
|  | 
 | ||
|  | // std
 | ||
|  | namespace std { | ||
|  | 
 | ||
|  | inline cutlass::half_t abs(cutlass::half_t const& h) { | ||
|  |   return cutlass::half_t::bitcast(h.x & 0x7fff); | ||
|  | } | ||
|  | 
 | ||
|  | inline bool isnan(cutlass::half_t const& h) { return h.isnan(); } | ||
|  | 
 | ||
|  | inline bool isfinite(cutlass::half_t const& h) { return h.isfinite(); } | ||
|  | 
 | ||
|  | inline cutlass::half_t nanh(const char*) { return cutlass::half_t::nan(); } | ||
|  | 
 | ||
|  | inline bool isinf(cutlass::half_t const& h) { return h.isinf(); } | ||
|  | 
 | ||
|  | inline bool isnormal(cutlass::half_t const& h) { return h.isnormal(); } | ||
|  | 
 | ||
|  | inline int fpclassify(cutlass::half_t const& h) { | ||
|  |   int exp = h.exponent(); | ||
|  |   unsigned short mantissa = h.mantissa(); | ||
|  |   if (exp < -14) { | ||
|  |     if (mantissa == 0) { | ||
|  |       return FP_ZERO; | ||
|  |     } else { | ||
|  |       return FP_SUBNORMAL; | ||
|  |     } | ||
|  |   } else if (exp > 15) { | ||
|  |     if (mantissa == 0) { | ||
|  |       return FP_INFINITE; | ||
|  |     } else { | ||
|  |       return FP_NAN; | ||
|  |     } | ||
|  |   } | ||
|  |   return FP_NORMAL; | ||
|  | } | ||
|  | 
 | ||
|  | inline bool signbit(cutlass::half_t const& h) { return h.signbit(); } | ||
|  | 
 | ||
|  | inline cutlass::half_t sqrt(cutlass::half_t const& h) { | ||
|  |   return cutlass::half_t(std::sqrt(float(h))); | ||
|  | } | ||
|  | }  // namespace std
 | ||
|  | 
 | ||
|  | //
 | ||
|  | // Stream interactions
 | ||
|  | //
 | ||
|  | 
 | ||
|  | /// put to stream - half_t-precision types bitcast as unsigned shorts if base is hexadecimal
 | ||
|  | inline std::ostream& operator<<(std::ostream& out, cutlass::half_t const& h) { | ||
|  |   if (out.flags() & std::ios::hex) { | ||
|  |     return out << h.x; | ||
|  |   } else { | ||
|  |     return out << float(h); | ||
|  |   } | ||
|  | } | ||
|  | 
 | ||
|  | /// read from stream - half_t-precision types parsed as unsigned shorts if base is hexadecimal
 | ||
|  | inline std::istream& operator>>(std::istream& in, cutlass::half_t& h) { | ||
|  |   if (in.flags() & std::ios::hex) { | ||
|  |     unsigned short u = 0; | ||
|  |     in >> u; | ||
|  |     h = cutlass::half_t::bitcast(u); | ||
|  |   } else { | ||
|  |     float f = 0; | ||
|  |     in >> f; | ||
|  |     h = cutlass::half_t(f); | ||
|  |   } | ||
|  |   return in; | ||
|  | } |