简单修改一下。
This commit is contained in:
parent
a43baa8b7f
commit
bf81e39d83
@ -81,7 +81,7 @@ __global__ void test_cute_tensor_kernel()
|
|||||||
Stride<_32, _2>{});
|
Stride<_32, _2>{});
|
||||||
Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
|
Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
|
||||||
__shared__ float smem[decltype(cosize(smem_layout))::value]; // (static-only allocation)
|
__shared__ float smem[decltype(cosize(smem_layout))::value]; // (static-only allocation)
|
||||||
// printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
|
printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
|
||||||
Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);
|
Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);
|
||||||
printf("tensor size is: %d, ind size is: %d, rmem size is: %d , rmem4x8 is: %d, smem size is: %d\n",
|
printf("tensor size is: %d, ind size is: %d, rmem size is: %d , rmem4x8 is: %d, smem size is: %d\n",
|
||||||
bool_tensor.size(),
|
bool_tensor.size(),
|
||||||
@ -92,7 +92,10 @@ __global__ void test_cute_tensor_kernel()
|
|||||||
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: Copy TAs as if they were uint128_t
|
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: Copy TAs as if they were uint128_t
|
||||||
Layout<Shape<_32, _8>>{}, // Thr layout 32x8 m-major
|
Layout<Shape<_32, _8>>{}, // Thr layout 32x8 m-major
|
||||||
Layout<Shape<_4, _1>>{}); // Val layout 4x1 m-major
|
Layout<Shape<_4, _1>>{}); // Val layout 4x1 m-major
|
||||||
|
printf("stensor size 1 is %d\n", cute::size<1>(stensor));
|
||||||
|
#if 0
|
||||||
print_latex(copyA);
|
print_latex(copyA);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// template <int head, int batch = 0, int head_dim = 0>
|
// template <int head, int batch = 0, int head_dim = 0>
|
||||||
|
|||||||
@ -5,6 +5,9 @@
|
|||||||
#include <cutlass/array.h>
|
#include <cutlass/array.h>
|
||||||
#include <cutlass/numeric_types.h>
|
#include <cutlass/numeric_types.h>
|
||||||
|
|
||||||
|
__device__ void mm_device(const float *src)
|
||||||
|
{
|
||||||
|
}
|
||||||
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
|
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
|
||||||
{
|
{
|
||||||
int batch_idx = blockIdx.x;
|
int batch_idx = blockIdx.x;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user