// torch_ext/csrc/max.cu

#include <cub/cub.cuh>
#include <cub/util_device.cuh>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include <torch/torch.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "core.h"
using namespace cute;
// Dispatch helpers (a subset of vLLM's dispatch macros) covering float, half and bfloat16.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// Per-block reduction: each block writes one partial result to dest[blockIdx.x].
// Note: despite the name, this kernel computes a sum (cub BlockReduce::Sum), accumulated in fp32.
template <int BLOCK_SIZE = 1024, typename scalar_t>
__global__ void reducemax_kernel(const scalar_t *src, scalar_t *dest, int len)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x;

    // Grid-stride loop: each thread accumulates its share of the input.
    float local_sum = 0.0f;
    for (int i = idx; i < len; i += blockDim.x * gridDim.x)
    {
        local_sum += (float)src[i];
    }

    // Reduce the per-thread partial sums across the block.
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    float sum = BlockReduce(temp_storage).Sum(local_sum);

    if (tid == 0)
    {
        dest[blockIdx.x] = (scalar_t)sum;
    }
}
// Stub for a fused rotary position embedding (RoPE) kernel; only the index mapping so far.
__global__ void fuse_rope_kernel()
{
    int tid = threadIdx.x;
    int lane_idx = threadIdx.x % 32; // lane within the warp
    int warp_idx = threadIdx.x / 32; // warp within the block
    int head_idx = blockIdx.x;
    int batch_idx = blockIdx.y;
}
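// A minimal sketch (not part of the original file) of the per-pair rotation a fused
// RoPE kernel would apply to each (even, odd) element pair of a head; rope_rotate_pair,
// cos_v and sin_v are illustrative names, not an existing API in this repo.
__device__ __forceinline__ void rope_rotate_pair(float &x0, float &x1, float cos_v, float sin_v)
{
    float r0 = x0 * cos_v - x1 * sin_v;
    float r1 = x0 * sin_v + x1 * cos_v;
    x0 = r0;
    x1 = r1;
}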
void reducemax(const torch::Tensor &src, torch::Tensor &dest)
{
    int len = src.size(0);
    constexpr int block_size = 1024;
    dim3 grid((len + block_size - 1) / block_size);
    dim3 block(block_size);
    VLLM_DISPATCH_FLOATING_TYPES(src.scalar_type(), "reducemax", [&] {
        reducemax_kernel<block_size, scalar_t><<<grid, block>>>(
            src.data_ptr<scalar_t>(),
            dest.data_ptr<scalar_t>(),
            len);
    });
}
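// A minimal usage sketch, not part of the original file: dest must hold one element per
// launched block, and the per-block partials still need a final pass (done here on the
// torch side for simplicity). reducemax_example is a hypothetical name.
void reducemax_example()
{
    const int64_t len = 1 << 20;
    torch::Tensor src = torch::rand({len}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    int64_t blocks = (len + 1023) / 1024; // one partial sum per block
    torch::Tensor partial = torch::empty({blocks}, src.options());
    reducemax(src, partial);                   // per-block partial sums
    float total = partial.sum().item<float>(); // finish the reduction
    printf("reducemax total: %f\n", total);
}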
__global__ void test_cute_tensor_kernel()
{
    cute::Tensor rmem_4x8_col = cute::make_tensor<cute::half_t>(Shape<_4, _8>{});
    Tensor ind_tensor = make_identity_tensor(make_shape(16, 16));
    Tensor bool_tensor = make_tensor<bool>(shape(rmem_4x8_col));
    Tensor rmem_4x8_pad = make_tensor<float>(Shape<_4, _8>{},
                                             Stride<_32, _2>{});
    Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
    __shared__ float smem[decltype(cosize(smem_layout))::value]; // static-only allocation
    // printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
    Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);
    // Cast the (possibly static) sizes to int so they are valid printf varargs.
    printf("tensor size is: %d, ind size is: %d, rmem size is: %d, rmem4x8 is: %d, smem size is: %d\n",
           (int)bool_tensor.size(),
           (int)ind_tensor.size(), (int)rmem_4x8_col.size(),
           (int)rmem_4x8_pad.size(),
           (int)stensor.size());
    auto TA = make_layout(make_shape(Int<32>{}, Int<8>{}), LayoutRight{});     // (m,k) -> thr_idx; k-major
    auto copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: copy floats as uint128_t
                                 Layout<Shape<_32, _8>>{},                     // Thr layout 32x8, m-major
                                 Layout<Shape<_4, _1>>{});                     // Val layout 4x1, m-major
    print_latex(copyA);
}
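// A minimal sketch (not in the original file) of how a TiledCopy like copyA is typically
// used: slice it per thread, partition the source and destination tensors, and issue
// cute::copy. The 128x8 tile, the gA/sA names and the 256-thread launch are assumptions;
// A must be 16-byte aligned for the uint128_t vectorized loads.
__global__ void test_cute_tiled_copy_kernel(const float *A)
{
    auto copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{},
                                 Layout<Shape<_32, _8>>{}, // 32x8 threads, m-major
                                 Layout<Shape<_4, _1>>{}); // 4x1 values per thread
    // One 128x8 tile of A, m-major (32*4 x 8*1 matches the tiler of copyA).
    Tensor gA = make_tensor(make_gmem_ptr(A), make_layout(Shape<_128, _8>{}));
    __shared__ alignas(16) float smemA[128 * 8];
    Tensor sA = make_tensor(make_smem_ptr(smemA), make_layout(Shape<_128, _8>{}));
    auto thr_copy = copyA.get_slice(threadIdx.x); // this thread's slice of the tiling
    Tensor tAgA = thr_copy.partition_S(gA);       // source fragment owned by this thread
    Tensor tAsA = thr_copy.partition_D(sA);       // matching destination fragment in smem
    copy(copyA, tAgA, tAsA);                      // cooperative gmem -> smem copy
    __syncthreads();
}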
// template <int head, int batch = 0, int head_dim = 0>
// void test_template()
// {
//     if (head % 2 == 0)
//     {
//         std::cout << "head is even" << std::endl;
//     }
// }
void test_cute_tensor()
{
    dim3 thread_block(16, 16); // 256 threads per block
    dim3 grid(16);
    test_cute_tensor_kernel<<<grid, thread_block>>>();
}
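// A small helper sketch (not in the original file): the test kernels above only printf,
// so it helps to synchronize and surface launch errors explicitly after calling them.
void check_last_launch()
{
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess)
        err = cudaDeviceSynchronize();
    if (err != cudaSuccess)
        printf("CUDA error: %s\n", cudaGetErrorString(err));
}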
__global__ void md_op(const float *a)
{
    int tidx = threadIdx.x;
    int bid = blockIdx.x;
    int hid = blockIdx.y;
    int offset = blockDim.x * blockDim.y;
    // bind to its own ...
}