diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index 46690bf1..e6e6d70f 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -2315,7 +2315,17 @@ class MmaTensorOpMultiplicandTileIterator< int access_contiguous_idx = -1; int access_strided_idx = -1; - if (Layout::kFactor == 4) { + if (Layout::kFactor == 8) { + int factor_in_partition = + (Layout::PartitionShape::kContiguous * Layout::kFactor / + Layout::TileShape::kContiguous); + + if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { + partition_contiguous_idx = lane_in_quad_pair / factor_in_partition; + access_contiguous_idx = ((lane_in_quad) ^ (lane_id / Layout::kFactor)); + access_strided_idx = lane_id / Layout::kFactor; + } + } else if (Layout::kFactor == 4) { // Super Integer matrix multiply Interleaved-32 int factor_in_partition =