Fix some sign conversion warnings (#1172)
* Fix sign conversion warnings * Fix type conversion warnings * Fix sign conversion warnings * Change smem_size_ to constexpr * clang warnings * undo cast change * one miss change * missing part --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
99c4eebe3b
commit
10b850f9c7
@ -231,9 +231,10 @@ int ceil_div(int a, int b) {
|
||||
* log2_up/down codes?
|
||||
*/
|
||||
template <typename value_t>
|
||||
CUTLASS_HOST_DEVICE value_t clz(value_t x) {
|
||||
CUTLASS_HOST_DEVICE int clz(value_t x) {
|
||||
for (int i = 31; i >= 0; --i) {
|
||||
if ((1 << i) & x) return 31 - i;
|
||||
if ((1 << i) & x)
|
||||
return value_t(31 - i);
|
||||
}
|
||||
return 32;
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
/***************************************************************************************************
|
||||
/**************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
@ -216,8 +216,8 @@ struct alignas(1) float8_base {
|
||||
|
||||
// Extract the bits in the FP32 type
|
||||
uint8_t sign = uint8_t((s >> 24 & 0x80));
|
||||
int8_t exp = uint8_t(((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS);
|
||||
int mantissa = s & 0x7fffff;
|
||||
int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS;
|
||||
uint32_t mantissa = s & 0x7fffff;
|
||||
uint8_t u = 0;
|
||||
|
||||
uint8_t const kF8_NaN = 0x7f;
|
||||
@ -233,7 +233,7 @@ struct alignas(1) float8_base {
|
||||
}
|
||||
|
||||
// Special handling
|
||||
if ( exp == -128 ) {
|
||||
if (exp == -128) {
|
||||
// int8 range is from -128 to 127
|
||||
// So 255(inf) - 127(bias) = 128 - will show up as -128
|
||||
|
||||
@ -248,8 +248,8 @@ struct alignas(1) float8_base {
|
||||
|
||||
if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) {
|
||||
// normal fp32 to normal fp8
|
||||
exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS));
|
||||
u = uint8_t(((exp & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS));
|
||||
exp = exp + FP8_EXPONENT_BIAS;
|
||||
u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS);
|
||||
u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)));
|
||||
} else if(exp < FP8_MIN_EXPONENT) {
|
||||
// normal single-precision to subnormal float8-precision representation
|
||||
@ -271,8 +271,8 @@ struct alignas(1) float8_base {
|
||||
if( exp == (FP8_MAX_EXPONENT + 1) ) {
|
||||
uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS));
|
||||
if( mantissa_tmp < FP8_MANTISSA_MASK) {
|
||||
exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS));
|
||||
u = uint8_t(exp << FP8_NUM_MANTISSA_BITS) | mantissa_tmp;
|
||||
exp = exp + FP8_EXPONENT_BIAS;
|
||||
u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp;
|
||||
may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1));
|
||||
} else {
|
||||
// satfinite
|
||||
@ -316,9 +316,9 @@ struct alignas(1) float8_base {
|
||||
uint32_t constexpr kF32_NaN = 0x7fffffff;
|
||||
|
||||
uint8_t const &f8 = x;
|
||||
int sign = (f8 >> (FP8_NUM_BITS - 1)) & 1;
|
||||
int exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK;
|
||||
int mantissa = f8 & FP8_MANTISSA_MASK;
|
||||
uint32_t sign = (f8 >> (FP8_NUM_BITS - 1)) & 1;
|
||||
uint32_t exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK;
|
||||
uint32_t mantissa = f8 & FP8_MANTISSA_MASK;
|
||||
unsigned f = (sign << (FP32_NUM_BITS-1));
|
||||
|
||||
if (IS_E4M3 && exp == 15 && mantissa == 0x7) {
|
||||
|
@ -112,7 +112,8 @@ protected:
|
||||
CUTLASS_THREAD_LOCAL static int sm_occupancy_;
|
||||
|
||||
/// Kernel dynamic shared memory allocation requirement
|
||||
CUTLASS_THREAD_LOCAL static int smem_size_;
|
||||
/// Update the kernel function's shared memory configuration for the current device
|
||||
static constexpr size_t smem_size_ = sizeof(typename GemmKernel::SharedStorage);
|
||||
|
||||
/// Initialize static thread-local members for the thread's current device,
|
||||
/// if necessary.
|
||||
@ -143,11 +144,8 @@ protected:
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
// Update the kernel function's shared memory configuration for the current device
|
||||
smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
// If requires more than 48KB: configure for extended, dynamic shared memory
|
||||
if (smem_size_ >= (48 << 10))
|
||||
if constexpr (smem_size_ >= (48 << 10))
|
||||
{
|
||||
cudart_result = cudaFuncSetAttribute(
|
||||
Kernel2<GemmKernel>,
|
||||
@ -377,7 +375,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Static initializers
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -394,12 +391,6 @@ CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
|
||||
template <typename GemmKernel_>
|
||||
CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
|
||||
|
||||
/// Kernel dynamic shared memory allocation requirement
|
||||
template <typename GemmKernel_>
|
||||
CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
|
@ -321,9 +321,9 @@ struct alignas(2) half_t {
|
||||
#endif
|
||||
|
||||
uint16_t const &h = x.storage;
|
||||
int sign = ((h >> 15) & 1);
|
||||
int exp = ((h >> 10) & 0x1f);
|
||||
int mantissa = (h & 0x3ff);
|
||||
uint32_t sign = ((h >> 15) & 1);
|
||||
uint32_t exp = ((h >> 10) & 0x1f);
|
||||
uint32_t mantissa = (h & 0x3ff);
|
||||
unsigned f = 0;
|
||||
|
||||
if (exp > 0 && exp < 31) {
|
||||
|
@ -364,7 +364,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
// software implementation rounds toward nearest even
|
||||
unsigned const& s = reinterpret_cast<unsigned const &>(flt);
|
||||
uint16_t sign = uint16_t((s >> 16) & 0x8000);
|
||||
int16_t exp = uint16_t(((s >> 23) & 0xff) - 127);
|
||||
int32_t exp = int32_t((s >> 23) & 0xff) - 127;
|
||||
int mantissa = s & 0x7fffff;
|
||||
uint16_t u = 0;
|
||||
|
||||
@ -386,8 +386,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
|
||||
if (exp >= -14) {
|
||||
// normal fp32 to normal fp16
|
||||
exp = uint16_t(exp + uint16_t(15));
|
||||
u = uint16_t(((exp & 0x1f) << 10));
|
||||
u = uint16_t((uint32_t(exp + 15) & 0x1f) << 10);
|
||||
u = uint16_t(u | (mantissa >> 13));
|
||||
} else {
|
||||
// normal single-precision to subnormal half_t-precision representation
|
||||
|
Loading…
Reference in New Issue
Block a user