// torch_ext/csrc/md.cu

#include "core.h"
#include <cub/cub.cuh>
#include <cub/util_device.cuh>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include <torch/torch.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> // for at::cuda::getCurrentCUDAStream
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
// Placeholder for the per-thread matrix-multiply work; not implemented yet.
__device__ void mm_device(const float *src)
{
}

template <int BLOCK_SIZE = 128>
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
{
    int batch_idx = blockIdx.x;
    int head_idx = blockIdx.y;
    int block_idx = blockIdx.z;
    int tidx = threadIdx.x;
    // Linear offset of this thread's element within the (batch, head, tile) decomposition.
    int current_idx = batch_idx * stride_a + head_idx * stride_b + block_idx * stride_c + tidx;
    (void)current_idx; // The multiply itself is still a stub; only the indexing is in place.
    // Wouldn't it be easier to write this in the most naive way first, and then split it up afterwards?
}
void md_mm(const torch::Tensor &src)
{
    int batch_size = src.size(0);
    int head_size = src.size(1);
    int sequence_size = src.size(2);
    int head_dim = src.size(3);
    int data_block = sequence_size * head_dim;
    int thread_num = 256;
    // One block per (batch, head, tile of thread_num elements).
    dim3 grid(batch_size, head_size, (data_block + thread_num - 1) / thread_num);
    dim3 block(thread_num);
    md_mm_kernel<<<grid, block>>>(src.data_ptr<float>(),
                                  src.stride(0), src.stride(1), src.stride(2),
                                  thread_num);
}
template <int BLOCK_SIZE = 1024>
__global__ void row_sum_kernel(const float *src, float *dest, int hidden_dim)
{
    // One block per row: each thread strides across the row accumulating
    // squared elements, then the block reduces the per-thread partial sums.
    float local_sum = 0.0f;
    int offset = blockIdx.x * hidden_dim;
    for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE)
    {
        local_sum += src[offset + i] * src[offset + i];
    }
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    float sum = BlockReduce(temp_storage).Sum(local_sum);
    if (threadIdx.x == 0)
    {
        dest[blockIdx.x] = sum; // Sum of squares of row blockIdx.x.
    }
}
void block_sum(const torch::Tensor &src, torch::Tensor &dest)
{
    constexpr int block_size = 1024;
    // The kernel's strided loop covers a whole row per block, so the grid
    // only needs one block per row (no second dimension).
    dim3 grid(src.size(0));
    dim3 block(block_size);
    row_sum_kernel<block_size><<<grid, block>>>(src.data_ptr<float>(), dest.data_ptr<float>(), src.size(1));
}
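
// A minimal host-side check for block_sum: a sketch, assuming this file is
// linked into a program with a CUDA device available. The shapes and the
// function name block_sum_check are illustrative, not part of this file's API.
void block_sum_check()
{
    auto src = torch::randn({8, 4096}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    auto dest = torch::empty({8}, src.options());
    block_sum(src, dest);
    // Reference: per-row sum of squares computed with ATen ops.
    auto ref = src.pow(2).sum(1);
    TORCH_CHECK(torch::allclose(dest, ref, /*rtol=*/1e-4, /*atol=*/1e-4));
}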
template <int BLOCK_SIZE = 1024>
__global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, int stride_b, int seq_len, int hidden_dim)
{
    // One block per (batch, sequence) position: a stride-aware offset into
    // src, a strided accumulation over hidden_dim, then a block-wide reduce.
    float local_sum = 0.0f;
    int offset = blockIdx.x * stride_a + blockIdx.y * stride_b;
    int out_idx = blockIdx.x * seq_len + blockIdx.y; // dest is contiguous [batch, seq].
    for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE)
    {
        local_sum += src[offset + i];
    }
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    float sum = BlockReduce(temp_storage).Sum(local_sum);
    if (threadIdx.x == 0)
    {
        dest[out_idx] = sum; // Row sum for (batch blockIdx.x, position blockIdx.y).
    }
}
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest)
{
    constexpr int block_size = 1024;
    // One block per (batch, sequence) row; the strided loop in the kernel
    // covers hidden_dim, so no third grid dimension is needed.
    dim3 grid(src.size(0), src.size(1));
    dim3 block(block_size);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(src.get_device());
    md_row_sum_kernel<block_size><<<grid, block, 0, stream>>>(src.data_ptr<float>(),
                                                              dest.data_ptr<float>(),
                                                              src.stride(0),
                                                              src.stride(1),
                                                              src.size(1),
                                                              src.size(2));
}
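
// A minimal binding sketch, assuming this file is built as a PyTorch C++
// extension; whether core.h already registers bindings elsewhere is unknown,
// and the module name is supplied by torch's TORCH_EXTENSION_NAME build macro.
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("md_mm", &md_mm, "Indexing-only mm kernel stub (CUDA)");
    m.def("block_sum", &block_sum, "Per-row sum of squares (CUDA)");
    m.def("md_block_sum", &md_block_sum, "Stride-aware per-(batch, seq) row sum (CUDA)");
}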