Add print_svg for mma (#1733)
* add print_svg for mma * correct the code indentation
This commit is contained in:
parent
1ebda1ccef
commit
2991ce18d3
@ -928,6 +928,183 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
|
||||
printf(latex_footer);
|
||||
}
|
||||
|
||||
// MNK MMA Layout to SVG -- 8-value color coded by thread
|
||||
template <class LayoutC, class ThrIDC,
|
||||
class LayoutA, class ThrIDA,
|
||||
class LayoutB, class ThrIDB>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx
|
||||
{
|
||||
char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175",
|
||||
"255,175,175", "210,210,255", "210,255,210",
|
||||
"255,255,210", "255,210,210"};
|
||||
|
||||
const int cell_width = 20;
|
||||
const int cell_height = 20;
|
||||
|
||||
const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width;
|
||||
const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height;
|
||||
|
||||
// header
|
||||
printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
|
||||
"preserveAspectRatio=\"xMidYMid meet\" "
|
||||
"xmlns=\"http://www.w3.org/2000/svg\">\n",
|
||||
page_width, page_height);
|
||||
|
||||
// C
|
||||
int c_base_x = (size<1>(A) + 2) * cell_width;
|
||||
int c_base_y = (size<1>(B) + 2) * cell_height;
|
||||
for (int m = 0; m < cute::size<0>(C); ++m) {
|
||||
for (int n = 0; n < cute::size<1>(C); ++n) {
|
||||
|
||||
int thrid = C(m, n) % size(TC);
|
||||
int val_idx = C(m, n) / size(TC);
|
||||
int thr_idx = TC(thrid);
|
||||
|
||||
int x = n * cell_width + c_base_x;
|
||||
int y = m * cell_height + c_base_y;
|
||||
|
||||
int thr_x = x + cell_width / 2;
|
||||
int thr_y = y + cell_height / 4;
|
||||
int val_x = x + cell_width / 2;
|
||||
int val_y = y + cell_height * 3 / 4;
|
||||
|
||||
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
|
||||
"fill=\"rgb(%s)\" stroke=\"black\"/>\n",
|
||||
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
|
||||
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
|
||||
thr_x, thr_y, thr_idx);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
|
||||
val_x, val_y, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// A
|
||||
int a_base_x = cell_width;
|
||||
int a_base_y = (size<1>(B) + 2) * cell_height;
|
||||
for (int m = 0; m < size<0>(A); ++m) {
|
||||
for (int k = 0; k < size<1>(A); ++k) {
|
||||
int thrid = A(m, k) % size(TA);
|
||||
int val_idx = A(m, k) / size(TA);
|
||||
int thr_idx = TA(thrid);
|
||||
|
||||
int x = k * cell_width + a_base_x;
|
||||
int y = m * cell_height + a_base_y;
|
||||
|
||||
int thr_x = x + cell_width / 2;
|
||||
int thr_y = y + cell_height / 4;
|
||||
int val_x = x + cell_width / 2;
|
||||
int val_y = y + cell_height * 3 / 4;
|
||||
|
||||
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
|
||||
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
|
||||
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
|
||||
thr_x, thr_y, thr_idx);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
|
||||
val_x, val_y, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// B
|
||||
int b_base_x = (size<1>(A) + 2) * cell_width;
|
||||
int b_base_y = cell_height;
|
||||
for (int n = 0; n < size<0>(B); ++n) {
|
||||
for (int k = 0; k < size<1>(B); ++k) {
|
||||
int thrid = B(n, k) % size(TB);
|
||||
int val_idx = B(n, k) / size(TB);
|
||||
int thr_idx = TB(thrid);
|
||||
|
||||
int x = n * cell_width + b_base_x;
|
||||
int y = k * cell_height + b_base_y;
|
||||
|
||||
int thr_x = x + cell_width / 2;
|
||||
int thr_y = y + cell_height / 4;
|
||||
int val_x = x + cell_width / 2;
|
||||
int val_y = y + cell_height * 3 / 4;
|
||||
|
||||
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
|
||||
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
|
||||
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
|
||||
thr_x, thr_y, thr_idx);
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
|
||||
val_x, val_y, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// A labels
|
||||
for (int m = 0; m < size<0>(A); ++m) {
|
||||
int x = cell_width / 2;
|
||||
int y = m * cell_height + cell_height / 2 + a_base_y;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, m);
|
||||
}
|
||||
for (int k = 0; k < size<1>(A); ++k) {
|
||||
int x = cell_width + k * cell_width + cell_width / 2;
|
||||
int y = -cell_height / 2 + a_base_y;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, k);
|
||||
}
|
||||
|
||||
// B labels
|
||||
for (int n = 0; n < size<0>(B); ++n) {
|
||||
int x = b_base_x + cell_width * n + cell_width / 2;
|
||||
int y = cell_height / 2;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, n);
|
||||
}
|
||||
for (int k = 0; k < size<1>(B); ++k) {
|
||||
int x = b_base_x - cell_width / 2;
|
||||
int y = cell_height * (k + 1) + cell_height / 2;
|
||||
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
|
||||
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
|
||||
x, y, k);
|
||||
}
|
||||
|
||||
// footer
|
||||
printf("</svg>");
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg(MMA_Atom<Args...> const &mma_atom) {
|
||||
print_svg(make_tiled_mma(mma_atom));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_svg(TiledMMA<Args...> const &mma) {
|
||||
auto layout_and_thrid_C = mma.get_layoutC_MN();
|
||||
auto layoutC_MN = get<0>(layout_and_thrid_C);
|
||||
auto thrID_C = get<1>(layout_and_thrid_C);
|
||||
|
||||
auto layout_and_thrid_A = mma.get_layoutA_MK();
|
||||
auto layoutA_MK = get<0>(layout_and_thrid_A);
|
||||
auto thrID_A = get<1>(layout_and_thrid_A);
|
||||
|
||||
auto layout_and_thrid_B = mma.get_layoutB_NK();
|
||||
auto layoutB_NK = get<0>(layout_and_thrid_B);
|
||||
auto thrID_B = get<1>(layout_and_thrid_B);
|
||||
|
||||
print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B);
|
||||
}
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Loading…
Reference in New Issue
Block a user