diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index 7a1bc92c..7d0c86f4 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -537,12 +537,12 @@ void strided_dgrad_starting_coords( // function locals for remainder by fast divmod int pad_h_rem_, pad_w_rem_; - // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; + // start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); stride_h_divmod.divmod(start_h, r_); - //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; + //start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); stride_w_divmod.divmod(start_w, s_); diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index a7719921..ff6e3db6 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -55,6 +55,7 @@ * (2) Re-implementations of STL functions and types: * - C++ features that need the \p __device__ annotation. These are * placed into the \p platform namespace. + * - \p abs * - \p plus * - \p less * - \p greater @@ -184,6 +185,22 @@ namespace cutlass { namespace platform { +//----------------------------------------------------------------------------- +// Abs operations +//----------------------------------------------------------------------------- + +#if defined(__CUDACC_RTC__) +/// std::abs +CUTLASS_HOST_DEVICE constexpr int abs(int a) { + return (a < 0) ? -a : a; +} +CUTLASS_HOST_DEVICE constexpr long long abs(long long a) { + return (a < 0) ? -a : a; +} +#else +using std::abs; +#endif + //----------------------------------------------------------------------------- // Minimum/maximum operations //----------------------------------------------------------------------------- @@ -435,7 +452,6 @@ struct is_base_of typename remove_cv::type>::value) || (is_same::type, typename remove_cv::type>::value)> {}; - #else using std::is_same;