Implement custom kernel for LLaMA rotary embedding (#14)

Author: Woosuk Kwon, committed by GitHub
Date:   2023-03-30 11:04:21 -07:00
Parent: 80a2f812f1
Commit: 88c0268a18
10 changed files with 318 additions and 69 deletions
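The kernel fuses the GPT-NeoX style rotary position embedding (rotation across the two halves of each head, rather than adjacent pairs) into a single launch over the flattened query and key tensors. A minimal sketch of the rotation it computes for one head vector, assuming the standard inverse-frequency schedule (an illustrative helper, not code from this commit):

import torch

def neox_rotate(x: torch.Tensor, pos: int, base: int = 10000) -> torch.Tensor:
    # Illustrative reference only; the kernel in this commit reads cos/sin
    # from a precomputed cache instead of recomputing the angles per call.
    head_size = x.shape[-1]
    half = head_size // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
    angles = pos * inv_freq                      # [head_size // 2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]        # NeoX style: split into halves
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)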


@@ -6,13 +6,14 @@ import torch.nn as nn
from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.models import InputMetadata
class OPTCacheFlowAttention(nn.Module):
class GPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
super(OPTCacheFlowAttention, self).__init__()
super().__init__()
self.scale = float(scale)
self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -136,3 +137,71 @@ class OPTCacheFlowAttention(nn.Module):
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
return output.view(-1, num_heads * head_size)
class OPTCacheFlowAttention(GPTCacheFlowAttention):
"""OPT uses the same attention mechanism as GPT."""
def __init__(self, scale: float) -> None:
super().__init__(scale)
class LlamaCacheFlowAttention(GPTCacheFlowAttention):
"""Llama uses GPT-NeoX style rotary embedding."""
def __init__(
self,
scale: float,
head_size: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
super().__init__(scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model. Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, head_size]
self.register_buffer('cos_sin_cache', cache, persistent=False)
def forward(
self,
positions: torch.LongTensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
out_query = torch.empty_like(query)
out_key = torch.empty_like(key)
pos_encoding_ops.rotary_embedding_neox(
out_query,
out_key,
positions,
query,
key,
self.cos_sin_cache,
)
return super().forward(
out_query,
out_key,
value,
key_cache,
value_cache,
input_metadata,
cache_event,
)
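The cos_sin_cache registered above packs each position's cos and sin halves side by side, so row pos is [cos(pos * inv_freq) | sin(pos * inv_freq)] of width head_size, whereas the Hugging Face style module removed from llama.py below keeps two duplicated [max_position, head_size] buffers. A small equivalence check between the two layouts (an illustrative sketch; sizes chosen arbitrarily):

import torch

head_size, max_position, base = 64, 8192, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)   # [max_position, head_size]

# Reference layout: emb = cat(freqs, freqs); cos/sin rows are duplicated halves.
emb = torch.cat((freqs, freqs), dim=-1)
half = head_size // 2
assert torch.allclose(cache[:, :half], emb.cos()[:, :half])
assert torch.allclose(cache[:, half:], emb.sin()[:, half:])

This cuts the cached footprint in half and gives the kernel a single contiguous row to read per position.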


@@ -8,12 +8,10 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers import PreTrainedModel
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -41,48 +39,8 @@ class LlamaRMSNorm(nn.Module):
return self.weight * hidden_states
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
self.register_buffer("inv_freq", inv_freq)
# Create cos and sin embeddings.
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=self.inv_freq.dtype)
sin = emb.sin().to(dtype=self.inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(
self,
positions: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
return cos, sin
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
# TODO: Optimize.
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
@@ -156,9 +114,7 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
# FIXME(woosuk): Rename this.
self.attn = OPTCacheFlowAttention(scale=self.scaling)
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
def forward(
self,
@@ -171,19 +127,9 @@
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
# Apply rotary embedding.
# TODO: Optimize.
q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
cos, sin = self.rotary_emb(positions)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
key_cache, value_cache = kv_cache
k_cache, v_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
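With the fused kernel, LlamaAttention no longer materializes cos/sin, reshapes q/k, or applies rotate_half in Python; it passes positions straight to LlamaCacheFlowAttention, which applies the rotation and the cached attention in a single forward call.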


@@ -165,8 +165,7 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.intermediate_size
self.vocab_size = config.vocab_size
# FIXME
self.max_position = 2048
self.max_position = 8192
def _get_param_size(self) -> int:
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size


@@ -51,7 +51,7 @@ class OPTAttention(nn.Module):
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.scaling = self.head_dim ** -0.5
# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
@@ -66,7 +66,6 @@ class OPTAttention(nn.Module):
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = OPTCacheFlowAttention(scale=self.scaling)
def forward(


@@ -12,7 +12,7 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa
class Sampler(nn.Module):
def __init__(self) -> None:
super(Sampler, self).__init__()
super().__init__()
def forward(
self,


@@ -122,13 +122,13 @@ void reshape_and_cache(
torch::Tensor& value_cache,
torch::Tensor& slot_mapping) {
int num_tokens = key.size(0);
int head_num = key.size(1);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
dim3 grid(num_tokens);
dim3 block(std::min(head_num * head_size, 512));
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
key.scalar_type(),
@@ -140,7 +140,7 @@ void reshape_and_cache(
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
head_num,
num_heads,
head_size,
block_size,
x);

csrc/pos_encoding.cpp (new file)

@@ -0,0 +1,16 @@
#include <torch/extension.h>
void rotary_embedding_neox(
torch::Tensor& out_query,
torch::Tensor& out_key,
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
torch::Tensor& cos_sin_cache);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rotary_embedding_neox",
&rotary_embedding_neox,
"Apply GPT-NeoX style rotary embedding to query and key");
}

csrc/pos_encoding_kernels.cu (new file)

@@ -0,0 +1,83 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
namespace cacheflow {
template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
scalar_t* __restrict__ out_query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ out_key, // [num_tokens, num_heads, head_size]
const int64_t* __restrict__ positions, // [num_tokens]
const scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
const int num_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
const int embed_dim = head_size / 2;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int idx = token_idx * n + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int token_head = token_idx * n + head_idx * head_size;
const bool is_first_half = head_offset < embed_dim;
const int rot_offset = head_offset % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
const scalar_t q_x = __ldg(query + token_head + x_index);
const scalar_t q_y = __ldg(query + token_head + y_index);
const scalar_t q_cos = is_first_half ? q_x : q_y;
const scalar_t q_sin = is_first_half ? -q_y : q_x;
out_query[idx] = q_cos * cos + q_sin * sin;
const scalar_t k_x = __ldg(key + token_head + x_index);
const scalar_t k_y = __ldg(key + token_head + y_index);
const scalar_t k_cos = is_first_half ? k_x : k_y;
const scalar_t k_sin = is_first_half ? -k_y : k_x;
out_key[idx] = k_cos * cos + k_sin * sin;
}
}
} // namespace cacheflow
void rotary_embedding_neox(
torch::Tensor& out_query, // [num_tokens, num_heads * head_size]
torch::Tensor& out_key, // [num_tokens, num_heads * head_size]
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_heads * head_size]
torch::Tensor& cos_sin_cache) // [max_position, head_size]
{
int num_tokens = query.size(0);
int head_size = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
query.scalar_type(),
"rotary_embedding_neox",
[&] {
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
out_query.data_ptr<scalar_t>(),
out_key.data_ptr<scalar_t>(),
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
num_heads,
head_size);
});
}
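Element-wise, each thread handles one output index i of its token: it locates the head and the rotation offset within the head, loads the matching cos/sin pair from the token's cache row (cos in the first half, sin in the second), and rotates the element pair drawn from the two halves of that head. A pure-Python mirror of that per-token arithmetic, as a sketch for reference only (the helper name is hypothetical):

import torch

def rotary_one_token(pos: int,
                     q: torch.Tensor,              # [num_heads * head_size]
                     cos_sin_cache: torch.Tensor,  # [max_position, head_size]
                     head_size: int) -> torch.Tensor:
    # Hypothetical reference of the kernel's per-element arithmetic.
    embed_dim = head_size // 2
    row = cos_sin_cache[pos]                       # [cos | sin] halves
    out = torch.empty_like(q)
    for i in range(q.numel()):
        head_offset = i % head_size
        token_head = (i // head_size) * head_size  # start of this head
        rot_offset = head_offset % embed_dim
        cos = row[rot_offset]
        sin = row[embed_dim + rot_offset]
        x = q[token_head + rot_offset]              # first-half element
        y = q[token_head + embed_dim + rot_offset]  # second-half element
        if head_offset < embed_dim:
            out[i] = x * cos - y * sin
        else:
            out[i] = y * cos + x * sin
    return out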


@@ -23,6 +23,14 @@ attention_extension = cpp_extension.CUDAExtension(
)
ext_modules.append(attention_extension)
# Positional encodings.
positional_encoding_extension = cpp_extension.CUDAExtension(
name='cacheflow.pos_encoding_ops',
sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(positional_encoding_extension)
setuptools.setup(
name='cacheflow',
ext_modules=ext_modules,


@@ -0,0 +1,129 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from cacheflow import pos_encoding_ops
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class RefRotaryEmbeddingNeox(nn.Module):
"""Reference implementation of the GPT-NeoX style rotary embedding."""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: int = 10000,
) -> None:
super().__init__()
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=inv_freq.dtype)
sin = emb.sin().to(dtype=inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(
self,
positions: torch.LongTensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
query, key = apply_rotary_pos_emb(query, key, cos, sin)
query = query.transpose(0, 1).contiguous()
key = key.transpose(0, 1).contiguous()
# Output query/key shape: [num_tokens, num_heads, head_size]
return query, key
@torch.inference_mode()
def test_rotary_embedding_neox(
num_tokens: int,
num_heads: int,
head_size: int,
max_position: int,
dtype: torch.dtype,
base: int = 10000,
) -> None:
positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
# Create the rotary embedding.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
# Run the kernel.
out_query = torch.empty_like(query)
out_key = torch.empty_like(key)
pos_encoding_ops.rotary_embedding_neox(
out_query,
out_key,
positions,
query,
key,
cos_sin_cache,
)
# Run the reference implementation.
ref_rotary_embedding = RefRotaryEmbeddingNeox(
dim=head_size,
max_position_embeddings=max_position,
base=base,
).to(dtype=dtype, device='cuda')
ref_query, ref_key = ref_rotary_embedding(
positions,
query.view(num_tokens, num_heads, head_size),
key.view(num_tokens, num_heads, head_size),
)
ref_query = ref_query.view(num_tokens, num_heads * head_size)
ref_key = ref_key.view(num_tokens, num_heads * head_size)
# Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Running tests for head_size={head_size} and dtype={dtype}')
test_rotary_embedding_neox(
num_tokens=2145,
num_heads=5,
head_size=head_size,
max_position=8192,
dtype=dtype,
)