From 81b06ee0e0690f7fd3a662c32949caef8fa9bfe3 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Wed, 10 Jul 2024 16:06:29 +0100 Subject: [PATCH] Fix B operand variable name and comments (#1458) --- include/cute/atom/mma_atom.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 9e5c93f2..d164a954 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -338,12 +338,12 @@ struct TiledMMA : MMA_Atom auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) // Tile the tensor for the Atom - auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + auto b_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), make_layout(size<2>(AtomShape_MNK{}))); - auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) - // Transform the Atom mode from (M,K) to (Thr,Val) - auto tv_tensor = a_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + // Transform the Atom mode from (N,K) to (Thr,Val) + auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) // Tile the tensor for the Thread auto thr_tile = make_tile(_, @@ -492,7 +492,7 @@ struct TiledMMA : MMA_Atom // (bthrid,val) -> (N,K) auto layoutB_TV = thrfrg_B(ref_B); - // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) auto btile = make_tile(_, make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), make_stride( Int<0>{} , Int<1>{} )),