Use platform:: instead of std::abs and std::conditional (#452)
* Fixed template struct/class mismatch * Use platform implementation instead of std::abs and std::conditional during nvrtc compilation * Use platform implementation instead of std::abs and std::conditional during nvrtc compilation * Revert absolute_value() usage
This commit is contained in:
parent
70f3ba57f5
commit
71def2f084
@ -537,12 +537,12 @@ void strided_dgrad_starting_coords(
|
||||
// function locals for remainder by fast divmod
|
||||
int pad_h_rem_, pad_w_rem_;
|
||||
|
||||
// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
|
||||
// start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
|
||||
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
|
||||
int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
|
||||
stride_h_divmod.divmod(start_h, r_);
|
||||
|
||||
//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
|
||||
//start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
|
||||
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
|
||||
int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
|
||||
stride_w_divmod.divmod(start_w, s_);
|
||||
|
@ -55,6 +55,7 @@
|
||||
* (2) Re-implementations of STL functions and types:
|
||||
* - C++ features that need the \p __device__ annotation. These are
|
||||
* placed into the \p platform namespace.
|
||||
* - \p abs
|
||||
* - \p plus
|
||||
* - \p less
|
||||
* - \p greater
|
||||
@ -184,6 +185,22 @@
|
||||
namespace cutlass {
|
||||
namespace platform {
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Abs operations <algorithm>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
/// std::abs
|
||||
CUTLASS_HOST_DEVICE constexpr int abs(int a) {
|
||||
return (a < 0) ? -a : a;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE constexpr long long abs(long long a) {
|
||||
return (a < 0) ? -a : a;
|
||||
}
|
||||
#else
|
||||
using std::abs;
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Minimum/maximum operations <algorithm>
|
||||
//-----------------------------------------------------------------------------
|
||||
@ -435,7 +452,6 @@ struct is_base_of
|
||||
typename remove_cv<DerivedT>::type>::value) ||
|
||||
(is_same<typename remove_cv<BaseT>::type,
|
||||
typename remove_cv<DerivedT>::type>::value)> {};
|
||||
|
||||
#else
|
||||
|
||||
using std::is_same;
|
||||
|
Loading…
Reference in New Issue
Block a user