diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 9e5c93f2..d164a954 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -338,12 +338,12 @@ struct TiledMMA : MMA_Atom auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) // Tile the tensor for the Atom - auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + auto b_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), make_layout(size<2>(AtomShape_MNK{}))); - auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) - // Transform the Atom mode from (M,K) to (Thr,Val) - auto tv_tensor = a_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + // Transform the Atom mode from (N,K) to (Thr,Val) + auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) // Tile the tensor for the Thread auto thr_tile = make_tile(_, @@ -492,7 +492,7 @@ struct TiledMMA : MMA_Atom // (bthrid,val) -> (N,K) auto layoutB_TV = thrfrg_B(ref_B); - // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) auto btile = make_tile(_, make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), make_stride( Int<0>{} , Int<1>{} )),