add is_last_tile
This commit is contained in:
parent
53668799b2
commit
755194a7bd
@ -308,6 +308,21 @@ public:
|
|||||||
current_work_linear_idx_ += uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count);
|
current_work_linear_idx_ += uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const {
  // work_tile_info is intentionally a by-value copy: continue_current_work
  // modifies the tile it is handed, and that change must not reach the caller.
  if (continue_current_work(work_tile_info)) {
    // More work remains on the current tile, so it is not reported as last.
    return false;
  }

  // Number of linear work indices covered by one step of the launched grid.
  uint64_t const grid_stride =
      uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);

  // Linear index that would be current after advancing advance_count steps.
  auto const lookahead_linear_idx =
      current_work_linear_idx_ + (grid_stride * uint64_t(advance_count));

  // The tile is the last one exactly when the look-ahead index yields no valid work.
  bool const has_following_work =
      get_current_work_for_linear_idx(lookahead_linear_idx, scheduler_params).is_valid();
  return !has_following_work;
}
|
||||||
|
|
||||||
// Given the inputs, computes the total number of output blocks this problem will compute over
|
// Given the inputs, computes the total number of output blocks this problem will compute over
|
||||||
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
|
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
|
||||||
template <class ProblemShape>
|
template <class ProblemShape>
|
||||||
|
|||||||
@ -193,6 +193,16 @@ public:
|
|||||||
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
|
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reports whether the given work tile is the final one for this scheduler instance.
//
// work_tile_info: the tile currently being processed. Taken by value so that
//                 anything continue_current_work does to it stays local to this
//                 query and the caller's tile is left untouched.
// advance_count:  how many scheduler steps ahead to look (defaults to 1).
//
// Returns false if continue_current_work reports that the current tile still has
// work; otherwise returns true exactly when the work item found at
// current_work_linear_idx_ + total_grid_size_ * advance_count is not valid.
CUTLASS_DEVICE
bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const {
  // Operates on the local copy only; see the parameter note above.
  if (continue_current_work(work_tile_info)) {
    return false;
  }

  // Look ahead by the same stride the scheduler uses when it advances.
  auto const lookahead_linear_idx =
      current_work_linear_idx_ + (total_grid_size_ * uint64_t(advance_count));

  return not get_current_work_for_linear_idx(lookahead_linear_idx).is_valid();
}
|
||||||
|
|
||||||
// Computes the linear index within a batch given M and N tile offsets within the batch.
|
// Computes the linear index within a batch given M and N tile offsets within the batch.
|
||||||
// This essentially inverts the mapping performed in get_work_idx_m_and_n
|
// This essentially inverts the mapping performed in get_work_idx_m_and_n
|
||||||
static CUTLASS_DEVICE
|
static CUTLASS_DEVICE
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user