#include "core.h"
// NOTE(review): the original include list was garbled by extraction (seven
// bare `#include` tokens). These are the headers the code below demonstrably
// needs — confirm against the original build.
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cub/cub.cuh>
#include <torch/extension.h>
#include <cmath>

// Convert between the kernel's storage type and float, using the dedicated
// conversion intrinsics for __half / __nv_bfloat16.
//
// Fix: the original primary template had an empty body (no `return`), which
// is undefined behavior for any non-specialized pair such as float -> float.
template <typename src_type, typename dest_type>
__device__ __forceinline__ dest_type fi_cast(src_type a) {
    return static_cast<dest_type>(a);
}

template <>
__device__ __forceinline__ float fi_cast<__nv_bfloat16, float>(__nv_bfloat16 a) {
    return __bfloat162float(a);
}

template <>
__device__ __forceinline__ float fi_cast<__half, float>(__half a) {
    return __half2float(a);
}

template <>
__device__ __forceinline__ __nv_bfloat16 fi_cast<float, __nv_bfloat16>(float a) {
    return __float2bfloat16(a);
}

template <>
__device__ __forceinline__ __half fi_cast<float, __half>(float a) {
    return __float2half(a);
}

// In-place RMSNorm over the last dimension.
//
// Launch contract: one block per row (gridDim.x == num_rows) and exactly
// BLOCK_SIZE threads per block (blockDim.x == BLOCK_SIZE), since CUB's
// BlockReduce is instantiated with BLOCK_SIZE.
//
// states: [num_rows, hidden_dim] tensor, normalized in place.
// eps   : added to the mean of squares before the square root.
// gamma : scalar scale applied after normalization.
template <typename scalar_t, int BLOCK_SIZE>
__global__ void rms_norm_kernel(scalar_t *states, int hidden_dim, float eps,
                                float gamma) {
    const int idx = threadIdx.x;
    // Fix: widen the row offset — blockIdx.x * hidden_dim overflows int for
    // large tensors.
    const long long offset = static_cast<long long>(blockIdx.x) * hidden_dim;

    // Per-thread partial sum of squares, accumulated in float regardless of
    // the storage type.
    float local_sum = 0.0f;
    for (int i = idx; i < hidden_dim; i += blockDim.x) {
        float tmp = fi_cast<scalar_t, float>(states[offset + i]);
        local_sum += tmp * tmp;
    }

    // Fix: feed the per-thread partials to CUB directly. The original staged
    // them through an extra __shared__ float smem[BLOCK_SIZE], whose
    // `else smem[idx] = 0.0f` branch was an out-of-bounds write whenever it
    // executed (idx >= BLOCK_SIZE).
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    float sum_res = BlockReduce(temp_storage).Sum(local_sum);

    // Fix: BlockReduce returns the valid aggregate only on thread 0; the
    // original read the (undefined) result on every thread. Thread 0 computes
    // the reciprocal RMS and broadcasts it through shared memory.
    // Fix: use the standard RMSNorm denominator sqrt(mean(x^2) + eps); the
    // original computed sqrt(sum(x^2)) + eps (no division by hidden_dim, eps
    // added outside the root).
    __shared__ float s_inv_rms;
    if (idx == 0) {
        s_inv_rms = rsqrtf(sum_res / static_cast<float>(hidden_dim) + eps);
    }
    __syncthreads();
    const float inv_rms = s_inv_rms;

    // Normalize and scale in place.
    for (int i = idx; i < hidden_dim; i += blockDim.x) {
        float tmp = fi_cast<scalar_t, float>(states[offset + i]);
        states[offset + i] = fi_cast<float, scalar_t>(tmp * inv_rms * gamma);
    }
}

// Host launcher: RMS-normalizes `states` ([rows, hidden_dim]) in place on the
// current CUDA stream.
void rms_norm(torch::Tensor &states, float eps, float gamma) {
    const int rows = states.size(0);
    const int hidden_dim = states.size(1);
    constexpr int kBlockSize = 1024;

    // Robustness: an empty tensor would otherwise produce an invalid launch
    // (gridDim.x == 0).
    if (rows == 0 || hidden_dim == 0) {
        return;
    }

    // Fix: the original swapped grid and block (`dim3 block(h); dim3
    // grid(block_size);`), launching 1024 blocks of `rows` threads — the
    // opposite of the kernel's contract and inconsistent with the
    // BLOCK_SIZE template argument. One block per row, kBlockSize threads.
    dim3 grid(rows);
    dim3 block(kBlockSize);

    // (Removed a leftover `cout << states.scalar_type()` debug print.)
    // NOTE(review): assumes TYPING_DISPATCH (from core.h) binds `scalar_t`
    // like AT_DISPATCH_* — confirm against the macro's definition.
    TYPING_DISPATCH(states.scalar_type(), [&] {
        rms_norm_kernel<scalar_t, kBlockSize><<<grid, block>>>(
            reinterpret_cast<scalar_t *>(states.data_ptr()), hidden_dim, eps,
            gamma);
    });

    // Surface launch-configuration errors instead of failing silently at the
    // next synchronizing call.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "rms_norm kernel launch failed: ", cudaGetErrorString(err));
}