diff --git a/README.md b/README.md index f67e6b6c..30746f0c 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ vLLM seamlessly supports many Huggingface models, including the following archit - Baichuan-7B (`baichuan-inc/Baichuan-7B`) - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) +- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) - GPT-2 (`gpt2`, `gpt2-xl`, etc.) - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 0c89ab08..98939fc7 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -10,7 +10,8 @@ __global__ void rotary_embedding_neox_kernel( scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, - const int stride, + const int query_stride, + const int key_stride, const int num_heads, const int num_kv_heads, const int head_size) { @@ -23,14 +24,14 @@ __global__ void rotary_embedding_neox_kernel( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int token_head = token_idx * stride + head_idx * head_size; + const int token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; const int x_index = rot_offset; const int y_index = embed_dim + rot_offset; - const int out_x = token_idx * stride + head_idx * head_size + x_index; - const int out_y = token_idx * stride + head_idx * head_size + y_index; + const int out_x = token_idx * query_stride + head_idx * head_size + x_index; + const int out_y = token_idx * query_stride + head_idx * head_size + y_index; const scalar_t cos = __ldg(cache_ptr + x_index); const scalar_t sin = __ldg(cache_ptr + y_index); @@ -39,13 +40,27 @@ __global__ void rotary_embedding_neox_kernel( const scalar_t q_y = query[token_head + y_index]; query[out_x] = q_x * cos - q_y * sin; query[out_y] = q_y * cos + q_x * sin; + } - if (head_idx < num_kv_heads) { - const scalar_t k_x = key[token_head + x_index]; - const scalar_t k_y = key[token_head + y_index]; - key[out_x] = k_x * cos - k_y * sin; - key[out_y] = k_y * cos + k_x * sin; - } + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * key_stride + head_idx * head_size; + + const int rot_offset = i % embed_dim; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int out_x = token_idx * key_stride + head_idx * head_size + x_index; + const int out_y = token_idx * key_stride + head_idx * head_size + y_index; + + const scalar_t cos = __ldg(cache_ptr + x_index); + const scalar_t sin = __ldg(cache_ptr + y_index); + + const scalar_t k_x = key[token_head + x_index]; + const scalar_t k_y = key[token_head + y_index]; + key[out_x] = k_x * cos - k_y * sin; + key[out_y] = k_y * cos + k_x * sin; } } @@ -62,8 +77,8 @@ void rotary_embedding_neox( int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(1) / head_size; int num_kv_heads = key.size(1) / head_size; - int stride = query.stride(0); - TORCH_CHECK(stride == key.stride(0)); + int query_stride = query.stride(0); + int key_stride = key.stride(0); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -80,7 +95,8 @@ void 
rotary_embedding_neox(
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
         rot_dim,
-        stride,
+        query_stride,
+        key_stride,
         num_heads,
         num_kv_heads,
         head_size);
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 4c9a47e7..46f7c119 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -20,6 +20,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
+  * - :code:`FalconForCausalLM`
+    - Falcon
+    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py
index 76fa117c..185d3762 100644
--- a/examples/llm_engine_example.py
+++ b/examples/llm_engine_example.py
@@ -10,7 +10,8 @@ def main(args: argparse.Namespace):
 
     # Test the following prompts.
     test_prompts = [
-        ("A robot may not injure a human being", SamplingParams()),
+        ("A robot may not injure a human being",
+         SamplingParams(temperature=0.0)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
diff --git a/vllm/config.py b/vllm/config.py
index ae089a6b..bd3dd6a2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -94,8 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
-        # For GPTBigCode:
-        if getattr(self.hf_config, "multi_query", False):
+        # For GPTBigCode & Falcon:
+        # Note: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        if (getattr(self.hf_config, "multi_query", False) and
+                not (self.hf_config.model_type == "falcon" and getattr(
+                    self.hf_config, "new_decoder_architecture", False))):
             # Multi-query attention, only one KV head.
             return 1
         # For Falcon:
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index bd25ee7a..e726a407 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -314,14 +314,13 @@ class PagedAttentionWithRoPE(PagedAttention):
 class PagedAttentionWithALiBi(PagedAttention):
     """PagedAttention with ALiBi attention bias."""
 
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        slopes: List[float],
-    ) -> None:
-        super().__init__(num_heads, head_size, scale)
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 slopes: List[float],
+                 num_kv_heads: Optional[int] = None) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads)
         assert len(slopes) == num_heads
 
         slopes = torch.tensor(slopes, dtype=torch.float32)
@@ -334,6 +333,11 @@ class PagedAttentionWithALiBi(PagedAttention):
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len)
+            # Note(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(prompt_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
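+            # For example, with prompt_len = 3 the bias below is
+            #     [[ 0,  1,  2],
+            #      [-1,  0,  1],
+            #      [-2, -1,  0]],
+            # whereas HF's version repeats [0, 1, 2] for every row. The two
+            # differ only by a constant per query row, and adding a per-row
+            # constant to the pre-softmax scores does not change the
+            # attention probabilities.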
bias = bias[None, :] - bias[:, None] bias = bias.to(self.alibi_slopes.device) @@ -363,10 +367,17 @@ class PagedAttentionWithALiBi(PagedAttention): Args: output: shape = [num_prompt_tokens, num_heads, head_size] query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_heads, head_size] - value: shape = [num_prompt_tokens, num_heads, head_size] + key: shape = [num_prompt_tokens, num_kv_heads, head_size] + value: shape = [num_prompt_tokens, num_kv_heads, head_size] input_metadata: metadata for paged attention. """ + if self.num_kv_heads != self.num_heads: + # Project the key and value tensors to the desired number of heads. + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, + self.num_queries_per_kv, + dim=1) + # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. @@ -400,9 +411,10 @@ class PagedAttentionWithALiBi(PagedAttention): Args: output: shape = [num_generation_tokens, num_heads, head_size] query: shape = [num_generation_tokens, num_heads, head_size] - key_cache: shape = [num_blocks, num_heads, head_size/x, + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] - value_cache: shape = [num_blocks, num_heads, head_size, block_size] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] input_metadata: metadata for paged attention. """ block_size = value_cache.shape[3] diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index c93d2db8..a1bcd159 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -14,6 +14,7 @@ _MODEL_REGISTRY = { "BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b "BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b "BloomForCausalLM": BloomForCausalLM, + "FalconForCausalLM": FalconForCausalLM, "GPT2LMHeadModel": GPT2LMHeadModel, "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, "GPTJForCausalLM": GPTJForCausalLM, @@ -22,6 +23,7 @@ _MODEL_REGISTRY = { "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* "MPTForCausalLM": MPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, + "RWForCausalLM": FalconForCausalLM, } diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d3259a05..787cb478 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,5 +1,6 @@ from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM from vllm.model_executor.models.bloom import BloomForCausalLM +from vllm.model_executor.models.falcon import FalconForCausalLM from vllm.model_executor.models.gpt2 import GPT2LMHeadModel from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM from vllm.model_executor.models.gpt_j import GPTJForCausalLM @@ -12,6 +13,7 @@ __all__ = [ "BaiChuanForCausalLM", "BaichuanForCausalLM", "BloomForCausalLM", + "FalconForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM", "GPTJForCausalLM", diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py new file mode 100644 index 00000000..7730b231 --- /dev/null +++ b/vllm/model_executor/models/falcon.py @@ -0,0 +1,496 @@ +# coding=utf-8 +# Adapted from +# 
https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Falcon model.""" + +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from transformers import FalconConfig as HF_FalconConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import (PagedAttention, + PagedAttentionWithALiBi, + PagedAttentionWithRoPE) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, + reduce_from_tensor_model_parallel_region) +from vllm.sequence import SequenceOutputs +from vllm.transformers_utils.configs import RWConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] +FalconConfig = Union[HF_FalconConfig, RWConfig] + + +# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during +# training, this means that there's one additional quantization to bfloat16 +# between the operations. In order not to degrade the quality of our HF-port, +# we keep these characteristics in the final model. 
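+# Concretely, FalconLinear below evaluates `x @ W.T` and then adds the bias
+# as a separate op rather than calling the fused `F.linear`, preserving the
+# extra bfloat16 rounding step described above.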
+class FalconLinear(nn.Linear): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = x @ self.weight.T + if self.bias is None: + return hidden_states + return hidden_states + self.bias + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(1, + 1 + 2 * num_remaining_heads, + 2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + + return slopes + + +class FalconAttention(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + + if self.new_decoder_architecture: + self.total_num_kv_heads = config.num_kv_heads + assert self.total_num_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.query_key_value = ColumnParallelLinear( + self.hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True, + ) + elif self.multi_query: + self.total_num_kv_heads = 1 + self.num_kv_heads = 1 + self.query = ColumnParallelLinear( + self.hidden_size, + self.total_num_heads * self.head_dim, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True, + ) + self.key_value = FalconLinear(self.hidden_size, + 2 * self.head_dim, + bias=config.bias) + else: + self.total_num_kv_heads = self.total_num_heads + self.num_kv_heads = self.num_heads + self.query_key_value = ColumnParallelLinear( + self.hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True, + ) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.bias, + input_is_parallel=True, + perform_initialization=False, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results) + + self.use_rotary = config.rotary + self.use_alibi = config.alibi + assert not (self.use_rotary and self.use_alibi), ( + "Rotary and alibi are mutually exclusive.") + + if self.use_rotary: + # TODO(zhuohan): Pass in correct `max_position`` + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.inv_norm_factor, + 
rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads) + elif self.use_alibi: + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * + self.inv_norm_factor) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + self.attn = PagedAttentionWithALiBi(self.num_heads, + self.head_dim, + self.inv_norm_factor, + alibi_slopes, + num_kv_heads=self.num_kv_heads) + else: + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + if not self.new_decoder_architecture and self.multi_query: + q, bias = self.query(hidden_states) + if bias is not None: + q += bias + kv = self.key_value(hidden_states) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + else: + qkv, bias = self.query_key_value(hidden_states) + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + k_cache, v_cache = kv_cache + if self.use_rotary: + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + else: + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) + attn_output, bias = self.dense(attn_output) + return attn_output, bias + + +class FalconMLP(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + hidden_size = config.hidden_size + + self.dense_h_to_4h = ColumnParallelLinear(hidden_size, + 4 * hidden_size, + bias=config.bias, + gather_output=False, + perform_initialization=False, + skip_bias_add=True) + self.act = nn.GELU() + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + bias=config.bias, + input_is_parallel=True, + perform_initialization=False, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. 
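+        # The parallel linear layers are constructed with skip_bias_add=True,
+        # so they return the bias separately and the addition happens as an
+        # explicit second op; the dense_4h_to_h bias is handed back to the
+        # caller so it is applied only once, after any deferred all-reduce.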
+ x, bias = self.dense_h_to_4h(x) + if bias is not None: + x += bias + x = self.act(x) + x, bias = self.dense_4h_to_h(x) + return x, bias + + +class FalconDecoderLayer(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.self_attention = FalconAttention(config) + self.mlp = FalconMLP(config) + self.config = config + + if config.new_decoder_architecture: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ): + residual = hidden_states + + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention( + positions=positions, + hidden_states=attention_layernorm_out, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + if self.reduce_row_parallel_results and attention_bias is not None: + attention_output += attention_bias + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual += attention_output + mlp_layernorm_out = self.post_attention_layernorm(residual) + + # MLP. + mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) + if self.reduce_row_parallel_results and mlp_bias is not None: + mlp_output += mlp_bias + + if not self.reduce_row_parallel_results: + # When MLP and Attention layers are parallel, we can use + # only one all-reduce operator to reduce the results from + # both MLP and Attention layers. 
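+            # Each tensor-parallel rank holds partial sums from its shard of
+            # the row-parallel weights, so summing the attention and MLP
+            # partial outputs first lets one all-reduce cover both; the
+            # replicated biases are added after the reduction so they are
+            # counted exactly once.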
+ mlp_output += attention_output + mlp_output = reduce_from_tensor_model_parallel_region(mlp_output) + if attention_bias is not None: + mlp_output += attention_bias + if mlp_bias is not None: + mlp_output += mlp_bias + + output = mlp_output + residual + + return output + + +class FalconModel(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_alibi = config.alibi + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, self.embed_dim, perform_initialization=False) + + # Transformer blocks + self.h = nn.ModuleList([ + FalconDecoderLayer(config) for _ in range(config.num_hidden_layers) + ]) + + # Final Layer Norm + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.word_embeddings(input_ids) + for i in range(len(self.h)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.h[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class FalconForCausalLM(nn.Module): + + def __init__(self, config: FalconConfig): + super().__init__() + self.config = config + self.transformer = FalconModel(config) + self.lm_head = ColumnParallelLinear(config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + perform_initialization=False) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.transformer( + input_ids, + positions, + kv_caches, + input_metadata, + cache_events, + ) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + input_metadata) + + return next_tokens + + _column_parallel_weights = [ + "word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight", + "dense_h_to_4h.bias" + ] + _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tp_size = (get_tensor_model_parallel_world_size()) + tp_rank = get_tensor_model_parallel_rank() + + hidden_size = self.config.hidden_size + total_num_heads = self.config.num_attention_heads + num_heads = total_num_heads // tp_size + head_size = hidden_size // total_num_heads + head_start = tp_rank * num_heads + head_end = (tp_rank + 1) * num_heads + if self.config.new_decoder_architecture: + total_num_kv_heads = self.config.num_kv_heads + num_kv_heads = total_num_kv_heads // tp_size + separated_q_kv = False + kv_head_start = tp_rank * num_kv_heads + kv_head_end = (tp_rank + 1) * num_kv_heads + elif self.config.multi_query: + total_num_kv_heads = 1 + num_kv_heads = 1 + separated_q_kv = True + kv_head_start = 0 + kv_head_end = 1 + else: + total_num_kv_heads = total_num_heads + num_kv_heads = total_num_kv_heads // tp_size + separated_q_kv = False + kv_head_start = tp_rank * num_kv_heads + kv_head_end = (tp_rank + 1) * num_kv_heads + 
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "query_key_value" in name: + loaded_weight_size = loaded_weight.size() + loaded_weight = loaded_weight.view( + total_num_kv_heads, num_query_heads_per_kv_head + 2, + head_size, *loaded_weight_size[1:]) + + wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:]) + wk = loaded_weight[:, [-2]].reshape(-1, + *loaded_weight_size[1:]) + wv = loaded_weight[:, [-1]].reshape(-1, + *loaded_weight_size[1:]) + + wq = wq[head_size * head_start:head_size * head_end] + wk = wk[head_size * kv_head_start:head_size * kv_head_end] + wv = wv[head_size * kv_head_start:head_size * kv_head_end] + + if separated_q_kv: + loaded_weight_q = wq + loaded_weight_kv = torch.cat([wk, wv], dim=0) + q_weight_name = name.replace("query_key_value", "query") + kv_weight_name = name.replace("query_key_value", + "key_value") + load_tensor_parallel_weights(state_dict[q_weight_name], + loaded_weight_q, + q_weight_name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank) + load_tensor_parallel_weights(state_dict[kv_weight_name], + loaded_weight_kv, + kv_weight_name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank) + continue + else: + loaded_weight = torch.cat([wq, wk, wv], dim=0) + + param = state_dict[name] + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, tp_rank) diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index c8988ac0..e5a43258 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None # rank when broadcasting weights from src to all other data parallel ranks _DATA_PARALLEL_GLOBAL_RANKS = None -_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None def initialize_model_parallel( tensor_model_parallel_size: int = 1, @@ -196,20 +195,6 @@ def initialize_model_parallel( if rank in ranks: _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks -def initialize_all_reduce_launcher( - max_num_tokens: int, - hidden_size: int, - dtype: torch.dtype, - disable_graph: bool = False, -) -> None: - global _ALL_REDUCE_LAUNCHER - _ALL_REDUCE_LAUNCHER = GraphAllReduce( - max_num_tokens=max_num_tokens, - hidden_size=hidden_size, - dtype=dtype, - disable_graph=disable_graph, - ) - def model_parallel_is_initialized(): """Check if model and data parallel groups are initialized.""" if _TENSOR_MODEL_PARALLEL_GROUP is None or \ @@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank(): last_rank_local = get_pipeline_model_parallel_world_size() - 1 return _PIPELINE_GLOBAL_RANKS[last_rank_local] + def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" assert _PIPELINE_GLOBAL_RANKS is not None, \ @@ -485,10 +471,6 @@ def get_data_parallel_rank(): """Return my rank for the data parallel group.""" return torch.distributed.get_rank(group=get_data_parallel_group()) -def get_all_reduce_launcher() -> 'GraphAllReduce': - assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized' - return _ALL_REDUCE_LAUNCHER - def destroy_model_parallel(): """Set the groups to none.""" global _MODEL_PARALLEL_GROUP @@ -515,56 +497,3 @@ def destroy_model_parallel(): 
_MPU_TENSOR_MODEL_PARALLEL_RANK = None global _MPU_PIPELINE_MODEL_PARALLEL_RANK _MPU_PIPELINE_MODEL_PARALLEL_RANK = None - - -class GraphAllReduce: - - def __init__( - self, - max_num_tokens: int, - hidden_size: int, - dtype: torch.dtype, - disable_graph: bool = False, - ) -> None: - self.max_num_tokens = max_num_tokens - self.hidden_size = hidden_size - self.disable_graph = disable_graph - - tp_world_size = get_tensor_model_parallel_world_size() - if tp_world_size == 1: - return - - self.group = get_tensor_model_parallel_group() - self.buffer = torch.empty( - size=(max_num_tokens, hidden_size), - dtype=dtype, - device='cuda', - ) - - # Build graphs for different number of tokens. - if not self.disable_graph: - self.graphs = {} - for num_tokens in range(8, max_num_tokens + 1, 8): - self.graphs[num_tokens] = self._build_graph(num_tokens) - - def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph: - # Warm up. - torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group) - torch.cuda.synchronize() - - # Build graph. - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - torch.distributed.all_reduce( - self.buffer[:num_tokens], group=self.group) - torch.cuda.synchronize() - return graph - - def launch(self, x: torch.Tensor) -> torch.Tensor: - # NOTE: x must be a slice of self.buffer. - num_tokens = x.shape[0] - if self.disable_graph: - torch.distributed.all_reduce(x, group=self.group) - else: - self.graphs[num_tokens].replay() - return x diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/__init__.py b/vllm/model_executor/parallel_utils/tensor_parallel/__init__.py index da0ce2a1..d17f12f3 100644 --- a/vllm/model_executor/parallel_utils/tensor_parallel/__init__.py +++ b/vllm/model_executor/parallel_utils/tensor_parallel/__init__.py @@ -12,6 +12,7 @@ from .mappings import ( copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region, gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region, scatter_to_sequence_parallel_region, ) @@ -38,7 +39,7 @@ __all__ = [ "copy_to_tensor_model_parallel_region", "gather_from_tensor_model_parallel_region", "gather_from_sequence_parallel_region", -# "reduce_from_tensor_model_parallel_region", + "reduce_from_tensor_model_parallel_region", "scatter_to_tensor_model_parallel_region", "scatter_to_sequence_parallel_region", # random.py diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py b/vllm/model_executor/parallel_utils/tensor_parallel/layers.py index 8c2343ab..0b4d32b6 100644 --- a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py +++ b/vllm/model_executor/parallel_utils/tensor_parallel/layers.py @@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_all_reduce_launcher, ) from .mappings import ( copy_to_tensor_model_parallel_region, @@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module): self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. 
- world_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, world_size) + self.world_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, self.world_size) self.skip_bias_add = skip_bias_add if params_dtype is None: @@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module): params_dtype: use_cpu_initialization: perform_initialization: + reduce_results: """ def __init__(self, input_size, output_size, *, @@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module): params_dtype=None, use_cpu_initialization=False, perform_initialization=True, + reduce_results=True, ): super(RowParallelLinear, self).__init__() @@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module): self.input_size = input_size self.output_size = output_size self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results if params_dtype is None: params_dtype = torch.get_default_dtype() # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, world_size) + self.world_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.world_size) self.skip_bias_add = skip_bias_add + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. @@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module): input_parallel = input_ else: input_parallel = scatter_to_tensor_model_parallel_region(input_) - if get_tensor_model_parallel_world_size() == 1: - # Matrix multiply. - output_ = F.linear(input_parallel, self.weight) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight) + if self.reduce_results and self.world_size > 1: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) else: - # Matrix multiply. - all_reduce_launcher = get_all_reduce_launcher() - num_tokens = input_parallel.shape[0] - output_buffer = all_reduce_launcher.buffer[:num_tokens] - torch.matmul(input_parallel, self.weight_t, out=output_buffer) - # All-reduce across all the partitions. - output_ = all_reduce_launcher.launch(output_buffer) + output_ = output_parallel if not self.skip_bias_add: output = output_ + self.bias if self.bias is not None else output_ diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 71e67118..b7b3da63 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -5,6 +5,8 @@ from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import _CONFIG_REGISTRY = { "mpt": MPTConfig, "baichuan": BaiChuanConfig, + "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) + "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 5f0ba4eb..b98c797c 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,7 +1,12 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.baichuan import BaiChuanConfig +# RWConfig is for the original tiiuae/falcon-40b(-instruct) and +# tiiuae/falcon-7b(-instruct) models. 
Newer Falcon models will use the +# `FalconConfig` class from the official HuggingFace transformers library. +from vllm.transformers_utils.configs.falcon import RWConfig __all__ = [ "MPTConfig", "BaiChuanConfig", + "RWConfig", ] diff --git a/vllm/transformers_utils/configs/falcon.py b/vllm/transformers_utils/configs/falcon.py new file mode 100644 index 00000000..c82cc606 --- /dev/null +++ b/vllm/transformers_utils/configs/falcon.py @@ -0,0 +1,87 @@ +# Adapted from +# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py +# Copyright 2023 The vLLM team. +# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon configuration""" +from transformers.configuration_utils import PretrainedConfig + + +class RWConfig(PretrainedConfig): + model_type = "falcon" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "num_kv_heads": "n_head_kv", + } + + def __init__( + self, + vocab_size=250880, + hidden_size=64, + n_layer=2, + n_head=8, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + multi_query=True, + n_head_kv=None, + alibi=False, + bias=False, + parallel_attn=False, + new_decoder_architecture=False, + **kwargs, + ) -> None: + self.vocab_size = vocab_size + # Backward compatibility with n_embed kwarg + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.multi_query = multi_query + self.n_head_kv = 1 if n_head_kv is None else n_head_kv + self.alibi = alibi + self.bias = bias + self.parallel_attn = parallel_attn + self.new_decoder_architecture = new_decoder_architecture + + if self.hidden_size == 8192: + # Hack for falcon-40b + self.new_decoder_architecture = True + + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) + + @property + def head_dim(self): + return self.hidden_size // self.n_head + + @property + def rotary(self): + return not self.alibi diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 30de4f60..bee7d441 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel, initialize_all_reduce_launcher) + initialize_model_parallel) from vllm.sampling_params import SamplingParams from vllm.sequence import 
SequenceData, SequenceGroupMetadata, SequenceOutputs
 from vllm.worker.cache_engine import CacheEngine
@@ -65,11 +65,6 @@ class Worker:
         # Initialize the model.
         set_random_seed(self.model_config.seed)
         self.model = get_model(self.model_config)
-        initialize_all_reduce_launcher(
-            self.scheduler_config.max_num_batched_tokens,
-            self.model_config.get_hidden_size(),
-            self.model_config.dtype,
-        )
 
     @torch.inference_mode()
     def profile_num_available_blocks(
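For anyone trying out the new Falcon support, here is a minimal offline-inference sketch. It relies only on vLLM's existing `LLM`/`SamplingParams` API (nothing introduced by this diff); the prompt and sampling values are illustrative.

```python
from vllm import LLM, SamplingParams

# Illustrative prompt and sampling settings; adjust freely.
prompts = ["The Falcon series of language models was trained on"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

# tiiuae/falcon-7b uses multi-query attention; tiiuae/falcon-40b uses the new
# decoder architecture with grouped KV heads and typically needs tensor
# parallelism (e.g. LLM(..., tensor_parallel_size=2) or more) to fit in memory.
llm = LLM(model="tiiuae/falcon-7b")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)
```

The 7B and 40B variants exercise different code paths in this change (multi-query vs. grouped KV heads, rotary vs. ALiBi is per-config), so it is worth smoke-testing both if the hardware allows.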