Fixes for public issue #265
This commit is contained in:
parent
b68113f5be
commit
da2f110906
@ -186,11 +186,14 @@ public:
|
|||||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||||
|
|
||||||
/// Clamping constant value
|
/// Clamping constant value
|
||||||
ElementCompute const kClamp =
|
ElementCompute const kClampMax =
|
||||||
ElementCompute((1U << (sizeof_bits<ElementOutput>::value - 1)) - 1);
|
ElementCompute(platform::numeric_limits<ElementOutput>::max());
|
||||||
|
|
||||||
intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1));
|
ElementCompute const kClampMin =
|
||||||
intermediate = min_accumulator(intermediate, kClamp);
|
ElementCompute(platform::numeric_limits<ElementOutput>::lowest());
|
||||||
|
|
||||||
|
intermediate = max_accumulator(intermediate, kClampMin);
|
||||||
|
intermediate = min_accumulator(intermediate, kClampMax);
|
||||||
|
|
||||||
// Convert to destination numeric type
|
// Convert to destination numeric type
|
||||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||||
|
@ -69,6 +69,7 @@ enum
|
|||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/platform/platform.h"
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
@ -545,9 +546,9 @@ half_t copysign(half_t const& a, half_t const& b) {
|
|||||||
//
|
//
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
namespace std {
|
namespace cutlass {
|
||||||
|
namespace platform {
|
||||||
|
|
||||||
#if !defined(__CUDACC_RTC__)
|
|
||||||
/// Numeric limits
|
/// Numeric limits
|
||||||
template <>
|
template <>
|
||||||
struct numeric_limits<cutlass::half_t> {
|
struct numeric_limits<cutlass::half_t> {
|
||||||
@ -593,9 +594,9 @@ struct numeric_limits<cutlass::half_t> {
|
|||||||
/// Returns smallest finite value
|
/// Returns smallest finite value
|
||||||
static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); }
|
static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); }
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace std
|
} // namespace platform
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
//
|
//
|
||||||
|
@ -200,4 +200,25 @@ struct sizeof_bits<uint4b_t> {
|
|||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<cutlass::int4b_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static cutlass::int4b_t const lowest() noexcept { return -8;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static cutlass::int4b_t const max() noexcept { return 7;}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<cutlass::uint4b_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static cutlass::uint4b_t const lowest() noexcept { return 0;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static cutlass::uint4b_t const max() noexcept { return 15;}
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace platform
|
||||||
} // namespace cutlass
|
} // namespace cutlass
|
||||||
|
@ -483,69 +483,16 @@ struct NumericConverterClamp {
|
|||||||
using result_type = T;
|
using result_type = T;
|
||||||
using source_type = S;
|
using source_type = S;
|
||||||
|
|
||||||
static_assert((platform::is_same<result_type, int32_t>::value ||
|
|
||||||
platform::is_same<result_type, int8_t>::value ||
|
|
||||||
platform::is_same<result_type, cutlass::int4b_t>::value),
|
|
||||||
"Clamp is only needed for integer types");
|
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
static result_type convert(source_type const & s) {
|
static result_type convert(source_type const & s) {
|
||||||
NumericConverter<result_type, source_type> convert_op;
|
NumericConverter<result_type, source_type> convert_op;
|
||||||
result_type const kClamp_max =
|
result_type const kClamp_max = platform::numeric_limits<result_type>::max();
|
||||||
(0x1U << (sizeof_bits<result_type>::value - 1)) - 1;
|
result_type const kClamp_min = platform::numeric_limits<result_type>::lowest();
|
||||||
result_type const kClamp_min = -kClamp_max - 1;
|
if (s < (source_type)kClamp_min)
|
||||||
bool is_int_min = !(s > kClamp_min);
|
return kClamp_min;
|
||||||
bool is_int_max = !(s < kClamp_max);
|
if (s > (source_type)kClamp_max)
|
||||||
return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s));
|
return kClamp_max;
|
||||||
}
|
return convert_op(s);
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
|
||||||
result_type operator()(source_type const &s) {
|
|
||||||
return convert(s);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Partial specialization for clamping from a single-precision float.
|
|
||||||
template <
|
|
||||||
typename T
|
|
||||||
>
|
|
||||||
struct NumericConverterClamp<T, float> {
|
|
||||||
|
|
||||||
using result_type = T;
|
|
||||||
using source_type = float;
|
|
||||||
|
|
||||||
static_assert((platform::is_same<result_type, int32_t>::value ||
|
|
||||||
platform::is_same<result_type, int16_t>::value ||
|
|
||||||
platform::is_same<result_type, uint16_t>::value ||
|
|
||||||
platform::is_same<result_type, int8_t>::value ||
|
|
||||||
platform::is_same<result_type, uint8_t>::value ||
|
|
||||||
platform::is_same<result_type, cutlass::int4b_t>::value ||
|
|
||||||
platform::is_same<result_type, cutlass::uint4b_t>::value),
|
|
||||||
"Clamp is only needed for integer types");
|
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
|
||||||
static result_type convert(source_type const & s) {
|
|
||||||
|
|
||||||
NumericConverter<result_type, double> convert_op;
|
|
||||||
double kClamp_max, kClamp_min;
|
|
||||||
|
|
||||||
if (platform::is_same<result_type, int32_t>::value ||
|
|
||||||
platform::is_same<result_type, int16_t>::value ||
|
|
||||||
platform::is_same<result_type, int8_t>::value ||
|
|
||||||
platform::is_same<result_type, cutlass::int4b_t>::value) {
|
|
||||||
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value - 1)) - 1);
|
|
||||||
kClamp_min = -kClamp_max - 1;
|
|
||||||
} else {
|
|
||||||
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value)) - 1);
|
|
||||||
kClamp_min = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
double source = s;
|
|
||||||
|
|
||||||
source = fmax(source, kClamp_min);
|
|
||||||
source = fmin(source, kClamp_max);
|
|
||||||
|
|
||||||
return convert_op(source);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
|
@ -783,5 +783,58 @@ void swap(unique_ptr<T, Deleter>& lhs, unique_ptr<T, Deleter>& rhs) noexcept {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/// std::numeric_limits
|
||||||
|
template <class T>
|
||||||
|
struct numeric_limits;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<int32_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr int32_t lowest() noexcept { return -2147483647 - 1;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr int32_t max() noexcept { return 2147483647;}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<int16_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr int16_t lowest() noexcept { return -32768;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr int16_t max() noexcept { return 32767;}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<int8_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr int8_t lowest() noexcept { return -128;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr int8_t max() noexcept { return 127;}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<uint32_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr uint32_t lowest() noexcept { return 0;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr uint32_t max() noexcept { return 4294967295;}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<uint16_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr uint16_t lowest() noexcept { return 0;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr uint16_t max() noexcept { return 65535;}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<uint8_t> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr uint8_t lowest() noexcept { return 0;}
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
static constexpr uint8_t max() noexcept { return 255;}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace platform
|
} // namespace platform
|
||||||
} // namespace cutlass
|
} // namespace cutlass
|
||||||
|
Loading…
Reference in New Issue
Block a user