1273 lines
36 KiB
C++
1273 lines
36 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
/*!
|
|
\file
|
|
\brief Defines classes for 8-bit floating-point types (E4M3 and E5M2) usable in host or
|
|
device code.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
// FP8 types are available starting CUDA 11.8+
|
|
#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
|
|
#define CUDA_FP8_ENABLED 1
|
|
#endif
|
|
|
|
#if defined(__CUDA_ARCH__)
|
|
# if (__CUDA_ARCH__ >= 900)
|
|
# if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
|
|
# define CUDA_PTX_FP8_CVT_ENABLED 1
|
|
# endif // (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
|
|
# elif (__CUDA_ARCH__ == 890)
|
|
# if (__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))
|
|
# define CUDA_PTX_FP8_CVT_ENABLED 1
|
|
# endif // (__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))
|
|
# endif // (__CUDA_ARCH__ >= 900)
|
|
#endif // defined(__CUDA_ARCH__)
|
|
|
|
#ifdef __GNUC__
|
|
// Ignore checks on reinterpret-casts that are being used for bitcasts.
|
|
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
|
#endif
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#if defined(__CUDACC_RTC__)
|
|
|
|
#include "cutlass/floating_point_nvrtc.h"
|
|
|
|
#else
|
|
//
|
|
// Standard Library headers belong here to avoid conflicts with NVRTC.
|
|
//
|
|
#include <cmath>
|
|
#include <limits>
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
#endif
|
|
|
|
#ifdef CUDA_FP8_ENABLED
|
|
#include <cuda_fp8.h>
|
|
#endif
|
|
#include <cuda_fp16.h>
|
|
|
|
#include "cutlass/cutlass.h"
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass {
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// FP8 Has 2 encodings possible : E4M3 and E5M2
|
|
//
|
|
// E4M3 : 7 | 6 5 4 3 | 2 1 0
|
|
// E5M2 : 7 | 6 5 4 3 2 | 1 0
|
|
//
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Selects which of the two FP8 encodings a float8_base instantiation models.
/// E4M3: 1 sign / 4 exponent / 3 mantissa bits (bias 7); E5M2: 1 / 5 / 2 (bias 15).
/// See the bit-layout diagram above and the FP8_* constants in float8_base.
enum class FloatEncoding {
  E4M3,
  E5M2
};
|
|
|
|
/// Common base for the two 8-bit floating-point types. Provides the raw byte of
/// storage, encoding constants for FP32/FP16/FP8, and software (bit-twiddling)
/// conversions between FP32 and FP8 used when hardware PTX conversions are
/// unavailable. The encoding (E4M3 vs E5M2) is fixed at compile time by T.
template<FloatEncoding T>
struct alignas(1) float8_base {

  /// True when this instantiation models the E4M3 encoding
  static constexpr bool IS_E4M3 = (T == FloatEncoding::E4M3);
  /// True when this instantiation models the E5M2 encoding
  static constexpr bool IS_E5M2 = (T == FloatEncoding::E5M2);

  // Number of Bits representing mantissa and exponents for FP32
  static constexpr int FP32_NUM_BITS = 32;
  static constexpr int FP32_NUM_EXPONENT_BITS = 8;
  static constexpr int FP32_NUM_MANTISSA_BITS = 23;
  static constexpr uint32_t FP32_NAN = 0x7fffffff;
  static constexpr uint32_t FP32_INFINITY_MASK = 0x7f800000;
  static constexpr int FP32_MAX_EXPONENT = 127;
  static constexpr int FP32_MIN_EXPONENT = -126;
  static constexpr int FP32_EXPONENT_BIAS = 127;

  // Number of Bits representing mantissa and exponents for FP16
  static constexpr int FP16_NUM_BITS = 16;
  static constexpr int FP16_NUM_EXPONENT_BITS = 5;
  static constexpr int FP16_NUM_MANTISSA_BITS = 10;
  static constexpr uint16_t FP16_NAN = 0x7fff;
  static constexpr uint16_t FP16_INFINITY_MASK = 0x7c00;
  static constexpr int FP16_MAX_EXPONENT = 15;
  static constexpr int FP16_MIN_EXPONENT = -14;
  static constexpr int FP16_EXPONENT_BIAS = 15;

  // Number of Bits representing mantissa and exponents for FP8.
  // Field widths and ranges depend on the encoding selected by T.
  static constexpr int FP8_NUM_BITS = 8;
  static constexpr int FP8_NUM_EXPONENT_BITS = IS_E4M3 ? 4 : 5;
  static constexpr int FP8_NUM_MANTISSA_BITS = IS_E4M3 ? 3 : 2;
  static constexpr uint8_t FP8_NAN = 0x7f; // Also F8_INF
  static constexpr uint8_t FP8_INFINITY_MASK = IS_E4M3 ? 0x78 : 0x7c;
  static constexpr int FP8_MAX_EXPONENT = IS_E4M3 ? 7 : 15;
  static constexpr int FP8_MIN_EXPONENT = IS_E4M3 ? -6 : -14;
  static constexpr int FP8_EXPONENT_BIAS = IS_E4M3 ? 7 : 15;

  static constexpr uint8_t FP8_EXPONENT_MASK = (1 << FP8_NUM_EXPONENT_BITS) - 1;
  static constexpr uint8_t FP8_MANTISSA_MASK = (1 << FP8_NUM_MANTISSA_BITS) - 1;

  // Largest finite magnitude bit pattern: 0x7e (=448) for E4M3, 0x7b (=57344) for E5M2
  static constexpr uint8_t FP8_MAX_FLT = (IS_E4M3 ? 0x7e : 0x7b);

  // 256 in float
  static constexpr uint32_t FP8_SAT_VAL_FP32 = 0x43800000;

  //
  // Data members
  //

  /// Data container: the raw FP8 byte
  uint8_t storage;

  /// Ctors. Zero-initializes the storage byte.
  CUTLASS_HOST_DEVICE
  float8_base() : storage(0) { }

  /// Is finite implementation: true iff the FP32 exponent field is not all ones
  CUTLASS_HOST_DEVICE
  static bool isfinite(float flt) {
    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    // memcpy is the strict-aliasing-safe bitcast on the host
    std::memcpy(&s, &flt, sizeof(s));
#endif

    return (s & 0x7f800000) < 0x7f800000;
  }

  /// Is NaN implementation: exponent all ones with a nonzero mantissa
  CUTLASS_HOST_DEVICE
  static bool isnan(float flt) {
    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    std::memcpy(&s, &flt, sizeof(s));
#endif

    return (s & 0x7fffffff) > 0x7f800000;
  }

  /// Is infinite implementation
  CUTLASS_HOST_DEVICE
  static bool isinf(float flt) {
    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    std::memcpy(&s, &flt, sizeof(s));
#endif

    // Sign = 0 for +inf, 1 for -inf
    // Exponent = all ones
    // Mantissa = all zeros
    return (s == 0x7f800000) || (s == 0xff800000);
  }

  /// FP32 -> FP8 conversion - rounds to nearest even, saturating to the largest
  /// finite FP8 value (satfinite) rather than producing infinity.
  CUTLASS_HOST_DEVICE
  static uint8_t convert_float_to_fp8(float const& flt) {

    // software implementation rounds toward nearest even
    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    std::memcpy(&s, &flt, sizeof(s));
#endif

    // Extract the bits in the FP32 type
    uint8_t sign = uint8_t((s >> 24 & 0x80));
    int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS;
    int mantissa = s & 0x7fffff;
    uint8_t u = 0;

    uint8_t const kF8_NaN = 0x7f;

    // NaN => NaN
    if (isnan(flt)) {
      return kF8_NaN;
    }

    // Inf => MAX_FLT (satfinite)
    if (isinf(flt)) {
      return sign | FP8_MAX_FLT;
    }

    // Special handling
    if (exp == -128) {
      // int8 range is from -128 to 127
      // So 255(inf) - 127(bias) = 128 - will show up as -128

      // satfinite
      return (sign | FP8_MAX_FLT);
    }

    int sticky_bit = 0;

    bool skip_sign = false;
    bool may_be_nan = false;

    if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) {
      // normal fp32 to normal fp8
      exp = exp + FP8_EXPONENT_BIAS;
      u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS);
      u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)));
    } else if(exp < FP8_MIN_EXPONENT) {
      // normal single-precision to subnormal float8-precision representation
      int rshift = (FP8_MIN_EXPONENT - exp);
      if (rshift < FP32_NUM_BITS) {
        // make the implicit leading 1 explicit before shifting it into the
        // subnormal range
        mantissa |= (1 << FP32_NUM_MANTISSA_BITS);

        // remember any bits shifted out so round-to-nearest-even stays exact
        sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0);

        mantissa = (mantissa >> rshift);
        u = (uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS- FP8_NUM_MANTISSA_BITS)) & FP8_MANTISSA_MASK);
      } else {
        // shifted entirely out of range: flush to (signed) zero
        mantissa = 0;
        u = 0;
      }
    // Exponent > FP8_MAX_EXPONENT - this is a special case done to match HW
    // 0x4380_0000 to 0x43e0_0000 - maps from 256 to 448, and does not saturate / inf.
    } else {
      if( exp == (FP8_MAX_EXPONENT + 1) ) {
        uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS));
        if( mantissa_tmp < FP8_MANTISSA_MASK) {
          exp = exp + FP8_EXPONENT_BIAS;
          u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp;
          // rounding up from here could land on the NaN bit pattern; flagged so
          // the sign bit is suppressed below in that case
          may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1));
        } else {
          // satfinite
          return (sign | FP8_MAX_FLT);
        }
      } else{
        // satfinite
        return (sign | FP8_MAX_FLT);
      }
    }

    // round to nearest even: round bit is the first dropped bit, sticky is the
    // OR of everything below it
    int NUM_BITS_SHIFT = FP32_NUM_MANTISSA_BITS - (FP8_NUM_MANTISSA_BITS + 1);
    int round_bit = ((mantissa >> NUM_BITS_SHIFT) & 1);
    sticky_bit |= ((mantissa & ((1 << NUM_BITS_SHIFT) - 1)) != 0);

    if ((round_bit && sticky_bit) || (round_bit && (u & 1))) {
      u = uint8_t(u + 1);
      if( may_be_nan ) {
        skip_sign = true;
      }
    }

    if (u > FP8_MAX_FLT) {
      // satfinite
      u = (sign | FP8_MAX_FLT);
    }

    if( ! skip_sign ) {
      u |= sign;
    }

    return u;
  }

  /// Converts a fp8 value stored as a uint8_t to a float. Handles NaN,
  /// infinity (E5M2 only), normals, subnormals, and signed zero.
  CUTLASS_HOST_DEVICE
  static float convert_fp8_to_float(uint8_t const& x) {

    uint32_t constexpr kF32_NaN = 0x7fffffff;

    uint8_t const &f8 = x;
    uint32_t sign = (f8 >> (FP8_NUM_BITS - 1)) & 1;
    uint32_t exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK;
    uint32_t mantissa = f8 & FP8_MANTISSA_MASK;
    unsigned f = (sign << (FP32_NUM_BITS-1));

    // E4M3 reserves only exponent=15 / mantissa=7 for NaN (no infinities)
    if (IS_E4M3 && exp == 15 && mantissa == 0x7) {
      f = kF32_NaN;
    }
    else if (exp > 0 && (IS_E4M3 || exp < (FP8_MAX_EXPONENT + FP8_EXPONENT_BIAS + 1))) {
      // normal: rebias the exponent and widen the mantissa
      exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS);
      f = f |
        (exp << FP32_NUM_MANTISSA_BITS) |
        (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS));
    } else if (exp == 0) {
      if (mantissa) {
        // subnormal: normalize by shifting the mantissa up until the leading
        // 1 reaches the implicit-bit position, decrementing the exponent
        exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS) + 1;
        while ((mantissa & (1 << FP8_NUM_MANTISSA_BITS)) == 0) {
          mantissa <<= 1;
          exp--;
        }
        mantissa &= FP8_MANTISSA_MASK;
        f = f |
          (exp << FP32_NUM_MANTISSA_BITS) |
          (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS));
      } else {
        // sign-preserving zero
      }
    } else {
      // E5M2 exponent all-ones (this branch is unreachable for E4M3)
      if(mantissa == 0){
        // Sign-preserving infinity
        f = (f | 0x7f800000);
      } else {
        // Canonical NaN
        f = kF32_NaN;
      }
    }

#if defined(__CUDA_ARCH__)
    return reinterpret_cast<float const&>(f);
#else
    float flt;
    std::memcpy(&flt, &f, sizeof(flt));
    return flt;
#endif
  }
};
|
|
|
|
|
|
// Forward declaration of float_e5m2_t to define float_e4m3_t <=> float_e5m2_t
|
|
// conversions in class float_e4m3_t
|
|
struct float_e5m2_t;
|
|
|
|
|
|
///////////////////////////////////////////////////////////////
|
|
///
|
|
/// floating-point 8 type : E4M3
|
|
///
|
|
///////////////////////////////////////////////////////////////
|
|
/// 8-bit floating-point type with 4 exponent bits and 3 mantissa bits (bias 7).
/// Uses hardware PTX conversions when CUDA_PTX_FP8_CVT_ENABLED is defined,
/// otherwise falls back to the software routines in float8_base.
struct alignas(1) float_e4m3_t : float8_base<FloatEncoding::E4M3> {

  using Base = float8_base<FloatEncoding::E4M3>;

  static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT;

  //
  // Static conversion operators
  //

  /// Constructs from an uint8_t holding the raw bit pattern
  CUTLASS_HOST_DEVICE
  static float_e4m3_t bitcast(uint8_t x) {
    float_e4m3_t f;
    f.storage = x;
    return f;
  }

  /// FP32 -> FP8 conversion - rounds to nearest even (satfinite)
  CUTLASS_HOST_DEVICE
  static float_e4m3_t from_float(float const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp;
    // The PTX instruction packs two results; the unused lane converts 0.0f
    float y = float();
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt));

    return *reinterpret_cast<float_e4m3_t *>(&tmp);
#else
    return bitcast(Base::convert_float_to_fp8(flt));
#endif
  }

  /// FP16 -> E4M3 conversion - rounds to nearest even (satfinite)
  CUTLASS_HOST_DEVICE
  static float_e4m3_t from_half(half const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp = 0;
    uint32_t bits = reinterpret_cast<uint16_t const &>(flt);
    asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits));

    return *reinterpret_cast<float_e4m3_t *>(&tmp);
#else
    return bitcast(Base::convert_float_to_fp8(__half2float(flt)));
#endif
  }

  // E4M3 -> half
  CUTLASS_HOST_DEVICE
  static half to_half(float_e4m3_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    // the low half of the packed pair holds our value
    return reinterpret_cast<half2 const &>(packed).x;
#else
    return __float2half(Base::convert_fp8_to_float(x.storage));
#endif
  }

  // E4M3 -> Float
  CUTLASS_HOST_DEVICE
  static float to_float(float_e4m3_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    // widening via FP16 is exact: every E4M3 value is representable in FP16
    return __half2float(reinterpret_cast<half2 const &>(packed).x);
#else
    return Base::convert_fp8_to_float(x.storage);
#endif
  }

  //
  // Methods
  //

  /// Constructor inheritance
  using Base::Base;

  /// Default constructor
  float_e4m3_t() = default;

#ifdef CUDA_FP8_ENABLED
  /// Conversion from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(__nv_fp8_e4m3 x) {
    storage = x.__x;
  }
#endif

  /// Floating point conversion
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(float x) {
    storage = from_float(x).storage;
  }

  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(half x) {
    storage = from_half(x).storage;
  }

  /// Floating point conversion
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(double x): float_e4m3_t(float(x)) {
  }

  /// Integer conversion
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(int x): float_e4m3_t(float(x)) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(unsigned x): float_e4m3_t(float(x)) {
  }

  /// E5M2 conversion. Defined after float_e5m2_t is defined.
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(float_e5m2_t x);

#ifdef CUDA_FP8_ENABLED
  /// Assignment from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  float_e4m3_t & operator=(__nv_fp8_e4m3 x) {
    storage = x.__x;
    return *this;
  }
#endif

  /// Converts to float
  CUTLASS_HOST_DEVICE
  operator float() const {
    return to_float(*this);
  }

  /// Converts to half
  CUTLASS_HOST_DEVICE
  operator half() const {
    return to_half(*this);
  }

  /// Converts to double
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(to_float(*this));
  }

  /// Converts to int (round to nearest on device)
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
#if defined(__CUDA_ARCH__)
    return __half2int_rn(to_half(*this));
#else
    return int(to_float(*this));
#endif
  }

  /// Casts to bool
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
#if defined(__CUDA_ARCH__)
    return bool(__half2int_rn(to_half(*this)));
#else
    return bool(int(to_float(*this)));
#endif
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t& raw() {
    return storage;
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0);
  }

  /// Returns the biased exponent
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK);
  }

  /// Returns the unbiased exponent
  CUTLASS_HOST_DEVICE
  int exponent() const {
    // E4M3's exponent bias is 7 (Base::FP8_EXPONENT_BIAS), not 15; the prior
    // code subtracted the E5M2 bias, producing exponents that were 8 too small.
    return exponent_biased() - Base::FP8_EXPONENT_BIAS;
  }

  /// Returns the mantissa
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(storage & Base::FP8_MANTISSA_MASK);
  }
};
|
|
///////////////////////////////////////////////////////////////
|
|
///
|
|
/// floating-point 8 type : E5M2
|
|
///
|
|
///////////////////////////////////////////////////////////////
|
|
/// 8-bit floating-point type with 5 exponent bits and 2 mantissa bits (bias 15).
/// Uses hardware PTX conversions when CUDA_PTX_FP8_CVT_ENABLED is defined,
/// otherwise falls back to the software routines in float8_base.
struct alignas(1) float_e5m2_t : float8_base<FloatEncoding::E5M2> {

  using Base = float8_base<FloatEncoding::E5M2>;

  static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT;

  //
  // Static conversion operators
  //

  /// Constructs from an uint8_t holding the raw bit pattern
  CUTLASS_HOST_DEVICE
  static float_e5m2_t bitcast(uint8_t x) {
    float_e5m2_t f;
    f.storage = x;
    return f;
  }

  /// FP32 -> FP8 conversion - rounds to nearest even (satfinite)
  CUTLASS_HOST_DEVICE
  static float_e5m2_t from_float(float const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp;
    // The PTX instruction packs two results; the unused lane converts 0.0f
    float y = float();
    asm volatile("cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt));

    return *reinterpret_cast<float_e5m2_t *>(&tmp);
#else
    return bitcast(Base::convert_float_to_fp8(flt));
#endif
  }

  /// FP16 -> E5M2 conversion - rounds to nearest even (satfinite)
  CUTLASS_HOST_DEVICE
  static float_e5m2_t from_half(half const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp = 0;
    uint32_t bits = reinterpret_cast<uint16_t const &>(flt);
    asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits));

    return *reinterpret_cast<float_e5m2_t *>(&tmp);
#else
    return bitcast(Base::convert_float_to_fp8(__half2float(flt)));
#endif
  }

  // E5M2 -> half
  CUTLASS_HOST_DEVICE
  static half to_half(float_e5m2_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    // the low half of the packed pair holds our value
    return reinterpret_cast<half2 const &>(packed).x;
#else
    return __float2half(Base::convert_fp8_to_float(x.storage));
#endif
  }

  // E5M2 -> Float
  CUTLASS_HOST_DEVICE
  static float to_float(float_e5m2_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    return __half2float(reinterpret_cast<half2 const &>(packed).x);
#else
    return Base::convert_fp8_to_float(x.storage);
#endif
  }

  //
  // Methods
  //

  /// Constructor inheritance
  using Base::Base;

  /// Default constructor
  float_e5m2_t() = default;

#ifdef CUDA_FP8_ENABLED
  /// Conversion from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(__nv_fp8_e5m2 x) {
    storage = x.__x;
  }
#endif

  /// Floating point conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(float x) {
    storage = from_float(x).storage;
  }

  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(half x) {
    storage = from_half(x).storage;
  }

  /// Floating point conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(double x): float_e5m2_t(float(x)) {
  }

  /// Integer conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(int x): float_e5m2_t(float(x)) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(unsigned x): float_e5m2_t(float(x)) {
  }

  /// E4M3 conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(float_e4m3_t x);

#ifdef CUDA_FP8_ENABLED
  /// Assignment from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  float_e5m2_t & operator=(__nv_fp8_e5m2 x) {
    storage = x.__x;
    return *this;
  }
#endif

  /// Converts to float
  CUTLASS_HOST_DEVICE
  operator float() const {
    return to_float(*this);
  }

  /// Converts to half
  CUTLASS_HOST_DEVICE
  operator half() const {
    return to_half(*this);
  }

  /// Converts to double
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(to_float(*this));
  }

  /// Converts to int (round to nearest on device)
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
#if defined(__CUDA_ARCH__)
    return __half2int_rn(to_half(*this));
#else
    return int(to_float(*this));
#endif
  }

  /// Casts to bool
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
#if defined(__CUDA_ARCH__)
    return bool(__half2int_rn(to_half(*this)));
#else
    return bool(int(to_float(*this)));
#endif
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t& raw() {
    return storage;
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0);
  }

  /// Returns the biased exponent
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK);
  }

  /// Returns the unbiased exponent
  CUTLASS_HOST_DEVICE
  int exponent() const {
    // 15 is the E5M2 exponent bias (matches Base::FP8_EXPONENT_BIAS)
    return exponent_biased() - 15;
  }

  /// Returns the mantissa
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(storage & Base::FP8_MANTISSA_MASK);
  }
};
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// Arithmetic operators
|
|
//
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Comparisons and arithmetic for float_e4m3_t: every operation widens both
// operands to float, computes there, and (for arithmetic) narrows the result
// back through float_e4m3_t's converting constructor.

CUTLASS_HOST_DEVICE
bool operator==(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a == b;
}

CUTLASS_HOST_DEVICE
bool operator!=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a != b;
}

CUTLASS_HOST_DEVICE
bool operator<(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a < b;
}

CUTLASS_HOST_DEVICE
bool operator<=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a <= b;
}

CUTLASS_HOST_DEVICE
bool operator>(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a > b;
}

CUTLASS_HOST_DEVICE
bool operator>=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a >= b;
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator+(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  return float_e4m3_t(sum);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator-(float_e4m3_t const& lhs) {
  float const negated = -static_cast<float>(lhs);
  return float_e4m3_t(negated);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator-(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  return float_e4m3_t(difference);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator*(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  return float_e4m3_t(product);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator/(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  return float_e4m3_t(quotient);
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator+=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  lhs = float_e4m3_t(sum);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator-=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  lhs = float_e4m3_t(difference);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator*=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  lhs = float_e4m3_t(product);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator/=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  lhs = float_e4m3_t(quotient);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator++(float_e4m3_t & lhs) {
  float widened = static_cast<float>(lhs);
  widened += 1.0f;
  lhs = float_e4m3_t(widened);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator--(float_e4m3_t & lhs) {
  float widened = static_cast<float>(lhs);
  widened -= 1.0f;
  lhs = float_e4m3_t(widened);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator++(float_e4m3_t & lhs, int) {
  // post-increment: return the original value, then bump lhs by one
  float_e4m3_t const previous(lhs);
  float widened = static_cast<float>(lhs);
  widened += 1.0f;
  lhs = float_e4m3_t(widened);
  return previous;
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator--(float_e4m3_t & lhs, int) {
  // post-decrement: return the original value, then drop lhs by one
  float_e4m3_t const previous(lhs);
  float widened = static_cast<float>(lhs);
  widened -= 1.0f;
  lhs = float_e4m3_t(widened);
  return previous;
}
|
|
|
|
// Comparisons and arithmetic for float_e5m2_t: every operation widens both
// operands to float, computes there, and (for arithmetic) narrows the result
// back through float_e5m2_t's converting constructor.

CUTLASS_HOST_DEVICE
bool operator==(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a == b;
}

CUTLASS_HOST_DEVICE
bool operator!=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a != b;
}

CUTLASS_HOST_DEVICE
bool operator<(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a < b;
}

CUTLASS_HOST_DEVICE
bool operator<=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a <= b;
}

CUTLASS_HOST_DEVICE
bool operator>(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a > b;
}

CUTLASS_HOST_DEVICE
bool operator>=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a >= b;
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator+(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  return float_e5m2_t(sum);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator-(float_e5m2_t const& lhs) {
  float const negated = -static_cast<float>(lhs);
  return float_e5m2_t(negated);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator-(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  return float_e5m2_t(difference);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator*(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  return float_e5m2_t(product);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator/(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  return float_e5m2_t(quotient);
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator+=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  lhs = float_e5m2_t(sum);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator-=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  lhs = float_e5m2_t(difference);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator*=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  lhs = float_e5m2_t(product);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator/=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  lhs = float_e5m2_t(quotient);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator++(float_e5m2_t & lhs) {
  float widened = static_cast<float>(lhs);
  widened += 1.0f;
  lhs = float_e5m2_t(widened);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator--(float_e5m2_t & lhs) {
  float widened = static_cast<float>(lhs);
  widened -= 1.0f;
  lhs = float_e5m2_t(widened);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator++(float_e5m2_t & lhs, int) {
  // post-increment: return the original value, then bump lhs by one
  float_e5m2_t const previous(lhs);
  float widened = static_cast<float>(lhs);
  widened += 1.0f;
  lhs = float_e5m2_t(widened);
  return previous;
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator--(float_e5m2_t & lhs, int) {
  // post-decrement: return the original value, then drop lhs by one
  float_e5m2_t const previous(lhs);
  float widened = static_cast<float>(lhs);
  widened -= 1.0f;
  lhs = float_e5m2_t(widened);
  return previous;
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// float_e4m3_t <=> float_e5m2_t conversions
|
|
//
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// float_e4m3_t <= float_e5m2_t: widen to float, then re-narrow to E4M3
CUTLASS_HOST_DEVICE
float_e4m3_t::float_e4m3_t(float_e5m2_t x) {
  float const widened = float_e5m2_t::to_float(x);
  storage = from_float(widened).storage;
}

/// float_e5m2_t <= float_e4m3_t: widen to float, then re-narrow to E5M2
CUTLASS_HOST_DEVICE
float_e5m2_t::float_e5m2_t(float_e4m3_t x) {
  float const widened = float_e4m3_t::to_float(x);
  storage = from_float(widened).storage;
}
|
|
|
|
///////////////////////////////////////////////////////////////
|
|
///
|
|
/// Umbrella floating-point 8-bit data type : type_erased_dynamic_float8_t
|
|
/// This umbrella datatype can be enabled when a user provides a specific
|
|
/// datatype in runtime argument list.
|
|
///
|
|
/// Currently supported runtime datatypes compatible with type_erased_dynamic_float8_t:
|
|
/// QMMAFormat::E5M2
|
|
/// QMMAFormat::E4M3
|
|
///
|
|
///////////////////////////////////////////////////////////////
|
|
|
|
/// Union overlaying the raw byte with both FP8 encodings; all members share
/// the same single byte of storage, so no conversion occurs when switching
/// views — callers must know which encoding the byte actually holds.
union type_erased_dynamic_float8_t {
  // Raw bit pattern of the stored FP8 value
  uint8_t data;
  // View of the byte as an E5M2 value
  cutlass::float_e5m2_t e5m2;
  // View of the byte as an E4M3 value
  cutlass::float_e4m3_t e4m3;

  /// Reinterprets (does not convert) the stored byte as E5M2
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::float_e5m2_t() const {
    return e5m2;
  }

  /// Reinterprets (does not convert) the stored byte as E4M3
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::float_e4m3_t() const {
    return e4m3;
  }

};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace cutlass
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// Standard Library operations and definitions
|
|
//
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#if !defined(__CUDACC_RTC__)
|
|
namespace std {
|
|
|
|
/// Numeric limits common to all float8 types
|
|
/// Numeric limits common to all float8 types. Specializations of
/// std::numeric_limits for the concrete FP8 types derive from this and add the
/// encoding-specific members (has_infinity, lowest, epsilon).
template <typename T>
struct float8_base_numeric_limits {
private:
  using F8Type = T;
public:
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_quiet_NaN = true;
  static bool const has_signaling_NaN = false;
  static std::float_denorm_style const has_denorm = std::denorm_present;
  static bool const has_denorm_loss = true;
  static std::float_round_style const round_style = std::round_to_nearest;
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;
  // NOTE(review): std convention counts the implicit leading bit in `digits`
  // (mantissa bits + 1); this reports only the stored mantissa bits — confirm
  // intended.
  static int const digits = F8Type::FP8_NUM_MANTISSA_BITS;

  /// Least positive value (smallest subnormal, bit pattern 0x01)
  CUTLASS_HOST_DEVICE
  static F8Type min() { return F8Type::bitcast(0x01); }

  /// Maximum finite value
  CUTLASS_HOST_DEVICE
  static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); }

  /// Returns maximum rounding error
  CUTLASS_HOST_DEVICE
  static F8Type round_error() { return F8Type(0.5f); }

  /// Returns positive infinity value
  // NOTE(review): for E4M3 the encoding has no infinity, and FP8_INFINITY_MASK
  // (0x78) decodes to the finite value 256.0 under convert_fp8_to_float —
  // verify callers do not rely on this for E4M3.
  CUTLASS_HOST_DEVICE
  static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); }

  /// Returns quiet NaN value
  CUTLASS_HOST_DEVICE
  static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); }

  /// Returns signaling NaN value (same bit pattern; has_signaling_NaN is false)
  CUTLASS_HOST_DEVICE
  static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); }

  /// Returns smallest positive subnormal value
  CUTLASS_HOST_DEVICE
  static F8Type denorm_min() { return F8Type::bitcast(0x01); }
};
|
|
|
|
/// Numeric limits for float_e4m3_t
|
|
template <>
|
|
struct numeric_limits<cutlass::float_e4m3_t> :
|
|
public float8_base_numeric_limits<cutlass::float_e4m3_t> {
|
|
static bool const has_infinity = false;
|
|
|
|
/// Minimum finite value
|
|
static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); }
|
|
|
|
/// Machine epsilon, that is, the difference between 1.0 and the next representable value
|
|
static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); }
|
|
};
|
|
|
|
/// Numeric limits for float_e5m2_t
|
|
template <>
|
|
struct numeric_limits<cutlass::float_e5m2_t> :
|
|
public float8_base_numeric_limits<cutlass::float_e5m2_t> {
|
|
static bool const has_infinity = true;
|
|
|
|
/// Minimum finite value
|
|
static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); }
|
|
|
|
/// Machine epsilon, that is, the difference between 1.0 and the next representable value
|
|
static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); }
|
|
};
|
|
|
|
} // namespace std
|
|
#endif
|
|
|
|
namespace cutlass {
|
|
namespace platform {
|
|
|
|
/// Numeric limits common to all float8 types
|
|
template <typename T>
|
|
struct float8_base_numeric_limits {
|
|
private:
|
|
using F8Type = T;
|
|
public:
|
|
static bool const is_specialized = true;
|
|
static bool const is_signed = true;
|
|
static bool const is_integer = false;
|
|
static bool const is_exact = false;
|
|
static bool const has_quiet_NaN = true;
|
|
static bool const has_signaling_NaN = false;
|
|
#if !defined(__CUDACC_RTC__)
|
|
static std::float_denorm_style const has_denorm = std::denorm_present;
|
|
#endif
|
|
static bool const has_denorm_loss = true;
|
|
#if !defined(__CUDACC_RTC__)
|
|
static std::float_round_style const round_style = std::round_to_nearest;
|
|
#endif
|
|
static bool const is_iec559 = false;
|
|
static bool const is_bounded = true;
|
|
static bool const is_modulo = false;
|
|
static int const digits = F8Type::FP8_NUM_MANTISSA_BITS;
|
|
|
|
/// Least positive value
|
|
CUTLASS_HOST_DEVICE
|
|
static F8Type min() { return F8Type::bitcast(0x01); }
|
|
|
|
/// Maximum finite value
|
|
CUTLASS_HOST_DEVICE
|
|
static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); }
|
|
|
|
/// Returns maximum rounding error
|
|
CUTLASS_HOST_DEVICE
|
|
static F8Type round_error() { return F8Type(0.5f); }
|
|
|
|
/// Returns positive infinity value
|
|
CUTLASS_HOST_DEVICE
|
|
static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); }
|
|
|
|
/// Returns quiet NaN value
|
|
CUTLASS_HOST_DEVICE
|
|
static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); }
|
|
|
|
/// Returns signaling NaN value
|
|
CUTLASS_HOST_DEVICE
|
|
static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); }
|
|
|
|
/// Returns smallest positive subnormal value
|
|
CUTLASS_HOST_DEVICE
|
|
static F8Type denorm_min() { return F8Type::bitcast(0x01); }
|
|
};
|
|
|
|
/// Forward declaration of the primary cutlass::platform::numeric_limits
/// template, so the float8 specializations below can be defined here.
template <class T>
struct numeric_limits;
|
|
|
|
/// Numeric limits for float_e4m3_t
|
|
template <>
|
|
struct numeric_limits<cutlass::float_e4m3_t> :
|
|
public float8_base_numeric_limits<cutlass::float_e4m3_t> {
|
|
static bool const has_infinity = false;
|
|
|
|
/// Minimum finite value
|
|
static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); }
|
|
|
|
/// Machine epsilon, that is, the difference between 1.0 and the next representable value
|
|
static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); }
|
|
};
|
|
|
|
/// Numeric limits for float_e5m2_t
|
|
template <>
|
|
struct numeric_limits<cutlass::float_e5m2_t> :
|
|
public float8_base_numeric_limits<cutlass::float_e5m2_t> {
|
|
static bool const has_infinity = true;
|
|
|
|
/// Minimum finite value
|
|
static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); }
|
|
|
|
/// Machine epsilon, that is, the difference between 1.0 and the next representable value
|
|
static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); }
|
|
};
|
|
|
|
} // namespace platform
|
|
|
|
} // namespace cutlass
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// User-defined literals
//
/// User-defined literal for float_e4m3_t from a floating-point literal, e.g. 0.5_fe4m3
CUTLASS_HOST_DEVICE
cutlass::float_e4m3_t operator "" _fe4m3(long double x) {
  // Narrow to float first; the float_e4m3_t constructor handles the final conversion
  return cutlass::float_e4m3_t(static_cast<float>(x));
}
|
|
|
|
/// User-defined literal for float_e4m3_t from an integer literal, e.g. 1_fe4m3
CUTLASS_HOST_DEVICE
cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) {
  // Narrow to int first; the float_e4m3_t constructor handles the final conversion
  return cutlass::float_e4m3_t(static_cast<int>(x));
}
|
|
|
|
/// User-defined literal for float_e5m2_t from a floating-point literal, e.g. 0.5_fe5m2
CUTLASS_HOST_DEVICE
cutlass::float_e5m2_t operator "" _fe5m2(long double x) {
  // Narrow to float first; the float_e5m2_t constructor handles the final conversion
  return cutlass::float_e5m2_t(static_cast<float>(x));
}
|
|
|
|
/// User-defined literal for float_e5m2_t from an integer literal, e.g. 1_fe5m2
CUTLASS_HOST_DEVICE
cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) {
  // Narrow to int first; the float_e5m2_t constructor handles the final conversion
  return cutlass::float_e5m2_t(static_cast<int>(x));
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|