add is_last_tile
This commit is contained in:
parent
53668799b2
commit
755194a7bd
@ -308,6 +308,21 @@ public:
|
||||
current_work_linear_idx_ += uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const {
|
||||
// Never pass this by reference; it needs a copy,
|
||||
// because continue_current_work will modify it.
|
||||
if (continue_current_work(work_tile_info)) {
|
||||
return false;
|
||||
}
|
||||
return not get_current_work_for_linear_idx(
|
||||
current_work_linear_idx_ + (
|
||||
uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count)
|
||||
),
|
||||
scheduler_params
|
||||
).is_valid();
|
||||
}
|
||||
|
||||
// Given the inputs, computes the total number of output blocks this problem will compute over
|
||||
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
|
||||
template <class ProblemShape>
|
||||
|
||||
@ -193,6 +193,16 @@ public:
|
||||
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_last_tile(WorkTileInfo& work_tile_info, uint32_t advance_count = 1) const {
|
||||
if (continue_current_work(work_tile_info)) {
|
||||
return false;
|
||||
}
|
||||
return not get_current_work_for_linear_idx(
|
||||
current_work_linear_idx_ + (total_grid_size_ * uint64_t(advance_count))
|
||||
).is_valid();
|
||||
}
|
||||
|
||||
// Computes the linear index within a batch given M and N tile offsets within the batch.
|
||||
// This essentially inverts the mapping performed in get_work_idx_m_and_n
|
||||
static CUTLASS_DEVICE
|
||||
|
||||
Loading…
Reference in New Issue
Block a user