Addressed code review comments.
This commit is contained in:
parent
e18292db46
commit
fb8b3a98b7
@ -82,9 +82,9 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
|
|||||||
int const* a_int = reinterpret_cast<int const*>(&a[0]);
|
int const* a_int = reinterpret_cast<int const*>(&a[0]);
|
||||||
int const* b_int = reinterpret_cast<int const*>(&b[0]);
|
int const* b_int = reinterpret_cast<int const*>(&b[0]);
|
||||||
|
|
||||||
#pragma unroll
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||||
#pragma unroll
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||||
|
|
||||||
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
|
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
|
||||||
|
@ -35,7 +35,7 @@ struct DefaultBlockSwizzle {
|
|||||||
CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
|
CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
|
||||||
|
|
||||||
/// Swizzle the block index.
|
/// Swizzle the block index.
|
||||||
CUTLASS_DEVICE dim3 swizzle() { return {blockIdx.x, blockIdx.y, blockIdx.z}; }
|
CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); }
|
||||||
|
|
||||||
///
|
///
|
||||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,
|
CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,
|
||||||
|
@ -40,11 +40,7 @@
|
|||||||
|
|
||||||
#include "stdio.h"
|
#include "stdio.h"
|
||||||
|
|
||||||
#if CUDA_VERSION >= 10000
|
|
||||||
#include <mma.h>
|
#include <mma.h>
|
||||||
#else
|
|
||||||
#include <mma.h>
|
|
||||||
#endif
|
|
||||||
#include "cutlass/fragment.h"
|
#include "cutlass/fragment.h"
|
||||||
#include "cutlass/matrix_traits.h"
|
#include "cutlass/matrix_traits.h"
|
||||||
#include "cutlass/shape.h"
|
#include "cutlass/shape.h"
|
||||||
|
@ -235,7 +235,9 @@ int fpclassify(cutlass::half_t const&); /// returns a flag classifying floating
|
|||||||
bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive
|
bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive
|
||||||
|
|
||||||
cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t
|
cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&);
|
cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace std {
|
namespace std {
|
||||||
@ -745,8 +747,11 @@ inline bool signbit(cutlass::half_t const& h) { return h.signbit(); }
|
|||||||
inline cutlass::half_t sqrt(cutlass::half_t const& h) {
|
inline cutlass::half_t sqrt(cutlass::half_t const& h) {
|
||||||
return cutlass::half_t(std::sqrt(float(h)));
|
return cutlass::half_t(std::sqrt(float(h)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
inline cutlass::half_t copysign(cutlass::half_t const& a,
|
inline cutlass::half_t copysign(cutlass::half_t const& a,
|
||||||
cutlass::half_t const& b) {
|
cutlass::half_t const& b) {
|
||||||
return cutlass::half_t(std::copysign(float(a), float(b)));
|
return cutlass::half_t(std::copysign(float(a), float(b)));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
} // namespace std
|
} // namespace std
|
||||||
|
Loading…
Reference in New Issue
Block a user