From 58093d7a710d742b99fc61ae7e9349bc73cb0181 Mon Sep 17 00:00:00 2001
From: longfei li
Date: Sun, 29 Dec 2024 01:23:00 +0800
Subject: [PATCH] Tried writing softmax and learned a bit more. It works now.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 csrc/core.h        |   8 ++++
 csrc/core_bind.cpp |   1 +
 csrc/max.cu        |   7 ---
 csrc/md.cu         | 115 ++++++++++++++++++++++++++++++++++++++++++++-
 test_reducemax.py  |  47 +++++++++++++++++-
 5 files changed, 169 insertions(+), 9 deletions(-)

diff --git a/csrc/core.h b/csrc/core.h
index 0afed83..96989f3 100644
--- a/csrc/core.h
+++ b/csrc/core.h
@@ -3,6 +3,13 @@
 #include
 #include
 #include
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
+    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
+    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
 void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2,
                      torch::Tensor &output);
@@ -21,4 +28,5 @@ void test_cute_tensor();
 void md_mm(const torch::Tensor &src);
 void block_sum(const torch::Tensor &src, torch::Tensor &dest);
 void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
+void softmax(const torch::Tensor &src, torch::Tensor &dest);
 #endif
\ No newline at end of file
diff --git a/csrc/core_bind.cpp b/csrc/core_bind.cpp
index 50ee703..90fa03b 100644
--- a/csrc/core_bind.cpp
+++ b/csrc/core_bind.cpp
@@ -18,4 +18,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("md_mm", &md_mm, "just a test of multi dimension mm");
     m.def("block_sum", &block_sum, "test block sum");
     m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
+    m.def("softmax", &softmax, "test softmax example");
 }
diff --git a/csrc/max.cu b/csrc/max.cu
index 8f2f15a..2f504ea 100644
--- a/csrc/max.cu
+++ b/csrc/max.cu
@@ -15,13 +15,6 @@
 #include "core.h"
 using namespace cute;
 
-#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
-    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
-    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
-    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
-
-#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
-    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
 template <typename scalar_t>
 __global__ void reducemax_kernel(const scalar_t *src, scalar_t *dest, int len)
diff --git a/csrc/md.cu b/csrc/md.cu
index 96c4ee0..81953ee 100644
--- a/csrc/md.cu
+++ b/csrc/md.cu
@@ -129,4 +129,117 @@ void md_block_sum(const torch::Tensor &src, torch::Tensor &dest)
         src.size(0), src.size(1), src.size(2));
-}
\ No newline at end of file
+}
+
+void interaction(const torch::Tensor &src)
+{
+    int block_size = 1024;
+
+    dim3 grid(src.size(0), src.size(1), (src.size(2) + block_size - 1) / block_size);
+    dim3 block(block_size);
+
+    printf("this is the device num:%d\n", src.get_device());
+    int dev = src.get_device();
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+    // seems we can do some other things here.
+}
+
+template <typename s_scalar>
+__device__ s_scalar exp(s_scalar a)
+{
+    return expf(a);
+}
+
+template <>
+__device__ __nv_bfloat16 exp(__nv_bfloat16 a)
+{
+    float tmp = __bfloat162float(a);
+    float tmp_score = expf(tmp);
+    return __float2bfloat16(tmp_score);
+}
+
+template <>
+__device__ __half exp(__half a)
+{
+    float tmp = __half2float(a);
+    float tmp_score = expf(tmp);
+    return __float2half(tmp_score);
+}
+
+template <>
+__device__ float exp(float a)
+{
+    return expf(a);
+}
+
+template <typename scalar_t>
+__device__ float fi_cast(scalar_t a)
+{
+    return a;
+}
+
+template <>
+__device__ float fi_cast(__nv_bfloat16 a)
+{
+    return __bfloat162float(a);
+}
+
+template <>
+__device__ float fi_cast(__half a)
+{
+    return __half2float(a);
+}
+
+template <int BLOCK_SIZE, typename scalar_t>
+__global__ void softmax_kernel(const scalar_t *src, scalar_t *dest, int hidden_dim)
+{
+    int tid = threadIdx.x;
+    int offset = blockIdx.x * hidden_dim;
+    __shared__ scalar_t smem[BLOCK_SIZE];
+    float local_sum = 0.0f;
+    for (int i = tid; i < hidden_dim; i += blockDim.x)
+    {
+        // exponentiate and accumulate the per-thread partial sum
+        int tmp_index = offset + i;
+        scalar_t tmp_score = exp(src[tmp_index]);
+        dest[tmp_index] = tmp_score;
+        local_sum += tmp_score;
+    }
+    if (tid < BLOCK_SIZE)
+        smem[tid] = local_sum;
+    else
+        smem[tid] = 0.0f;
+    __syncthreads();
+    typedef cub::BlockReduce<scalar_t, BLOCK_SIZE> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+    scalar_t sum = BlockReduce(temp_storage).Sum(smem[tid]);
+    // remember: after the block reduce, only the first thread holds the real sum.
+    if (tid == 0)
+        smem[0] = sum;
+    __syncthreads();
+    for (int i = tid; i < hidden_dim; i += blockDim.x)
+    {
+        int tmp_index = offset + i;
+        scalar_t tmp_score = dest[tmp_index] / smem[0];
+        dest[tmp_index] = tmp_score;
+    }
+}
+
+void softmax(const torch::Tensor &src, torch::Tensor &dest)
+{
+    int batch_num = src.size(0);
+    int hidden_dim = src.size(1);
+    int block_size = 1024;
+    dim3 grid(batch_num);
+    dim3 block(block_size);
+    VLLM_DISPATCH_FLOATING_TYPES(
+        src.scalar_type(), "softmax",
+        [&]
+        {
+            int dev = src.get_device();
+            cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+            softmax_kernel<1024, scalar_t><<<grid, block, 0, stream>>>(
+                src.data_ptr<scalar_t>(),
+                dest.data_ptr<scalar_t>(),
+                hidden_dim); });
+}
diff --git a/test_reducemax.py b/test_reducemax.py
index 2ad7347..6fea348 100644
--- a/test_reducemax.py
+++ b/test_reducemax.py
@@ -2,7 +2,7 @@ import torch
 import torch_cuda_ext.core as core
 
 n = 1000000
-for i in range(10000):
+for i in range(100):
     src = torch.randn(size=(n,)).float().cuda()
     dest_n = int((n + 1024 - 1) / 1024)
     dest = torch.zeros(size=(dest_n,)).float().cuda()
@@ -31,3 +31,48 @@ core.md_block_sum(src, dest)
 real_sum = src.sum(dim=-1)
 diff = real_sum - dest
 print(diff)
+
+for k in range(128, 4096, 128):
+    for j in range(1024, 4096, 1024):
+        a = torch.randn(size=(k, j)).half().cuda()
+        b = torch.empty_like(a)
+        num_runs = 100
+        times = []
+        for _ in range(num_runs):
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            core.softmax(a, b)
+            end.record()
+            torch.cuda.synchronize()  # wait for the CUDA ops to finish
+            elapsed_time = start.elapsed_time(end) / 1000  # convert to seconds
+            times.append(elapsed_time)
+
+        own_avg_time = sum(times) / num_runs
+        own_std_time = (sum((t - own_avg_time) ** 2 for t in times) / num_runs) ** 0.5
+        print(f"own softmax cost time: {own_avg_time}, {own_std_time}")
+
+        times = []
+        for _ in range(num_runs):
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+
+            res = torch.softmax(a, dim=1)
+            end.record()
+            torch.cuda.synchronize()  # wait for the CUDA ops to finish
+            elapsed_time = start.elapsed_time(end) / 1000  # convert to seconds
+            times.append(elapsed_time)
+
+        avg_time = sum(times) / num_runs
+        std_time = (sum((t - avg_time) ** 2 for t in times) / num_runs) ** 0.5
+        print(f"torch softmax cost time: {avg_time}, {std_time}")
+
+        # print("this is b", b)
+        diff = (res - b).abs().max()
+        if diff < 1e-4:
+            print("softmax is good")
+            time_diff_rate = (own_avg_time - avg_time) / avg_time
+            print(f"{k}, {j} matrix result {time_diff_rate}")
+        else:
+            print("softmax is not equal")
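
Note (not part of the patch): softmax_kernel above exponentiates src[i] directly, so rows containing large logits can overflow, especially in fp16/bf16. The usual fix is to subtract the per-row maximum before exp. Below is a minimal, hypothetical sketch of that variant for float inputs, reusing the same one-block-per-row layout and cub::BlockReduce pattern as the committed kernel; the name stable_softmax_kernel is illustrative only.

#include <cub/cub.cuh>
#include <float.h>

// Numerically stable row softmax sketch: one block per row, float only for clarity.
template <int BLOCK_SIZE>
__global__ void stable_softmax_kernel(const float *src, float *dest, int hidden_dim)
{
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float row_max, row_sum;

    int tid = threadIdx.x;
    int offset = blockIdx.x * hidden_dim;

    // Pass 1: per-thread max, then block-wide max of the row.
    float local_max = -FLT_MAX;
    for (int i = tid; i < hidden_dim; i += blockDim.x)
        local_max = fmaxf(local_max, src[offset + i]);
    float m = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
    if (tid == 0)
        row_max = m; // only thread 0 holds the reduced value
    __syncthreads();

    // Pass 2: exp(x - max) and block-wide sum of the row.
    float local_sum = 0.0f;
    for (int i = tid; i < hidden_dim; i += blockDim.x)
    {
        float e = expf(src[offset + i] - row_max);
        dest[offset + i] = e;
        local_sum += e;
    }
    __syncthreads(); // make it safe to reuse temp_storage
    float s = BlockReduce(temp_storage).Sum(local_sum);
    if (tid == 0)
        row_sum = s;
    __syncthreads();

    // Pass 3: normalize.
    for (int i = tid; i < hidden_dim; i += blockDim.x)
        dest[offset + i] /= row_sum;
}

It would be launched the same way as the committed kernel, e.g. stable_softmax_kernel<1024><<<grid, block, 0, stream>>>(src_ptr, dest_ptr, hidden_dim), at the cost of one extra pass over the row for the max reduction.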