Fix isnan namespace qualification in cutlass/functional.h (#1679)

* Fix unrelated MSVC build warnings

* Fix use of isnan in functional.h

Correct namespace qualification of isnan in functional.h
so that it invokes cutlass::isnan for half_t, instead of
converting half_t to float and invoking std::isnan (on host,
or ::isnan on device).
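
For readers unfamiliar with the pitfall: a minimal host-only sketch (illustrative, not part of the commit; the helper names are made up) of what the wrong and the corrected qualification do for cutlass::half_t:

    #include <cmath>
    #include "cutlass/half.h"

    // Wrong: overload resolution only sees std::isnan(float) and friends,
    // so x is implicitly converted from half_t to float before the test.
    bool has_nan_qualified(cutlass::half_t x) {
      return std::isnan(x);
    }

    // Right: the using-declaration brings std::isnan into scope as a
    // fallback, and the unqualified call lets argument-dependent lookup
    // (ADL) find cutlass::isnan(half_t), avoiding the conversion.
    bool has_nan_adl(cutlass::half_t x) {
      using std::isnan;
      return isnan(x);
    }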
Author: Mark Hoemmen
Date: 2024-08-05 12:28:13 -06:00
Commit: 19b4c5e065 (parent 06b21349bc)
4 changed files with 153 additions and 37 deletions

@@ -95,6 +95,44 @@ CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &)
 #endif
 #endif
 
+// CUTLASS_CMATH_NAMESPACE is the namespace where code can find
+// <cmath> functions like isnan and log. Such functions are in
+// the std namespace in host code, but in the global namespace
+// in device code.
+//
+// The intended use case for this macro is in "using" declarations
+// for making argument-dependent lookup (ADL) work in generic code.
+// For example, if T is cutlass::half_t, the following code will
+// invoke cutlass::isnan(half_t). If T is float, it will invoke
+// std::isnan on host and ::isnan on device. (CUTLASS's support
+// for NVRTC prevents it from using things in the std namespace
+// in device code.) Correct use of "using" declarations can help
+// avoid unexpected implicit conversions, like from half_t to float.
+//
+// template<class T>
+// bool foo(T x) {
+//   using CUTLASS_CMATH_NAMESPACE :: isnan;
+//   return isnan(x);
+// }
+//
+// Without this macro, one would need to write the following.
+//
+// template<class T>
+// bool foo(T x) {
+// #if defined(__CUDA_ARCH__)
+//   using ::isnan;
+// #else
+//   using std::isnan;
+// #endif
+//   return isnan(x);
+// }
+#if defined(__CUDA_ARCH__)
+# define CUTLASS_CMATH_NAMESPACE
+#else
+# define CUTLASS_CMATH_NAMESPACE std
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 namespace cutlass {
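
The macro's intended use, as a compilable sketch (assuming cutlass/half.h supplies the cutlass::isnan(half_t) overload; the function name is illustrative):

    #include "cutlass/cutlass.h"
    #include "cutlass/half.h"

    template <class T>
    CUTLASS_HOST_DEVICE bool is_nan_value(T x) {
      // Expands to "using std :: isnan;" on host and "using :: isnan;"
      // in device code, where the macro is empty.
      using CUTLASS_CMATH_NAMESPACE :: isnan;
      return isnan(x);  // unqualified: ADL can still find cutlass::isnan
    }

Note the spaces around "::": they are what let the macro expand to nothing, so the declaration degrades cleanly to "using :: isnan;" in device code.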

@@ -670,7 +670,8 @@ public:
     // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions
     // Sync requirements of smem reuse may preclude this optimization
     // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD
-    int epi_m_prev = 0, epi_n_prev = 0;
+    [[maybe_unused]] int epi_m_prev = 0;
+    [[maybe_unused]] int epi_n_prev = 0;
     static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock");
 
     // The TMA store sequence for one subtile iteration
@@ -725,7 +726,7 @@
     for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) {
       CUTLASS_PRAGMA_UNROLL
       for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) {
-        bool is_first_iteration = epi_m == 0 && epi_n == 0;
+        [[maybe_unused]] bool is_first_iteration = epi_m == 0 && epi_n == 0;
         bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1;
 
         if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gD_epi)) + epi_m) != subtile_idx) {

@@ -369,11 +369,14 @@ template <typename T>
 struct maximum<T, true> {
   CUTLASS_HOST_DEVICE
   T operator()(T const &lhs, T const &rhs) const {
-#if defined(__CUDA_ARCH__)
-    return lhs > rhs or ::isnan(lhs) ? lhs : rhs;
-#else
-    return lhs > rhs or std::isnan(lhs) ? lhs : rhs;
-#endif
+    using CUTLASS_CMATH_NAMESPACE :: isnan;
+
+    // Call isnan unqualified, so argument-dependent lookup (ADL)
+    // will find overloads such as cutlass::isnan(half_t).
+    // Calling ::isnan or std::isnan directly would force
+    // implicit conversions to float of custom number types
+    // in the cutlass namespace (e.g., cutlass::half_t).
+    return lhs > rhs || isnan(lhs) ? lhs : rhs;
   }
 };
@@ -389,15 +392,14 @@ template <>
 struct maximum<float, true> {
   CUTLASS_HOST_DEVICE
   float operator()(float const lhs, float const rhs) const {
-    float res;
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+    float res;
     asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs));
-#elif defined(__CUDA_ARCH__)
-    res = lhs > rhs or ::isnan(lhs) ? lhs : rhs;
-#else
-    res = lhs > rhs or std::isnan(lhs) ? lhs : rhs;
-#endif
     return res;
+#else
+    using CUTLASS_CMATH_NAMESPACE :: isnan;
+    return lhs > rhs || isnan(lhs) ? lhs : rhs;
+#endif
   }
 };
@@ -427,11 +429,9 @@ template <typename T>
 struct minimum<T, true> {
   CUTLASS_HOST_DEVICE
   T operator()(T const &lhs, T const &rhs) const {
-#if defined(__CUDA_ARCH__)
-    return lhs < rhs or ::isnan(lhs) ? lhs : rhs;
-#else
-    return lhs < rhs or std::isnan(lhs) ? lhs : rhs;
-#endif
+    using CUTLASS_CMATH_NAMESPACE :: isnan;
+
+    return lhs < rhs || isnan(lhs) ? lhs : rhs;
   }
 };
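
The behavioral difference between the propagating and non-propagating functors, as a hedged host-side sketch (values and the helper are illustrative; assumes the non-propagating maximum returns lhs < rhs ? rhs : lhs, as in functional.h):

    #include <cmath>
    #include "cutlass/functional.h"
    #include "cutlass/half.h"

    void nan_propagation_example() {
      cutlass::half_t one(1.0f);
      cutlass::half_t nan_h(std::nanf(""));

      cutlass::maximum<cutlass::half_t, true> max_nan;    // propagate NaN
      cutlass::maximum<cutlass::half_t, false> max_plain; // ordinary max

      // With the fix, isnan(lhs) resolves to cutlass::isnan(half_t),
      // so the NaN check happens on the half_t value directly.
      cutlass::half_t a = max_nan(one, nan_h);
      // a is NaN: 1 > NaN and isnan(1) are both false, so rhs (NaN) is returned.
      cutlass::half_t b = max_plain(one, nan_h);
      // b is 1.0: 1 < NaN is false, so lhs is returned.
    }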
@@ -512,6 +512,8 @@ template <typename A, typename B = A, typename C = A>
 struct guarded_multiply_add {
   CUTLASS_HOST_DEVICE
   C operator()(A const &a, B const &b, C const &c) const {
+    using CUTLASS_CMATH_NAMESPACE :: isnan;
+
     if (isnan(a) || isnan(b)) {
       return C(0);
     }
@@ -531,7 +533,10 @@ struct guarded_multiply_add<half_t, half_t, half_t> {
       : "h"(*reinterpret_cast<uint16_t const*>(&a)), "h"(*reinterpret_cast<uint16_t const*>(&b)), "h"(*reinterpret_cast<uint16_t const*>(&c)));
     return result;
 #else
-    if (isnan(a) || isnan(b)) {
+    // Namespace-qualifying isnan as cutlass::isnan saves the compiler
+    // the trouble of argument-dependent lookup. Calling std::isnan or
+    // ::isnan here would result in unwanted implicit conversion to float.
+    if (cutlass::isnan(a) || cutlass::isnan(b)) {
       return half_t(0);
     }
     return a * b + c;
@@ -544,13 +549,9 @@ template <typename A, typename B = A, typename C = A>
 struct guarded_multiply_add_relu0 {
   CUTLASS_HOST_DEVICE
   C operator()(A const &a, B const &b, C const &c) const {
-    if (
-#if defined(__CUDA_ARCH__)
-      ::isnan(a) || ::isnan(b)
-#else
-      std::isnan(a) || std::isnan(b)
-#endif
-    ) {
+    using CUTLASS_CMATH_NAMESPACE :: isnan;
+
+    if (isnan(a) || isnan(b)) {
       return C(0);
     }
     maximum<C> mx;
@@ -569,13 +570,7 @@ struct guarded_multiply_add_relu0<half_t, half_t, half_t> {
       : "h"(*reinterpret_cast<uint16_t const*>(&a)), "h"(*reinterpret_cast<uint16_t const*>(&b)), "h"(*reinterpret_cast<uint16_t const*>(&c)));
     return result;
 #else
-    if (
-#if defined(__CUDA_ARCH__)
-      ::isnan(a) || ::isnan(b)
-#else
-      std::isnan(a) || std::isnan(b)
-#endif
-    ) {
+    if (cutlass::isnan(a) || cutlass::isnan(b)) {
       return half_t(0);
     }
     maximum<half_t> mx;
@@ -782,6 +777,10 @@ struct atomic_add
   {
 #if defined(__CUDA_ARCH__)
     atomicAdd(ptr, data);
+#else
+    CUTLASS_UNUSED(ptr);
+    CUTLASS_UNUSED(data);
+    CUTLASS_NOT_IMPLEMENTED();
 #endif
   }
 };
@@ -795,6 +794,7 @@ struct atomic_add<double>
 #if !defined(__CUDA_ARCH__)
     CUTLASS_UNUSED(ptr);
     CUTLASS_UNUSED(data);
+    CUTLASS_NOT_IMPLEMENTED();
 #elif (__CUDA_ARCH__ >= 600)
     atomicAdd(ptr, data);
 #else
@@ -821,6 +821,7 @@ struct atomic_add<half2>
 #if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600))
     CUTLASS_UNUSED(ptr);
     CUTLASS_UNUSED(data);
+    CUTLASS_NOT_IMPLEMENTED();
 #else
     // Vector-2 atomic reduction requires .target sm_60 or higher
     uint32_t word = reinterpret_cast<const uint32_t&>(data);

@@ -491,4 +491,80 @@ TEST(Functional, multiply_add_quaternion_f32) {
   Functional_multiply_add_QuaternionT<float>();
 }
 
+namespace cutlass_test {
+
+__global__ void
+test_cutlass_maximum(cutlass::half_t const* in1, cutlass::half_t const* in2, cutlass::half_t* out)
+{
+  {
+    constexpr bool propagate_NaN = true;
+    cutlass::maximum<cutlass::half_t, propagate_NaN> op;
+
+    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0
+        && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+      *out = op(*in1, *in2);
+    }
+  }
+  {
+    constexpr bool propagate_NaN = false;
+    cutlass::maximum<cutlass::half_t, propagate_NaN> op;
+
+    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0
+        && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+      *out = op(*in1, *in2);
+    }
+  }
+}
+
+} // cutlass_test
+
+// Test compilation on both host and device.
+TEST(Functional, maximum_half_host_propagate_NaN) {
+  constexpr bool propagate_NaN = true;
+  cutlass::maximum<cutlass::half_t, propagate_NaN> op;
+  cutlass::half_t x(1.0f);
+  cutlass::half_t y(2.0f);
+
+  auto result = op(x, y);
+  static_assert(std::is_same_v<decltype(result), cutlass::half_t>);
+  EXPECT_EQ(result, y);
+
+  result = op(y, x);
+  EXPECT_EQ(result, y);
+}
+
+TEST(Functional, maximum_half_host_dont_propagate_NaN) {
+  constexpr bool propagate_NaN = false;
+  cutlass::maximum<cutlass::half_t, propagate_NaN> op;
+  cutlass::half_t x(1.0f);
+  cutlass::half_t y(2.0f);
+
+  auto result = op(x, y);
+  static_assert(std::is_same_v<decltype(result), cutlass::half_t>);
+  EXPECT_EQ(result, y);
+
+  result = op(y, x);
+  EXPECT_EQ(result, y);
+}
+
+TEST(Functional, maximum_half_device) {
+  using Tensor = cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor>;
+  Tensor in1({1, 1});
+  Tensor in2({1, 1});
+  Tensor out({1, 1});
+  in1.host_data()[0] = cutlass::half_t(1.0f);
+  in2.host_data()[0] = cutlass::half_t(2.0f);
+  out.host_data()[0] = cutlass::half_t(0.0f);
+
+  in1.sync_device();
+  in2.sync_device();
+  out.sync_device();
+
+  cutlass_test::test_cutlass_maximum<<< 1, 1 >>>(
+    in1.device_data(),
+    in2.device_data(),
+    out.device_data()
+  );
+  out.sync_host();
+  EXPECT_EQ(out.host_data()[0], 2.0f);
+}
+
 /////////////////////////////////////////////////////////////////////////////////////////////////