Add RMS norm (#979)
parent e066ced33b
commit f679663224

test/unit/util/CMakeLists.txt
@@ -30,4 +30,5 @@ cutlass_test_unit_add_executable(
   cutlass_test_unit_util
   tensor_reduce.cu
   cutlass_test_levels.cu
+  rms_norm.cu
 )

test/unit/util/rms_norm.cu (new file, 123 lines)
@@ -0,0 +1,123 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
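
/*! \file
    \brief Unit test for cutlass::rmsnorm(): checks the device kernels against
      a simple fp32 host reference on half-precision data.
*/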
#include "../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/util/device_rmsnorm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/constants.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
using ElementType = cutlass::half_t;
|
||||
using Layout = cutlass::layout::RowMajor;
|
||||
|
||||
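// Host reference: y[m][n] = x[m][n] / sqrt(mean_n(x[m][n]^2) + 1e-6) * w[n],
// computed in fp32 and rounded back to ElementType at the end.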
void rmsnorm_host(cutlass::MatrixCoord tensor_size,
                  cutlass::TensorRef<ElementType, Layout> output,
                  cutlass::TensorRef<ElementType, Layout> input,
                  cutlass::TensorRef<ElementType, Layout> weight) {
  const int M = tensor_size.row();
  const int N = tensor_size.column();

  for (int m = 0; m < M; ++m) {
    float square_sum{0};

    for (int n = 0; n < N; ++n) {
      float inp = static_cast<float>(input.at({m, n}));
      square_sum += inp * inp;
    }

    float sq_mean = square_sum / (float)N;
    float sqrt_var = cutlass::fast_sqrt(sq_mean + (float)1e-6);

    for (int n = 0; n < N; ++n) {
      float inp = static_cast<float>(input.at({m, n}));
      float g = static_cast<float>(weight.at({0, n}));
      float res_fp32 = inp / sqrt_var * g;
      output.at({m, n}) = ElementType(res_fp32);
    }
  }
}

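// Fills input and weight with random values, runs cutlass::rmsnorm on the
// device, and compares against rmsnorm_host using max/mean absolute-error
// thresholds of 1e-3.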
void run_test(int M, int N) {
  cutlass::HostTensor<ElementType, Layout> input, output_ref, output, weight;
  input.reset({M, N});
  output.reset({M, N});
  output_ref.reset({M, N});
  weight.reset({1, N});

  const unsigned seed = 2022;

  cutlass::reference::host::TensorFillRandomUniform(input.host_view(),
                                                    seed,
                                                    ElementType(5),
                                                    ElementType(-5),
                                                    0);

  cutlass::reference::host::TensorFillRandomUniform(weight.host_view(),
                                                    seed,
                                                    ElementType(5),
                                                    ElementType(-5),
                                                    0);

  input.sync_device();
  weight.sync_device();

  rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref());
  cutlass::rmsnorm({M, N}, output.device_ref(),
                   input.device_ref(), weight.device_ref(), NULL);

  output.sync_host();

  float max_abs_diff = -1;
  float mean_abs_diff = 0;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      auto diff = abs(static_cast<float>(output_ref.at({m, n}) - output.at({m, n})));
      mean_abs_diff += diff;
      max_abs_diff = max(max_abs_diff, diff);
    }
  }

  mean_abs_diff /= float(M * N);

  EXPECT_TRUE(max_abs_diff < 0.001f && mean_abs_diff < 0.001f)
      << "Max absolute difference : " << max_abs_diff << "\n"
      << "Mean absolute difference: " << mean_abs_diff;
}

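// 16x1024 exercises the vectorized float4 path (N divisible by 8);
// 1x127 exercises the scalar fallback kernel.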
TEST(RMSNorm, 16x1024) {
  run_test(16, 1024);
}

TEST(RMSNorm, 1x127) {
  run_test(1, 127);
}

tools/util/include/cutlass/util/device_rmsnorm.h (new file, 185 lines)
@@ -0,0 +1,185 @@
/******************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
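
/*! \file
    \brief Device-side RMS norm: two-pass kernels (a vectorized float4/half2
      path and a generic scalar fallback) plus the host-side launcher
      cutlass::rmsnorm().
*/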

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/device_utils.h"
#include <float.h>

namespace cutlass {

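// Vectorized kernel for half-precision rows whose length is a multiple of 8:
// one thread block per row, each float4 load carrying eight packed halves.
// Pass 1 accumulates sum(x^2) and reduces it across the block; pass 2 scales
// each element by rsqrt(mean + 1e-6) times the per-column weight.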
__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
                                       const float4 *weight,
                                       const int m, const int n) {
  const int m_idx = blockIdx.x;
  const int tid = threadIdx.x;
  const int bdimx = blockDim.x;
  __shared__ float s_mean;
  float local_sums[1] = {0.0f};
  const int n_8 = n / 8;
  int offset = m_idx * n_8;
  input += offset;
  output += offset;

  for (int index = tid; index < n_8; index += bdimx) {
    const float4 local_val = input[index];
    const half2 *h1 = (half2 *)&local_val.x;
    const half2 *h2 = (half2 *)&local_val.y;
    const half2 *h3 = (half2 *)&local_val.z;
    const half2 *h4 = (half2 *)&local_val.w;
    local_sums[0] += static_cast<float>(h1->x) * static_cast<float>(h1->x) +
                     static_cast<float>(h1->y) * static_cast<float>(h1->y) +
                     static_cast<float>(h2->x) * static_cast<float>(h2->x) +
                     static_cast<float>(h2->y) * static_cast<float>(h2->y) +
                     static_cast<float>(h3->x) * static_cast<float>(h3->x) +
                     static_cast<float>(h3->y) * static_cast<float>(h3->y) +
                     static_cast<float>(h4->x) * static_cast<float>(h4->x) +
                     static_cast<float>(h4->y) * static_cast<float>(h4->y);
  }

  if (blockDim.x <= 32) {
    warpReduceSum<float, 1>(local_sums);
  } else {
    blockReduceSum<float, 1>(local_sums);
  }
  if (threadIdx.x == 0) {
    s_mean = rsqrtf(local_sums[0] / n + 1e-6);
  }
  __syncthreads();

  for (int index = tid; index < n_8; index += bdimx) {
    const float4 local_val = input[index];
    const float4 weight_val = weight[index];

    const half2 *l1 = (half2 *)&local_val.x;
    const half2 *l2 = (half2 *)&local_val.y;
    const half2 *l3 = (half2 *)&local_val.z;
    const half2 *l4 = (half2 *)&local_val.w;

    const half2 *g1 = (half2 *)&weight_val.x;
    const half2 *g2 = (half2 *)&weight_val.y;
    const half2 *g3 = (half2 *)&weight_val.z;
    const half2 *g4 = (half2 *)&weight_val.w;

    float4 tmp;
    half2 *h1 = (half2 *)&tmp.x;
    half2 *h2 = (half2 *)&tmp.y;
    half2 *h3 = (half2 *)&tmp.z;
    half2 *h4 = (half2 *)&tmp.w;  // half2, matching h1..h3: each 32-bit word packs two halves

    h1->x = half(static_cast<float>(l1->x) * s_mean * static_cast<float>(g1->x));
    h1->y = half(static_cast<float>(l1->y) * s_mean * static_cast<float>(g1->y));
    h2->x = half(static_cast<float>(l2->x) * s_mean * static_cast<float>(g2->x));
    h2->y = half(static_cast<float>(l2->y) * s_mean * static_cast<float>(g2->y));
    h3->x = half(static_cast<float>(l3->x) * s_mean * static_cast<float>(g3->x));
    h3->y = half(static_cast<float>(l3->y) * s_mean * static_cast<float>(g3->y));
    h4->x = half(static_cast<float>(l4->x) * s_mean * static_cast<float>(g4->x));
    h4->y = half(static_cast<float>(l4->y) * s_mean * static_cast<float>(g4->y));

    output[index] = tmp;
  }
}

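// Generic fallback for element types other than half_t or row lengths not
// divisible by 8: same two-pass scheme, one element per load.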
template<typename T>
__global__ void rmsnorm_twoPassAlgo_e1(T* output,
                                       const T* input,
                                       const T* weight,
                                       const int m, const int n)
{
  const int m_idx = blockIdx.x;
  const int tid = threadIdx.x;
  const int bdimx = blockDim.x;
  __shared__ float s_mean;
  float local_sums[1] = {0.0f};
  int offset = m_idx * n;
  input += offset;
  output += offset;

  for (int index = tid; index < n; index += bdimx) {
    float local_val = static_cast<float>(input[index]);
    local_sums[0] += local_val * local_val;
  }
  if (blockDim.x <= 32) {
    warpReduceSum<float, 1>(local_sums);
  } else {
    blockReduceSum<float, 1>(local_sums);
  }
  if (threadIdx.x == 0) {
    s_mean = rsqrtf(local_sums[0] / n + 1e-6);
  }
  __syncthreads();

  for (int index = tid; index < n; index += bdimx) {
    const T weight_val = weight[index];
    const T local_val = input[index];
    output[index] = T(static_cast<float>(local_val) * s_mean * static_cast<float>(weight_val));
  }
}

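// Host-side launcher: one thread block per row. half_t rows whose length is
// divisible by 8 take the vectorized float4 path, everything else the scalar
// kernel. Called as in test/unit/util/rms_norm.cu:
//   cutlass::rmsnorm({M, N}, output.device_ref(),
//                    input.device_ref(), weight.device_ref(), stream);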
template <typename T>
void rmsnorm(cutlass::MatrixCoord tensor_size,
             TensorRef<T, layout::RowMajor> ref_output,
             TensorRef<T, layout::RowMajor> ref_input,
             TensorRef<T, layout::RowMajor> ref_weight,
             cudaStream_t stream) {
  const int m = tensor_size.row();
  const int n = tensor_size.column();
  T* output = ref_output.data();
  const T* input = ref_input.data();
  const T* weight = ref_weight.data();
  dim3 grid(m);

  if (n % 8 == 0 && std::is_same<T, cutlass::half_t>::value) {
    dim3 block(min(1024, (n / 8 + 31) / 32 * 32));

    rmsnorm_twoPassAlgo_e8<<<grid, block, 0, stream>>>(
        (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n);
  } else {
    dim3 block(min(1024, ((n + 31) / 32 + 31) / 32 * 32));

    rmsnorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
        output, input, weight, m, n);
  }

  auto result = cudaGetLastError();
  if (result != cudaSuccess) {
    std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl;
    abort();
  }
}

} // namespace cutlass