测试一下。

This commit is contained in:
longfei li 2024-12-27 21:55:12 +08:00
parent 4da12fd0c2
commit acdacc2592
3 changed files with 15 additions and 8 deletions

View File

@ -1,6 +1,8 @@
#ifndef CORE_H
#define CORE_H
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

View File

@ -115,13 +115,18 @@ __global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, i
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest)
{
int block_size = 1024;
dim3 grid(src.size(0), src.size(1), (src.size(2) + block_size - 1) / block_size);
dim3 block(block_size);
md_row_sum_kernel<<<grid, block>>>(src.data_ptr<float>(),
dest.data_ptr<float>(),
src.stride(0),
src.stride(1),
src.size(0),
src.size(1),
src.size(2));
printf("this is the device num:%d\n", src.get_device());
int dev = src.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
md_row_sum_kernel<<<grid, block, 0, stream>>>(src.data_ptr<float>(),
dest.data_ptr<float>(),
src.stride(0),
src.stride(1),
src.size(0),
src.size(1),
src.size(2));
}

View File

@ -2,7 +2,7 @@ import torch
import torch_cuda_ext.core as core
n = 1000000
for i in range(1000):
for i in range(10000):
src = torch.randn(size=(n,)).float().cuda()
dest_n = int((n + 1024 - 1) / 1024)
dest = torch.zeros(size=(dest_n,)).float().cuda()