add is_last_tile

This commit is contained in:
Haicheng Wu 2024-10-17 12:11:02 -07:00
parent 53668799b2
commit 755194a7bd
2 changed files with 25 additions and 0 deletions

View File

@@ -308,6 +308,21 @@ public:
current_work_linear_idx_ += uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count);
}
// Reports whether the work currently held by this scheduler is the last it
// will produce.
//
// Returns false if continue_current_work() reports that the given tile still
// has work outstanding. Otherwise returns true exactly when the work found
// `advance_count` grid-sized steps past current_work_linear_idx_ is not valid.
//
// work_tile_info is intentionally taken BY VALUE: continue_current_work()
// modifies the tile it is handed, so this query must operate on a private copy
// and leave the caller's tile untouched. Do not change it to a reference.
CUTLASS_DEVICE
bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const {
  bool const current_work_continues = continue_current_work(work_tile_info);
  if (current_work_continues) {
    return false;
  }

  // Look-ahead distance in linear work indices: one step spans every block of
  // the launched grid, and we peek `advance_count` such steps ahead.
  auto const blocks_in_grid =
      uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
  auto const lookahead_linear_idx =
      current_work_linear_idx_ + (blocks_in_grid * uint64_t(advance_count));

  // This is the last tile when nothing valid is found at the look-ahead index.
  return !get_current_work_for_linear_idx(lookahead_linear_idx, scheduler_params).is_valid();
}
// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template <class ProblemShape>

View File

@@ -193,6 +193,16 @@ public:
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
}
// Reports whether the work currently held by this scheduler is the last it
// will produce.
//
// Returns false if continue_current_work() reports that the given tile still
// has work outstanding. Otherwise returns true exactly when the work found
// `total_grid_size_ * advance_count` linear indices past
// current_work_linear_idx_ is not valid.
//
// work_tile_info is taken BY VALUE on purpose. It is forwarded to
// continue_current_work(), which receives it as a mutable object; taking a
// copy guarantees that this const query can never alter the caller's tile,
// and keeps the signature in line with the other schedulers' is_last_tile().
// NOTE(review): whether this scheduler's continue_current_work() actually
// mutates its argument is not visible from here - the copy makes it moot.
CUTLASS_DEVICE
bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const {
  if (continue_current_work(work_tile_info)) {
    // The current tile is not finished, so it cannot be the last one.
    return false;
  }

  // Peek `advance_count` scheduling steps ahead without advancing the scheduler.
  auto const lookahead_linear_idx =
      current_work_linear_idx_ + (total_grid_size_ * uint64_t(advance_count));

  // This is the last tile when nothing valid is found at the look-ahead index.
  return !get_current_work_for_linear_idx(lookahead_linear_idx).is_valid();
}
// Computes the linear index within a batch given M and N tile offsets within the batch.
// This essentially inverts the mapping performed in get_work_idx_m_and_n
static CUTLASS_DEVICE