From e066ced33b39b9fa4074da43bab31aa24451f4cb Mon Sep 17 00:00:00 2001
From: ChangyouSiom <122081726+ChangyouSiom@users.noreply.github.com>
Date: Tue, 11 Jul 2023 09:30:31 +0800
Subject: [PATCH] fix epilogue iterator error (#995)

* fix epilogue iterator error

* fix epilogue iterator error

---------

Co-authored-by: maxiao
---
 .../threadblock/predicated_tile_iterator.h | 32 ++++++++++++++-----
 1 file changed, 24 insertions(+), 8 deletions(-)

diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
index c8b3c3bf..4da07d45 100644
--- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
+++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
@@ -440,12 +440,16 @@ public:
           }
 
           if (group + 1 < ThreadMap::Iterations::kGroup) {
-            byte_pointer += params_.increment_group;
+            if (!ScatterD && !PermuteD) {
+              byte_pointer += params_.increment_group;
+            }
           }
         }
 
         if (cluster + 1 < ThreadMap::Iterations::kCluster) {
-          byte_pointer += params_.increment_cluster;
+          if (!ScatterD && !PermuteD) {
+            byte_pointer += params_.increment_cluster;
+          }
         }
       }
     }
@@ -650,8 +654,12 @@ public:
 
       state_[0] = 0;
       ++state_[1];
-      byte_pointer_ += params_.advance_group;
-      store_byte_pointer_ += params_.advance_group;
+      if (!ScatterD) {
+        byte_pointer_ += params_.advance_group;
+      }
+      if (!ScatterD && !PermuteD) {
+        store_byte_pointer_ += params_.advance_group;
+      }
 
       thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
         ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
@@ -660,16 +668,24 @@ public:
 
         state_[1] = 0;
         ++state_[2];
-        byte_pointer_ += params_.advance_cluster;
-        store_byte_pointer_ += params_.advance_cluster;
+        if (!ScatterD) {
+          byte_pointer_ += params_.advance_cluster;
+        }
+        if (!ScatterD && !PermuteD) {
+          store_byte_pointer_ += params_.advance_cluster;
+        }
 
         thread_start_row_ += ThreadMap::Count::kGroup *
           ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
 
         if (state_[2] == ThreadMap::Count::kCluster) {
           state_[2] = 0;
-          byte_pointer_ += params_.advance_tile;
-          store_byte_pointer_ += params_.advance_tile;
+          if (!ScatterD) {
+            byte_pointer_ += params_.advance_tile;
+          }
+          if (!ScatterD && !PermuteD) {
+            store_byte_pointer_ += params_.advance_tile;
+          }
 
           thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow
             * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile;