Just testing.

longfei li 2024-12-27 21:55:12 +08:00
parent 4da12fd0c2
commit acdacc2592
3 changed files with 15 additions and 8 deletions

View File

@@ -1,6 +1,8 @@
 #ifndef CORE_H
 #define CORE_H
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

View File

@@ -115,13 +115,18 @@ __global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, i
 void md_block_sum(const torch::Tensor &src, torch::Tensor &dest)
 {
     int block_size = 1024;
     dim3 grid(src.size(0), src.size(1), (src.size(2) + block_size - 1) / block_size);
     dim3 block(block_size);
-    md_row_sum_kernel<<<grid, block>>>(src.data_ptr<float>(),
-                                       dest.data_ptr<float>(),
-                                       src.stride(0),
-                                       src.stride(1),
-                                       src.size(0),
-                                       src.size(1),
-                                       src.size(2));
+    printf("this is the device num:%d\n", src.get_device());
+    int dev = src.get_device();
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+    md_row_sum_kernel<<<grid, block, 0, stream>>>(src.data_ptr<float>(),
+                                                  dest.data_ptr<float>(),
+                                                  src.stride(0),
+                                                  src.stride(1),
+                                                  src.size(0),
+                                                  src.size(1),
+                                                  src.size(2));
 }
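
For context: at::cuda::getCurrentCUDAStream (declared in the newly included ATen/cuda/CUDAContext.h) returns the stream PyTorch is currently issuing work on for the given device, so passing it as the fourth launch parameter keeps the custom kernel ordered with the surrounding ATen operations instead of running on the default stream. Below is a minimal, self-contained sketch of that launch pattern; the CUDAGuard (from the also-included c10/cuda/CUDAGuard.h), the launch-check macro, and copy_kernel itself are illustrative additions, not part of this commit.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>

// Placeholder kernel for the sketch (not from the commit): copies src into dest.
__global__ void copy_kernel(const float *src, float *dest, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        dest[i] = src[i];
}

void launch_on_current_stream(const torch::Tensor &src, torch::Tensor &dest)
{
    // Make CUDA calls below target the tensor's device
    // (the commit includes CUDAGuard.h but does not use a guard yet).
    c10::cuda::CUDAGuard guard(src.get_device());

    // The stream PyTorch is currently recording work on for this device.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(src.get_device());

    int n = static_cast<int>(src.numel());
    int block_size = 1024;
    int grid_size = (n + block_size - 1) / block_size;
    copy_kernel<<<grid_size, block_size, 0, stream>>>(
        src.data_ptr<float>(), dest.data_ptr<float>(), n);

    // Surface launch errors immediately rather than at the next synchronization.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}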

View File

@@ -2,7 +2,7 @@ import torch
 import torch_cuda_ext.core as core
 n = 1000000
-for i in range(1000):
+for i in range(10000):
     src = torch.randn(size=(n,)).float().cuda()
     dest_n = int((n + 1024 - 1) / 1024)
     dest = torch.zeros(size=(dest_n,)).float().cuda()
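
The Python loop above mostly stresses the launch path under repeated iterations. For reference, a host-side sanity check of the md_block_sum entry point from the .cu hunk could look like the sketch below. It assumes dest holds one partial sum per 1024-element block along the last dimension (matching the grid shape above); the check itself is an illustration and not part of the commit.

#include <torch/torch.h>
#include <iostream>

// Declaration matching the definition in the .cu hunk above; the dest layout
// (size(0), size(1), ceil(size(2) / 1024)) is an assumption, not stated in the diff.
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);

int main()
{
    auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto src = torch::randn({4, 8, 1000000}, opts);
    int64_t blocks = (src.size(2) + 1024 - 1) / 1024;
    auto dest = torch::zeros({src.size(0), src.size(1), blocks}, opts);

    md_block_sum(src, dest);
    torch::cuda::synchronize();

    // The per-block partial sums should add up to the full sum over the last dim.
    auto err = (dest.sum(-1) - src.sum(-1)).abs().max();
    std::cout << "max abs error: " << err.item<float>() << std::endl;
    return 0;
}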