Implement custom kernel for LLaMA rotary embedding (#14)
This commit is contained in:
parent
80a2f812f1
commit
88c0268a18
@ -6,13 +6,14 @@ import torch.nn as nn
|
||||
|
||||
from cacheflow import attention_ops
|
||||
from cacheflow import cache_ops
|
||||
from cacheflow import pos_encoding_ops
|
||||
from cacheflow.models import InputMetadata
|
||||
|
||||
|
||||
class OPTCacheFlowAttention(nn.Module):
|
||||
class GPTCacheFlowAttention(nn.Module):
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super(OPTCacheFlowAttention, self).__init__()
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
|
||||
self.flash_attn = FlashAttention(softmax_scale=self.scale)
|
||||
@ -136,3 +137,71 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
# Reshape the output tensor.
|
||||
# NOTE(woosuk): The output tensor may include paddings.
|
||||
return output.view(-1, num_heads * head_size)
|
||||
|
||||
|
||||
class OPTCacheFlowAttention(GPTCacheFlowAttention):
|
||||
"""OPT uses the same attention mechanism as GPT."""
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__(scale)
|
||||
|
||||
|
||||
class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
||||
"""Llama uses GPT-NeoX style rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: float,
|
||||
head_size: int,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
super().__init__(scale)
|
||||
|
||||
# Create the cos and sin cache.
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
|
||||
# FIXME(woosuk): This assumes that we configure the default dtype when
|
||||
# initializing the model. Make it more robust.
|
||||
torch_dtype = torch.get_default_dtype()
|
||||
cache = cache.to(torch_dtype)
|
||||
# Embedding size: [max_position, head_size]
|
||||
self.register_buffer('cos_sin_cache', cache, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor, # [num_tokens]
|
||||
query: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
out_query = torch.empty_like(query)
|
||||
out_key = torch.empty_like(key)
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
out_query,
|
||||
out_key,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.cos_sin_cache,
|
||||
)
|
||||
return super().forward(
|
||||
out_query,
|
||||
out_key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
|
||||
@ -8,12 +8,10 @@ from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import LlamaConfig
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.models.attention import OPTCacheFlowAttention
|
||||
from cacheflow.models.attention import LlamaCacheFlowAttention
|
||||
from cacheflow.models.sample import Sampler
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
@ -41,48 +39,8 @@ class LlamaRMSNorm(nn.Module):
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class LlamaRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Create cos and sin embeddings.
|
||||
t = torch.arange(max_position_embeddings).float()
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype=self.inv_freq.dtype)
|
||||
sin = emb.sin().to(dtype=self.inv_freq.dtype)
|
||||
self.register_buffer("cos_cached", cos, persistent=False)
|
||||
self.register_buffer("sin_cached", sin, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = F.embedding(positions, self.cos_cached)
|
||||
sin = F.embedding(positions, self.sin_cached)
|
||||
return cos, sin
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin):
|
||||
# TODO: Optimize.
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@ -156,9 +114,7 @@ class LlamaAttention(nn.Module):
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
|
||||
# FIXME(woosuk): Rename this.
|
||||
self.attn = OPTCacheFlowAttention(scale=self.scaling)
|
||||
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -171,19 +127,9 @@ class LlamaAttention(nn.Module):
|
||||
q, _ = self.q_proj(hidden_states)
|
||||
k, _ = self.k_proj(hidden_states)
|
||||
v, _ = self.v_proj(hidden_states)
|
||||
|
||||
# Apply rotrary embedding.
|
||||
# TODO: Optimize.
|
||||
q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
|
||||
k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
|
||||
cos, sin = self.rotary_emb(positions)
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
|
||||
k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
|
||||
|
||||
key_cache, value_cache = kv_cache
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
@ -165,8 +165,7 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
self.head_size = config.hidden_size // self.num_heads
|
||||
self.ffn_size = config.intermediate_size
|
||||
self.vocab_size = config.vocab_size
|
||||
# FIXME
|
||||
self.max_position = 2048
|
||||
self.max_position = 8192
|
||||
|
||||
def _get_param_size(self) -> int:
|
||||
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
||||
|
||||
@ -51,7 +51,7 @@ class OPTAttention(nn.Module):
|
||||
assert num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = embed_dim // total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
|
||||
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
@ -66,7 +66,6 @@ class OPTAttention(nn.Module):
|
||||
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
|
||||
self.attn = OPTCacheFlowAttention(scale=self.scaling)
|
||||
|
||||
def forward(
|
||||
|
||||
@ -12,7 +12,7 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super(Sampler, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -122,13 +122,13 @@ void reshape_and_cache(
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping) {
|
||||
int num_tokens = key.size(0);
|
||||
int head_num = key.size(1);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(head_num * head_size, 512));
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
key.scalar_type(),
|
||||
@ -140,7 +140,7 @@ void reshape_and_cache(
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int>(),
|
||||
head_num,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
x);
|
||||
|
||||
16
csrc/pos_encoding.cpp
Normal file
16
csrc/pos_encoding.cpp
Normal file
@ -0,0 +1,16 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rotary_embedding_neox(
|
||||
torch::Tensor& out_query,
|
||||
torch::Tensor& out_key,
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& cos_sin_cache);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rotary_embedding_neox",
|
||||
&rotary_embedding_neox,
|
||||
"Apply GPT-NeoX style rotary embedding to query and key");
|
||||
}
|
||||
83
csrc/pos_encoding_kernels.cu
Normal file
83
csrc/pos_encoding_kernels.cu
Normal file
@ -0,0 +1,83 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
namespace cacheflow {
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void rotary_embedding_neox_kernel(
|
||||
scalar_t* __restrict__ out_query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ out_key, // [num_tokens, num_heads, head_size]
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
const scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
|
||||
const int num_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
|
||||
|
||||
const int embed_dim = head_size / 2;
|
||||
const int n = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int idx = token_idx * n + i;
|
||||
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int token_head = token_idx * n + head_idx * head_size;
|
||||
|
||||
const bool is_first_half = head_offset < embed_dim;
|
||||
const int rot_offset = head_offset % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
|
||||
const scalar_t q_x = __ldg(query + token_head + x_index);
|
||||
const scalar_t q_y = __ldg(query + token_head + y_index);
|
||||
const scalar_t q_cos = is_first_half ? q_x : q_y;
|
||||
const scalar_t q_sin = is_first_half ? -q_y : q_x;
|
||||
out_query[idx] = q_cos * cos + q_sin * sin;
|
||||
|
||||
const scalar_t k_x = __ldg(key + token_head + x_index);
|
||||
const scalar_t k_y = __ldg(key + token_head + y_index);
|
||||
const scalar_t k_cos = is_first_half ? k_x : k_y;
|
||||
const scalar_t k_sin = is_first_half ? -k_y : k_x;
|
||||
out_key[idx] = k_cos * cos + k_sin * sin;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
|
||||
void rotary_embedding_neox(
|
||||
torch::Tensor& out_query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& out_key, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& cos_sin_cache) // [max_position, head_size]
|
||||
{
|
||||
int num_tokens = query.size(0);
|
||||
int head_size = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_neox",
|
||||
[&] {
|
||||
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out_query.data_ptr<scalar_t>(),
|
||||
out_key.data_ptr<scalar_t>(),
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
num_heads,
|
||||
head_size);
|
||||
});
|
||||
}
|
||||
8
setup.py
8
setup.py
@ -23,6 +23,14 @@ attention_extension = cpp_extension.CUDAExtension(
|
||||
)
|
||||
ext_modules.append(attention_extension)
|
||||
|
||||
# Positional encodings.
|
||||
positional_encoding_extension = cpp_extension.CUDAExtension(
|
||||
name='cacheflow.pos_encoding_ops',
|
||||
sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
|
||||
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
||||
)
|
||||
ext_modules.append(positional_encoding_extension)
|
||||
|
||||
setuptools.setup(
|
||||
name='cacheflow',
|
||||
ext_modules=ext_modules,
|
||||
|
||||
129
tests/kernels/pos_encoding.py
Normal file
129
tests/kernels/pos_encoding.py
Normal file
@ -0,0 +1,129 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cacheflow import pos_encoding_ops
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class RefRotaryEmbeddingNeox(nn.Module):
|
||||
"""Reference implementation of the GPT-NeoX style rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# Create cos and sin embeddings.
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
|
||||
t = torch.arange(max_position_embeddings).float()
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype=inv_freq.dtype)
|
||||
sin = emb.sin().to(dtype=inv_freq.dtype)
|
||||
self.register_buffer("cos_cached", cos, persistent=False)
|
||||
self.register_buffer("sin_cached", sin, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor, # [num_tokens]
|
||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = F.embedding(positions, self.cos_cached)
|
||||
sin = F.embedding(positions, self.sin_cached)
|
||||
query = query.transpose(0, 1)
|
||||
key = key.transpose(0, 1)
|
||||
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
||||
query = query.transpose(0, 1).contiguous()
|
||||
key = key.transpose(0, 1).contiguous()
|
||||
# Output query/key shape: [num_tokens, num_tokens, head_size]
|
||||
return query, key
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding_neox(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
max_position: int,
|
||||
dtype: torch.dtype,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
|
||||
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
|
||||
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
|
||||
|
||||
# Create the rotary embedding.
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
|
||||
|
||||
# Run the kernel.
|
||||
out_query = torch.empty_like(query)
|
||||
out_key = torch.empty_like(key)
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
out_query,
|
||||
out_key,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
cos_sin_cache,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
||||
dim=head_size,
|
||||
max_position_embeddings=max_position,
|
||||
base=base,
|
||||
).to(dtype=dtype, device='cuda')
|
||||
ref_query, ref_key = ref_rotary_embedding(
|
||||
positions,
|
||||
query.view(num_tokens, num_heads, head_size),
|
||||
key.view(num_tokens, num_heads, head_size),
|
||||
)
|
||||
ref_query = ref_query.view(num_tokens, num_heads * head_size)
|
||||
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
||||
|
||||
# Compare the results.
|
||||
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
|
||||
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
for dtype in [torch.half, torch.float]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Running tests for head_size={head_size} and dtype={dtype}')
|
||||
test_rotary_embedding_neox(
|
||||
num_tokens=2145,
|
||||
num_heads=5,
|
||||
head_size=head_size,
|
||||
max_position=8192,
|
||||
dtype=dtype,
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user