Use platform:: instead of std::abs and std::conditional (#452)
* Fixed template struct/class mismatch
* Use platform implementation instead of std::abs and std::conditional during nvrtc compilation
* Use platform implementation instead of std::abs and std::conditional during nvrtc compilation
* Revert absolute_value() usage
This commit is contained in:
parent
70f3ba57f5
commit
71def2f084
@ -537,12 +537,12 @@ void strided_dgrad_starting_coords(
|
|||||||
// function locals for remainder by fast divmod
|
// function locals for remainder by fast divmod
|
||||||
int pad_h_rem_, pad_w_rem_;
|
int pad_h_rem_, pad_w_rem_;
|
||||||
|
|
||||||
// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
|
// start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
|
||||||
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
|
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
|
||||||
int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
|
int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
|
||||||
stride_h_divmod.divmod(start_h, r_);
|
stride_h_divmod.divmod(start_h, r_);
|
||||||
|
|
||||||
//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
|
//start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
|
||||||
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
|
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
|
||||||
int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
|
int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
|
||||||
stride_w_divmod.divmod(start_w, s_);
|
stride_w_divmod.divmod(start_w, s_);
|
||||||
|
@ -55,6 +55,7 @@
|
|||||||
* (2) Re-implementations of STL functions and types:
|
* (2) Re-implementations of STL functions and types:
|
||||||
* - C++ features that need the \p __device__ annotation. These are
|
* - C++ features that need the \p __device__ annotation. These are
|
||||||
* placed into the \p platform namespace.
|
* placed into the \p platform namespace.
|
||||||
|
* - \p abs
|
||||||
* - \p plus
|
* - \p plus
|
||||||
* - \p less
|
* - \p less
|
||||||
* - \p greater
|
* - \p greater
|
||||||
@ -184,6 +185,22 @@
|
|||||||
namespace cutlass {
|
namespace cutlass {
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
|
||||||
|
//-----------------------------------------------------------------------------
|
||||||
|
// Abs operations <cstdlib>
|
||||||
|
//-----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#if defined(__CUDACC_RTC__)
|
||||||
|
/// std::abs (int and long long overloads), re-implemented for NVRTC compilation
|
||||||
|
CUTLASS_HOST_DEVICE constexpr int abs(int a) {
|
||||||
|
return (a < 0) ? -a : a;
|
||||||
|
}
|
||||||
|
CUTLASS_HOST_DEVICE constexpr long long abs(long long a) {
|
||||||
|
return (a < 0) ? -a : a;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
using std::abs;
|
||||||
|
#endif
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
// Minimum/maximum operations <algorithm>
|
// Minimum/maximum operations <algorithm>
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
@ -435,7 +452,6 @@ struct is_base_of
|
|||||||
typename remove_cv<DerivedT>::type>::value) ||
|
typename remove_cv<DerivedT>::type>::value) ||
|
||||||
(is_same<typename remove_cv<BaseT>::type,
|
(is_same<typename remove_cv<BaseT>::type,
|
||||||
typename remove_cv<DerivedT>::type>::value)> {};
|
typename remove_cv<DerivedT>::type>::value)> {};
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
using std::is_same;
|
using std::is_same;
|
||||||
|
Loading…
Reference in New Issue
Block a user