简单修改一下。
This commit is contained in:
parent
a43baa8b7f
commit
bf81e39d83
@ -81,7 +81,7 @@ __global__ void test_cute_tensor_kernel()
|
||||
Stride<_32, _2>{});
|
||||
Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
|
||||
__shared__ float smem[decltype(cosize(smem_layout))::value]; // (static-only allocation)
|
||||
// printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
|
||||
printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
|
||||
Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);
|
||||
printf("tensor size is: %d, ind size is: %d, rmem size is: %d , rmem4x8 is: %d, smem size is: %d\n",
|
||||
bool_tensor.size(),
|
||||
@ -92,7 +92,10 @@ __global__ void test_cute_tensor_kernel()
|
||||
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: Copy TAs as if they were uint128_t
|
||||
Layout<Shape<_32, _8>>{}, // Thr layout 32x8 m-major
|
||||
Layout<Shape<_4, _1>>{}); // Val layout 4x1 m-major
|
||||
printf("stensor size 1 is %d\n", cute::size<1>(stensor));
|
||||
#if 0
|
||||
print_latex(copyA);
|
||||
#endif
|
||||
}
|
||||
|
||||
// template <int head, int batch = 0, int head_dim = 0>
|
||||
|
||||
@ -5,6 +5,9 @@
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
__device__ void mm_device(const float *src)
|
||||
{
|
||||
}
|
||||
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
|
||||
{
|
||||
int batch_idx = blockIdx.x;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user