多维的问题也实验了一下,看起来还不错的样子。
This commit is contained in:
parent
bf81e39d83
commit
4da12fd0c2
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
@ -1,3 +1,6 @@
|
|||||||
{
|
{
|
||||||
"git.ignoreLimitWarning": true
|
"git.ignoreLimitWarning": true,
|
||||||
|
"files.associations": {
|
||||||
|
"__config": "cpp"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,4 +17,6 @@ void print_idx();
|
|||||||
void reducemax(const torch::Tensor &src, torch::Tensor &dest);
|
void reducemax(const torch::Tensor &src, torch::Tensor &dest);
|
||||||
void test_cute_tensor();
|
void test_cute_tensor();
|
||||||
void md_mm(const torch::Tensor &src);
|
void md_mm(const torch::Tensor &src);
|
||||||
|
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
|
||||||
|
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
|
||||||
#endif
|
#endif
|
||||||
@ -16,4 +16,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
|||||||
m.def("reducemax", &reducemax, "reduce max");
|
m.def("reducemax", &reducemax, "reduce max");
|
||||||
m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor");
|
m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor");
|
||||||
m.def("md_mm", &md_mm, "just a test of multi dimension mm");
|
m.def("md_mm", &md_mm, "just a test of multi dimension mm");
|
||||||
|
m.def("block_sum", &block_sum, "test block sum");
|
||||||
|
m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
|
||||||
}
|
}
|
||||||
|
|||||||
97
csrc/md.cu
97
csrc/md.cu
@ -1,4 +1,12 @@
|
|||||||
#include "core.h"
|
#include "core.h"
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#include <cub/util_device.cuh>
|
||||||
|
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <cute/tensor.hpp>
|
#include <cute/tensor.hpp>
|
||||||
#include <cutlass/cutlass.h>
|
#include <cutlass/cutlass.h>
|
||||||
@ -8,12 +16,16 @@
|
|||||||
__device__ void mm_device(const float *src)
|
__device__ void mm_device(const float *src)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int BLOCk_SIZE = 128>
|
||||||
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
|
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
|
||||||
{
|
{
|
||||||
int batch_idx = blockIdx.x;
|
int batch_idx = blockIdx.x;
|
||||||
int head_idx = blockIdx.y;
|
int head_idx = blockIdx.y;
|
||||||
int sequence_idx = blockIdx.z;
|
int block_idx = blockIdx.z;
|
||||||
int tidx = threadIdx.x;
|
int tidx = threadIdx.x;
|
||||||
|
int current_idx = batch_idx * stride_a + head_idx * stride_b + block_idx * stride_c + tidx;
|
||||||
|
// 其实是否一开始就用最原始的方法来写,然后后面进行拆分更容易一些呢。
|
||||||
}
|
}
|
||||||
|
|
||||||
void md_mm(const torch::Tensor &src)
|
void md_mm(const torch::Tensor &src)
|
||||||
@ -30,3 +42,86 @@ void md_mm(const torch::Tensor &src)
|
|||||||
src.stride(0), src.stride(1), src.stride(2),
|
src.stride(0), src.stride(1), src.stride(2),
|
||||||
thread_num);
|
thread_num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int BLOCK_SIZE = 1024>
|
||||||
|
__global__ void row_sum_kernel(const float *src, float *dest, int hidden_dim)
|
||||||
|
{
|
||||||
|
|
||||||
|
__shared__ float tmp_data[BLOCK_SIZE];
|
||||||
|
float local_sum = 0.0f;
|
||||||
|
int offset = blockIdx.x * hidden_dim;
|
||||||
|
int idx = blockIdx.y * blockDim.y + threadIdx.x;
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE)
|
||||||
|
{
|
||||||
|
// add some other place's data.
|
||||||
|
local_sum += (src[offset + i] * src[offset + i]);
|
||||||
|
}
|
||||||
|
if (idx < hidden_dim)
|
||||||
|
tmp_data[tid] = local_sum;
|
||||||
|
else
|
||||||
|
tmp_data[tid] = 0.0f;
|
||||||
|
__syncthreads();
|
||||||
|
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
|
||||||
|
float sum = BlockReduce(temp_storage).Sum(tmp_data[tid]);
|
||||||
|
if (tid == 0)
|
||||||
|
{
|
||||||
|
dest[blockIdx.x] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void block_sum(const torch::Tensor &src, torch::Tensor &dest)
|
||||||
|
{
|
||||||
|
int block_size = 1024;
|
||||||
|
dim3 grid(src.size(0), (src.size(1) + block_size - 1) / block_size);
|
||||||
|
dim3 block(block_size);
|
||||||
|
row_sum_kernel<<<grid, block>>>(src.data_ptr<float>(), dest.data_ptr<float>(), src.size(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int BLOCK_SIZE = 1024>
|
||||||
|
__global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, int stride_b, int batch, int seq_len, int hidden_dim)
|
||||||
|
{
|
||||||
|
|
||||||
|
__shared__ float tmp_data[BLOCK_SIZE];
|
||||||
|
float local_sum = 0.0f;
|
||||||
|
int offset = blockIdx.x * stride_a + blockIdx.y * stride_b;
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
int block_offset = blockIdx.x * seq_len + blockIdx.y;
|
||||||
|
int all_len = batch * seq_len;
|
||||||
|
int idx = blockIdx.z * BLOCK_SIZE + tid;
|
||||||
|
|
||||||
|
for (int i = threadIdx.x; i < hidden_dim; i += BLOCK_SIZE)
|
||||||
|
{
|
||||||
|
// add some other place's data.
|
||||||
|
local_sum += src[offset + i];
|
||||||
|
}
|
||||||
|
if (idx < hidden_dim)
|
||||||
|
tmp_data[tid] = local_sum;
|
||||||
|
else
|
||||||
|
tmp_data[tid] = 0.0f;
|
||||||
|
__syncthreads();
|
||||||
|
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
|
||||||
|
float sum = BlockReduce(temp_storage).Sum(tmp_data[tid]);
|
||||||
|
if (tid == 0 && block_offset < all_len)
|
||||||
|
{
|
||||||
|
dest[block_offset] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest)
|
||||||
|
{
|
||||||
|
int block_size = 1024;
|
||||||
|
dim3 grid(src.size(0), src.size(1), (src.size(2) + block_size - 1) / block_size);
|
||||||
|
dim3 block(block_size);
|
||||||
|
md_row_sum_kernel<<<grid, block>>>(src.data_ptr<float>(),
|
||||||
|
dest.data_ptr<float>(),
|
||||||
|
src.stride(0),
|
||||||
|
src.stride(1),
|
||||||
|
src.size(0),
|
||||||
|
src.size(1),
|
||||||
|
src.size(2));
|
||||||
|
}
|
||||||
@ -12,3 +12,22 @@ print(dest[0])
|
|||||||
print(src.sum())
|
print(src.sum())
|
||||||
|
|
||||||
core.test_cute_tensor()
|
core.test_cute_tensor()
|
||||||
|
|
||||||
|
src = torch.randn(size=(4096, 4096)).float().cuda()
|
||||||
|
dest = torch.zeros(size=(4096,)).float().cuda()
|
||||||
|
core.block_sum(src, dest)
|
||||||
|
src = src * src
|
||||||
|
real_sum = src.sum(dim=1)
|
||||||
|
|
||||||
|
diff = real_sum - dest
|
||||||
|
print(diff)
|
||||||
|
|
||||||
|
|
||||||
|
src = torch.randn(size=((64, 128, 4096))).float().cuda()
|
||||||
|
dest = torch.randn(size=(64, 128)).float().cuda()
|
||||||
|
|
||||||
|
core.md_block_sum(src, dest)
|
||||||
|
|
||||||
|
real_sum = src.sum(dim=-1)
|
||||||
|
diff = real_sum - dest
|
||||||
|
print(diff)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user