#include "core.h" #include #include #include #include #include #include #include #include #include #include #include __device__ void mm_device(const float *src) { } template __global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num) { int batch_idx = blockIdx.x; int head_idx = blockIdx.y; int block_idx = blockIdx.z; int tidx = threadIdx.x; int current_idx = batch_idx * stride_a + head_idx * stride_b + block_idx * stride_c + tidx; // 其实是否一开始就用最原始的方法来写,然后后面进行拆分更容易一些呢。 } void md_mm(const torch::Tensor &src) { int batch_size = src.size(0); int head_size = src.size(1); int sequence_size = src.size(2); int head_dim = src.size(3); int data_block = sequence_size * head_dim; int thread_num = 256; dim3 grid(batch_size, head_size, (data_block + thread_num - 1) / thread_num); dim3 block(thread_num); md_mm_kernel<<>>(reinterpret_cast(src.data_ptr()), src.stride(0), src.stride(1), src.stride(2), thread_num); } template __global__ void row_sum_kernel(const float *src, float *dest, int hidden_dim) { __shared__ float tmp_data[BLOCK_SIZE]; float local_sum = 0.0f; int offset = blockIdx.x * hidden_dim; int idx = blockIdx.y * blockDim.y + threadIdx.x; int tid = threadIdx.x; for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE) { // add some other place's data. local_sum += (src[offset + i] * src[offset + i]); } if (idx < hidden_dim) tmp_data[tid] = local_sum; else tmp_data[tid] = 0.0f; __syncthreads(); typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; float sum = BlockReduce(temp_storage).Sum(tmp_data[tid]); if (tid == 0) { dest[blockIdx.x] = sum; printf("blockidx.x: %d, blockIdx.y %d, blockIdx.z %d\n", blockIdx.x, blockIdx.y, blockIdx.z); } } void block_sum(const torch::Tensor &src, torch::Tensor &dest) { int block_size = 1024; dim3 grid(src.size(0), (src.size(1) + block_size - 1) / block_size); dim3 block(block_size); row_sum_kernel<<>>(src.data_ptr(), dest.data_ptr(), src.size(1)); } template __global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, int stride_b, int batch, int seq_len, int hidden_dim) { __shared__ float tmp_data[BLOCK_SIZE]; float local_sum = 0.0f; int offset = blockIdx.x * stride_a + blockIdx.y * stride_b; int tid = threadIdx.x; int block_offset = blockIdx.x * seq_len + blockIdx.y; int all_len = batch * seq_len; int idx = blockIdx.z * BLOCK_SIZE + tid; for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE) { // add some other place's data. 
// Like row_sum_kernel, but for a (batch, seq, hidden) tensor: each
// (blockIdx.x, blockIdx.y) pair owns one row and sums its raw (not squared)
// elements into dest[blockIdx.x * seq_len + blockIdx.y].
template <int BLOCK_SIZE>
__global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a,
                                  int stride_b, int batch, int seq_len,
                                  int hidden_dim) {
  __shared__ float tmp_data[BLOCK_SIZE];
  float local_sum = 0.0f;
  int offset = blockIdx.x * stride_a + blockIdx.y * stride_b;
  int tid = threadIdx.x;
  int block_offset = blockIdx.x * seq_len + blockIdx.y;
  int all_len = batch * seq_len;
  int idx = blockIdx.z * BLOCK_SIZE + tid;
  for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE) {
    // data from other locations could also be added here.
    local_sum += src[offset + i];
  }
  if (idx < hidden_dim)
    tmp_data[tid] = local_sum;
  else
    tmp_data[tid] = 0.0f;
  __syncthreads();
  typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float sum = BlockReduce(temp_storage).Sum(tmp_data[tid]);
  if (tid == 0 && block_offset < all_len) {
    dest[block_offset] = sum;
    printf("blockIdx.x %d, blockIdx.y %d, blockIdx.z %d, blockDim.x %d\n",
           blockIdx.x, blockIdx.y, blockIdx.z, blockDim.x);
  }
}

void md_block_sum(const torch::Tensor &src, torch::Tensor &dest) {
  int block_size = 1024;
  dim3 grid(src.size(0), src.size(1),
            (src.size(2) + block_size - 1) / block_size);
  dim3 block(block_size);
  int dev = src.get_device();
  printf("this is the device num: %d\n", dev);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  md_row_sum_kernel<1024><<<grid, block, 0, stream>>>(
      src.data_ptr<float>(), dest.data_ptr<float>(), src.stride(0),
      src.stride(1), src.size(0), src.size(1), src.size(2));
}
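// ---------------------------------------------------------------------------
// Usage sketch for md_block_sum, assuming dest holds one scalar per
// (batch, seq) position and src is contiguous in its last dimension (the
// kernel walks the hidden dimension with unit stride). The helper name
// check_md_block_sum is illustrative only.
// ---------------------------------------------------------------------------
inline void check_md_block_sum() {
  auto opts =
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto src = torch::randn({2, 4, 512}, opts);  // (batch, seq, hidden)
  auto dest = torch::empty({2, 4}, opts);      // one sum per (batch, seq)
  md_block_sum(src, dest);
  auto ref = src.sum(-1);  // reference reduction over the hidden dimension
  TORCH_CHECK(torch::allclose(dest, ref, /*rtol=*/1e-4, /*atol=*/1e-4),
              "md_block_sum disagrees with src.sum(-1)");
}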
void interaction(const torch::Tensor &src) {
  int block_size = 1024;
  dim3 grid(src.size(0), src.size(1),
            (src.size(2) + block_size - 1) / block_size);
  dim3 block(block_size);
  int dev = src.get_device();
  printf("this is the device num: %d\n", dev);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  // Placeholder: other work could be launched on this stream later.
}

// exp() overloads: half/bfloat16 inputs are widened to float, passed through
// expf, and narrowed back.
template <typename s_scalar>
__device__ s_scalar exp(s_scalar a) {
  return expf(a);
}

template <>
__device__ __nv_bfloat16 exp(__nv_bfloat16 a) {
  float tmp = __bfloat162float(a);
  float tmp_score = expf(tmp);
  return __float2bfloat16(tmp_score);
}

template <>
__device__ __half exp(__half a) {
  float tmp = __half2float(a);
  float tmp_score = expf(tmp);
  return __float2half(tmp_score);
}

template <>
__device__ float exp(float a) {
  return expf(a);
}

// fi_cast(): widen any supported scalar type to float.
template <typename scalar_t>
__device__ float fi_cast(scalar_t a) {
  return a;
}

template <>
__device__ float fi_cast(__nv_bfloat16 a) {
  return __bfloat162float(a);
}

template <>
__device__ float fi_cast(__half a) {
  return __half2float(a);
}

// Row-wise softmax: one block per row. Note that no running-max subtraction
// is done, so very large logits can overflow exp().
template <int BLOCK_SIZE, typename scalar_t>
__global__ void softmax_kernel(const scalar_t *src, scalar_t *dest,
                               int hidden_dim) {
  int tid = threadIdx.x;
  int offset = blockIdx.x * hidden_dim;
  __shared__ scalar_t smem[BLOCK_SIZE];
  float local_sum = 0.0f;
  for (int i = tid; i < hidden_dim; i += blockDim.x) {
    // Exponentiate, stash the numerator, and accumulate the denominator.
    int tmp_index = offset + i;
    scalar_t tmp_score = exp(src[tmp_index]);
    dest[tmp_index] = tmp_score;
    local_sum += fi_cast(tmp_score);
  }
  if (tid < BLOCK_SIZE)
    smem[tid] = local_sum;
  else
    smem[tid] = 0.0f;
  __syncthreads();
  typedef cub::BlockReduce<scalar_t, BLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  scalar_t sum = BlockReduce(temp_storage).Sum(smem[tid]);
  // Remember: after BlockReduce::Sum only thread 0 holds the real sum, so
  // broadcast it through shared memory.
  if (tid == 0) smem[0] = sum;
  __syncthreads();
  for (int i = tid; i < hidden_dim; i += blockDim.x) {
    int tmp_index = offset + i;
    scalar_t tmp_score = dest[tmp_index] / smem[0];
    dest[tmp_index] = tmp_score;
  }
}

void softmax(const torch::Tensor &src, torch::Tensor &dest) {
  int batch_num = src.size(0);
  int hidden_dim = src.size(1);
  int block_size = 1024;
  dim3 grid(batch_num);
  dim3 block(block_size);
  VLLM_DISPATCH_FLOATING_TYPES(src.scalar_type(), "softmax", [&] {
    int dev = src.get_device();
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
    softmax_kernel<1024, scalar_t><<<grid, block, 0, stream>>>(
        src.data_ptr<scalar_t>(), dest.data_ptr<scalar_t>(), hidden_dim);
  });
}

// Dummy kernel used only to check that the head-dim template instantiations
// compile and launch.
template <int HEAD_NUM>
__global__ void test_head_dim_kernel() {
  int idx = threadIdx.x;
}

#define LANUCH(head_num) test_head_dim_kernel<head_num><<<grid, block>>>();

void test_head_dim(int head_num) {
  dim3 block(10);
  dim3 grid(1024);
  switch (head_num) {
    case 1:
      LANUCH(1);
      break;
    case 8:
      LANUCH(8);
      break;
    case 16:
      LANUCH(16);
      break;
    case 32:
      LANUCH(32);
      break;
    case 48:
      LANUCH(48);
      break;
    case 64:
      LANUCH(64);
      break;
    default:
      printf("do not support head num\n");
  }
}
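// ---------------------------------------------------------------------------
// Usage sketch for the softmax launcher above: compare against torch::softmax
// over the last dimension. Inputs drawn from randn stay small, which matters
// because the kernel does not subtract the row max before exp(). The helper
// name check_softmax is illustrative only.
// ---------------------------------------------------------------------------
inline void check_softmax() {
  auto opts =
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto src = torch::randn({4, 2048}, opts);  // (batch, hidden)
  auto dest = torch::empty_like(src);
  softmax(src, dest);
  auto ref = torch::softmax(src, /*dim=*/-1);
  TORCH_CHECK(torch::allclose(dest, ref, /*rtol=*/1e-3, /*atol=*/1e-5),
              "softmax kernel disagrees with torch::softmax");
}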