From bfdcfa6a053c693800551bd1bd71acabbe1941e8 Mon Sep 17 00:00:00 2001 From: Seonghyeon Date: Thu, 29 Feb 2024 17:51:48 +0900 Subject: [PATCH] Support starcoder2 architecture (#3089) --- README.md | 1 + tests/models/test_models.py | 1 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/starcoder2.py | 310 ++++++++++++++++++ vllm/transformers_utils/config.py | 10 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/starcoder2.py | 127 +++++++ 7 files changed, 452 insertions(+) create mode 100644 vllm/model_executor/models/starcoder2.py create mode 100644 vllm/transformers_utils/configs/starcoder2.py diff --git a/README.md b/README.md index f771788d..064faa55 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) - Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.) - StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) +- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e44452e9..fb567e83 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -19,6 +19,7 @@ MODELS = [ "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t", "allenai/OLMo-1B", + "bigcode/starcoder2-3b", ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index e4f3a785..75c2ae1e 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -45,6 +45,7 @@ _MODELS = { "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), } # Models not supported by ROCm. diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py new file mode 100644 index 00000000..1eda07b7 --- /dev/null +++ b/vllm/model_executor/models/starcoder2.py @@ -0,0 +1,310 @@ +# coding=utf-8 +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Starcoder2 model.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) +from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +try: + from transformers import Starcoder2Config +except ImportError: + # fallback to PretrainedConfig + # NOTE: Please install transformers from source or use transformers>=4.39.0 + from transformers import PretrainedConfig as Starcoder2Config + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Starcoder2Attention(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + self.max_position_embeddings = config.max_position_embeddings + self.use_bias = config.use_bias + self.sliding_window = config.sliding_window + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.use_bias, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=self.use_bias, + linear_method=linear_method, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Starcoder2MLP(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.c_fc = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=config.use_bias, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=config.use_bias, + linear_method=linear_method, + ) + self.act = get_act_fn(config.hidden_act, + intermediate_size=config.intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class Starcoder2DecoderLayer(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Starcoder2Attention(config, + linear_method=linear_method) + self.mlp = Starcoder2MLP(config, linear_method=linear_method) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Starcoder2Model(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # TODO: consider padding_idx (currently removed) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + Starcoder2DecoderLayer(config, linear_method=linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, kv_caches[i], + input_metadata) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Starcoder2ForCausalLM(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.model = Starcoder2Model(config, linear_method=linear_method) + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + if config.tie_word_embeddings: + self.lm_head_weight = self.model.embed_tokens.weight + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + self.lm_head_weight = self.lm_head.weight + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 6b0413f4..5e1f0439 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -9,6 +9,7 @@ _CONFIG_REGISTRY = { "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) + "starcoder2": Starcoder2Config, } @@ -16,6 +17,15 @@ def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None, code_revision: Optional[str] = None) -> PretrainedConfig: + # FIXME(woosuk): This is a temporary fix for StarCoder2. + # Remove this when the model is supported by HuggingFace transformers. + if "bigcode" in model and "starcoder2" in model: + config_class = _CONFIG_REGISTRY["starcoder2"] + config = config_class.from_pretrained(model, + revision=revision, + code_revision=code_revision) + return config + try: config = AutoConfig.from_pretrained( model, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index ef955f75..4966526f 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -4,9 +4,11 @@ from vllm.transformers_utils.configs.mpt import MPTConfig # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config __all__ = [ "ChatGLMConfig", "MPTConfig", "RWConfig", + "Starcoder2Config", ] diff --git a/vllm/transformers_utils/configs/starcoder2.py b/vllm/transformers_utils/configs/starcoder2.py new file mode 100644 index 00000000..4c3b6b8d --- /dev/null +++ b/vllm/transformers_utils/configs/starcoder2.py @@ -0,0 +1,127 @@ +from transformers import PretrainedConfig + + +class Starcoder2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Starcoder2Model`]. It is used to instantiate a + Starcoder2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the [bigcode/starcoder2-7b_16k](https://huggingface.co/bigcode/starcoder2-7b_16k) model. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 49152): + Vocabulary size of the Starcoder2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Starcoder2Model`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 30): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 24): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. Starcoder2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_epsilon (`float`, *optional*, defaults to 1e-05): + Epsilon value for the layer norm + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + bos_token_id (`int`, *optional*, defaults to 50256): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 50256): + The id of the "end-of-sequence" token. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `None` (no sliding window). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + residual_dropout (`float`, *optional*, defaults to 0.0): + Residual connection dropout value. + embedding_dropout (`float`, *optional*, defaults to 0.0): + Embedding dropout. + use_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias term on linear layers of the model. + + + ```python + >>> from transformers import Starcoder2Model, Starcoder2Config + + >>> # Initializing a Starcoder2 7B style configuration + >>> configuration = Starcoder2Config() + + >>> # Initializing a model from the Starcoder2 7B style configuration + >>> model = Starcoder2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "starcoder2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=49152, + hidden_size=3072, + intermediate_size=12288, + num_hidden_layers=30, + num_attention_heads=24, + num_key_value_heads=2, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=4096, + initializer_range=0.018042, + norm_epsilon=1e-5, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + rope_theta=10000.0, + sliding_window=None, + attention_dropout=0.0, + residual_dropout=0.0, + embedding_dropout=0.0, + use_bias=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.use_bias = use_bias + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_epsilon = norm_epsilon + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.residual_dropout = residual_dropout + self.embedding_dropout = embedding_dropout + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + if self.architectures is None: + self.architectures = ['Starcoder2ForCausalLM']