Implement approximate GELU kernels (#828)

Authored by Woosuk Kwon on 2023-08-23 07:43:21 +09:00, committed by GitHub
parent a41c20435e
commit d64bf1646c
4 changed files with 164 additions and 18 deletions


@@ -4,9 +4,25 @@ void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  m.def(
    "gelu_new",
    &gelu_new,
    "GELU implementation used in GPT-2.");
  m.def(
    "gelu_fast",
    &gelu_fast,
    "Approximate GELU implementation.");
}
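
Not part of the diff: once the extension is built, these bindings are called from Python as out-parameter ops, exactly as the tests later in this commit do. A minimal usage sketch (shapes and dtype are illustrative only):

    import torch
    from vllm import activation_ops  # compiled extension exposing the ops bound above

    x = torch.randn(16, 4096, dtype=torch.half, device='cuda')  # illustrative shape/dtype
    out = torch.empty_like(x)
    activation_ops.gelu_new(out, x)    # GPT-2 style approximate GELU, written into `out`
    activation_ops.gelu_fast(out, x)   # fast approximate GELU, written into `out`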


@@ -46,3 +46,71 @@ void silu_and_mul(
        d);
    });
}

namespace vllm {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,          // [num_tokens, d]
  const scalar_t* __restrict__ input,  // [num_tokens, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                             \
  int num_tokens = input.size(0);                                    \
  int d = input.size(1);                                             \
  dim3 grid(num_tokens);                                             \
  dim3 block(std::min(d, 1024));                                     \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();      \
  AT_DISPATCH_FLOATING_TYPES_AND2(                                   \
    at::ScalarType::Half,                                            \
    at::ScalarType::BFloat16,                                        \
    input.scalar_type(),                                             \
    "activation_kernel",                                             \
    [&] {                                                            \
      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
        out.data_ptr<scalar_t>(),                                    \
        input.data_ptr<scalar_t>(),                                  \
        d);                                                          \
    });
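
Not part of the diff: the launch maps one CUDA block to each token (row) and min(d, 1024) threads to the hidden dimension, with each thread striding over d in steps of blockDim.x. A small Python sketch of that index mapping, under the illustrative name simulate_activation_indexing:

    def simulate_activation_indexing(num_tokens: int, d: int, max_block_dim: int = 1024) -> None:
        # Mirrors activation_kernel's indexing: blockIdx.x selects the token,
        # threadIdx.x strides over the hidden dimension in steps of blockDim.x.
        block_dim = min(d, max_block_dim)
        visited = set()
        for token_idx in range(num_tokens):                  # grid: one block per token
            for thread_idx in range(block_dim):              # threads within the block
                for idx in range(thread_idx, d, block_dim):  # stride loop over d
                    visited.add((token_idx, idx))
        assert len(visited) == num_tokens * d                # every element visited exactly once

    simulate_activation_indexing(num_tokens=3, d=2500)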
namespace vllm {

template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float) (x * x * x);
  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float) x;
  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

} // namespace vllm
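
Not part of the diff: both device functions compute the standard tanh-based GELU approximation (0.79788456 ≈ sqrt(2/π)); they differ only in how the polynomial inside the tanh is factored and in where the float casts happen. A PyTorch reference sketch of the same two formulas:

    import math
    import torch

    def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

    def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
        # 0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2)))
        return 0.5 * x * (1.0 + torch.tanh(
            0.79788456 * x * (1.0 + 0.044715 * x * x)))

    # Both factorings agree with PyTorch's tanh-approximate GELU up to float rounding.
    x = torch.randn(8, dtype=torch.float32)
    ref = torch.nn.functional.gelu(x, approximate='tanh')
    assert torch.allclose(gelu_new_ref(x), ref, atol=1e-6)
    assert torch.allclose(gelu_fast_ref(x), ref, atol=1e-6)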
void gelu_new(
  torch::Tensor& out,    // [num_tokens, d]
  torch::Tensor& input)  // [num_tokens, d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

void gelu_fast(
  torch::Tensor& out,    // [num_tokens, d]
  torch::Tensor& input)  // [num_tokens, d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}


@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
from transformers.activations import get_activation

from vllm import activation_ops

@@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_silu_and_mul(num_tokens, d, dtype)


@torch.inference_mode()
def run_gelu_new(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_new(out, x)
    ref_out = get_activation("gelu_new")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_gelu_new() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_gelu_new(num_tokens, d, dtype)


@torch.inference_mode()
def run_gelu_fast(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_gelu_fast() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_gelu_fast(num_tokens, d, dtype)


@@ -4,23 +4,6 @@ import torch.nn as nn

from vllm import activation_ops

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    # NOTE: The following GELU functions may introduce small rounding errors.
    "gelu_new": nn.GELU(approximate="tanh"),
    "gelu_fast": nn.GELU(approximate="tanh"),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")
class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

@@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out


class NewGELU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        d = x.shape[1]
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.gelu_new(out, x)
        return out


class FastGELU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        d = x.shape[1]
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.gelu_fast(out, x)
        return out


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")
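
Not part of the diff: callers resolve activations by name via get_act_fn, so "gelu_new" and "gelu_fast" now dispatch to the custom CUDA ops instead of nn.GELU(approximate="tanh"). A minimal usage sketch (shapes and dtype are illustrative only):

    import torch

    act = get_act_fn("gelu_new")        # returns the NewGELU module defined above
    x = torch.randn(4, 4096, dtype=torch.half, device='cuda')
    y = act(x)                          # calls activation_ops.gelu_new under the hood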