From 4da12fd0c227afbe8306c66c0a26749de1d4ad7f Mon Sep 17 00:00:00 2001 From: longfei li Date: Fri, 22 Nov 2024 22:31:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E7=BB=B4=E7=9A=84=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E4=B9=9F=E5=AE=9E=E9=AA=8C=E4=BA=86=E4=B8=80=E4=B8=8B=EF=BC=8C?= =?UTF-8?q?=E7=9C=8B=E8=B5=B7=E6=9D=A5=E8=BF=98=E4=B8=8D=E9=94=99=E7=9A=84?= =?UTF-8?q?=E6=A0=B7=E5=AD=90=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 5 ++- csrc/core.h | 2 + csrc/core_bind.cpp | 2 + csrc/md.cu | 97 ++++++++++++++++++++++++++++++++++++++++++- test_reducemax.py | 19 +++++++++ 5 files changed, 123 insertions(+), 2 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 23830fb..9daa158 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,6 @@ { - "git.ignoreLimitWarning": true + "git.ignoreLimitWarning": true, + "files.associations": { + "__config": "cpp" + } } diff --git a/csrc/core.h b/csrc/core.h index 4c0b4f9..815e0df 100644 --- a/csrc/core.h +++ b/csrc/core.h @@ -17,4 +17,6 @@ void print_idx(); void reducemax(const torch::Tensor &src, torch::Tensor &dest); void test_cute_tensor(); void md_mm(const torch::Tensor &src); +void block_sum(const torch::Tensor &src, torch::Tensor &dest); +void md_block_sum(const torch::Tensor &src, torch::Tensor &dest); #endif \ No newline at end of file diff --git a/csrc/core_bind.cpp b/csrc/core_bind.cpp index 1672e6a..50ee703 100644 --- a/csrc/core_bind.cpp +++ b/csrc/core_bind.cpp @@ -16,4 +16,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("reducemax", &reducemax, "reduce max"); m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor"); m.def("md_mm", &md_mm, "just a test of multi dimension mm"); + m.def("block_sum", &block_sum, "test block sum"); + m.def("md_block_sum", &md_block_sum, "multi dimension block sum"); } diff --git a/csrc/md.cu b/csrc/md.cu index 809df94..32bb37c 100644 --- a/csrc/md.cu +++ b/csrc/md.cu @@ -1,4 +1,12 @@ #include "core.h" +#include +#include + +#include +#include +#include +#include +#include #include #include @@ -8,12 +16,16 @@ __device__ void mm_device(const float *src) { } + +template __global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num) { int batch_idx = blockIdx.x; int head_idx = blockIdx.y; - int sequence_idx = blockIdx.z; + int block_idx = blockIdx.z; int tidx = threadIdx.x; + int current_idx = batch_idx * stride_a + head_idx * stride_b + block_idx * stride_c + tidx; + // 其实是否一开始就用最原始的方法来写,然后后面进行拆分更容易一些呢。 } void md_mm(const torch::Tensor &src) @@ -30,3 +42,86 @@ void md_mm(const torch::Tensor &src) src.stride(0), src.stride(1), src.stride(2), thread_num); } + +template +__global__ void row_sum_kernel(const float *src, float *dest, int hidden_dim) +{ + + __shared__ float tmp_data[BLOCK_SIZE]; + float local_sum = 0.0f; + int offset = blockIdx.x * hidden_dim; + int idx = blockIdx.y * blockDim.y + threadIdx.x; + int tid = threadIdx.x; + for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE) + { + // add some other place's data. + local_sum += (src[offset + i] * src[offset + i]); + } + if (idx < hidden_dim) + tmp_data[tid] = local_sum; + else + tmp_data[tid] = 0.0f; + __syncthreads(); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + float sum = BlockReduce(temp_storage).Sum(tmp_data[tid]); + if (tid == 0) + { + dest[blockIdx.x] = sum; + } +} + +void block_sum(const torch::Tensor &src, torch::Tensor &dest) +{ + int block_size = 1024; + dim3 grid(src.size(0), (src.size(1) + block_size - 1) / block_size); + dim3 block(block_size); + row_sum_kernel<<>>(src.data_ptr(), dest.data_ptr(), src.size(1)); +} + +template +__global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, int stride_b, int batch, int seq_len, int hidden_dim) +{ + + __shared__ float tmp_data[BLOCK_SIZE]; + float local_sum = 0.0f; + int offset = blockIdx.x * stride_a + blockIdx.y * stride_b; + int tid = threadIdx.x; + int block_offset = blockIdx.x * seq_len + blockIdx.y; + int all_len = batch * seq_len; + int idx = blockIdx.z * BLOCK_SIZE + tid; + + for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE) + { + // add some other place's data. + local_sum += src[offset + i]; + } + if (idx < hidden_dim) + tmp_data[tid] = local_sum; + else + tmp_data[tid] = 0.0f; + __syncthreads(); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + float sum = BlockReduce(temp_storage).Sum(tmp_data[tid]); + if (tid == 0 && block_offset < all_len) + { + dest[block_offset] = sum; + } +} + +void md_block_sum(const torch::Tensor &src, torch::Tensor &dest) +{ + int block_size = 1024; + dim3 grid(src.size(0), src.size(1), (src.size(2) + block_size - 1) / block_size); + dim3 block(block_size); + md_row_sum_kernel<<>>(src.data_ptr(), + dest.data_ptr(), + src.stride(0), + src.stride(1), + src.size(0), + src.size(1), + src.size(2)); +} \ No newline at end of file diff --git a/test_reducemax.py b/test_reducemax.py index ace4516..c76fbc7 100644 --- a/test_reducemax.py +++ b/test_reducemax.py @@ -12,3 +12,22 @@ print(dest[0]) print(src.sum()) core.test_cute_tensor() + +src = torch.randn(size=(4096, 4096)).float().cuda() +dest = torch.zeros(size=(4096,)).float().cuda() +core.block_sum(src, dest) +src = src * src +real_sum = src.sum(dim=1) + +diff = real_sum - dest +print(diff) + + +src = torch.randn(size=((64, 128, 4096))).float().cuda() +dest = torch.randn(size=(64, 128)).float().cuda() + +core.md_block_sum(src, dest) + +real_sum = src.sum(dim=-1) +diff = real_sum - dest +print(diff)