#include <cub/cub.cuh>
#include <cub/util_device.cuh>

#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>

#include <torch/torch.h>
#include <torch/all.h>

#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "core.h"

using namespace cute;

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
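// Added note: VLLM_DISPATCH_FLOATING_TYPES wraps ATen's AT_DISPATCH_SWITCH so the
// lambda body is instantiated once per supported dtype, with `scalar_t` bound to
// float, at::Half, or at::BFloat16. A minimal (hypothetical) use looks like:
//
//   VLLM_DISPATCH_FLOATING_TYPES(t.scalar_type(), "my_op", [&] {
//     my_kernel<scalar_t><<<grid, block>>>(t.data_ptr<scalar_t>(), n);
//   });
//
// `t`, `my_op`, `my_kernel`, `grid`, `block`, and `n` are placeholders, not symbols
// defined in this file.
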
// Per-block sum reduction (note: despite the "max" in the name, this computes a sum).
// Each block writes its partial sum to dest[blockIdx.x].
template <int BLOCK_SIZE = 1024, typename scalar_t>
__global__ void reducemax_kernel(const scalar_t *src, scalar_t *dest, int len)
{
    __shared__ float tmp_data[BLOCK_SIZE];
    float local_sum = 0.0f;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x;

    // Grid-stride loop: stays in bounds and keeps blocks from re-reading each other's elements.
    for (int i = idx; i < len; i += gridDim.x * blockDim.x)
    {
        local_sum += (float)src[i];
    }
    // Threads that read nothing simply contribute 0.
    tmp_data[tid] = local_sum;
    __syncthreads();

    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // BlockReduce returns the block-wide aggregate only in thread 0, hence the guard below.
    float sum = BlockReduce(temp_storage).Sum(tmp_data[tid]);
    if (tid == 0)
    {
        dest[blockIdx.x] = (scalar_t)sum;
    }
}

// Stub for a fused RoPE kernel: only the thread/block index bookkeeping is sketched out.
__global__ void fuse_rope_kernel()
{
    int tid = threadIdx.x;
    int lane_idx = threadIdx.x % 32; // lane within the warp
    int warp_idx = threadIdx.x / 32; // warp within the block
    int head_idx = blockIdx.x;
    int batch_idx = blockIdx.y;
}

// Host wrapper: launches one block per 1024 elements; dest must hold one element per
// block (ceil(len / 1024)) to receive the per-block partial sums.
void reducemax(const torch::Tensor &src, torch::Tensor &dest)
{
    int len = static_cast<int>(src.size(0));
    constexpr int block_size = 1024;
    dim3 grid((len + block_size - 1) / block_size);
    dim3 block(block_size);
    VLLM_DISPATCH_FLOATING_TYPES(src.scalar_type(), "reducemax", [&] {
        reducemax_kernel<block_size, scalar_t><<<grid, block>>>(
            src.data_ptr<scalar_t>(),
            dest.data_ptr<scalar_t>(),
            len);
    });
}
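// Illustrative usage (a sketch added for clarity, not part of the original file; the
// tensor names are hypothetical and stream/error handling is omitted):
//
//   int64_t n = 1 << 20;
//   torch::Tensor src  = torch::rand({n}, torch::kCUDA);
//   torch::Tensor dest = torch::empty({(n + 1023) / 1024}, torch::kCUDA);
//   reducemax(src, dest);       // dest[b] = partial sum of block b
//   auto total = dest.sum();    // combine the per-block partials
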
__global__ void test_cute_tensor_kernel()
{
    // Register-backed 4x8 tensor of half_t (column-major by default).
    cute::Tensor rmem_4x8_col = cute::make_tensor<cute::half_t>(Shape<_4, _8>{});
    // Identity (coordinate) tensor over a dynamic 16x16 shape.
    Tensor ind_tensor = make_identity_tensor(make_shape(16, 16));
    Tensor bool_tensor = make_tensor<bool>(shape(rmem_4x8_col));
    // Padded layout: shape (4,8) with strides (32,2), so the codomain is larger than the domain.
    Tensor rmem_4x8_pad = make_tensor<float>(Shape<_4, _8>{},
                                             Stride<_32, _2>{});

    Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
    __shared__ float smem[decltype(cosize(smem_layout))::value]; // static-only allocation
    printf("smem size is: %d\n", decltype(cosize(smem_layout))::value);
    Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);

    // Static sizes are integral-constant types; cast to int before passing through printf's varargs.
    printf("tensor size is: %d, ind size is: %d, rmem size is: %d, rmem4x8 is: %d, smem size is: %d\n",
           (int)bool_tensor.size(),
           (int)ind_tensor.size(),
           (int)rmem_4x8_col.size(),
           (int)rmem_4x8_pad.size(),
           (int)stensor.size());

    auto TA = make_layout(make_shape(Int<32>{}, Int<8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major (currently unused)
    TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: copy elements as if they were uint128_t
                                      Layout<Shape<_32, _8>>{},                     // Thr layout 32x8 m-major
                                      Layout<Shape<_4, _1>>{});                     // Val layout 4x1 m-major
    printf("stensor size 1 is %d\n", (int)cute::size<1>(stensor));
#if 0
    print_latex(copyA);
#endif
}
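// Sketch of how a TiledCopy like copyA is typically consumed (added for illustration;
// the tensors gA and rA and their shapes are hypothetical, not defined in this file):
//
//   auto thr_copy = copyA.get_slice(threadIdx.x);  // per-thread slice of the tiled copy
//   Tensor tAgA   = thr_copy.partition_S(gA);      // partition the source (e.g. gmem) tile
//   Tensor tArA   = thr_copy.partition_D(rA);      // partition the destination (e.g. rmem) tile
//   cute::copy(copyA, tAgA, tArA);                 // issue the vectorized copies
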
// template <int head, int batch = 0, int head_dim = 0>
// void test_template()
// {
//     if (head % 2 == 0)
//     {
//         std::cout << "head is even" << std::endl;
//     }
// }

void test_cute_tensor()
{
    dim3 thread_block(16, 16); // 256 threads per block
    dim3 grid(16);             // 16 blocks
    test_cute_tensor_kernel<<<grid, thread_block>>>();
}

// Stub kernel: only index bookkeeping so far; the input pointer is not yet used.
__global__ void md_op(const float *a)
{
    int tidx = threadIdx.x;
    int bid = blockIdx.x;
    int hid = blockIdx.y;
    int offset = blockDim.x * blockDim.y; // threads per block
    // bind to its own ...
}