Use platform:: instead of std::abs and std::conditional (#452)

* Fixed template struct/class mismatch

* Use platform implementation instead of std::abs and std::conditional during nvrtc compilation

* Use platform implementation instead of std::abs and std::conditional during nvrtc compilation

* Revert absolute_value() usage
This commit is contained in:
Stepan Tezyunichev 2022-04-25 21:40:22 +03:00 committed by GitHub
parent 70f3ba57f5
commit 71def2f084
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 3 deletions

View File

@ -537,12 +537,12 @@ void strided_dgrad_starting_coords(
// function locals for remainder by fast divmod
int pad_h_rem_, pad_w_rem_;
// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
// start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
stride_h_divmod.divmod(start_h, r_);
//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
//start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
stride_w_divmod.divmod(start_w, s_);

View File

@ -55,6 +55,7 @@
* (2) Re-implementations of STL functions and types:
* - C++ features that need the \p __device__ annotation. These are
* placed into the \p platform namespace.
* - \p abs
* - \p plus
* - \p less
* - \p greater
@ -184,6 +185,22 @@
namespace cutlass {
namespace platform {
//-----------------------------------------------------------------------------
// Abs operations <cstdlib>
//-----------------------------------------------------------------------------
#if defined(__CUDACC_RTC__)
/// Stand-ins for the integral overloads of std::abs, used when compiling under
/// NVRTC (__CUDACC_RTC__ defined). Each overload returns the magnitude of \p a.
///
/// As with std::abs, negating the most negative value of the type overflows,
/// so the result is undefined for that input.
CUTLASS_HOST_DEVICE constexpr int abs(int a) {
  return (a < 0) ? -a : a;
}
/// \p long overload. Without it, a \p long argument would be ambiguous between
/// the \p int and \p long \p long overloads (both are integral conversions).
CUTLASS_HOST_DEVICE constexpr long abs(long a) {
  return (a < 0) ? -a : a;
}
/// \p long \p long overload.
CUTLASS_HOST_DEVICE constexpr long long abs(long long a) {
  return (a < 0) ? -a : a;
}
#else
// Outside NVRTC, reuse the standard library overloads directly.
using std::abs;
#endif
//-----------------------------------------------------------------------------
// Minimum/maximum operations <algorithm>
//-----------------------------------------------------------------------------
@ -435,7 +452,6 @@ struct is_base_of
typename remove_cv<DerivedT>::type>::value) ||
(is_same<typename remove_cv<BaseT>::type,
typename remove_cv<DerivedT>::type>::value)> {};
#else
using std::is_same;