简单修改一下。

This commit is contained in:
longfei li 2024-11-18 22:13:43 +08:00
parent a43baa8b7f
commit bf81e39d83
3 changed files with 7 additions and 1 deletions

View File

@ -81,7 +81,7 @@ __global__ void test_cute_tensor_kernel()
Stride<_32, _2>{});
Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
__shared__ float smem[decltype(cosize(smem_layout))::value]; // (static-only allocation)
// printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);
printf("tensor size is: %d, ind size is: %d, rmem size is: %d , rmem4x8 is: %d, smem size is: %d\n",
bool_tensor.size(),
@ -92,7 +92,10 @@ __global__ void test_cute_tensor_kernel()
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: Copy TAs as if they were uint128_t
Layout<Shape<_32, _8>>{}, // Thr layout 32x8 m-major
Layout<Shape<_4, _1>>{}); // Val layout 4x1 m-major
printf("stensor size 1 is %d\n", cute::size<1>(stensor));
#if 0
print_latex(copyA);
#endif
}
// template <int head, int batch = 0, int head_dim = 0>

View File

@ -5,6 +5,9 @@
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
__device__ void mm_device(const float *src)
{
}
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
{
int batch_idx = blockIdx.x;

BIN
test

Binary file not shown.