Implement approximate GELU kernels (#828)
parent a41c20435e
commit d64bf1646c
@@ -4,9 +4,25 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  m.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  m.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
 }
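The two new bindings follow the same calling convention as the existing silu_and_mul op: the caller pre-allocates the output tensor, and both tensors are laid out as [num_tokens, d]. A minimal usage sketch from Python, assuming the extension is built and importable as vllm.activation_ops (as in the tests later in this diff); the shapes and dtype below are illustrative only:

    import torch
    from vllm import activation_ops

    # Illustrative shapes/dtype; any [num_tokens, d] CUDA tensor works.
    x = torch.randn(16, 4096, dtype=torch.half, device='cuda')
    out = torch.empty_like(x)

    activation_ops.gelu_new(out, x)   # GPT-2 style tanh-approximate GELU, written into `out`
    activation_ops.gelu_fast(out, x)  # second approximation variant, same in/out contract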
@@ -46,3 +46,71 @@ void silu_and_mul(
         d);
       });
 }
+
+namespace vllm {
+
+// Element-wise activation kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void activation_kernel(
+  scalar_t* __restrict__ out,         // [num_tokens, d]
+  const scalar_t* __restrict__ input, // [num_tokens, d]
+  const int d) {
+  const int token_idx = blockIdx.x;
+  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x);
+  }
+}
+
+} // namespace vllm
+
+// Launch element-wise activation kernel.
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
+  int num_tokens = input.size(0); \
+  int d = input.size(1); \
+  dim3 grid(num_tokens); \
+  dim3 block(std::min(d, 1024)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  AT_DISPATCH_FLOATING_TYPES_AND2( \
+    at::ScalarType::Half, \
+    at::ScalarType::BFloat16, \
+    input.scalar_type(), \
+    "activation_kernel", \
+    [&] { \
+      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(), \
+        input.data_ptr<scalar_t>(), \
+        d); \
+    });
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T gelu_new_kernel(const T& x) {
+  const float x3 = (float) (x * x * x);
+  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
+  const float f = (float) x;
+  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+} // namespace vllm
+
+void gelu_new(
+  torch::Tensor& out,   // [num_tokens, d]
+  torch::Tensor& input) // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
+}
+
+void gelu_fast(
+  torch::Tensor& out,   // [num_tokens, d]
+  torch::Tensor& input) // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
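Both device functions compute the tanh approximation of GELU, 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3))), where 0.79788456 is sqrt(2/pi); gelu_fast_kernel only factors the polynomial differently. A pure-PyTorch restatement of the two formulas, useful as a sanity check against the kernels (a sketch only; the kernels above do part of the arithmetic in the tensor's own precision, so tiny rounding differences versus this reference are expected):

    import torch

    def ref_gelu_new(x: torch.Tensor) -> torch.Tensor:
        # Mirrors gelu_new_kernel: 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

    def ref_gelu_fast(x: torch.Tensor) -> torch.Tensor:
        # Mirrors gelu_fast_kernel: same approximation, factored as x * 0.79788456 * (1 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(x * 0.79788456 * (1.0 + 0.044715 * x * x)))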
@@ -1,6 +1,6 @@
 import torch
 import torch.nn.functional as F
-
+from transformers.activations import get_activation
 from vllm import activation_ops
 
 
@@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
             for d in [512, 4096, 5120, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                 run_silu_and_mul(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_new(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_new(out, x)
+    ref_out = get_activation("gelu_new")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_new() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_new(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_fast(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_fast(out, x)
+    ref_out = get_activation("gelu_fast")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_fast() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_fast(num_tokens, d, dtype)
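The new tests mirror the existing run_silu_and_mul pattern: random CUDA inputs, the custom op into a pre-allocated buffer, and a comparison against Hugging Face's get_activation reference at 1e-5 tolerance. The nested loops could equally be written with pytest.mark.parametrize; a sketch for one of the ops, not part of this commit:

    import pytest
    import torch
    from transformers.activations import get_activation

    from vllm import activation_ops


    @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
    @pytest.mark.parametrize("num_tokens", [7, 83, 2048])
    @pytest.mark.parametrize("d", [512, 4096, 5120, 13824])
    @torch.inference_mode()
    def test_gelu_new_parametrized(num_tokens: int, d: int, dtype: torch.dtype) -> None:
        x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
        out = torch.empty_like(x)
        activation_ops.gelu_new(out, x)
        ref_out = get_activation("gelu_new")(x)
        assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)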
@@ -4,23 +4,6 @@ import torch.nn as nn
 
 from vllm import activation_ops
 
-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    # NOTE: The following GELU functions may introduce small rounding errors.
-    "gelu_new": nn.GELU(approximate="tanh"),
-    "gelu_fast": nn.GELU(approximate="tanh"),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-}
-
-
-def get_act_fn(act_fn: str) -> nn.Module:
-    """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
-
 
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
@@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out
+
+
+class NewGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_new(out, x)
+        return out
+
+
+class FastGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_fast(out, x)
+        return out
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_fast": FastGELU(),
+    "gelu_new": NewGELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
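With the registry now pointing "gelu_new" and "gelu_fast" at the custom-kernel modules, models pick up the fused implementations through get_act_fn without further changes. A usage sketch; the import path is assumed for illustration, since this diff does not show the file's location:

    import torch

    # Import path assumed; use wherever get_act_fn is defined in this file.
    from vllm.model_executor.layers.activation import get_act_fn

    act = get_act_fn("gelu_new")   # returns the NewGELU module registered above
    x = torch.randn(4, 5120, dtype=torch.half, device='cuda')
    y = act(x)                     # same shape, dtype, and device as x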