diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
index d3dd1f5a..80b374ad 100644
--- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
+++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
@@ -308,6 +308,23 @@ public:
     current_work_linear_idx_ += uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count);
   }
 
+  // Returns true when no valid work tile remains after advancing by advance_count waves,
+  // i.e. the given tile is the last one this scheduler hands out.
+  CUTLASS_DEVICE
+  bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const {
+    // Never pass this by reference; it needs a copy,
+    // because continue_current_work will modify it.
+    if (continue_current_work(work_tile_info)) {
+      return false;
+    }
+    return not get_current_work_for_linear_idx(
+      current_work_linear_idx_ + (
+        uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count)
+      ),
+      scheduler_params
+    ).is_valid();
+  }
+
   // Given the inputs, computes the total number of output blocks this problem will compute over
   // Note that this is only the logical size of our grid, not the physical grid we will actually launch.
   template
diff --git a/include/cutlass/gemm/kernel/static_tile_scheduler.hpp b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp
index 2b61c737..67d346e3 100644
--- a/include/cutlass/gemm/kernel/static_tile_scheduler.hpp
+++ b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp
@@ -193,6 +193,19 @@ public:
     current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
   }
 
+  // Returns true when no valid work tile remains after advancing by advance_count waves,
+  // i.e. the given tile is the last one this scheduler hands out.
+  // NOTE(review): unlike the stream-K scheduler's is_last_tile, this takes
+  // work_tile_info by reference — confirm continue_current_work does not mutate it here.
+  CUTLASS_DEVICE
+  bool is_last_tile(WorkTileInfo& work_tile_info, uint32_t advance_count = 1) const {
+    if (continue_current_work(work_tile_info)) {
+      return false;
+    }
+    return not get_current_work_for_linear_idx(
+      current_work_linear_idx_ + (total_grid_size_ * uint64_t(advance_count))
+    ).is_valid();
+  }
+
   // Computes the linear index within a batch given M and N tile offsets within the batch.
   // This essentially inverts the mapping performed in get_work_idx_m_and_n
   static CUTLASS_DEVICE