From 19f51596e8be9fe87d583616466581ab5740c19d Mon Sep 17 00:00:00 2001 From: chenwei <15601910741@163.com> Date: Tue, 29 Oct 2024 23:56:59 +0800 Subject: [PATCH] feat: support kFactor 8 used in mma tensor op tile iterator (#1512) --- .../cutlass/gemm/warp/mma_tensor_op_tile_iterator.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index 46690bf1..e6e6d70f 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -2315,7 +2315,17 @@ class MmaTensorOpMultiplicandTileIterator< int access_contiguous_idx = -1; int access_strided_idx = -1; - if (Layout::kFactor == 4) { + if (Layout::kFactor == 8) { + int factor_in_partition = + (Layout::PartitionShape::kContiguous * Layout::kFactor / + Layout::TileShape::kContiguous); + + if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { + partition_contiguous_idx = lane_in_quad_pair / factor_in_partition; + access_contiguous_idx = ((lane_in_quad) ^ (lane_id / Layout::kFactor)); + access_strided_idx = lane_id / Layout::kFactor; + } + } else if (Layout::kFactor == 4) { // Super Integer matrix multiply Interleaved-32 int factor_in_partition =