diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index c8b3c3bf..4da07d45 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -440,12 +440,16 @@ public: } if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; + if (!ScatterD && !PermuteD) { + byte_pointer += params_.increment_group; + } } } if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; + if (!ScatterD && !PermuteD) { + byte_pointer += params_.increment_cluster; + } } } } @@ -650,8 +654,12 @@ public: state_[0] = 0; ++state_[1]; - byte_pointer_ += params_.advance_group; - store_byte_pointer_ += params_.advance_group; + if (!ScatterD) { + byte_pointer_ += params_.advance_group; + } + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_group; + } thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -660,16 +668,24 @@ public: state_[1] = 0; ++state_[2]; - byte_pointer_ += params_.advance_cluster; - store_byte_pointer_ += params_.advance_cluster; + if (!ScatterD) { + byte_pointer_ += params_.advance_cluster; + } + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_cluster; + } thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; - byte_pointer_ += params_.advance_tile; - store_byte_pointer_ += params_.advance_tile; + if (!ScatterD) { + byte_pointer_ += params_.advance_tile; + } + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_tile; + } thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile;