Fix epilogue register spill: keep only the row extent (`extent_row_`) in the iterator instead of the full `TensorCoord` extent
This commit is contained in:
parent
4516b833ce
commit
a77c658439
@ -703,7 +703,7 @@ private:
|
|||||||
|
|
||||||
int output_row = destination_iterator.thread_start_row() + row_offset;
|
int output_row = destination_iterator.thread_start_row() + row_offset;
|
||||||
|
|
||||||
fetch = (output_row < destination_iterator.extent().row() && column_guard);
|
fetch = (output_row < destination_iterator.extent_row() && column_guard);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
fetch = true;
|
fetch = true;
|
||||||
@ -785,7 +785,7 @@ private:
|
|||||||
|
|
||||||
int output_row = destination_iterator.thread_start_row() + row_offset;
|
int output_row = destination_iterator.thread_start_row() + row_offset;
|
||||||
|
|
||||||
fetch = (output_row < destination_iterator.extent().row() && column_guard);
|
fetch = (output_row < destination_iterator.extent_row() && column_guard);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
fetch = true;
|
fetch = true;
|
||||||
|
@ -172,7 +172,7 @@ private:
|
|||||||
Mask mask_;
|
Mask mask_;
|
||||||
|
|
||||||
/// Extent of the matrix tile in rows
|
/// Extent of the matrix tile in rows
|
||||||
TensorCoord extent_;
|
Index extent_row_;
|
||||||
|
|
||||||
/// A thread's starting row position (assuming steady-state predicates have been computed)
|
/// A thread's starting row position (assuming steady-state predicates have been computed)
|
||||||
Index thread_start_row_;
|
Index thread_start_row_;
|
||||||
@ -184,7 +184,7 @@ private:
|
|||||||
// Static asserts about internal strides
|
// Static asserts about internal strides
|
||||||
//
|
//
|
||||||
|
|
||||||
static_assert(sizeof(extent_.row()) == 4, "Expected 32b extents");
|
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
|
||||||
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
|
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
|
||||||
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
|
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
|
||||||
|
|
||||||
@ -209,12 +209,12 @@ public:
|
|||||||
int thread_idx,
|
int thread_idx,
|
||||||
TensorCoord threadblock_offset = TensorCoord()
|
TensorCoord threadblock_offset = TensorCoord()
|
||||||
):
|
):
|
||||||
params_(params),
|
params_(params)
|
||||||
extent_(extent)
|
|
||||||
{
|
{
|
||||||
|
|
||||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
|
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
|
||||||
|
|
||||||
|
extent_row_ = extent.row();
|
||||||
thread_start_row_ = thread_offset.row();
|
thread_start_row_ = thread_offset.row();
|
||||||
|
|
||||||
// Initialize predicates
|
// Initialize predicates
|
||||||
@ -222,7 +222,7 @@ public:
|
|||||||
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
|
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
|
||||||
|
|
||||||
mask_.predicates[c] = ((thread_offset.column()
|
mask_.predicates[c] = ((thread_offset.column()
|
||||||
+ ThreadMap::Delta::kColumn * c) < extent_.column());
|
+ ThreadMap::Delta::kColumn * c) < extent.column());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Null pointer performs no accesses
|
// Null pointer performs no accesses
|
||||||
@ -268,7 +268,7 @@ public:
|
|||||||
+ group * ThreadMap::Delta::kGroup
|
+ group * ThreadMap::Delta::kGroup
|
||||||
+ cluster * ThreadMap::Delta::kCluster;
|
+ cluster * ThreadMap::Delta::kCluster;
|
||||||
|
|
||||||
bool row_guard = ((row_offset + thread_start_row_) < extent_.row());
|
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||||
|
|
||||||
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
|
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
|
||||||
|
|
||||||
@ -332,7 +332,7 @@ public:
|
|||||||
+ group * ThreadMap::Delta::kGroup
|
+ group * ThreadMap::Delta::kGroup
|
||||||
+ cluster * ThreadMap::Delta::kCluster;
|
+ cluster * ThreadMap::Delta::kCluster;
|
||||||
|
|
||||||
bool row_guard = ((row_offset + thread_start_row_) < extent_.row());
|
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||||
|
|
||||||
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
|
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
|
||||||
|
|
||||||
@ -379,8 +379,8 @@ public:
|
|||||||
|
|
||||||
/// Extent of the matrix in rows
|
/// Extent of the matrix in rows
|
||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
TensorCoord extent() const {
|
Index extent_row() const {
|
||||||
return extent_;
|
return extent_row_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Advances to the next position to load or store
|
/// Advances to the next position to load or store
|
||||||
|
Loading…
Reference in New Issue
Block a user