测试一下。
This commit is contained in:
parent
4da12fd0c2
commit
acdacc2592
@ -1,6 +1,8 @@
|
||||
#ifndef CORE_H
|
||||
#define CORE_H
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
|
||||
|
||||
|
||||
19
csrc/md.cu
19
csrc/md.cu
@ -115,13 +115,18 @@ __global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, i
|
||||
/**
 * Host-side launcher for md_row_sum_kernel.
 *
 * Launch layout: grid = (src.size(0), src.size(1), ceil(src.size(2) / 1024)),
 * block = 1024 threads, no dynamic shared memory. The kernel is enqueued
 * exactly once, on the current PyTorch CUDA stream of src's device, so it is
 * ordered with the other PyTorch work queued on that stream instead of
 * running on the legacy default stream.
 *
 * @param src  CUDA float tensor; sizes 0..2 and strides 0..1 are forwarded to
 *             the kernel, so it must have at least three dimensions.
 * @param dest CUDA float tensor written by the kernel.
 *             NOTE(review): the required shape of dest depends on the kernel
 *             body, which is not visible here -- confirm against the kernel.
 *
 * Raises a c10::Error (via TORCH_CHECK) if either tensor is not on a CUDA
 * device or if the launch is rejected by the runtime.
 *
 * NOTE(review): this assumes the calling thread's current device is src's
 * device; for multi-GPU callers consider a c10::cuda::CUDAGuard here.
 */
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest)
{
    // Device pointers are handed straight to the kernel; reject host tensors
    // up front rather than faulting asynchronously inside the kernel.
    TORCH_CHECK(src.is_cuda(), "md_block_sum: src must be a CUDA tensor");
    TORCH_CHECK(dest.is_cuda(), "md_block_sum: dest must be a CUDA tensor");

    // Convert once to a plain int so the value matches the %d specifier.
    const int dev = static_cast<int>(src.get_device());
    printf("this is the device num:%d\n", dev);

    // A zero-sized grid dimension is an invalid launch configuration;
    // with nothing to reduce there is nothing to launch.
    if (src.numel() == 0)
    {
        return;
    }

    constexpr int block_size = 1024;

    const dim3 grid(static_cast<unsigned int>(src.size(0)),
                    static_cast<unsigned int>(src.size(1)),
                    static_cast<unsigned int>((src.size(2) + block_size - 1) / block_size));
    const dim3 block(block_size);

    cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

    md_row_sum_kernel<<<grid, block, 0, stream>>>(src.data_ptr<float>(),
                                                  dest.data_ptr<float>(),
                                                  src.stride(0),
                                                  src.stride(1),
                                                  src.size(0),
                                                  src.size(1),
                                                  src.size(2));

    // Kernel launches do not return a status; surface configuration errors
    // (e.g. grid dimensions over the device limit) instead of dropping them.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "md_block_sum: md_row_sum_kernel launch failed: ",
                cudaGetErrorString(launch_err));
}
|
||||
@ -2,7 +2,7 @@ import torch
|
||||
import torch_cuda_ext.core as core
|
||||
|
||||
n = 1000000
|
||||
for i in range(1000):
|
||||
for i in range(10000):
|
||||
src = torch.randn(size=(n,)).float().cuda()
|
||||
dest_n = int((n + 1024 - 1) / 1024)
|
||||
dest = torch.zeros(size=(dest_n,)).float().cuda()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user