diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 31f2cb5f..def07421 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -231,9 +231,10 @@ int ceil_div(int a, int b) { * log2_up/down codes? */ template -CUTLASS_HOST_DEVICE value_t clz(value_t x) { +CUTLASS_HOST_DEVICE int clz(value_t x) { for (int i = 31; i >= 0; --i) { - if ((1 << i) & x) return 31 - i; + if ((1 << i) & x) + return value_t(31 - i); } return 32; } diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 38af9c18..f7ec68e8 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -1,4 +1,4 @@ -/*************************************************************************************************** +/************************************************************************************************** * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * @@ -216,8 +216,8 @@ struct alignas(1) float8_base { // Extract the bits in the FP32 type uint8_t sign = uint8_t((s >> 24 & 0x80)); - int8_t exp = uint8_t(((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS); - int mantissa = s & 0x7fffff; + int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS; + uint32_t mantissa = s & 0x7fffff; uint8_t u = 0; uint8_t const kF8_NaN = 0x7f; @@ -233,7 +233,7 @@ struct alignas(1) float8_base { } // Special handling - if ( exp == -128 ) { + if (exp == -128) { // int8 range is from -128 to 127 // So 255(inf) - 127(bias) = 128 - will show up as -128 @@ -248,8 +248,8 @@ struct alignas(1) float8_base { if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) { // normal fp32 to normal fp8 - exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); - u = uint8_t(((exp & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS)); + exp = exp + FP8_EXPONENT_BIAS; + u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS); u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS))); } else if(exp < FP8_MIN_EXPONENT) { // normal single-precision to subnormal float8-precision representation @@ -271,8 +271,8 @@ struct alignas(1) float8_base { if( exp == (FP8_MAX_EXPONENT + 1) ) { uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)); if( mantissa_tmp < FP8_MANTISSA_MASK) { - exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); - u = uint8_t(exp << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; + exp = exp + FP8_EXPONENT_BIAS; + u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1)); } else { // satfinite @@ -316,9 +316,9 @@ struct alignas(1) float8_base { uint32_t constexpr kF32_NaN = 0x7fffffff; uint8_t const &f8 = x; - int sign = (f8 >> (FP8_NUM_BITS - 1)) & 1; - int exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK; - int mantissa = f8 & FP8_MANTISSA_MASK; + uint32_t sign = (f8 >> (FP8_NUM_BITS - 1)) & 1; + uint32_t exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK; + uint32_t mantissa = f8 & FP8_MANTISSA_MASK; unsigned f = (sign << (FP32_NUM_BITS-1)); if (IS_E4M3 && exp == 15 && mantissa == 0x7) { diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 66e6a6d5..5e69ffb5 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -112,7 +112,8 @@ protected: CUTLASS_THREAD_LOCAL static int sm_occupancy_; /// Kernel dynamic shared memory allocation requirement - CUTLASS_THREAD_LOCAL static int smem_size_; + /// Update the kernel function's shared memory configuration for the current device + static constexpr size_t smem_size_ = sizeof(typename GemmKernel::SharedStorage); /// Initialize static thread-local members for the thread's current device, /// if necessary. @@ -143,11 +144,8 @@ protected: return Status::kErrorInternal; } - // Update the kernel function's shared memory configuration for the current device - smem_size_ = int(sizeof(typename GemmKernel::SharedStorage)); - // If requires more than 48KB: configure for extended, dynamic shared memory - if (smem_size_ >= (48 << 10)) + if constexpr (smem_size_ >= (48 << 10)) { cudart_result = cudaFuncSetAttribute( Kernel2, @@ -377,7 +375,6 @@ public: } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Static initializers ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -394,12 +391,6 @@ CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_sms_ = -1; template CUTLASS_THREAD_LOCAL int GemmUniversalBase::sm_occupancy_ = -1; -/// Kernel dynamic shared memory allocation requirement -template -CUTLASS_THREAD_LOCAL int GemmUniversalBase::smem_size_ = -1; - - - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace device diff --git a/include/cutlass/half.h b/include/cutlass/half.h index 10262db6..cff04e09 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -321,9 +321,9 @@ struct alignas(2) half_t { #endif uint16_t const &h = x.storage; - int sign = ((h >> 15) & 1); - int exp = ((h >> 10) & 0x1f); - int mantissa = (h & 0x3ff); + uint32_t sign = ((h >> 15) & 1); + uint32_t exp = ((h >> 10) & 0x1f); + uint32_t mantissa = (h & 0x3ff); unsigned f = 0; if (exp > 0 && exp < 31) { diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 27fc0e6f..6de5d038 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -364,7 +364,7 @@ struct NumericConverter { // software implementation rounds toward nearest even unsigned const& s = reinterpret_cast(flt); uint16_t sign = uint16_t((s >> 16) & 0x8000); - int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); + int32_t exp = int32_t((s >> 23) & 0xff) - 127; int mantissa = s & 0x7fffff; uint16_t u = 0; @@ -386,8 +386,7 @@ struct NumericConverter { if (exp >= -14) { // normal fp32 to normal fp16 - exp = uint16_t(exp + uint16_t(15)); - u = uint16_t(((exp & 0x1f) << 10)); + u = uint16_t((uint32_t(exp + 15) & 0x1f) << 10); u = uint16_t(u | (mantissa >> 13)); } else { // normal single-precision to subnormal half_t-precision representation