From acdacc25927586234cb240a98499c35610372f52 Mon Sep 17 00:00:00 2001 From: longfei li Date: Fri, 27 Dec 2024 21:55:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=B8=80=E4=B8=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- csrc/core.h | 2 ++ csrc/md.cu | 19 ++++++++++++------- test_reducemax.py | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/csrc/core.h b/csrc/core.h index 815e0df..0afed83 100644 --- a/csrc/core.h +++ b/csrc/core.h @@ -1,6 +1,8 @@ #ifndef CORE_H #define CORE_H #include +#include +#include void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output); diff --git a/csrc/md.cu b/csrc/md.cu index 32bb37c..96c4ee0 100644 --- a/csrc/md.cu +++ b/csrc/md.cu @@ -115,13 +115,18 @@ __global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, i void md_block_sum(const torch::Tensor &src, torch::Tensor &dest) { int block_size = 1024; + dim3 grid(src.size(0), src.size(1), (src.size(2) + block_size - 1) / block_size); dim3 block(block_size); - md_row_sum_kernel<<>>(src.data_ptr(), - dest.data_ptr(), - src.stride(0), - src.stride(1), - src.size(0), - src.size(1), - src.size(2)); + + printf("this is the device num:%d\n", src.get_device()); + int dev = src.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + md_row_sum_kernel<<>>(src.data_ptr(), + dest.data_ptr(), + src.stride(0), + src.stride(1), + src.size(0), + src.size(1), + src.size(2)); } \ No newline at end of file diff --git a/test_reducemax.py b/test_reducemax.py index c76fbc7..2ad7347 100644 --- a/test_reducemax.py +++ b/test_reducemax.py @@ -2,7 +2,7 @@ import torch import torch_cuda_ext.core as core n = 1000000 -for i in range(1000): +for i in range(10000): src = torch.randn(size=(n,)).float().cuda() dest_n = int((n + 1024 - 1) / 1024) dest = torch.zeros(size=(dest_n,)).float().cuda()