// torch_ext/csrc/layernorm.cu
#include "core.h"
#include <cub/cub.cuh>
#include <cub/util_device.cuh>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include <torch/torch.h>
#include <torch/all.h>
using namespace std;
// Generic float<->reduced-precision conversion helper.
// Primary template: fall back to an ordinary static_cast so that
// identity conversions (e.g. float -> float) and any pair with a
// defined conversion are well-formed. The original body was empty,
// which is undefined behavior for a non-void function if ever
// instantiated without a matching specialization below.
template <typename src_type, typename dest_type>
__device__ dest_type fi_cast(src_type a)
{
return static_cast<dest_type>(a);
}
// Widen bfloat16 to full-precision float via the CUDA intrinsic.
template <>
__device__ float fi_cast<__nv_bfloat16, float>(__nv_bfloat16 a)
{
const float widened = __bfloat162float(a);
return widened;
}
// Widen half to full-precision float via the CUDA intrinsic.
template <>
__device__ float fi_cast<__half, float>(__half a)
{
const float widened = __half2float(a);
return widened;
}
// Narrow float back down to bfloat16 (round-to-nearest-even).
template <>
__device__ __nv_bfloat16 fi_cast<float, __nv_bfloat16>(float a)
{
const __nv_bfloat16 narrowed = __float2bfloat16(a);
return narrowed;
}
// Narrow float back down to half precision.
template <>
__device__ __half fi_cast<float, __half>(float a)
{
const __half narrowed = __float2half(a);
return narrowed;
}
// In-place RMS normalization of one row per block:
//   states[row, :] *= gamma / sqrt(mean(states[row, :]^2) + eps)
//
// Launch layout: gridDim.x = number of rows, blockDim.x = BLOCK_SIZE.
// Precondition: blockDim.x == BLOCK_SIZE (CUB's BlockReduce is
// compiled for exactly BLOCK_SIZE threads).
//
// Fixes vs. previous version:
//  - cub::BlockReduce::Sum returns the aggregate only on thread 0;
//    every other thread previously read an undefined value. The result
//    is now broadcast through shared memory behind a __syncthreads().
//  - The normalizer is now the true RMS, sqrt(mean(x^2) + eps):
//    previously the sum was not divided by hidden_dim and eps was
//    added *after* the sqrt, so it did not guard the division.
//  - Dropped the redundant smem[BLOCK_SIZE] staging array; the
//    thread-local partial sum feeds BlockReduce directly.
// NOTE(review): gamma is a single scalar here, whereas RMSNorm weights
// are usually per-channel — confirm this is intended by the callers.
template <typename scalar_t, int BLOCK_SIZE = 1024>
__global__ void rms_norm_kernel(scalar_t *states, int hidden_dim, float eps, float gamma)
{
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float inv_rms; // broadcast slot for the reciprocal RMS
int idx = threadIdx.x;
int offset = blockIdx.x * hidden_dim; // start of this block's row
// Each thread accumulates a partial sum of squares over a strided
// slice of the row (grid-stride-within-block pattern).
float local_sum = 0.0f;
for (int i = idx; i < hidden_dim; i += blockDim.x)
{
float tmp = fi_cast<scalar_t, float>(states[offset + i]);
local_sum += tmp * tmp;
}
// Block-wide reduction; the aggregate is valid only on thread 0.
float sum_sq = BlockReduce(temp_storage).Sum(local_sum);
if (idx == 0)
inv_rms = rsqrtf(sum_sq / hidden_dim + eps);
__syncthreads(); // publish inv_rms to the whole block
// Second pass: scale the row in place.
for (int i = idx; i < hidden_dim; i += blockDim.x)
{
float tmp = fi_cast<scalar_t, float>(states[offset + i]);
states[offset + i] = fi_cast<float, scalar_t>(tmp * inv_rms * gamma);
}
}
// Host launcher: RMS-normalizes each row of a 2-D tensor in place.
//
// states: (rows, hidden_dim) tensor, modified in place. Must be 2-D
//         and contiguous since the kernel indexes a raw pointer.
// eps:    numerical-stability term added to mean(x^2) before the sqrt.
// gamma:  scalar scale applied after normalization.
//
// Fixes vs. previous version: `block`/`grid` names were swapped
// relative to the <<<grid, block>>> launch syntax (the launch itself
// was correct, the naming misleading); removed a leftover debug
// cout of the scalar type; added input validation and a launch-error
// check.
void rms_norm(torch::Tensor &states, float eps, float gamma)
{
TORCH_CHECK(states.dim() == 2, "rms_norm: expected a 2-D tensor");
TORCH_CHECK(states.is_contiguous(), "rms_norm: tensor must be contiguous");
int rows = states.size(0);
int hidden_dim = states.size(1);
constexpr int block_size = 1024; // must match the kernel's BLOCK_SIZE default
dim3 grid(rows);        // one block per row
dim3 block(block_size); // block_size threads cooperate on a row
TYPING_DISPATCH(states.scalar_type(), [&]
{ rms_norm_kernel<fi_type><<<grid, block>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });
// Surface bad-launch-configuration errors immediately.
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess, "rms_norm launch failed: ", cudaGetErrorString(err));
}