feat: support kFactor 8 used in mma tensor op tile iterator (#1512)
This commit is contained in:
parent
e8a8b69365
commit
19f51596e8
@ -2315,7 +2315,17 @@ class MmaTensorOpMultiplicandTileIterator<
|
|||||||
int access_contiguous_idx = -1;
|
int access_contiguous_idx = -1;
|
||||||
int access_strided_idx = -1;
|
int access_strided_idx = -1;
|
||||||
|
|
||||||
if (Layout::kFactor == 4) {
|
if (Layout::kFactor == 8) {
|
||||||
|
int factor_in_partition =
|
||||||
|
(Layout::PartitionShape::kContiguous * Layout::kFactor /
|
||||||
|
Layout::TileShape::kContiguous);
|
||||||
|
|
||||||
|
if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
|
||||||
|
partition_contiguous_idx = lane_in_quad_pair / factor_in_partition;
|
||||||
|
access_contiguous_idx = ((lane_in_quad) ^ (lane_id / Layout::kFactor));
|
||||||
|
access_strided_idx = lane_id / Layout::kFactor;
|
||||||
|
}
|
||||||
|
} else if (Layout::kFactor == 4) {
|
||||||
// Super Integer matrix multiply Interleaved-32
|
// Super Integer matrix multiply Interleaved-32
|
||||||
|
|
||||||
int factor_in_partition =
|
int factor_in_partition =
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user