diff --git a/csrc/activation.cpp b/csrc/activation.cpp
index f95afdc0..c100f89a 100644
--- a/csrc/activation.cpp
+++ b/csrc/activation.cpp
@@ -4,9 +4,25 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  m.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  m.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
 }
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 45788093..fc1f086f 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -46,3 +46,71 @@ void silu_and_mul(
         d);
     });
 }
+
+namespace vllm {
+
+// Element-wise activation kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void activation_kernel(
+  scalar_t* __restrict__ out,               // [num_tokens, d]
+  const scalar_t* __restrict__ input,       // [num_tokens, d]
+  const int d) {
+  const int token_idx = blockIdx.x;
+  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x);
+  }
+}
+
+} // namespace vllm
+
+// Launch element-wise activation kernel.
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                 \
+  int num_tokens = input.size(0);                                                        \
+  int d = input.size(1);                                                                 \
+  dim3 grid(num_tokens);                                                                 \
+  dim3 block(std::min(d, 1024));                                                         \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                          \
+  AT_DISPATCH_FLOATING_TYPES_AND2(                                                       \
+    at::ScalarType::Half,                                                                \
+    at::ScalarType::BFloat16,                                                            \
+    input.scalar_type(),                                                                 \
+    "activation_kernel",                                                                 \
+    [&] {                                                                                \
+      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(   \
+        out.data_ptr<scalar_t>(),                                                        \
+        input.data_ptr<scalar_t>(),                                                      \
+        d);                                                                              \
+    });
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T gelu_new_kernel(const T& x) {
+  const float x3 = (float) (x * x * x);
+  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
+  const float f = (float) x;
+  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+} // namespace vllm
+
+void gelu_new(
+  torch::Tensor& out,     // [num_tokens, d]
+  torch::Tensor& input)   // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
+}
+
+void gelu_fast(
+  torch::Tensor& out,     // [num_tokens, d]
+  torch::Tensor& input)   // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index 173f00cd..b4ddd3e5 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn.functional as F
-
+from transformers.activations import get_activation
 from vllm import activation_ops
 
 
@@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
             for d in [512, 4096, 5120, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                 run_silu_and_mul(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_new(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_new(out, x)
+    ref_out = get_activation("gelu_new")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_new() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_new(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_fast(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_fast(out, x)
+    ref_out = get_activation("gelu_fast")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_fast() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_fast(num_tokens, d, dtype)
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 53ce6de2..9222fe27 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -4,23 +4,6 @@ import torch.nn as nn
 
 from vllm import activation_ops
 
-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    # NOTE: The following GELU functions may introduce small rounding errors.
-    "gelu_new": nn.GELU(approximate="tanh"),
-    "gelu_fast": nn.GELU(approximate="tanh"),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-}
-
-
-def get_act_fn(act_fn: str) -> nn.Module:
-    """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
-
 
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
@@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out
+
+
+class NewGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_new(out, x)
+        return out
+
+
+class FastGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_fast(out, x)
+        return out
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_fast": FastGELU(),
+    "gelu_new": NewGELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
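Usage sketch (not part of the diff; assumes a CUDA device and a built vllm.activation_ops extension): with this change, get_act_fn("gelu_new") and get_act_fn("gelu_fast") return NewGELU/FastGELU modules that dispatch to the new CUDA kernels, which can be spot-checked against the Hugging Face reference with the same tolerances used in the new tests.

    import torch
    from transformers.activations import get_activation
    from vllm.model_executor.layers.activation import get_act_fn

    x = torch.randn(16, 4096, dtype=torch.float, device='cuda')
    out = get_act_fn("gelu_new")(x)      # NewGELU -> activation_ops.gelu_new CUDA kernel
    ref = get_activation("gelu_new")(x)  # Hugging Face reference implementation
    assert torch.allclose(out, ref, atol=1e-5, rtol=1e-5)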