Addressed code review comments.

This commit is contained in:
Artem Belevich 2019-05-09 14:53:01 -07:00
parent e18292db46
commit fb8b3a98b7
4 changed files with 8 additions and 7 deletions

View File

@ -82,9 +82,9 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
int const* a_int = reinterpret_cast<int const*>(&a[0]);
int const* b_int = reinterpret_cast<int const*>(&b[0]);
#pragma unroll
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
#pragma unroll
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"

View File

@ -35,7 +35,7 @@ struct DefaultBlockSwizzle {
CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
/// Swizzle the block index.
CUTLASS_DEVICE dim3 swizzle() { return {blockIdx.x, blockIdx.y, blockIdx.z}; }
CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); }
///
CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,

View File

@ -40,11 +40,7 @@
#include "stdio.h"
#if CUDA_VERSION >= 10000
#include <mma.h>
#else
#include <mma.h>
#endif
#include "cutlass/fragment.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/shape.h"

View File

@ -235,7 +235,9 @@ int fpclassify(cutlass::half_t const&); /// returns a flag classifying floating
bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive
cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t
#if __cplusplus >= 201103L
cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&); /// magnitude of the first operand with the sign of the second (declared only for C++11 and newer)
#endif
}
namespace std {
@ -745,8 +747,11 @@ inline bool signbit(cutlass::half_t const& h) { return h.signbit(); }
/// Square root of a half_t value: the operand is converted to float, passed to
/// std::sqrt, and the float result is converted back to half_t.
inline cutlass::half_t sqrt(cutlass::half_t const& h) {
return cutlass::half_t(std::sqrt(float(h)));
}
// Compiled only when the compiler reports C++11 or newer through __cplusplus;
// the body relies on std::copysign, which <cmath> provides starting with C++11.
#if __cplusplus >= 201103L
/// copysign for half_t: both operands are converted to float, std::copysign is
/// applied to them (magnitude taken from a, sign taken from b), and the float
/// result is converted back to half_t.
inline cutlass::half_t copysign(cutlass::half_t const& a,
cutlass::half_t const& b) {
return cutlass::half_t(std::copysign(float(a), float(b)));
}
#endif
} // namespace std