diff --git a/setup.py b/setup.py index 89b5d27..3fb7666 100644 --- a/setup.py +++ b/setup.py @@ -136,6 +136,8 @@ ext_modules.append( "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math",