// 2024-12-14 13:34:30 +08:00
|
|
|
#include "core.h"
|
|
|
|
|
#include <cub/cub.cuh>
|
|
|
|
|
#include <cub/util_device.cuh>
|
|
|
|
|
|
|
|
|
|
#include <cuda_fp16.h>
|
|
|
|
|
#include <cuda_fp8.h>
|
|
|
|
|
#include <cuda_bf16.h>
|
|
|
|
|
#include <torch/torch.h>
|
|
|
|
|
#include <torch/all.h>
|
|
|
|
|
using namespace std;
|
|
|
|
|
|
|
|
|
|
// Generic numeric conversion helper used by the kernels below.
// Falls back to a plain static_cast; explicit specializations follow for
// half/bfloat16 <-> float pairs where a dedicated CUDA intrinsic exists.
// NOTE: the original primary template had an empty body, which is
// undefined behavior (non-void function with no return) for any
// non-specialized instantiation — the cast makes it well-defined.
template <typename src_type, typename dest_type>
__device__ dest_type fi_cast(src_type a)
{
    return static_cast<dest_type>(a);
}
|
|
|
|
|
// Widen a bfloat16 value to float using the CUDA intrinsic.
template <>
__device__ float fi_cast<__nv_bfloat16, float>(__nv_bfloat16 value)
{
    const float widened = __bfloat162float(value);
    return widened;
}
|
|
|
|
|
|
|
|
|
|
// Widen a half-precision value to float using the CUDA intrinsic.
template <>
__device__ float fi_cast<__half, float>(__half value)
{
    const float widened = __half2float(value);
    return widened;
}
|
|
|
|
|
|
|
|
|
|
// Narrow a float to bfloat16 using the CUDA intrinsic (round-to-nearest).
template <>
__device__ __nv_bfloat16 fi_cast<float, __nv_bfloat16>(float value)
{
    const __nv_bfloat16 narrowed = __float2bfloat16(value);
    return narrowed;
}
|
|
|
|
|
|
|
|
|
|
// Narrow a float to half precision using the CUDA intrinsic.
template <>
__device__ __half fi_cast<float, __half>(float value)
{
    const __half narrowed = __float2half(value);
    return narrowed;
}
|
|
|
|
|
|
|
|
|
|
// Row-wise in-place RMS normalization.
//
// Expected launch layout: one block per row, blockDim.x == BLOCK_SIZE
// threads (the host wrapper launches 1024). `states` points at a
// [num_rows, hidden_dim] matrix; blockIdx.x selects the row. `gamma` is a
// single scalar scale applied uniformly (not a per-element weight vector).
//
// Fixes vs. the original:
//  * cub::BlockReduce::Sum returns the aggregate only in thread 0 — every
//    other thread previously divided by an undefined value. The result is
//    now broadcast to the whole block through shared memory.
//  * The normalizer is now rsqrtf(mean(x^2) + eps), the standard RMSNorm
//    formula; the original used sqrtf(sum(x^2)) + eps (no division by
//    hidden_dim, eps added outside the root).
//  * The dead `if (idx < BLOCK_SIZE)` guard and the smem staging array are
//    removed — each thread's partial sum feeds BlockReduce directly.
template <typename scalar_t, int BLOCK_SIZE = 1024>
__global__ void rms_norm_kernel(scalar_t *states, int hidden_dim, float eps, float gamma)
{
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float inv_rms; // broadcast slot for the reduction result

    const int tid = threadIdx.x;
    const int row_offset = blockIdx.x * hidden_dim;

    // Each thread accumulates a partial sum of squares over a strided
    // slice of the row; accumulation is done in float regardless of
    // scalar_t to avoid precision loss for half/bfloat16 inputs.
    float local_sum = 0.0f;
    for (int i = tid; i < hidden_dim; i += blockDim.x)
    {
        float v = fi_cast<scalar_t, float>(states[row_offset + i]);
        local_sum += v * v;
    }

    // Block-wide reduction. The sum is valid only in thread 0, so thread 0
    // derives the scale and publishes it for the whole block.
    float sum_sq = BlockReduce(temp_storage).Sum(local_sum);
    if (tid == 0)
        inv_rms = rsqrtf(sum_sq / (float)hidden_dim + eps);
    __syncthreads();

    // Normalize the row in place.
    for (int i = tid; i < hidden_dim; i += blockDim.x)
    {
        float v = fi_cast<scalar_t, float>(states[row_offset + i]);
        states[row_offset + i] = fi_cast<float, scalar_t>(v * inv_rms * gamma);
    }
}
|
|
|
|
|
|
|
|
|
|
// Applies RMS normalization in place over the last dimension of a 2-D
// tensor.
//
// states: [num_rows, hidden_dim] CUDA tensor; element type is resolved by
//         the project's TYPING_DISPATCH macro (presumably half / bfloat16 /
//         float — confirm against the macro's definition).
// eps:    numerical-stability term.
// gamma:  scalar scale applied uniformly to every element.
//
// Fixes vs. the original: the grid/block dim3 variables were named
// backwards (the launch happened to pass them in a working order, but the
// names invited a future bug), the debug cout of the scalar type is
// removed, and the launch is now checked for errors.
void rms_norm(torch::Tensor &states, float eps, float gamma)
{
    const int num_rows = states.size(0);
    const int hidden_dim = states.size(1);
    constexpr int block_size = 1024; // must match rms_norm_kernel's BLOCK_SIZE default

    dim3 grid(num_rows); // one block per row
    dim3 block(block_size);

    TYPING_DISPATCH(states.scalar_type(), [&]
                    { rms_norm_kernel<fi_type><<<grid, block>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });

    // Surface launch-configuration errors immediately instead of letting
    // them poison a later, unrelated CUDA call.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "rms_norm kernel launch failed: ", cudaGetErrorString(err));
}
|