feat: support kFactor 8 used in mma tensor op tile iterator (#1512)

This commit is contained in:
chenwei 2024-10-29 23:56:59 +08:00 committed by GitHub
parent e8a8b69365
commit 19f51596e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2315,7 +2315,17 @@ class MmaTensorOpMultiplicandTileIterator<
int access_contiguous_idx = -1;
int access_strided_idx = -1;
if (Layout::kFactor == 4) {
if (Layout::kFactor == 8) {
int factor_in_partition =
(Layout::PartitionShape::kContiguous * Layout::kFactor /
Layout::TileShape::kContiguous);
if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
partition_contiguous_idx = lane_in_quad_pair / factor_in_partition;
access_contiguous_idx = ((lane_in_quad) ^ (lane_id / Layout::kFactor));
access_strided_idx = lane_id / Layout::kFactor;
}
} else if (Layout::kFactor == 4) {
// Super Integer matrix multiply Interleaved-32
int factor_in_partition =