Merge branch 'main' of http://192.168.0.100:3000/squall/torch_ext
commit a1aa7fd0d6
29 csrc/attention.cu Normal file
@@ -0,0 +1,29 @@
#include "core.h"

// calculate the dot product of a query row and a key row.
template <typename scalar_t>
__device__ scalar_t vecsum(const scalar_t *q, const scalar_t *k, int dim)
{
    scalar_t acc = scalar_t(0);
    for (int i = 0; i < dim; ++i)
        acc += q[i] * k[i];
    return acc;
}

template <typename scalar_t>
__global__ void attention_kernel(const scalar_t *q,
                                 const scalar_t *k,
                                 const scalar_t *v,
                                 int head_num,
                                 int head_dim,
                                 int seq_len,
                                 int batch_size,
                                 int hidden_dim,
                                 scalar_t *output)
{
    // calculate the gemm.
    int tid = threadIdx.x;
    // calculate the offsets; the fixed factor of 1 marks a single query token per block.
    int q_offset = blockIdx.x * head_num * 1 * head_dim;
    int k_offset = blockIdx.x * head_num * seq_len * head_dim;
    int v_offset = blockIdx.x * head_num * seq_len * head_dim;
    // calculate the sum.
    // calculate the softmax.
    // calculate the weighted sum.
}
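The three trailing comments are the unimplemented steps of a single-query (decode) attention pass. A rough sketch of what they could expand to for one head, assuming seq_len <= 1024 and blockDim.x >= max(seq_len, head_dim) — layout and sizing here are assumptions, not the repository's implementation:

    __shared__ float scores[1024];
    if (tid < seq_len)
        scores[tid] = (float)vecsum(q + q_offset, k + k_offset + tid * head_dim, head_dim);
    __syncthreads();
    // softmax over the scores, serial on thread 0 for clarity rather than speed.
    if (tid == 0)
    {
        float m = scores[0];
        for (int i = 1; i < seq_len; ++i)
            m = fmaxf(m, scores[i]);
        float s = 0.0f;
        for (int i = 0; i < seq_len; ++i)
        {
            scores[i] = expf(scores[i] - m);
            s += scores[i];
        }
        for (int i = 0; i < seq_len; ++i)
            scores[i] /= s;
    }
    __syncthreads();
    // weighted sum of the value rows; one output element per thread.
    if (tid < head_dim)
    {
        float acc = 0.0f;
        for (int i = 0; i < seq_len; ++i)
            acc += scores[i] * (float)v[v_offset + i * head_dim + tid];
        output[blockIdx.x * head_dim + tid] = (scalar_t)acc;
    }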
@@ -1,6 +1,7 @@
#include "core.h"
#include <iostream>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
// #include <mma.h>
#include <cuda_runtime.h>

@@ -58,7 +59,8 @@ __global__ void matmul_sigmoid_cuda(const T *in1, const T *in2, T *output, int r
}

#define BASE_BLOCK 256
#define CALL_ADD_FUNCTION add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
#define CALL_ADD_FUNCTION \
    add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
{
    // cout << input1.dtype() << " the size 1 is : " << input1.size(0) << " size 2 is " << input1.size(1) << "output dim is :" << output.size(0) << output.size(1) << endl;
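CALL_ADD_FUNCTION uses the standard ceiling-division launch: (n + BASE_BLOCK - 1) / BASE_BLOCK blocks of BASE_BLOCK threads cover n elements, with the last block partially full. The kernel it launches is not part of this hunk; a minimal sketch consistent with the macro's call site (signature assumed, not the repository's definition):

    template <typename T>
    __global__ void add_two_tensors_cuda(const T *src, const T *src1, T *dest, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) // guard the partially full tail block
            dest[i] = src[i] + src1[i];
    }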
31 csrc/core.h
@@ -11,6 +11,36 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define TYPING_DISPATCH(scalar_t, ...)        \
    switch (scalar_t)                         \
    {                                         \
    case at::ScalarType::Float:               \
    {                                         \
        using fi_type = float;                \
        __VA_ARGS__();                        \
        break;                                \
    }                                         \
    case at::ScalarType::BFloat16:            \
    {                                         \
        using fi_type = __nv_bfloat16;        \
        __VA_ARGS__();                        \
        break;                                \
    }                                         \
    case at::ScalarType::Half:                \
    {                                         \
        using fi_type = __half;               \
        __VA_ARGS__();                        \
        break;                                \
    }                                         \
    default:                                  \
        printf("do not support such type\n"); \
    }
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)            \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);

@@ -29,4 +59,5 @@ void md_mm(const torch::Tensor &src);
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
void softmax(const torch::Tensor &src, torch::Tensor &dest);
void rms_norm(torch::Tensor &states, float eps, float gamma);
#endif
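TYPING_DISPATCH pastes the trailing lambda into each case, so fi_type is already bound to the matching CUDA type when the lambda body is compiled, and the kernel launch is written only once. A usage sketch (my_kernel, grid, block, and n are placeholders, not repository names):

    TYPING_DISPATCH(input.scalar_type(), [&]
                    { my_kernel<fi_type><<<grid, block>>>(reinterpret_cast<fi_type *>(input.data_ptr()), n); });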
@@ -19,4 +19,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    m.def("block_sum", &block_sum, "test block sum");
    m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
    m.def("softmax", &softmax, "test softmax example");
    m.def("rms_norm", &rms_norm, "rms norm");
}
4 csrc/fp8_vec.cu Normal file
@@ -0,0 +1,4 @@
#include "core.h"
#include <cuda_fp8.h>

// short alias for the CUDA fp8 type, matching the FP16/BF16 aliases in type_utils.h.
#define fp8_e4m3 __nv_fp8_e4m3
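A round-trip through the e4m3 format using this alias might look like the sketch below (kernel name is mine; the conversions are the cuda_fp8.h float constructor and float conversion operator):

    __global__ void fp8_roundtrip(const float *in, float *out, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
        {
            fp8_e4m3 q = fp8_e4m3(in[i]); // quantize to e4m3 (lossy)
            out[i] = float(q);            // dequantize back to fp32
        }
    }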
82 csrc/layernorm.cu (Normal file; filename inferred from the setup.py hunk below)
@@ -0,0 +1,82 @@
#include "core.h"
#include <cub/cub.cuh>
#include <cub/util_device.cuh>

#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include <torch/torch.h>
#include <torch/all.h>
using namespace std;

// fallback for trivially convertible pairs (float -> float is hit by the
// Float dispatch case); half/bf16 use the intrinsic specializations below.
template <typename src_type, typename dest_type>
__device__ dest_type fi_cast(src_type a)
{
    return static_cast<dest_type>(a);
}
template <>
__device__ float fi_cast<__nv_bfloat16, float>(__nv_bfloat16 a)
{
    return __bfloat162float(a);
}

template <>
__device__ float fi_cast<__half, float>(__half a)
{
    return __half2float(a);
}

template <>
__device__ __nv_bfloat16 fi_cast<float, __nv_bfloat16>(float a)
{
    return __float2bfloat16(a);
}

template <>
__device__ __half fi_cast<float, __half>(float a)
{
    return __float2half(a);
}

// rms_norm(x) = x / sqrt(mean(x^2) + eps) * gamma, applied to each row in place.
template <typename scalar_t, int BLOCK_SIZE = 1024>
__global__ void rms_norm_kernel(scalar_t *states, int hidden_dim, float eps, float gamma)
{
    int idx = threadIdx.x;
    int offset = blockIdx.x * hidden_dim;
    // each thread accumulates the squared sum of its strided slice of the row.
    float local_sum = 0.0f;
    for (int i = idx; i < hidden_dim; i += blockDim.x)
    {
        int local_offset = offset + i;
        float tmp = fi_cast<scalar_t, float>(states[local_offset]);
        local_sum += tmp * tmp;
    }
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    // BlockReduce returns the total only on thread 0, so broadcast the
    // normalization factor through shared memory.
    __shared__ float inv_rms;
    float sum_res = BlockReduce(temp_storage).Sum(local_sum);
    if (idx == 0)
        inv_rms = rsqrtf(sum_res / hidden_dim + eps);
    __syncthreads();
    for (int i = idx; i < hidden_dim; i += blockDim.x)
    {
        int local_offset = offset + i;
        float tmp = fi_cast<scalar_t, float>(states[local_offset]);
        tmp = tmp * inv_rms * gamma;
        states[local_offset] = fi_cast<float, scalar_t>(tmp);
    }
}

void rms_norm(torch::Tensor &states, float eps, float gamma)
{
    int h = states.size(0);
    int hidden_dim = states.size(1);
    int block_size = 1024;
    dim3 grid(h);           // one block per row
    dim3 block(block_size); // matches the kernel's BLOCK_SIZE default
    TYPING_DISPATCH(states.scalar_type(), [&]
                    { rms_norm_kernel<fi_type><<<grid, block>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });
}
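The same entry point driven from libtorch C++, mirroring test_layernorm.py below (a sketch, assuming the calling translation unit links against this file):

    torch::Tensor states = torch::randn({100, 1024}, torch::device(torch::kCUDA).dtype(torch::kHalf));
    rms_norm(states, 0.01f, 1.0f); // normalizes each 1024-wide row in place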
33 csrc/md.cu
@@ -69,6 +69,7 @@ __global__ void row_sum_kernel(const float *src, float *dest, int hidden_dim)
    if (tid == 0)
    {
        dest[blockIdx.x] = sum;
        printf("blockIdx.x: %d, blockIdx.y %d, blockIdx.z %d\n", blockIdx.x, blockIdx.y, blockIdx.z);
    }
}

@@ -109,6 +110,7 @@ __global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, i
    if (tid == 0 && block_offset < all_len)
    {
        dest[block_offset] = sum;
        printf("blockIdx.x %d, blockIdx.y %d, blockIdx.z %d, blockDim.x %d\n", blockIdx.x, blockIdx.y, blockIdx.z, blockDim.x);
    }
}

@@ -243,3 +245,34 @@ void softmax(const torch::Tensor &src, torch::Tensor &dest)
        dest.data_ptr<scalar_t>(),
        hidden_dim); });
}

template <int head_num = 8>
__global__ void test_head_dim_kernel()
{
    int idx = threadIdx.x; // placeholder body; head_num is a compile-time constant here
}

#define LAUNCH(head_num) test_head_dim_kernel<head_num><<<grid, block>>>();

void test_head_dim(int head_num)
{
    dim3 grid(10);
    dim3 block(1024);
    switch (head_num)
    {
    case 1:
        LAUNCH(1);
        break;
    case 8:
        LAUNCH(8);
        break;
    case 16:
        LAUNCH(16);
        break;
    case 32:
        LAUNCH(32);
        break;
    case 48:
        LAUNCH(48);
        break;
    case 64:
        LAUNCH(64);
        break;
    default:
        printf("do not support head num\n");
    }
}
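The switch is the usual runtime-to-template bridge: head_num must be a compile-time constant to instantiate the kernel, so each case instantiates one specialization. Calling it is just:

    test_head_dim(8); // launches test_head_dim_kernel<8> on 10 blocks of 1024 threads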
35 csrc/random_env.cu Normal file
@@ -0,0 +1,35 @@
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <cub/cub.cuh>
#include <cub/util_device.cuh>

__global__ void initRandom(curandState *state, unsigned long seed)
{
    int id = threadIdx.x + blockIdx.x * blockDim.x;
    curand_init(seed, id, 0, &state[id]);
}

__global__ void random_generate(float *out, curandState *state)
{
    int id = threadIdx.x + blockIdx.x * blockDim.x; // global thread id into the RNG state array
    curandState localState = state[id];
    __shared__ float shared_data[1024];
    int idx = threadIdx.x;
    typedef cub::BlockReduce<float, 1024> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float block_sum;
    shared_data[idx] = 0.0f; // shared memory is uninitialized on entry
    for (int i = 0; i < 1024; i++)
    {
        shared_data[idx] += curand_uniform(&localState);
        // the reduction result is only valid on thread 0; broadcast it.
        float sum = BlockReduce(temp_storage).Sum(shared_data[idx]);
        if (idx == 0)
            block_sum = sum;
        __syncthreads();
        shared_data[idx] += shared_data[idx] / block_sum;
        __syncthreads(); // temp_storage is reused on the next iteration
    }
    out[idx] = shared_data[idx];
}

void random_invoke()
{
    curandState *devStates;
    float *out;
    int thread_num = 1024;
    cudaMalloc(&devStates, thread_num * sizeof(curandState)); // device storage for RNG states
    cudaMalloc(&out, thread_num * sizeof(float));             // kernels need device, not stack, memory
    initRandom<<<1, thread_num>>>(devStates, 1234);
    random_generate<<<1, thread_num>>>(out, devStates);
    cudaFree(out);
    cudaFree(devStates);
}
35 csrc/type_utils.h Normal file
@@ -0,0 +1,35 @@
#ifndef TYPE_UTILS_H
#define TYPE_UTILS_H
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#define FP16 __half
#define BF16 __nv_bfloat16
// fallback for trivially convertible pairs (e.g. float -> float);
// half/bf16 conversions use the intrinsic specializations below.
// specializations are inline so the header can be included in several
// translation units without duplicate definitions.
template <typename src_type, typename dest_type>
__device__ dest_type fi_cast(src_type a)
{
    return static_cast<dest_type>(a);
}
template <>
__device__ inline float fi_cast<BF16, float>(BF16 a)
{
    return __bfloat162float(a);
}

template <>
__device__ inline float fi_cast<FP16, float>(FP16 a)
{
    return __half2float(a);
}

template <>
__device__ inline BF16 fi_cast<float, BF16>(float a)
{
    return __float2bfloat16(a);
}

template <>
__device__ inline FP16 fi_cast<float, FP16>(float a)
{
    return __float2half(a);
}

#endif
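fi_cast spells out the usual mixed-precision pattern: widen to float, compute, narrow back. A sketch of a generic kernel built on it (kernel name and shapes are mine, not the repository's):

    template <typename scalar_t>
    __global__ void scale_kernel(scalar_t *data, float factor, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
        {
            float x = fi_cast<scalar_t, float>(data[i]);    // widen
            data[i] = fi_cast<float, scalar_t>(x * factor); // compute in fp32, narrow back
        }
    }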
28 fi/test_module.py Normal file
@@ -0,0 +1,28 @@
# coding=utf-8

import torch.nn as nn


class TestModule(nn.Module):
    def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = DecodeLayer()

    def forward(self, x):
        # DecodeLayer itself is not iterable; walk its ModuleList.
        for module in self.model.layers:
            x = module(x)
        return x


class DecodeLayer(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.ModuleList()
        for _ in range(10):
            self.layers.append(nn.Linear(10, 10))


if __name__ == "__main__":
    test_module = TestModule(0, 3)
    for name, _ in test_module.named_parameters():
        print(name)
3 setup.py
@@ -12,12 +12,13 @@ files = [
    "csrc/max.cu",
    "csrc/md.cu",
    "csrc/quantize.cu",
    "csrc/layernorm.cu",
]
extension = CUDAExtension(
    name="torch_cuda_ext.core",
    sources=files,
    extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
    include_dirs=["/home/squall/program/cutlass/include"],
    include_dirs=["/home/squall/quant_data/program/cutlass/include"],
)

cuda_exts.append(extension)
13 test_layernorm.py Normal file
@@ -0,0 +1,13 @@
# coding=utf-8
import torch
import torch_cuda_ext.core as core
import torch.nn.functional as F

eps = float(0.01)
gamma = float(1)
states = torch.randn(size=(100, 1024)).half().cuda()
res_states = F.rms_norm(states, [1024], eps=eps)
print(res_states)

core.rms_norm(states, eps, gamma)
print(states)