diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp
index 6dc826ef..2358dd56 100644
--- a/include/cute/atom/mma_atom.hpp
+++ b/include/cute/atom/mma_atom.hpp
@@ -928,6 +928,200 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
   printf(latex_footer);
 }
 
+// MNK MMA Layout to SVG -- 8-value color coded by thread
+//
+// Emits, via printf, an SVG document that draws the three operands of an MMA
+// as grids of 20x20 cells: A (M rows x K columns) to the left of C, B (K rows
+// x N columns) above C, and C (M rows x N columns) at the bottom right.
+// Every cell shows the owning thread ("T<thr_idx>") and value ("V<val_idx>")
+// and is filled with color_map[thr_idx % 8]. Row/column indices of A and B
+// are printed in the cell-wide strip to the left of and above each of them.
+//
+// C, A, B    : layouts (m,n) / (m,k) / (n,k) -> idx, decoded below as
+//              tid = idx % size(T*) and vid = idx / size(T*).
+// TC, TA, TB : layouts tid -> thr_idx.
+template <class LayoutC, class ThrIDC,
+          class LayoutA, class ThrIDA,
+          class LayoutB, class ThrIDB>
+CUTE_HOST_DEVICE
+void
+print_svg_mma(LayoutC const& C, ThrIDC const& TC,  // (m,n) -> (tid,vid)  and  tid -> thr_idx
+              LayoutA const& A, ThrIDA const& TA,  // (m,k) -> (tid,vid)  and  tid -> thr_idx
+              LayoutB const& B, ThrIDB const& TB)  // (n,k) -> (tid,vid)  and  tid -> thr_idx
+{
+  char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175",
+                              "255,175,175", "210,210,255", "210,255,210",
+                              "255,255,210", "255,210,210"};
+
+  const int cell_width = 20;
+  const int cell_height = 20;
+
+  const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width;
+  const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height;
+
+  // header
+  printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
+         "preserveAspectRatio=\"xMidYMid meet\" "
+         "xmlns=\"http://www.w3.org/2000/svg\">\n",
+         page_width, page_height);
+
+  // C
+  int c_base_x = (size<1>(A) + 2) * cell_width;
+  int c_base_y = (size<1>(B) + 2) * cell_height;
+  for (int m = 0; m < cute::size<0>(C); ++m) {
+    for (int n = 0; n < cute::size<1>(C); ++n) {
+
+      // C(m,n) is decoded as tid + vid * size(TC); TC then maps tid -> thr_idx.
+      int thrid = C(m, n) % size(TC);
+      int val_idx = C(m, n) / size(TC);
+      int thr_idx = TC(thrid);
+
+      int x = n * cell_width + c_base_x;
+      int y = m * cell_height + c_base_y;
+
+      // Thread label sits in the upper half of the cell, value label in the lower half.
+      int thr_x = x + cell_width / 2;
+      int thr_y = y + cell_height / 4;
+      int val_x = x + cell_width / 2;
+      int val_y = y + cell_height * 3 / 4;
+
+      printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
+             "fill=\"rgb(%s)\" stroke=\"black\"/>\n",
+             x, y, cell_width, cell_height, color_map[thr_idx % 8]);
+
+      printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+             "alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
+             thr_x, thr_y, thr_idx);
+      printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+             "alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
+             val_x, val_y, val_idx);
+    }
+  }
+
+  // A
+  int a_base_x = cell_width;
+  int a_base_y = (size<1>(B) + 2) * cell_height;
+  for (int m = 0; m < size<0>(A); ++m) {
+    for (int k = 0; k < size<1>(A); ++k) {
+      int thrid = A(m, k) % size(TA);
+      int val_idx = A(m, k) / size(TA);
+      int thr_idx = TA(thrid);
+
+      int x = k * cell_width + a_base_x;
+      int y = m * cell_height + a_base_y;
+
+      int thr_x = x + cell_width / 2;
+      int thr_y = y + cell_height / 4;
+      int val_x = x + cell_width / 2;
+      int val_y = y + cell_height * 3 / 4;
+
+      printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
+             "fill=\"rgb(%s)\" stroke=\"black\"/>\n",
+             x, y, cell_width, cell_height, color_map[thr_idx % 8]);
+      printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+             "alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
+             thr_x, thr_y, thr_idx);
+      printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+             "alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
+             val_x, val_y, val_idx);
+    }
+  }
+
+  // B
+  int b_base_x = (size<1>(A) + 2) * cell_width;
+  int b_base_y = cell_height;
+  for (int n = 0; n < size<0>(B); ++n) {
+    for (int k = 0; k < size<1>(B); ++k) {
+      int thrid = B(n, k) % size(TB);
+      int val_idx = B(n, k) / size(TB);
+      int thr_idx = TB(thrid);
+
+      // B is drawn transposed: n runs along x and k along y, so it lines up above C.
+      int x = n * cell_width + b_base_x;
+      int y = k * cell_height + b_base_y;
+
+      int thr_x = x + cell_width / 2;
+      int thr_y = y + cell_height / 4;
+      int val_x = x + cell_width / 2;
+      int val_y = y + cell_height * 3 / 4;
+
+      printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
+             "fill=\"rgb(%s)\" stroke=\"black\"/>\n",
+             x, y, cell_width, cell_height, color_map[thr_idx % 8]);
+      printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+             "alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
+             thr_x, thr_y, thr_idx);
+      printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+             "alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
+             val_x, val_y, val_idx);
+    }
+  }
+
+  // A labels
+  for (int m = 0; m < size<0>(A); ++m) {
+    int x = cell_width / 2;
+    int y = m * cell_height + cell_height / 2 + a_base_y;
+    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
+           x, y, m);
+  }
+  for (int k = 0; k < size<1>(A); ++k) {
+    int x = cell_width + k * cell_width + cell_width / 2;
+    int y = -cell_height / 2 + a_base_y;
+    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
+           x, y, k);
+  }
+
+  // B labels
+  for (int n = 0; n < size<0>(B); ++n) {
+    int x = b_base_x + cell_width * n + cell_width / 2;
+    int y = cell_height / 2;
+    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
+           x, y, n);
+  }
+  for (int k = 0; k < size<1>(B); ++k) {
+    int x = b_base_x - cell_width / 2;
+    int y = cell_height * (k + 1) + cell_height / 2;
+    printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
+           "alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
+           x, y, k);
+  }
+
+  // footer
+  printf("</svg>");
+}
+
+// Print the SVG for a single MMA_Atom by wrapping it with make_tiled_mma.
+template <class... Args>
+CUTE_HOST_DEVICE
+void
+print_svg(MMA_Atom<Args...> const &mma_atom) {
+  print_svg(make_tiled_mma(mma_atom));
+}
+
+// Print the SVG for a TiledMMA: fetch the (layout, thrID) pair for C, A and B
+// from the TiledMMA and forward all six to print_svg_mma.
+template <class... Args>
+CUTE_HOST_DEVICE
+void
+print_svg(TiledMMA<Args...> const &mma) {
+  auto layout_and_thrid_C = mma.get_layoutC_MN();
+  auto layoutC_MN = get<0>(layout_and_thrid_C);
+  auto thrID_C = get<1>(layout_and_thrid_C);
+
+  auto layout_and_thrid_A = mma.get_layoutA_MK();
+  auto layoutA_MK = get<0>(layout_and_thrid_A);
+  auto thrID_A = get<1>(layout_and_thrid_A);
+
+  auto layout_and_thrid_B = mma.get_layoutB_NK();
+  auto layoutB_NK = get<0>(layout_and_thrid_B);
+  auto thrID_B = get<1>(layout_and_thrid_B);
+
+  print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B);
+}
+
 } // namespace cute
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////