Tried writing softmax, learned a bit more. It works now.

longfei li 2024-12-29 01:23:00 +08:00
parent acdacc2592
commit 58093d7a71
5 changed files with 169 additions and 9 deletions


@@ -3,6 +3,13 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
@@ -21,4 +28,5 @@ void test_cute_tensor();
void md_mm(const torch::Tensor &src);
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
void softmax(const torch::Tensor &src, torch::Tensor &dest);
#endif


@@ -18,4 +18,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    m.def("md_mm", &md_mm, "just a test of multi dimension mm");
    m.def("block_sum", &block_sum, "test block sum");
    m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
    m.def("softmax", &softmax, "test softmax example");
}


@@ -15,13 +15,6 @@
#include "core.h"
using namespace cute;
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
template <int BLOCK_SIZE = 1024, typename scalar_t>
__global__ void reducemax_kernel(const scalar_t *src, scalar_t *dest, int len)


@@ -129,4 +129,117 @@ void md_block_sum(const torch::Tensor &src, torch::Tensor &dest)
            src.size(0),
            src.size(1),
            src.size(2));
}
}
void interaction(const torch::Tensor &src)
{
    int block_size = 1024;
    dim3 grid(src.size(0), src.size(1), (src.size(2) + block_size - 1) / block_size);
    dim3 block(block_size);
    printf("this is the device num:%d\n", src.get_device());
    int dev = src.get_device();
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
    // placeholder: the stream fetched here could drive further work later.
}
// exp overloads: half/bf16 have no native expf, so widen to float,
// exponentiate, and narrow back. The generic version covers float (and the
// c10::Half/c10::BFloat16 types that AT_DISPATCH supplies, which convert to
// float implicitly).
template <typename s_scalar>
__device__ s_scalar exp(s_scalar a)
{
    return expf(a);
}
template <>
__device__ __nv_bfloat16 exp(__nv_bfloat16 a)
{
    float tmp = __bfloat162float(a);
    float tmp_score = expf(tmp);
    return __float2bfloat16(tmp_score);
}
template <>
__device__ __half exp(__half a)
{
    float tmp = __half2float(a);
    float tmp_score = expf(tmp);
    return __float2half(tmp_score);
}
template <>
__device__ float exp(float a)
{
    return expf(a);
}
// fi_cast widens any supported scalar to float for accumulation.
template <typename scalar_t>
__device__ float fi_cast(scalar_t a)
{
    return a;
}
template <>
__device__ float fi_cast(__nv_bfloat16 a)
{
    return __bfloat162float(a);
}
template <>
__device__ float fi_cast(__half a)
{
    return __half2float(a);
}
template <int BLOCK_SIZE, typename scalar_t>
__global__ void softmax_kernel(const scalar_t *src, scalar_t *dest, int hidden_dim)
{
    int tid = threadIdx.x;
    int offset = blockIdx.x * hidden_dim;
    // Each thread exponentiates a strided slice of its row and keeps a
    // float partial sum (fi_cast widens half/bf16 before accumulating).
    float local_sum = 0.0f;
    for (int i = tid; i < hidden_dim; i += blockDim.x)
    {
        int tmp_index = offset + i;
        scalar_t tmp_score = exp(src[tmp_index]);
        dest[tmp_index] = tmp_score;
        local_sum += fi_cast(tmp_score);
    }
    // Reduce the partials across the block; blockDim.x must equal BLOCK_SIZE.
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    float sum = BlockReduce(temp_storage).Sum(local_sum);
    // Remember: after BlockReduce only thread 0 holds the real sum, so
    // broadcast it through shared memory.
    __shared__ float row_sum;
    if (tid == 0)
        row_sum = sum;
    __syncthreads();
    for (int i = tid; i < hidden_dim; i += blockDim.x)
    {
        int tmp_index = offset + i;
        dest[tmp_index] = static_cast<scalar_t>(fi_cast(dest[tmp_index]) / row_sum);
    }
}
void softmax(const torch::Tensor &src, torch::Tensor &dest)
{
    int batch_num = src.size(0);
    int hidden_dim = src.size(1);
    int block_size = 1024;
    dim3 grid(batch_num);
    dim3 block(block_size);
    VLLM_DISPATCH_FLOATING_TYPES(
        src.scalar_type(), "softmax",
        [&]
        {
            int dev = src.get_device();
            cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
            softmax_kernel<1024, scalar_t><<<grid, block, 0, stream>>>(
                src.data_ptr<scalar_t>(),
                dest.data_ptr<scalar_t>(),
                hidden_dim);
        });
}
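One thing this kernel skips is the usual max-subtraction: it exponentiates raw inputs, so a large logit can overflow expf, which matters once inputs live in half precision. Below is a minimal sketch of a max-subtracted variant. It reuses `fi_cast` from this commit and assumes the same one-block-per-row launch shape; `stable_softmax_kernel` is a hypothetical name, not part of the commit, and FLT_MAX comes from <cfloat>.

// Sketch only: numerically stable softmax that subtracts the row max first.
// Assumes blockDim.x == BLOCK_SIZE and one block per row, like softmax_kernel.
template <int BLOCK_SIZE, typename scalar_t>
__global__ void stable_softmax_kernel(const scalar_t *src, scalar_t *dest, int hidden_dim)
{
    int tid = threadIdx.x;
    int offset = blockIdx.x * hidden_dim;
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float row_stat; // broadcasts the max, then the sum

    // Pass 1: row max in float, so exp never sees a huge argument.
    float local_max = -FLT_MAX;
    for (int i = tid; i < hidden_dim; i += blockDim.x)
        local_max = fmaxf(local_max, fi_cast(src[offset + i]));
    float block_max = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
    if (tid == 0)
        row_stat = block_max;
    __syncthreads(); // publishes row_stat; also lets temp_storage be reused
    float max_val = row_stat;

    // Pass 2: exponentiate the shifted values and accumulate the row sum.
    float local_sum = 0.0f;
    for (int i = tid; i < hidden_dim; i += blockDim.x)
    {
        float e = expf(fi_cast(src[offset + i]) - max_val);
        dest[offset + i] = static_cast<scalar_t>(e);
        local_sum += e;
    }
    float block_sum = BlockReduce(temp_storage).Sum(local_sum);
    if (tid == 0)
        row_stat = block_sum;
    __syncthreads();

    // Pass 3: normalize in float, then narrow back to scalar_t.
    for (int i = tid; i < hidden_dim; i += blockDim.x)
        dest[offset + i] = static_cast<scalar_t>(fi_cast(dest[offset + i]) / row_stat);
}

Wiring this into the same VLLM_DISPATCH_FLOATING_TYPES launcher would only change the kernel name. torch.softmax performs the equivalent shift internally, which is part of why the unshifted kernel can still match it on randn inputs (values stay small) while diverging on larger logits.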


@@ -2,7 +2,7 @@ import torch
import torch_cuda_ext.core as core
n = 1000000
for i in range(10000):
for i in range(100):
    src = torch.randn(size=(n,)).float().cuda()
    dest_n = int((n + 1024 - 1) / 1024)
    dest = torch.zeros(size=(dest_n,)).float().cuda()
@@ -31,3 +31,48 @@ core.md_block_sum(src, dest)
real_sum = src.sum(dim=-1)
diff = real_sum - dest
print(diff)
for k in range(128, 4096, 128):
    for j in range(1024, 4096, 1024):
        a = torch.randn(size=(k, j)).half().cuda()
        b = torch.empty_like(a)
        num_runs = 100
        times = []
        for _ in range(num_runs):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            core.softmax(a, b)
            end.record()
            torch.cuda.synchronize()  # wait for the CUDA work to finish
            elapsed_time = start.elapsed_time(end) / 1000  # convert ms to seconds
            times.append(elapsed_time)
        own_avg_time = sum(times) / num_runs
        own_std_time = (sum((t - own_avg_time) ** 2 for t in times) / num_runs) ** 0.5
        print(f"own softmax cost time: {own_avg_time}, {own_std_time}")
        times = []
        for _ in range(num_runs):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            res = torch.softmax(a, dim=1)
            end.record()
            torch.cuda.synchronize()  # wait for the CUDA work to finish
            elapsed_time = start.elapsed_time(end) / 1000  # convert ms to seconds
            times.append(elapsed_time)
        avg_time = sum(times) / num_runs
        std_time = (sum((t - avg_time) ** 2 for t in times) / num_runs) ** 0.5
        print(f"torch softmax cost time: {avg_time}, {std_time}")
        # print("this is b", b)
        diff = (res - b).abs().max()
        if diff < 1e-4:
            print("softmax is good")
            time_diff_rate = (own_avg_time - avg_time) / avg_time
            print(f"{k}, {j} matrix result {time_diff_rate}")
        else:
            print("softmax is not equal")