Add support for LLaMA-2 (#505)
This commit is contained in:
parent c487a221ee
commit 6fc2a38b11
@@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone

 ---

 *Latest News* 🔥
+- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
 - [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
 - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

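For context, running one of the newly supported LLaMA-2 checkpoints through vLLM's offline Python API might look roughly like the sketch below. The model name `meta-llama/Llama-2-7b-hf` and the sampling settings are illustrative and not part of this diff.

```python
from vllm import LLM, SamplingParams

# Illustrative only: any LLaMA-2 checkpoint in the Hugging Face format should
# work after this change (e.g. the 7B/13B/70B variants named in the news entry).
llm = LLM(model="meta-llama/Llama-2-7b-hf")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)

outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)
```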
@@ -46,7 +47,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
-- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)

@@ -7,11 +7,12 @@ template<typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int stride,
   const int num_heads,
+  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -19,8 +20,8 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

   const int embed_dim = rot_dim / 2;
-  const int n = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
     const int token_head = token_idx * stride + head_idx * head_size;

@@ -39,10 +40,12 @@ __global__ void rotary_embedding_neox_kernel(
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;

-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+    if (head_idx < num_kv_heads) {
+      const scalar_t k_x = key[token_head + x_index];
+      const scalar_t k_y = key[token_head + y_index];
+      key[out_x] = k_x * cos - k_y * sin;
+      key[out_y] = k_y * cos + k_x * sin;
+    }
   }
 }

@@ -51,13 +54,14 @@ __global__ void rotary_embedding_neox_kernel(
 void rotary_embedding_neox(
   torch::Tensor& positions,         // [num_tokens]
   torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
   int head_size,
   torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
 {
   int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
+  int num_kv_heads = key.size(1) / head_size;
   int stride = query.stride(0);
   TORCH_CHECK(stride == key.stride(0));

@@ -78,6 +82,7 @@ void rotary_embedding_neox(
         rot_dim,
         stride,
         num_heads,
+        num_kv_heads,
         head_size);
     });
 }
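The kernel changes above boil down to one idea: with grouped-query attention the key tensor carries only `num_kv_heads` heads, so the rotary embedding must skip head indices beyond that. A standalone NumPy sketch of the same per-token computation (NeoX-style rotation, illustrative sizes, not the CUDA code itself):

```python
import numpy as np

def rotate_neox(x, cos, sin):
    # NeoX-style rotary embedding: pair element i of the first half of each
    # head with element i of the second half, as the kernel's x/y indices do.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

num_heads, num_kv_heads, head_size, pos = 8, 2, 16, 5   # illustrative sizes
rng = np.random.default_rng(0)
q = rng.standard_normal((num_heads, head_size))         # all query heads
k = rng.standard_normal((num_kv_heads, head_size))      # fewer KV heads (GQA)

inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))
cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)

q_rot = rotate_neox(q, cos, sin)   # rotated for every query head
k_rot = rotate_neox(k, cos, sin)   # only the num_kv_heads key heads exist
```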
@@ -30,8 +30,8 @@ Alongside each architecture, we include some popular models that use it.
     - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
     - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
   * - :code:`LlamaForCausalLM`
-    - LLaMA, Vicuna, Alpaca, Koala, Guanaco
-    - :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
+    - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
+    - :code:`meta-llama/Llama-2-13b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
   * - :code:`MPTForCausalLM`
     - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
     - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
@@ -4,7 +4,7 @@ ray >= 2.5.1
 sentencepiece  # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0
-transformers >= 4.28.0  # Required for LLaMA.
+transformers >= 4.31.0  # Required for LLaMA-2.
 xformers >= 0.0.19
 fastapi
 uvicorn
@@ -100,7 +100,12 @@ class ModelConfig:
             return 1
         # For Falcon:
         if getattr(self.hf_config, "n_head_kv", None) is not None:
-            return self.hf_config.n_head_kv
+            return (self.hf_config.n_head_kv //
+                    parallel_config.tensor_parallel_size)
+        # For LLaMA-2:
+        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
+            return (self.hf_config.num_key_value_heads //
+                    parallel_config.tensor_parallel_size)
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size

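The `ModelConfig` change above extends the KV-head lookup so that Falcon's `n_head_kv` and LLaMA-2's `num_key_value_heads` are both divided by the tensor-parallel degree. A self-contained sketch of that dispatch (the stub class and function name here are illustrative, not vLLM's):

```python
class HFConfigStub:
    # Illustrative values in the spirit of a LLaMA-2 70B config.
    num_attention_heads = 64
    num_key_value_heads = 8      # present only for GQA models such as LLaMA-2

def kv_heads_per_gpu(hf_config, tensor_parallel_size: int) -> int:
    n_head_kv = getattr(hf_config, "n_head_kv", None)          # Falcon-style
    if n_head_kv is not None:
        return n_head_kv // tensor_parallel_size
    num_kv = getattr(hf_config, "num_key_value_heads", None)   # LLaMA-2-style
    if num_kv is not None:
        return num_kv // tensor_parallel_size
    # Plain multi-head attention: every attention head has its own KV head.
    return hf_config.num_attention_heads // tensor_parallel_size

print(kv_heads_per_gpu(HFConfigStub(), tensor_parallel_size=4))  # -> 2
```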
@@ -128,7 +128,8 @@ class PagedAttention(nn.Module):
             query: shape = [num_generation_tokens, num_heads, head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
@@ -241,8 +242,9 @@ class PagedAttentionWithRoPE(PagedAttention):
         rotary_dim: int,
         max_position: int = 8192,
         base: int = 10000,
+        num_kv_heads: Optional[int] = None,
     ) -> None:
-        super().__init__(num_heads, head_size, scale)
+        super().__init__(num_heads, head_size, scale, num_kv_heads)

         # Create the cos and sin cache.
         inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -276,11 +278,12 @@ class PagedAttentionWithRoPE(PagedAttention):
         Args:
             positions: shape = [num_tokens]
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.

@@ -84,21 +84,26 @@ class LlamaAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        num_kv_heads: int,
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = (self.total_num_heads //
-                          tensor_model_parallel_world_size)
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5

         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
-            3 * self.total_num_heads * self.head_dim,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) *
+            self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
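To see what the new bookkeeping in `LlamaAttention.__init__` works out to, here is the arithmetic for an assumed LLaMA-2 70B style configuration (hidden size 8192, 64 query heads, 8 KV heads) sharded over 8 tensor-parallel GPUs; the numbers are illustrative, not taken from the diff.

```python
hidden_size, total_num_heads, total_num_kv_heads, tp_size = 8192, 64, 8, 8

head_dim = hidden_size // total_num_heads            # 128
num_heads = total_num_heads // tp_size               # 8 query heads per GPU
num_kv_heads = total_num_kv_heads // tp_size         # 1 KV head per GPU
q_size = num_heads * head_dim                        # 1024
kv_size = num_kv_heads * head_dim                    # 128

# Fused QKV projection width before sharding: one Q block plus one K and one V
# block per KV head, replacing the old 3 * total_num_heads * head_dim.
qkv_out = (total_num_heads + 2 * total_num_kv_heads) * head_dim   # 10240

# In forward(), each rank's fused output is then split unevenly:
#     q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
# which is why the old qkv.chunk(chunks=3, dim=-1) no longer applies.
```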
@@ -113,7 +118,8 @@ class LlamaAttention(nn.Module):
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
                                            self.scaling,
-                                           rotary_dim=self.head_dim)
+                                           rotary_dim=self.head_dim,
+                                           num_kv_heads=self.num_kv_heads)

     def forward(
         self,
@@ -124,7 +130,7 @@ class LlamaAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
@@ -140,6 +146,7 @@ class LlamaDecoderLayer(nn.Module):
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
@@ -259,9 +266,19 @@ class LlamaForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        q_proj_shard_size = (self.config.hidden_size // tp_size)
+        kv_proj_shard_size = (self.config.hidden_size //
+                              self.config.num_attention_heads *
+                              self.config.num_key_value_heads // tp_size)
+        attention_weight_specs = [
+            # (weight_name, shard_size, offset)
+            ("q_proj", q_proj_shard_size, 0),
+            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
+            ("v_proj", kv_proj_shard_size,
+             q_proj_shard_size + kv_proj_shard_size),
+        ]
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
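The loader change above mirrors the same layout: each q/k/v checkpoint shard is copied into the fused `qkv_proj` parameter at a fixed row offset. Using the same assumed 70B-style numbers, the shard sizes and offsets come out as follows (illustrative, not taken from the diff):

```python
hidden_size, num_attention_heads, num_key_value_heads, tp_size = 8192, 64, 8, 8

q_proj_shard_size = hidden_size // tp_size                      # 1024 rows
kv_proj_shard_size = (hidden_size // num_attention_heads *
                      num_key_value_heads // tp_size)           # 128 rows

attention_weight_specs = [
    # (weight_name, shard_size, row offset inside the fused qkv_proj weight)
    ("q_proj", q_proj_shard_size, 0),
    ("k_proj", kv_proj_shard_size, q_proj_shard_size),
    ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size),
]

# Per-rank qkv_proj rows: 1024 + 128 + 128 = 1280, matching
# (total_num_heads + 2 * total_num_kv_heads) * head_dim // tp_size.
```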
@@ -272,8 +289,7 @@ class LlamaForCausalLM(nn.Module):
             if "embed_tokens" in name or "lm_head" in name:
                 param = state_dict[name]
                 # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
+                padded_vocab_size = (param.shape[0] * tp_size)
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])
@@ -281,18 +297,17 @@ class LlamaForCausalLM(nn.Module):
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

             is_attention_weight = False
-            for stride_id, att_weight_name in enumerate(
-                    ["q_proj", "k_proj", "v_proj"]):
-                if att_weight_name not in name:
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
                     continue
-                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
-                shard_size = param.shape[0] // 3
+                param = state_dict[name.replace(weight_name, "qkv_proj")]
+
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
                     (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
+                param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
+
                 param_slice.copy_(loaded_weight)
                 is_attention_weight = True
                 break