From 9b294976a2373f6fda22c1b2e704c395c8bd0787 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 2 Dec 2023 21:18:40 -0800 Subject: [PATCH] Add PyTorch-native implementation of custom layers (#1898) --- tests/kernels/test_activation.py | 27 ++- tests/kernels/test_layernorm.py | 57 +++---- tests/kernels/test_pos_encoding.py | 157 +++--------------- vllm/model_executor/layers/activation.py | 18 ++ vllm/model_executor/layers/layernorm.py | 20 +++ .../model_executor/layers/rotary_embedding.py | 54 ++++++ 6 files changed, 149 insertions(+), 184 deletions(-) diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 978b377e..ba062054 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -1,9 +1,7 @@ import pytest import torch -import torch.nn.functional as F -from transformers.activations import get_activation -from vllm._C import ops +from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing @@ -11,11 +9,6 @@ D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] -def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor: - x1, x2 = x.chunk(chunks=2, dim=1) - return F.silu(x1) * x2 - - @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -30,9 +23,9 @@ def test_silu_and_mul( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") - out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - ops.silu_and_mul(out, x) - ref_out = ref_silu_and_mul(x) + layer = SiluAndMul() + out = layer(x) + ref_out = layer._forward(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -50,9 +43,9 @@ def test_gelu_new( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") - out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - ops.gelu_new(out, x) - ref_out = get_activation("gelu_new")(x) + layer = NewGELU() + out = layer(x) + ref_out = layer._forward(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -69,7 +62,7 @@ def test_gelu_fast( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") - out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - ops.gelu_fast(out, x) - ref_out = get_activation("gelu_fast")(x) + layer = FastGELU() + out = layer(x) + ref_out = layer._forward(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index ee5228d6..b362e2c4 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -1,58 +1,47 @@ import pytest import torch -import torch.nn as nn -from vllm._C import ops +from vllm.model_executor.layers.layernorm import RMSNorm DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing +HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing +ADD_RESIDUAL = [False, True] SEEDS = [0] -class RefRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - weight = torch.empty(hidden_size) - weight.normal_(mean=1.0, std=0.1) - self.weight = nn.Parameter(weight) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_rms_norm( num_tokens: int, hidden_size: int, + add_residual: bool, dtype: torch.dtype, seed: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - scale = float(hidden_size**-0.5) - x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") - x.uniform_(-scale, scale) - ref = RefRMSNorm(hidden_size).to(dtype).cuda() + layer = RMSNorm(hidden_size).to(dtype).cuda() + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda") + x *= scale + residual = torch.randn_like(x) * scale if add_residual else None - out = torch.empty_like(x) - ops.rms_norm( - out, - x, - ref.weight.data, - ref.variance_epsilon, - ) - ref_out = ref(x) - assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5) + # NOTE(woosuk): The reference implementation should be executed first + # because the custom kernel is in-place. + ref_out = layer._forward(x, residual) + out = layer(x, residual) + # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger + # numerical errors than other operators because they involve reductions. + # Therefore, we use a larger tolerance. + if add_residual: + assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2) + assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2) + else: + assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 7d22bdab..25d6bf23 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,105 +1,23 @@ -from typing import Optional, Tuple +from typing import Optional import pytest import torch -import torch.nn as nn -import torch.nn.functional as F -from vllm._C import ops +from vllm.model_executor.layers.rotary_embedding import get_rope IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 96, 112, 128, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size -NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing -NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing +NUM_HEADS = [7, 17] # Arbitrary values for testing +BATCH_SIZES = [1, 5] # Arbitrary values for testing +SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] -def rotate_neox(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def rotate_gptj(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., ::2] - x2 = x[..., 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) - - -def apply_rope( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: - rotate_fn = rotate_neox if is_neox_style else rotate_gptj - q_embed = (q * cos) + (rotate_fn(q) * sin) - k_embed = (k * cos) + (rotate_fn(k) * sin) - return q_embed, k_embed - - -class RefRotaryEmbedding(nn.Module): - """Reference implementation of rotary embedding.""" - - def __init__( - self, - dim: int, - is_neox_style: bool, - max_position_embeddings: int = 8192, - base: int = 10000, - ) -> None: - super().__init__() - self.rotary_dim = dim - self.is_neox_style = is_neox_style - self.max_position_embeddings = max_position_embeddings - - # Create cos and sin embeddings. - inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) - t = torch.arange(max_position_embeddings).float() - freqs = torch.einsum("i,j->ij", t, inv_freq.float()) - if is_neox_style: - emb = torch.cat((freqs, freqs), dim=-1) - else: - emb = torch.repeat_interleave(freqs, 2, -1) - cos = emb.cos().to(dtype=inv_freq.dtype) - sin = emb.sin().to(dtype=inv_freq.dtype) - self.register_buffer("cos_cached", cos, persistent=False) - self.register_buffer("sin_cached", sin, persistent=False) - - def forward( - self, - positions: torch.Tensor, # [num_tokens] - query: torch.Tensor, # [num_tokens, num_heads, head_size] - key: torch.Tensor, # [num_tokens, num_heads, head_size] - ) -> Tuple[torch.Tensor, torch.Tensor]: - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - - query_rot = query_rot.transpose(0, 1) - key_rot = key_rot.transpose(0, 1) - cos = F.embedding(positions, self.cos_cached) - sin = F.embedding(positions, self.sin_cached) - - query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin, - self.is_neox_style) - query_rot = query_rot.transpose(0, 1).contiguous() - key_rot = key_rot.transpose(0, 1).contiguous() - - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - - # Output query/key shape: [num_tokens, num_tokens, head_size] - return query, key - - @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) @@ -108,7 +26,8 @@ class RefRotaryEmbedding(nn.Module): @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, - num_tokens: int, + batch_size: int, + seq_len: int, num_heads: int, head_size: int, rotary_dim: Optional[int], @@ -122,53 +41,25 @@ def test_rotary_embedding( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") - query = torch.randn(num_tokens, + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope = rope.to(dtype).cuda() + + positions = torch.randint(0, + max_position, (batch_size, seq_len), + device="cuda") + query = torch.randn(batch_size, + seq_len, num_heads * head_size, dtype=dtype, device="cuda") - key = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device="cuda") - - # Create the rotary embedding. - inv_freq = 1.0 / (base**( - torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) - t = torch.arange(max_position).float() - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") - - # Run the kernel. The kernel is in-place, so we need to clone the inputs. - out_query = query.clone() - out_key = key.clone() - ops.rotary_embedding( - positions, - out_query, - out_key, - head_size, - cos_sin_cache, - is_neox_style, - ) - - # Run the reference implementation. - ref_rotary_embedding = RefRotaryEmbedding( - dim=rotary_dim, - is_neox_style=is_neox_style, - max_position_embeddings=max_position, - base=base, - ).to(dtype=dtype, device="cuda") - ref_query, ref_key = ref_rotary_embedding( - positions, - query.view(num_tokens, num_heads, head_size), - key.view(num_tokens, num_heads, head_size), - ) - ref_query = ref_query.view(num_tokens, num_heads * head_size) - ref_key = ref_key.view(num_tokens, num_heads * head_size) + key = torch.randn_like(query) + # NOTE(woosuk): The reference implementation should be executed first + # because the custom kernel is in-place. + ref_query, ref_key = rope._forward(positions, query, key) + out_query, out_key = rope.forward(positions, query, key) # Compare the results. assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 5c0def82..1af120d1 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -1,8 +1,10 @@ """Custom activation functions.""" +import math from typing import Optional import torch import torch.nn as nn +import torch.nn.functional as F from vllm._C import ops from vllm.model_executor.layers.quantization import QuantizationConfig @@ -22,6 +24,11 @@ class SiluAndMul(nn.Module): return: (batch_size, seq_len, d) or (num_tokens, d) """ + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + def forward(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) @@ -32,6 +39,12 @@ class SiluAndMul(nn.Module): class NewGELU(nn.Module): + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + c = math.sqrt(2.0 / math.pi) + return 0.5 * x * (1.0 + torch.tanh(c * + (x + 0.044715 * torch.pow(x, 3.0)))) + def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) ops.gelu_new(out, x) @@ -40,6 +53,11 @@ class NewGELU(nn.Module): class FastGELU(nn.Module): + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * + (1.0 + 0.044715 * x * x))) + def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) ops.gelu_fast(out, x) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 69fba087..cb3cee2b 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -23,6 +23,26 @@ class RMSNorm(nn.Module): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + def _forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + def forward( self, x: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b3a4d38b..91c093e3 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -30,6 +30,19 @@ import torch.nn as nn from vllm._C import ops +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + class RotaryEmbedding(nn.Module): """Original rotary positional embedding.""" @@ -81,6 +94,47 @@ class RotaryEmbedding(nn.Module): cache = torch.cat((cos, sin), dim=-1) return cache + def _forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + query = query.flatten(-2) + key = key.flatten(-2) + return query, key + def forward( self, positions: torch.Tensor,