From 6fc2a38b110f9ba6037b31ee016f20df32426877 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 20 Jul 2023 11:38:27 -0700 Subject: [PATCH] Add support for LLaMA-2 (#505) --- README.md | 3 +- csrc/pos_encoding_kernels.cu | 21 ++++++---- docs/source/models/supported_models.rst | 4 +- requirements.txt | 2 +- vllm/config.py | 7 +++- vllm/model_executor/layers/attention.py | 15 ++++--- vllm/model_executor/models/llama.py | 53 ++++++++++++++++--------- 7 files changed, 67 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index b9de3886..d26e7acb 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 +- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command! - [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds. - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). @@ -46,7 +47,7 @@ vLLM seamlessly supports many Huggingface models, including the following archit - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) -- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) +- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index ab58ef2e..0c89ab08 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -7,11 +7,12 @@ template __global__ void rotary_embedding_neox_kernel( const int64_t* __restrict__ positions, // [num_tokens] scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, const int stride, const int num_heads, + const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. 
const int token_idx = blockIdx.x; @@ -19,8 +20,8 @@ __global__ void rotary_embedding_neox_kernel( const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const int embed_dim = rot_dim / 2; - const int n = num_heads * embed_dim; - for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; const int token_head = token_idx * stride + head_idx * head_size; @@ -39,10 +40,12 @@ __global__ void rotary_embedding_neox_kernel( query[out_x] = q_x * cos - q_y * sin; query[out_y] = q_y * cos + q_x * sin; - const scalar_t k_x = key[token_head + x_index]; - const scalar_t k_y = key[token_head + y_index]; - key[out_x] = k_x * cos - k_y * sin; - key[out_y] = k_y * cos + k_x * sin; + if (head_idx < num_kv_heads) { + const scalar_t k_x = key[token_head + x_index]; + const scalar_t k_y = key[token_head + y_index]; + key[out_x] = k_x * cos - k_y * sin; + key[out_y] = k_y * cos + k_x * sin; + } } } @@ -51,13 +54,14 @@ __global__ void rotary_embedding_neox_kernel( void rotary_embedding_neox( torch::Tensor& positions, // [num_tokens] torch::Tensor& query, // [num_tokens, num_heads * head_size] - torch::Tensor& key, // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [num_tokens, num_kv_heads * head_size] int head_size, torch::Tensor& cos_sin_cache) // [max_position, rot_dim] { int num_tokens = query.size(0); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(1) / head_size; + int num_kv_heads = key.size(1) / head_size; int stride = query.stride(0); TORCH_CHECK(stride == key.stride(0)); @@ -78,6 +82,7 @@ void rotary_embedding_neox( rot_dim, stride, num_heads, + num_kv_heads, head_size); }); } diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c6ae8a2b..ea9cbcab 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -30,8 +30,8 @@ Alongside each architecture, we include some popular models that use it. - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. * - :code:`LlamaForCausalLM` - - LLaMA, Vicuna, Alpaca, Koala, Guanaco - - :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc. + - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco + - :code:`meta-llama/Llama-2-13b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc. * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. diff --git a/requirements.txt b/requirements.txt index 2dcefd21..42dfbeeb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ ray >= 2.5.1 sentencepiece # Required for LLaMA tokenizer. numpy torch >= 2.0.0 -transformers >= 4.28.0 # Required for LLaMA. +transformers >= 4.31.0 # Required for LLaMA-2. 
xformers >= 0.0.19 fastapi uvicorn diff --git a/vllm/config.py b/vllm/config.py index a3430f23..38d9108f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -100,7 +100,12 @@ class ModelConfig: return 1 # For Falcon: if getattr(self.hf_config, "n_head_kv", None) is not None: - return self.hf_config.n_head_kv + return (self.hf_config.n_head_kv // + parallel_config.tensor_parallel_size) + # For LLaMA-2: + if getattr(self.hf_config, "num_key_value_heads", None) is not None: + return (self.hf_config.num_key_value_heads // + parallel_config.tensor_parallel_size) total_num_attention_heads = self.hf_config.num_attention_heads return total_num_attention_heads // parallel_config.tensor_parallel_size diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 375f9f59..bd25ee7a 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -128,7 +128,8 @@ class PagedAttention(nn.Module): query: shape = [num_generation_tokens, num_heads, head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] - value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] input_metadata: metadata for paged attention. """ block_size = value_cache.shape[3] @@ -241,8 +242,9 @@ class PagedAttentionWithRoPE(PagedAttention): rotary_dim: int, max_position: int = 8192, base: int = 10000, + num_kv_heads: Optional[int] = None, ) -> None: - super().__init__(num_heads, head_size, scale) + super().__init__(num_heads, head_size, scale, num_kv_heads) # Create the cos and sin cache. inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) @@ -276,11 +278,12 @@ class PagedAttentionWithRoPE(PagedAttention): Args: positions: shape = [num_tokens] query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_heads * head_size] - value: shape = [num_tokens, num_heads * head_size] - key_cache: shape = [num_blocks, num_heads, head_size/x, + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] - value_cache: shape = [num_blocks, num_heads, head_size, block_size] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] input_metadata: metadata for paged attention. cache_event: event to wait for the cache operations to finish. 
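A quick standalone illustration of the config.py change above: LLaMA-2 70B uses grouped-query attention, so it exposes fewer KV heads than query heads, and each tensor-parallel worker keeps an equal slice of both. The sketch below reproduces that arithmetic in plain Python; the concrete figures (64 query heads, 8 KV heads) are the published LLaMA-2 70B settings, used here only for illustration and not taken from this patch.

def heads_per_worker(num_attention_heads: int,
                     num_key_value_heads: int,
                     tensor_parallel_size: int) -> tuple[int, int]:
    # Both head counts must divide evenly across the tensor-parallel group,
    # mirroring the asserts added to LlamaAttention further below.
    assert num_attention_heads % tensor_parallel_size == 0
    assert num_key_value_heads % tensor_parallel_size == 0
    return (num_attention_heads // tensor_parallel_size,
            num_key_value_heads // tensor_parallel_size)

# LLaMA-2 70B with tensor_parallel_size=8: each worker holds 8 query heads but
# only 1 KV head, which is why the KV-cache head count in ModelConfig now
# divides num_key_value_heads by the tensor-parallel size.
print(heads_per_worker(64, 8, 8))  # -> (8, 1)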
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 65800207..93ab499e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -84,21 +84,26 @@ class LlamaAttention(nn.Module): self, hidden_size: int, num_heads: int, + num_kv_heads: int, ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.qkv_proj = ColumnParallelLinear( hidden_size, - 3 * self.total_num_heads * self.head_dim, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, bias=False, gather_output=False, perform_initialization=False, @@ -113,7 +118,8 @@ class LlamaAttention(nn.Module): self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.scaling, - rotary_dim=self.head_dim) + rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads) def forward( self, @@ -124,7 +130,7 @@ class LlamaAttention(nn.Module): cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) @@ -140,6 +146,7 @@ class LlamaDecoderLayer(nn.Module): self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -259,9 +266,19 @@ class LlamaForCausalLM(nn.Module): model_name_or_path: str, cache_dir: Optional[str] = None, use_np_cache: bool = False): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tp_size = get_tensor_model_parallel_world_size() tensor_model_parallel_rank = get_tensor_model_parallel_rank() + q_proj_shard_size = (self.config.hidden_size // tp_size) + kv_proj_shard_size = (self.config.hidden_size // + self.config.num_attention_heads * + self.config.num_key_value_heads // tp_size) + attention_weight_specs = [ + # (weight_name, shard_size, offset) + ("q_proj", q_proj_shard_size, 0), + ("k_proj", kv_proj_shard_size, q_proj_shard_size), + ("v_proj", kv_proj_shard_size, + q_proj_shard_size + kv_proj_shard_size), + ] state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( @@ -272,8 +289,7 @@ class LlamaForCausalLM(nn.Module): if "embed_tokens" in name or "lm_head" in name: param = state_dict[name] # Consider padding in the vocab size. 
- padded_vocab_size = (param.shape[0] * - tensor_model_parallel_world_size) + padded_vocab_size = (param.shape[0] * tp_size) num_extra_rows = padded_vocab_size - self.config.vocab_size extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) @@ -281,18 +297,17 @@ class LlamaForCausalLM(nn.Module): loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 + param = state_dict[name.replace(weight_name, "qkv_proj")] + loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + param_slice = param.data[offset:offset + shard_size] assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) is_attention_weight = True break
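The llama.py hunks above replace the even three-way qkv.chunk with a size-aware split: under grouped-query attention the fused QKV projection packs one wide query block next to two narrower key/value blocks, and weight loading has to copy each Hugging Face q_proj/k_proj/v_proj shard to the matching row offset. Below is a minimal standalone sketch of that layout, assuming the per-worker head counts from the earlier example (8 query heads, 1 KV head, head_dim 128); the numbers are illustrative, not read from any checkpoint.

import torch

num_heads, num_kv_heads, head_dim = 8, 1, 128    # illustrative per-worker values
q_size = num_heads * head_dim                    # 1024
kv_size = num_kv_heads * head_dim                # 128

# qkv_proj emits [q | k | v] packed along the last dimension, so the forward
# pass splits by size instead of chunking into three equal pieces.
qkv = torch.randn(4, q_size + 2 * kv_size)       # 4 tokens
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape == (4, q_size) and k.shape == (4, kv_size) and v.shape == (4, kv_size)

# The same packing fixes where each loaded weight lands inside the fused
# parameter, matching the (weight_name, shard_size, offset) triples in
# attention_weight_specs: q at row 0, k right after q, v after q and k.
offsets = {"q_proj": 0, "k_proj": q_size, "v_proj": q_size + kv_size}
print(offsets)  # {'q_proj': 0, 'k_proj': 1024, 'v_proj': 1152}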
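As for the "single command" promised in the README news item, one way to exercise this patch is the offline entry point sketched below. The model name and sampling settings are illustrative, the 70B checkpoint additionally needs tensor_parallel_size=8 (and access to the gated weights), and the API shown is the vLLM interface of that era rather than anything introduced by this diff.

from vllm import LLM, SamplingParams

# Load the model and generate a short completion in one go.
llm = LLM(model="meta-llama/Llama-2-7b-hf")  # 70B: LLM(model=..., tensor_parallel_size=8)
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)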