diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp index 565d134c..eee0cf0d 100644 --- a/csrc/pos_encoding.cpp +++ b/csrc/pos_encoding.cpp @@ -1,15 +1,16 @@ #include -void rotary_embedding_neox( +void rotary_embedding( torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int head_size, - torch::Tensor& cos_sin_cache); + torch::Tensor& cos_sin_cache, + bool is_neox); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( - "rotary_embedding_neox", - &rotary_embedding_neox, - "Apply GPT-NeoX style rotary embedding to query and key"); + "rotary_embedding", + &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); } diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index ced26ecb..b4351ee0 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -5,8 +5,38 @@ namespace vllm { -template -__global__ void rotary_embedding_neox_kernel( +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) +{ + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = __ldg(cos_ptr + x_index); + sin = __ldg(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = __ldg(cos_ptr + x_index / 2); + sin = __ldg(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +__global__ void rotary_embedding_kernel( const int64_t* __restrict__ positions, // [num_tokens] scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] @@ -23,58 +53,37 @@ __global__ void rotary_embedding_neox_kernel( const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; const int token_head = token_idx * query_stride + head_idx * head_size; - const int rot_offset = i % embed_dim; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; - - const int out_x = token_idx * query_stride + head_idx * head_size + x_index; - const int out_y = token_idx * query_stride + head_idx * head_size + y_index; - - const scalar_t cos = __ldg(cache_ptr + x_index); - const scalar_t sin = __ldg(cache_ptr + y_index); - - const scalar_t q_x = query[token_head + x_index]; - const scalar_t q_y = query[token_head + y_index]; - query[out_x] = q_x * cos - q_y * sin; - query[out_y] = q_y * cos + q_x * sin; + apply_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); } const int nk = num_kv_heads * embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; const int token_head = token_idx * key_stride + head_idx * head_size; - const int rot_offset = i % embed_dim; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; - - const int out_x = token_idx * key_stride + head_idx * head_size + x_index; - const int out_y = token_idx * key_stride + head_idx * head_size + y_index; - - const scalar_t cos = __ldg(cache_ptr + x_index); - const scalar_t sin = __ldg(cache_ptr + y_index); - - const scalar_t k_x = key[token_head + x_index]; - const scalar_t k_y = key[token_head + y_index]; - key[out_x] = k_x * cos - k_y * sin; - key[out_y] = k_y * cos + k_x * sin; + apply_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); } } } // namespace vllm -void rotary_embedding_neox( +void rotary_embedding( torch::Tensor& positions, // [num_tokens] torch::Tensor& query, // [num_tokens, num_heads * head_size] torch::Tensor& key, // [num_tokens, num_kv_heads * head_size] int head_size, - torch::Tensor& cos_sin_cache) // [max_position, rot_dim] -{ + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { int num_tokens = query.size(0); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(1) / head_size; @@ -87,18 +96,32 @@ void rotary_embedding_neox( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( query.scalar_type(), - "rotary_embedding_neox", + "rotary_embedding", [&] { - vllm::rotary_embedding_neox_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } }); } diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index d830b268..1e591295 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -7,49 +7,64 @@ import torch.nn.functional as F from vllm import pos_encoding_ops +IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 96, 112, 128, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing -NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing +NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing SEEDS = [0] -def rotate_half(x: torch.Tensor) -> torch.Tensor: +def rotate_neox(x: torch.Tensor) -> torch.Tensor: x1 = x[..., :x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb( +def rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def apply_rope( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + is_neox_style: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + rotate_fn = rotate_neox if is_neox_style else rotate_gptj + q_embed = (q * cos) + (rotate_fn(q) * sin) + k_embed = (k * cos) + (rotate_fn(k) * sin) return q_embed, k_embed -class RefRotaryEmbeddingNeox(nn.Module): - """Reference implementation of the GPT-NeoX style rotary embedding.""" +class RefRotaryEmbedding(nn.Module): + """Reference implementation of rotary embedding.""" def __init__( self, dim: int, - max_position_embeddings: int = 2048, + is_neox_style: bool, + max_position_embeddings: int = 8192, base: int = 10000, ) -> None: super().__init__() self.rotary_dim = dim + self.is_neox_style = is_neox_style self.max_position_embeddings = max_position_embeddings # Create cos and sin embeddings. inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) t = torch.arange(max_position_embeddings).float() freqs = torch.einsum("i,j->ij", t, inv_freq.float()) - emb = torch.cat((freqs, freqs), dim=-1) + if is_neox_style: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.repeat_interleave(freqs, 2, -1) cos = emb.cos().to(dtype=inv_freq.dtype) sin = emb.sin().to(dtype=inv_freq.dtype) self.register_buffer("cos_cached", cos, persistent=False) @@ -61,7 +76,6 @@ class RefRotaryEmbeddingNeox(nn.Module): query: torch.Tensor, # [num_tokens, num_heads, head_size] key: torch.Tensor, # [num_tokens, num_heads, head_size] ) -> Tuple[torch.Tensor, torch.Tensor]: - query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] key_rot = key[..., :self.rotary_dim] @@ -71,7 +85,9 @@ class RefRotaryEmbeddingNeox(nn.Module): key_rot = key_rot.transpose(0, 1) cos = F.embedding(positions, self.cos_cached) sin = F.embedding(positions, self.sin_cached) - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin, + self.is_neox_style) query_rot = query_rot.transpose(0, 1).contiguous() key_rot = key_rot.transpose(0, 1).contiguous() @@ -82,6 +98,7 @@ class RefRotaryEmbeddingNeox(nn.Module): return query, key +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -89,7 +106,8 @@ class RefRotaryEmbeddingNeox(nn.Module): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_rotary_embedding_neox( +def test_rotary_embedding( + is_neox_style: bool, num_tokens: int, num_heads: int, head_size: int, @@ -104,15 +122,15 @@ def test_rotary_embedding_neox( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - positions = torch.randint(0, max_position, (num_tokens, ), device='cuda') + positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, - device='cuda') + device="cuda") key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, - device='cuda') + device="cuda") # Create the rotary embedding. inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) @@ -126,20 +144,22 @@ def test_rotary_embedding_neox( # Run the kernel. The kernel is in-place, so we need to clone the inputs. out_query = query.clone() out_key = key.clone() - pos_encoding_ops.rotary_embedding_neox( + pos_encoding_ops.rotary_embedding( positions, out_query, out_key, head_size, cos_sin_cache, + is_neox_style, ) # Run the reference implementation. - ref_rotary_embedding = RefRotaryEmbeddingNeox( + ref_rotary_embedding = RefRotaryEmbedding( dim=rotary_dim, + is_neox_style=is_neox_style, max_position_embeddings=max_position, base=base, - ).to(dtype=dtype, device='cuda') + ).to(dtype=dtype, device="cuda") ref_query, ref_key = ref_rotary_embedding( positions, query.view(num_tokens, num_heads, head_size), diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index c59208e2..29bfe328 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -242,7 +242,7 @@ class PagedAttention(nn.Module): class PagedAttentionWithRoPE(PagedAttention): - """PagedAttention with GPT-NeoX style rotary embedding.""" + """PagedAttention with rotary embedding.""" def __init__( self, @@ -253,8 +253,10 @@ class PagedAttentionWithRoPE(PagedAttention): max_position: int = 8192, base: int = 10000, num_kv_heads: Optional[int] = None, + is_neox_style: bool = True, ) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads) + self.is_neox_style = is_neox_style # Create the cos and sin cache. inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) @@ -303,12 +305,13 @@ class PagedAttentionWithRoPE(PagedAttention): # Apply rotary embedding to the query and key before passing them # to the attention op. - pos_encoding_ops.rotary_embedding_neox( + pos_encoding_ops.rotary_embedding( positions, query, key, self.head_size, self.cos_sin_cache, + self.is_neox_style, ) return super().forward( query, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 35a1518e..0c9a7ef9 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -67,8 +67,11 @@ class GPTJAttention(nn.Module): scaling = self.head_size**-0.5 assert getattr(config, "rotary", True) assert config.rotary_dim % 2 == 0 - self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size, - scaling, config.rotary_dim) + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_size, + scaling, + config.rotary_dim, + is_neox_style=False) self.warmup = False def forward(