Fix some sign conversion warnings (#1172)

* Fix sign conversion warnings

* Fix type conversion warnings

* Fix sign conversion warnings

* Change smem_size_ to constexpr

* clang warnings

* undo cast change

* one miss change

* missing part

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
cyyever 2023-11-30 13:28:40 +08:00 committed by GitHub
parent 99c4eebe3b
commit 10b850f9c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 31 deletions

View File

@@ -231,9 +231,10 @@ int ceil_div(int a, int b) {
* log2_up/down codes? * log2_up/down codes?
*/ */
template <typename value_t> template <typename value_t>
CUTLASS_HOST_DEVICE value_t clz(value_t x) { CUTLASS_HOST_DEVICE int clz(value_t x) {
for (int i = 31; i >= 0; --i) { for (int i = 31; i >= 0; --i) {
if ((1 << i) & x) return 31 - i; if ((1 << i) & x)
return value_t(31 - i);
} }
return 32; return 32;
} }

View File

@@ -1,4 +1,4 @@
/*************************************************************************************************** /**************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause * SPDX-License-Identifier: BSD-3-Clause
* *
@@ -216,8 +216,8 @@ struct alignas(1) float8_base {
// Extract the bits in the FP32 type // Extract the bits in the FP32 type
uint8_t sign = uint8_t((s >> 24 & 0x80)); uint8_t sign = uint8_t((s >> 24 & 0x80));
int8_t exp = uint8_t(((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS); int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS;
int mantissa = s & 0x7fffff; uint32_t mantissa = s & 0x7fffff;
uint8_t u = 0; uint8_t u = 0;
uint8_t const kF8_NaN = 0x7f; uint8_t const kF8_NaN = 0x7f;
@@ -233,7 +233,7 @@ struct alignas(1) float8_base {
} }
// Special handling // Special handling
if ( exp == -128 ) { if (exp == -128) {
// int8 range is from -128 to 127 // int8 range is from -128 to 127
// So 255(inf) - 127(bias) = 128 - will show up as -128 // So 255(inf) - 127(bias) = 128 - will show up as -128
@@ -248,8 +248,8 @@ struct alignas(1) float8_base {
if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) { if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) {
// normal fp32 to normal fp8 // normal fp32 to normal fp8
exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); exp = exp + FP8_EXPONENT_BIAS;
u = uint8_t(((exp & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS)); u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS);
u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS))); u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)));
} else if(exp < FP8_MIN_EXPONENT) { } else if(exp < FP8_MIN_EXPONENT) {
// normal single-precision to subnormal float8-precision representation // normal single-precision to subnormal float8-precision representation
@@ -271,8 +271,8 @@ struct alignas(1) float8_base {
if( exp == (FP8_MAX_EXPONENT + 1) ) { if( exp == (FP8_MAX_EXPONENT + 1) ) {
uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)); uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS));
if( mantissa_tmp < FP8_MANTISSA_MASK) { if( mantissa_tmp < FP8_MANTISSA_MASK) {
exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); exp = exp + FP8_EXPONENT_BIAS;
u = uint8_t(exp << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp;
may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1)); may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1));
} else { } else {
// satfinite // satfinite
@@ -316,9 +316,9 @@ struct alignas(1) float8_base {
uint32_t constexpr kF32_NaN = 0x7fffffff; uint32_t constexpr kF32_NaN = 0x7fffffff;
uint8_t const &f8 = x; uint8_t const &f8 = x;
int sign = (f8 >> (FP8_NUM_BITS - 1)) & 1; uint32_t sign = (f8 >> (FP8_NUM_BITS - 1)) & 1;
int exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK; uint32_t exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK;
int mantissa = f8 & FP8_MANTISSA_MASK; uint32_t mantissa = f8 & FP8_MANTISSA_MASK;
unsigned f = (sign << (FP32_NUM_BITS-1)); unsigned f = (sign << (FP32_NUM_BITS-1));
if (IS_E4M3 && exp == 15 && mantissa == 0x7) { if (IS_E4M3 && exp == 15 && mantissa == 0x7) {

View File

@@ -112,7 +112,8 @@ protected:
CUTLASS_THREAD_LOCAL static int sm_occupancy_; CUTLASS_THREAD_LOCAL static int sm_occupancy_;
/// Kernel dynamic shared memory allocation requirement /// Kernel dynamic shared memory allocation requirement
CUTLASS_THREAD_LOCAL static int smem_size_; /// Update the kernel function's shared memory configuration for the current device
static constexpr size_t smem_size_ = sizeof(typename GemmKernel::SharedStorage);
/// Initialize static thread-local members for the thread's current device, /// Initialize static thread-local members for the thread's current device,
/// if necessary. /// if necessary.
@@ -143,11 +144,8 @@ protected:
return Status::kErrorInternal; return Status::kErrorInternal;
} }
// Update the kernel function's shared memory configuration for the current device
smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));
// If requires more than 48KB: configure for extended, dynamic shared memory // If requires more than 48KB: configure for extended, dynamic shared memory
if (smem_size_ >= (48 << 10)) if constexpr (smem_size_ >= (48 << 10))
{ {
cudart_result = cudaFuncSetAttribute( cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>, Kernel2<GemmKernel>,
@@ -377,7 +375,6 @@ public:
} }
}; };
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Static initializers /// Static initializers
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -394,12 +391,6 @@ CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
template <typename GemmKernel_> template <typename GemmKernel_>
CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1; CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
/// Kernel dynamic shared memory allocation requirement
template <typename GemmKernel_>
CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device } // namespace device

View File

@@ -321,9 +321,9 @@ struct alignas(2) half_t {
#endif #endif
uint16_t const &h = x.storage; uint16_t const &h = x.storage;
int sign = ((h >> 15) & 1); uint32_t sign = ((h >> 15) & 1);
int exp = ((h >> 10) & 0x1f); uint32_t exp = ((h >> 10) & 0x1f);
int mantissa = (h & 0x3ff); uint32_t mantissa = (h & 0x3ff);
unsigned f = 0; unsigned f = 0;
if (exp > 0 && exp < 31) { if (exp > 0 && exp < 31) {

View File

@@ -364,7 +364,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {
// software implementation rounds toward nearest even // software implementation rounds toward nearest even
unsigned const& s = reinterpret_cast<unsigned const &>(flt); unsigned const& s = reinterpret_cast<unsigned const &>(flt);
uint16_t sign = uint16_t((s >> 16) & 0x8000); uint16_t sign = uint16_t((s >> 16) & 0x8000);
int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); int32_t exp = int32_t((s >> 23) & 0xff) - 127;
int mantissa = s & 0x7fffff; int mantissa = s & 0x7fffff;
uint16_t u = 0; uint16_t u = 0;
@@ -386,8 +386,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {
if (exp >= -14) { if (exp >= -14) {
// normal fp32 to normal fp16 // normal fp32 to normal fp16
exp = uint16_t(exp + uint16_t(15)); u = uint16_t((uint32_t(exp + 15) & 0x1f) << 10);
u = uint16_t(((exp & 0x1f) << 10));
u = uint16_t(u | (mantissa >> 13)); u = uint16_t(u | (mantissa >> 13));
} else { } else {
// normal single-precision to subnormal half_t-precision representation // normal single-precision to subnormal half_t-precision representation