diff --git a/examples/cute/tutorial/sgemm_1.cu b/examples/cute/tutorial/sgemm_1.cu index 46bf537a..e5bf9a92 100644 --- a/examples/cute/tutorial/sgemm_1.cu +++ b/examples/cute/tutorial/sgemm_1.cu @@ -80,7 +80,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, static_assert(is_static::value); CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M - CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K diff --git a/examples/cute/tutorial/sgemm_2.cu b/examples/cute/tutorial/sgemm_2.cu index b0d25bfe..ee2b6b2e 100644 --- a/examples/cute/tutorial/sgemm_2.cu +++ b/examples/cute/tutorial/sgemm_2.cu @@ -69,7 +69,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, static_assert(is_static::value); CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M - CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K diff --git a/examples/cute/tutorial/sgemm_sm70.cu b/examples/cute/tutorial/sgemm_sm70.cu index 8aba8132..ef6284cf 100644 --- a/examples/cute/tutorial/sgemm_sm70.cu +++ b/examples/cute/tutorial/sgemm_sm70.cu @@ -69,7 +69,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, static_assert(is_static::value); CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M - CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K diff --git a/examples/cute/tutorial/sgemm_sm80.cu b/examples/cute/tutorial/sgemm_sm80.cu index 3adb042b..e1211aac 100644 --- a/examples/cute/tutorial/sgemm_sm80.cu +++ b/examples/cute/tutorial/sgemm_sm80.cu @@ -69,7 +69,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, static_assert(is_static::value); CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M - CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K diff --git a/media/docs/cute/0x_gemm_tutorial.md b/media/docs/cute/0x_gemm_tutorial.md index 63a58784..7fe5f81c 100644 --- a/media/docs/cute/0x_gemm_tutorial.md +++ b/media/docs/cute/0x_gemm_tutorial.md @@ -188,7 +188,7 @@ As is evident, these smem layouts can be almost anything. Inside the kernel, the static_assert(is_static::value); CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M - CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K