Allow changing epsilon parameter in RMS norm kernel (#1112)
This commit is contained in:
parent
26986bbc60
commit
ff61a49dd1
@ -43,7 +43,8 @@ using Layout = cutlass::layout::RowMajor;
|
||||
void rmsnorm_host(cutlass::MatrixCoord tensor_size,
|
||||
cutlass::TensorRef<ElementType, Layout> output,
|
||||
cutlass::TensorRef<ElementType, Layout> input,
|
||||
cutlass::TensorRef<ElementType, Layout> weight) {
|
||||
cutlass::TensorRef<ElementType, Layout> weight,
|
||||
float epsilon) {
|
||||
const int M = tensor_size.row();
|
||||
const int N = tensor_size.column();
|
||||
|
||||
@ -56,7 +57,7 @@ void rmsnorm_host(cutlass::MatrixCoord tensor_size,
|
||||
}
|
||||
|
||||
float sq_mean = square_sum / (float)N;
|
||||
float sqrt_var = cutlass::fast_sqrt(sq_mean + (float)1e-6);
|
||||
float sqrt_var = cutlass::fast_sqrt(sq_mean + epsilon);
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
float inp = static_cast<float>(input.at({m, n}));
|
||||
@ -91,9 +92,9 @@ void run_test(int M, int N) {
|
||||
input.sync_device();
|
||||
weight.sync_device();
|
||||
|
||||
rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref());
|
||||
rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5);
|
||||
cutlass::rmsnorm({M, N}, output.device_ref(),
|
||||
input.device_ref(), weight.device_ref(), NULL);
|
||||
input.device_ref(), weight.device_ref(), NULL, (float)1e-5);
|
||||
|
||||
output.sync_host();
|
||||
|
||||
|
@ -43,7 +43,7 @@ namespace cutlass {
|
||||
|
||||
__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
|
||||
const float4 *weight,
|
||||
const int m, const int n) {
|
||||
const int m, const int n, float epsilon) {
|
||||
const int m_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
const int bdimx = blockDim.x;
|
||||
@ -76,7 +76,7 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
|
||||
blockReduceSum<float, 1>(local_sums);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
s_mean = rsqrtf(local_sums[0] / n + 1e-6);
|
||||
s_mean = rsqrtf(local_sums[0] / n + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -117,7 +117,8 @@ template<typename T>
|
||||
__global__ void rmsnorm_twoPassAlgo_e1(T* output,
|
||||
const T* input,
|
||||
const T* weight,
|
||||
const int m, const int n)
|
||||
const int m, const int n,
|
||||
float epsilon)
|
||||
{
|
||||
const int m_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
@ -139,7 +140,7 @@ __global__ void rmsnorm_twoPassAlgo_e1(T* output,
|
||||
blockReduceSum<float, 1>(local_sums);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
s_mean = rsqrtf(local_sums[0] / n + 1e-6);
|
||||
s_mean = rsqrtf(local_sums[0] / n + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -155,7 +156,7 @@ void rmsnorm(cutlass::MatrixCoord tensor_size,
|
||||
TensorRef<T, layout::RowMajor> ref_output,
|
||||
TensorRef<T, layout::RowMajor> ref_input,
|
||||
TensorRef<T, layout::RowMajor> ref_weight,
|
||||
cudaStream_t stream){
|
||||
cudaStream_t stream, float epsilon = 1e-5){
|
||||
const int m = tensor_size.row();
|
||||
const int n = tensor_size.column();
|
||||
T* output = ref_output.data();
|
||||
@ -167,12 +168,12 @@ void rmsnorm(cutlass::MatrixCoord tensor_size,
|
||||
dim3 block(min(1024, (n / 8 + 31) / 32 * 32));
|
||||
|
||||
rmsnorm_twoPassAlgo_e8<<<grid, block, 0, stream>>>(
|
||||
(float4 *)output, (const float4 *)input, (const float4 *)weight, m, n);
|
||||
(float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon);
|
||||
} else {
|
||||
dim3 block(min(1024, ((n + 31)/32 + 31)/32*32));
|
||||
|
||||
rmsnorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
|
||||
output, input, weight, m, n);
|
||||
output, input, weight, m, n, epsilon);
|
||||
}
|
||||
|
||||
auto result = cudaGetLastError();
|
||||
|
Loading…
Reference in New Issue
Block a user