Merge branch 'main' of http://192.168.0.100:3000/squall/torch_ext
This commit is contained in:
commit
a1aa7fd0d6
29
csrc/attention.cu
Normal file
29
csrc/attention.cu
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#include "core.h"
|
||||||
|
|
||||||
|
// calculate the vec cum of different matrix row and col.
|
||||||
|
template <typename scalar_t>
|
||||||
|
__device__ scalar_t vecsum(scalar_t *q, scalar_t *k)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
__global__ void attention_kernel(const scalar_t *q,
|
||||||
|
const scalar_t *k,
|
||||||
|
const scalar_t *v,
|
||||||
|
int head_num,
|
||||||
|
int head_dim,
|
||||||
|
int seq_len,
|
||||||
|
int batch_size,
|
||||||
|
int hidden_dim,
|
||||||
|
scalar_t *output)
|
||||||
|
{
|
||||||
|
// calculate the gemm.
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
// caculate the offset.
|
||||||
|
int q_offset = blockIdx.x * head_num * 1 * head_dim;
|
||||||
|
int k_offset = blockIdx.x * head_num * seq_len * head_dim;
|
||||||
|
int v_offset = blockIdx.x * head_num * seq_len * head_dim;
|
||||||
|
// calculate the sum.
|
||||||
|
// calculate the softmax
|
||||||
|
// calculate the weighted sum.
|
||||||
|
}
|
||||||
@ -1,6 +1,7 @@
|
|||||||
#include "core.h"
|
#include "core.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
// #include <mma.h>
|
// #include <mma.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
@ -58,7 +59,8 @@ __global__ void matmul_sigmoid_cuda(const T *in1, const T *in2, T *output, int r
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define BASE_BLOCK 256
|
#define BASE_BLOCK 256
|
||||||
#define CALL_ADD_FUNCTION add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
|
#define CALL_ADD_FUNCTION \
|
||||||
|
add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
|
||||||
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
|
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
|
||||||
{
|
{
|
||||||
// cout << input1.dtype() << " the size 1 is : " << input1.size(0) << " size 2 is " << input1.size(1) << "output dim is :" << output.size(0) << output.size(1) << endl;
|
// cout << input1.dtype() << " the size 1 is : " << input1.size(0) << " size 2 is " << input1.size(1) << "output dim is :" << output.size(0) << output.size(1) << endl;
|
||||||
|
|||||||
31
csrc/core.h
31
csrc/core.h
@ -11,6 +11,36 @@
|
|||||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
#define TYPING_DISPATCH(scalar_t, ...) \
|
||||||
|
switch (scalar_t) \
|
||||||
|
{ \
|
||||||
|
case at::ScalarType::Float: \
|
||||||
|
{ \
|
||||||
|
using fi_type = float; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: \
|
||||||
|
{ \
|
||||||
|
using fi_type = __nv_bfloat16; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: \
|
||||||
|
{ \
|
||||||
|
using fi_type = __half; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
// default: \
|
||||||
|
// printf("do not support such type\n"); \
|
||||||
|
// }
|
||||||
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
|
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
|
||||||
|
|
||||||
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);
|
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);
|
||||||
@ -29,4 +59,5 @@ void md_mm(const torch::Tensor &src);
|
|||||||
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
|
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
|
||||||
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
|
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
|
||||||
void softmax(const torch::Tensor &src, torch::Tensor &dest);
|
void softmax(const torch::Tensor &src, torch::Tensor &dest);
|
||||||
|
void rms_norm(torch::Tensor &states, float eps, float gamma);
|
||||||
#endif
|
#endif
|
||||||
@ -19,4 +19,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
|||||||
m.def("block_sum", &block_sum, "test block sum");
|
m.def("block_sum", &block_sum, "test block sum");
|
||||||
m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
|
m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
|
||||||
m.def("softmax", &softmax, "test softmax example");
|
m.def("softmax", &softmax, "test softmax example");
|
||||||
|
m.def("rms_norm", &rms_norm, "rms noram");
|
||||||
}
|
}
|
||||||
|
|||||||
4
csrc/fp8_vec.cu
Normal file
4
csrc/fp8_vec.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "core.h"
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
|
||||||
|
#define __nv_fp8_e4m3 fp8_e4m3
|
||||||
@ -0,0 +1,82 @@
|
|||||||
|
#include "core.h"
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#include <cub/util_device.cuh>
|
||||||
|
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
template <typename src_type, typename dest_type>
|
||||||
|
__device__ dest_type fi_cast(src_type a)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ float fi_cast<__nv_bfloat16, float>(__nv_bfloat16 a)
|
||||||
|
{
|
||||||
|
return __bfloat162float(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ float fi_cast<__half, float>(__half a)
|
||||||
|
{
|
||||||
|
return __half2float(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __nv_bfloat16 fi_cast<float, __nv_bfloat16>(float a)
|
||||||
|
{
|
||||||
|
return __float2bfloat16(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __half fi_cast<float, __half>(float a)
|
||||||
|
{
|
||||||
|
return __float2half(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int BLOCK_SIZE = 1024>
|
||||||
|
__global__ void rms_norm_kernel(scalar_t *states, int hidden_dim, float eps, float gamma)
|
||||||
|
{
|
||||||
|
__shared__ float smem[BLOCK_SIZE];
|
||||||
|
int idx = threadIdx.x;
|
||||||
|
int offset = blockIdx.x * hidden_dim;
|
||||||
|
float local_sum = 0.0f;
|
||||||
|
for (int i = idx; i < hidden_dim; i += blockDim.x)
|
||||||
|
{
|
||||||
|
int local_offset = offset + i;
|
||||||
|
float tmp = fi_cast<scalar_t, float>(states[local_offset]);
|
||||||
|
local_sum += tmp * tmp;
|
||||||
|
}
|
||||||
|
if (idx < BLOCK_SIZE)
|
||||||
|
smem[idx] = local_sum;
|
||||||
|
else
|
||||||
|
smem[idx] = 0.0f;
|
||||||
|
__syncthreads();
|
||||||
|
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
|
||||||
|
float sum_res = BlockReduce(temp_storage).Sum(smem[idx]);
|
||||||
|
sum_res = sqrtf(sum_res);
|
||||||
|
sum_res = sum_res + eps;
|
||||||
|
for (int i = idx; i < hidden_dim; i += blockDim.x)
|
||||||
|
{
|
||||||
|
int local_offset = offset + i;
|
||||||
|
float tmp = fi_cast<scalar_t, float>(states[local_offset]);
|
||||||
|
tmp = tmp / sum_res * gamma;
|
||||||
|
states[local_offset] = fi_cast<float, scalar_t>(tmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rms_norm(torch::Tensor &states, float eps, float gamma)
|
||||||
|
{
|
||||||
|
int h = states.size(0);
|
||||||
|
int hidden_dim = states.size(1);
|
||||||
|
int block_size = 1024;
|
||||||
|
dim3 block(h);
|
||||||
|
dim3 grid(block_size);
|
||||||
|
TYPING_DISPATCH(states.scalar_type(), [&]
|
||||||
|
{ rms_norm_kernel<fi_type><<<block, grid>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });
|
||||||
|
}
|
||||||
33
csrc/md.cu
33
csrc/md.cu
@ -69,6 +69,7 @@ __global__ void row_sum_kernel(const float *src, float *dest, int hidden_dim)
|
|||||||
if (tid == 0)
|
if (tid == 0)
|
||||||
{
|
{
|
||||||
dest[blockIdx.x] = sum;
|
dest[blockIdx.x] = sum;
|
||||||
|
printf("blockidx.x: %d, blockIdx.y %d, blockIdx.z %d\n", blockIdx.x, blockIdx.y, blockIdx.z);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,6 +110,7 @@ __global__ void md_row_sum_kernel(const float *src, float *dest, int stride_a, i
|
|||||||
if (tid == 0 && block_offset < all_len)
|
if (tid == 0 && block_offset < all_len)
|
||||||
{
|
{
|
||||||
dest[block_offset] = sum;
|
dest[block_offset] = sum;
|
||||||
|
printf("blockIdx.x %d, blockIdx.y %d, blockIdx.z %d, blockDim.x %d\n", blockIdx.x, blockIdx.y, blockIdx.z, blockDim.x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,3 +245,34 @@ void softmax(const torch::Tensor &src, torch::Tensor &dest)
|
|||||||
dest.data_ptr<scalar_t>(),
|
dest.data_ptr<scalar_t>(),
|
||||||
hidden_dim); });
|
hidden_dim); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int head_num = 8>
|
||||||
|
__global__ void test_head_dim_kernel()
|
||||||
|
{
|
||||||
|
int idx = threadIdx.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define LANUCH(head_num) test_head_dim_kernel<head_num><<<block, grid>>>();
|
||||||
|
|
||||||
|
void test_head_dim(int head_num)
|
||||||
|
{
|
||||||
|
dim3 block(10);
|
||||||
|
dim3 grid(1024);
|
||||||
|
switch (head_num)
|
||||||
|
{
|
||||||
|
case 1:
|
||||||
|
LANUCH(1);
|
||||||
|
case 8:
|
||||||
|
LANUCH(8);
|
||||||
|
case 16:
|
||||||
|
LANUCH(16);
|
||||||
|
case 32:
|
||||||
|
LANUCH(32);
|
||||||
|
case 48:
|
||||||
|
LANUCH(48);
|
||||||
|
case 64:
|
||||||
|
LANUCH(64);
|
||||||
|
default:
|
||||||
|
printf("do not support head num\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
35
csrc/random_env.cu
Normal file
35
csrc/random_env.cu
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <curand_kernel.h>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#include <cub/util_device.cuh>
|
||||||
|
|
||||||
|
__global__ void initRandom(curandState *state, unsigned long seed)
|
||||||
|
{
|
||||||
|
int id = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
curand_init(seed, id, 0, &state[id]);
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void random_generate(float *out, curandState *state)
|
||||||
|
{
|
||||||
|
curandState localState = state[id];
|
||||||
|
__shared__ float shared_data[1024];
|
||||||
|
int idx = threadIdx.x;
|
||||||
|
typedef cub::BlockReduce<float, 1024> BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
for (int i = 0; i < 1024; i++)
|
||||||
|
{
|
||||||
|
shared_data[idx] += curand_uniform(&localState);
|
||||||
|
float sum = BlockReduce(temp_storage).Sum(shared_data[idx]);
|
||||||
|
shared_data[idx] += shared_data[idx] / sum;
|
||||||
|
}
|
||||||
|
out[idx] = shared_data[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
void random_invoke()
|
||||||
|
{
|
||||||
|
curandState *devStates;
|
||||||
|
int thread_num = 1024;
|
||||||
|
float out[1024];
|
||||||
|
initRandom<<<1, thread_num>>>(devStates, 1234);
|
||||||
|
random_generate<<<1, thread_num>>>(out, devStates);
|
||||||
|
}
|
||||||
35
csrc/type_utils.h
Normal file
35
csrc/type_utils.h
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#ifndef TYPE_UTILS_H
|
||||||
|
#define TYPE_UTILS_H
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#define FP16 __half
|
||||||
|
#define BF16 __nv_bfloat16
|
||||||
|
template <typename src_type, typename dest_type>
|
||||||
|
__device__ dest_type fi_cast(src_type a)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ float fi_cast<BF16, float>(BF16 a)
|
||||||
|
{
|
||||||
|
return __bfloat162float(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ float fi_cast<FP16, float>(FP16 a)
|
||||||
|
{
|
||||||
|
return __half2float(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ BF16 fi_cast<float, BF16>(float a)
|
||||||
|
{
|
||||||
|
return __float2bfloat16(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ FP16 fi_cast<float, FP16>(float a)
|
||||||
|
{
|
||||||
|
return __float2half(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
28
fi/test_module.py
Normal file
28
fi/test_module.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class TestModule(nn.Module):
|
||||||
|
def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.model = DecodeLayer()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for module in self.model:
|
||||||
|
x = module(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DecodeLayer(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for i in range(10):
|
||||||
|
self.layers.append(nn.Linear(10, 10))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_module = TestModule(0, 3)
|
||||||
|
for x in test_module.named_parameters():
|
||||||
|
print(x[0])
|
||||||
3
setup.py
3
setup.py
@ -12,12 +12,13 @@ files = [
|
|||||||
"csrc/max.cu",
|
"csrc/max.cu",
|
||||||
"csrc/md.cu",
|
"csrc/md.cu",
|
||||||
"csrc/quantize.cu",
|
"csrc/quantize.cu",
|
||||||
|
"csrc/layernorm.cu",
|
||||||
]
|
]
|
||||||
extension = CUDAExtension(
|
extension = CUDAExtension(
|
||||||
name="torch_cuda_ext.core",
|
name="torch_cuda_ext.core",
|
||||||
sources=files,
|
sources=files,
|
||||||
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
||||||
include_dirs=["/home/squall/program/cutlass/include"],
|
include_dirs=["/home/squall/quant_data/program/cutlass/include"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_exts.append(extension)
|
cuda_exts.append(extension)
|
||||||
|
|||||||
13
test_layernorm.py
Normal file
13
test_layernorm.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import torch
|
||||||
|
import torch_cuda_ext.core as core
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
eps = float(0.01)
|
||||||
|
gamma = float(1)
|
||||||
|
states = torch.randn(size=(100, 1024)).half().cuda()
|
||||||
|
res_states = F.rms_norm(states, [1024], eps=eps)
|
||||||
|
print(res_states)
|
||||||
|
|
||||||
|
core.rms_norm(states, eps, gamma)
|
||||||
|
print(states)
|
||||||
Loading…
Reference in New Issue
Block a user