fix epilogue register spill

This commit is contained in:
Haicheng Wu 2021-07-29 14:25:48 -07:00
parent 4516b833ce
commit a77c658439
2 changed files with 11 additions and 11 deletions

View File

@ -703,7 +703,7 @@ private:
int output_row = destination_iterator.thread_start_row() + row_offset;
fetch = (output_row < destination_iterator.extent().row() && column_guard);
fetch = (output_row < destination_iterator.extent_row() && column_guard);
}
else {
fetch = true;
@ -785,7 +785,7 @@ private:
int output_row = destination_iterator.thread_start_row() + row_offset;
fetch = (output_row < destination_iterator.extent().row() && column_guard);
fetch = (output_row < destination_iterator.extent_row() && column_guard);
}
else {
fetch = true;

View File

@ -172,7 +172,7 @@ private:
Mask mask_;
/// Extent of the matrix tile in rows
TensorCoord extent_;
Index extent_row_;
/// A thread's starting row position (assuming steady-state predicates have been computed)
Index thread_start_row_;
@ -184,7 +184,7 @@ private:
// Static asserts about internal strides
//
static_assert(sizeof(extent_.row()) == 4, "Expected 32b extents");
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
@ -209,12 +209,12 @@ public:
int thread_idx,
TensorCoord threadblock_offset = TensorCoord()
):
params_(params),
extent_(extent)
params_(params)
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
thread_start_row_ = thread_offset.row();
// Initialize predicates
@ -222,7 +222,7 @@ public:
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] = ((thread_offset.column()
+ ThreadMap::Delta::kColumn * c) < extent_.column());
+ ThreadMap::Delta::kColumn * c) < extent.column());
}
// Null pointer performs no accesses
@ -268,7 +268,7 @@ public:
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_.row());
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
@ -332,7 +332,7 @@ public:
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_.row());
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
@ -379,8 +379,8 @@ public:
/// Extent of the matrix in rows
CUTLASS_DEVICE
TensorCoord extent() const {
return extent_;
Index extent_row() const {
return extent_row_;
}
/// Advances to the next position to load or store