Add Falcon support (new) (#592)
Parent: 20044cab7a
Commit: 1b0bd0fe8a
@@ -44,6 +44,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:

 - Baichuan-7B (`baichuan-inc/Baichuan-7B`)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
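For readers who want to try one of the newly listed Falcon checkpoints, a minimal usage sketch with vLLM's public Python API follows (not part of this commit; the model name and sampling settings are illustrative, and `trust_remote_code=True` is only needed for checkpoints that ship custom configuration code):

from vllm import LLM, SamplingParams

# Illustrative: load a Falcon checkpoint and run greedy decoding.
llm = LLM(model="tiiuae/falcon-7b", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)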
@@ -10,7 +10,8 @@ __global__ void rotary_embedding_neox_kernel(
   scalar_t* __restrict__ key,                  // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int stride,
+  const int query_stride,
+  const int key_stride,
   const int num_heads,
   const int num_kv_heads,
   const int head_size) {
@@ -23,14 +24,14 @@ __global__ void rotary_embedding_neox_kernel(
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+    const int token_head = token_idx * query_stride + head_idx * head_size;

     const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;

-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;

     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);
@@ -39,14 +40,28 @@ __global__ void rotary_embedding_neox_kernel(
     const scalar_t q_y = query[token_head + y_index];
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;
+  }

+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+
+    const int rot_offset = i % embed_dim;
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+
+    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
+
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);
+
-    if (head_idx < num_kv_heads) {
     const scalar_t k_x = key[token_head + x_index];
     const scalar_t k_y = key[token_head + y_index];
     key[out_x] = k_x * cos - k_y * sin;
     key[out_y] = k_y * cos + k_x * sin;
   }
-  }
 }

 } // namespace vllm
@@ -62,8 +77,8 @@ void rotary_embedding_neox(
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
   int num_kv_heads = key.size(1) / head_size;
-  int stride = query.stride(0);
-  TORCH_CHECK(stride == key.stride(0));
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -80,7 +95,8 @@ void rotary_embedding_neox(
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
         rot_dim,
-        stride,
+        query_stride,
+        key_stride,
         num_heads,
         num_kv_heads,
         head_size);
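As a reading aid for the kernel change above, here is a hedged PyTorch sketch of the NeoX-style rotation the kernel applies, with query and key rotated independently now that they carry their own strides and possibly different head counts. The function name and cache layout follow the comments in the kernel; the sketch itself is illustrative and not part of the commit.

import torch

def rotary_embedding_neox_ref(positions: torch.Tensor,
                              query: torch.Tensor,          # [num_tokens, num_heads, head_size]
                              key: torch.Tensor,            # [num_tokens, num_kv_heads, head_size]
                              cos_sin_cache: torch.Tensor,  # [max_position, 2, rot_dim // 2]
                              rot_dim: int) -> None:
    # Look up cos/sin for each token's position and broadcast over heads.
    cos, sin = cos_sin_cache[positions].unbind(dim=1)  # each [num_tokens, rot_dim // 2]
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    # Query and key may have different numbers of heads, so rotate each in turn.
    for x in (query, key):
        x1 = x[..., :rot_dim // 2].clone()
        x2 = x[..., rot_dim // 2:rot_dim].clone()
        x[..., :rot_dim // 2] = x1 * cos - x2 * sin
        x[..., rot_dim // 2:rot_dim] = x2 * cos + x1 * sin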
@@ -20,6 +20,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
+  * - :code:`FalconForCausalLM`
+    - Falcon
+    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
@@ -10,7 +10,8 @@ def main(args: argparse.Namespace):

     # Test the following prompts.
     test_prompts = [
-        ("A robot may not injure a human being", SamplingParams()),
+        ("A robot may not injure a human being",
+         SamplingParams(temperature=0.0)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
@@ -94,8 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
-        # For GPTBigCode:
-        if getattr(self.hf_config, "multi_query", False):
+        # For GPTBigCode & Falcon:
+        # Note: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        if (getattr(self.hf_config, "multi_query", False) and
+            (self.hf_config.model_type == "falcon" and
+             not getattr(self.hf_config, "new_decoder_architecture", False))):
             # Multi-query attention, only one KV head.
             return 1
         # For Falcon:
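To make the comment above concrete, here is a small illustrative helper (not part of the commit) showing how the number of KV heads falls out of the three Falcon attention regimes; it mirrors the branches of `FalconAttention.__init__` added later in this diff, and the attribute names assume a Falcon-style config.

def falcon_num_kv_heads(hf_config) -> int:
    # Illustrative only; mirrors FalconAttention.__init__ below.
    if getattr(hf_config, "new_decoder_architecture", False):
        # Grouped-query attention (e.g. falcon-40b): n_head_kv KV heads.
        return hf_config.num_kv_heads
    if getattr(hf_config, "multi_query", False):
        # Multi-query attention (e.g. falcon-7b): a single shared KV head.
        return 1
    # Standard multi-head attention: one KV head per query head.
    return hf_config.num_attention_heads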
@@ -314,14 +314,13 @@ class PagedAttentionWithRoPE(PagedAttention):
 class PagedAttentionWithALiBi(PagedAttention):
     """PagedAttention with ALiBi attention bias."""

-    def __init__(
-        self,
+    def __init__(self,
         num_heads: int,
         head_size: int,
         scale: float,
         slopes: List[float],
-    ) -> None:
-        super().__init__(num_heads, head_size, scale)
+        num_kv_heads: Optional[int] = None) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads)
         assert len(slopes) == num_heads

         slopes = torch.tensor(slopes, dtype=torch.float32)
@@ -334,6 +333,11 @@ class PagedAttentionWithALiBi(PagedAttention):
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len)
+            # Note(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(prompt_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
             bias = bias[None, :] - bias[:, None]
             bias = bias.to(self.alibi_slopes.device)

@@ -363,10 +367,17 @@ class PagedAttentionWithALiBi(PagedAttention):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
@@ -400,9 +411,10 @@ class PagedAttentionWithALiBi(PagedAttention):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
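Since the ALiBi prompt path above now accepts fewer KV heads than query heads, here is a tiny self-contained sketch (illustrative shapes only, not part of the commit) of the `torch.repeat_interleave` expansion it performs:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 16
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Expand the KV heads so that key/value line up with the 8 query heads.
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
print(key.shape, value.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 8, 16])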
@@ -14,6 +14,7 @@ _MODEL_REGISTRY = {
     "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
     "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
+    "FalconForCausalLM": FalconForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTJForCausalLM": GPTJForCausalLM,
@@ -22,6 +23,7 @@ _MODEL_REGISTRY = {
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
+    "RWForCausalLM": FalconForCausalLM,
 }

@@ -1,5 +1,6 @@
 from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
+from vllm.model_executor.models.falcon import FalconForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_j import GPTJForCausalLM
@@ -12,6 +13,7 @@ __all__ = [
     "BaiChuanForCausalLM",
     "BaichuanForCausalLM",
     "BloomForCausalLM",
+    "FalconForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
     "GPTJForCausalLM",
vllm/model_executor/models/falcon.py (new file, 496 lines added)

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention,
                                                  PagedAttentionWithALiBi,
                                                  PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
    reduce_from_tensor_model_parallel_region)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs import RWConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]


# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = x @ self.weight.T
        if self.bias is None:
            return hidden_states
        return hidden_states + self.bias


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)

    return slopes


class FalconAttention(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query

        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
            assert self.total_num_heads % tp_size == 0
            self.num_kv_heads = self.total_num_kv_heads // tp_size
            self.query_key_value = ColumnParallelLinear(
                self.hidden_size,
                (self.total_num_heads + 2 * self.total_num_kv_heads) *
                self.head_dim,
                bias=config.bias,
                gather_output=False,
                perform_initialization=False,
                skip_bias_add=True,
            )
        elif self.multi_query:
            self.total_num_kv_heads = 1
            self.num_kv_heads = 1
            self.query = ColumnParallelLinear(
                self.hidden_size,
                self.total_num_heads * self.head_dim,
                bias=config.bias,
                gather_output=False,
                perform_initialization=False,
                skip_bias_add=True,
            )
            self.key_value = FalconLinear(self.hidden_size,
                                          2 * self.head_dim,
                                          bias=config.bias)
        else:
            self.total_num_kv_heads = self.total_num_heads
            self.num_kv_heads = self.num_heads
            self.query_key_value = ColumnParallelLinear(
                self.hidden_size,
                (self.total_num_heads + 2 * self.total_num_kv_heads) *
                self.head_dim,
                bias=config.bias,
                gather_output=False,
                perform_initialization=False,
                skip_bias_add=True,
            )

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            input_is_parallel=True,
            perform_initialization=False,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            # TODO(zhuohan): Pass in correct `max_position`
            self.attn = PagedAttentionWithRoPE(self.num_heads,
                                               self.head_dim,
                                               self.inv_norm_factor,
                                               rotary_dim=self.head_dim,
                                               num_kv_heads=self.num_kv_heads)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = PagedAttentionWithALiBi(self.num_heads,
                                                self.head_dim,
                                                self.inv_norm_factor,
                                                alibi_slopes,
                                                num_kv_heads=self.num_kv_heads)
        else:
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       scale=self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        if not self.new_decoder_architecture and self.multi_query:
            q, bias = self.query(hidden_states)
            if bias is not None:
                q += bias
            kv = self.key_value(hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        else:
            qkv, bias = self.query_key_value(hidden_states)
            if bias is not None:
                qkv += bias
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        k_cache, v_cache = kv_cache
        if self.use_rotary:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)
        else:
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias


class FalconMLP(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  gather_output=False,
                                                  perform_initialization=False,
                                                  skip_bias_add=True)
        self.act = nn.GELU()
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            input_is_parallel=True,
            perform_initialization=False,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias


class FalconDecoderLayer(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config)
        self.mlp = FalconMLP(config)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ):
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual

        return output


class FalconModel(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, self.embed_dim, perform_initialization=False)

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class FalconForCausalLM(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        self.config = config
        self.transformer = FalconModel(config)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
            cache_events,
        )
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)

        return next_tokens

    _column_parallel_weights = [
        "word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
        "dense_h_to_4h.bias"
    ]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_size = (get_tensor_model_parallel_world_size())
        tp_rank = get_tensor_model_parallel_rank()

        hidden_size = self.config.hidden_size
        total_num_heads = self.config.num_attention_heads
        num_heads = total_num_heads // tp_size
        head_size = hidden_size // total_num_heads
        head_start = tp_rank * num_heads
        head_end = (tp_rank + 1) * num_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
            num_kv_heads = total_num_kv_heads // tp_size
            separated_q_kv = False
            kv_head_start = tp_rank * num_kv_heads
            kv_head_end = (tp_rank + 1) * num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
            num_kv_heads = 1
            separated_q_kv = True
            kv_head_start = 0
            kv_head_end = 1
        else:
            total_num_kv_heads = total_num_heads
            num_kv_heads = total_num_kv_heads // tp_size
            separated_q_kv = False
            kv_head_start = tp_rank * num_kv_heads
            kv_head_end = (tp_rank + 1) * num_kv_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "query_key_value" in name:
                loaded_weight_size = loaded_weight.size()
                loaded_weight = loaded_weight.view(
                    total_num_kv_heads, num_query_heads_per_kv_head + 2,
                    head_size, *loaded_weight_size[1:])

                wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
                wk = loaded_weight[:, [-2]].reshape(-1,
                                                    *loaded_weight_size[1:])
                wv = loaded_weight[:, [-1]].reshape(-1,
                                                    *loaded_weight_size[1:])

                wq = wq[head_size * head_start:head_size * head_end]
                wk = wk[head_size * kv_head_start:head_size * kv_head_end]
                wv = wv[head_size * kv_head_start:head_size * kv_head_end]

                if separated_q_kv:
                    loaded_weight_q = wq
                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
                    q_weight_name = name.replace("query_key_value", "query")
                    kv_weight_name = name.replace("query_key_value",
                                                  "key_value")
                    load_tensor_parallel_weights(state_dict[q_weight_name],
                                                 loaded_weight_q,
                                                 q_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tp_rank)
                    load_tensor_parallel_weights(state_dict[kv_weight_name],
                                                 loaded_weight_kv,
                                                 kv_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tp_rank)
                    continue
                else:
                    loaded_weight = torch.cat([wq, wk, wv], dim=0)

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
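The `load_weights` reshaping above assumes the HF Falcon checkpoint stores the fused `query_key_value` weight grouped per KV head as [q x num_queries_per_kv, k, v]. A small standalone check with made-up sizes (all numbers hypothetical, not part of the commit) illustrates the slicing:

import torch

# Hypothetical sizes: 2 KV heads, 2 query heads per KV head, head_size 3, hidden 8.
total_num_kv_heads, num_query_heads_per_kv_head, head_size, hidden_size = 2, 2, 3, 8
rows = total_num_kv_heads * (num_query_heads_per_kv_head + 2) * head_size
fused = torch.arange(rows * hidden_size, dtype=torch.float32).view(rows, hidden_size)

w = fused.view(total_num_kv_heads, num_query_heads_per_kv_head + 2, head_size,
               hidden_size)
wq = w[:, :-2].reshape(-1, hidden_size)   # all query heads, grouped per KV head
wk = w[:, [-2]].reshape(-1, hidden_size)  # one key head per group
wv = w[:, [-1]].reshape(-1, hidden_size)  # one value head per group
print(wq.shape, wk.shape, wv.shape)       # (12, 8), (6, 8), (6, 8)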
@@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
 # rank when broadcasting weights from src to all other data parallel ranks
 _DATA_PARALLEL_GLOBAL_RANKS = None

-_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None

 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
@@ -196,20 +195,6 @@ def initialize_model_parallel(
     if rank in ranks:
         _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

-def initialize_all_reduce_launcher(
-    max_num_tokens: int,
-    hidden_size: int,
-    dtype: torch.dtype,
-    disable_graph: bool = False,
-) -> None:
-    global _ALL_REDUCE_LAUNCHER
-    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
-        max_num_tokens=max_num_tokens,
-        hidden_size=hidden_size,
-        dtype=dtype,
-        disable_graph=disable_graph,
-    )
-
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
     if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]

+
 def get_pipeline_model_parallel_next_rank():
     """Return the global rank that follows the caller in the pipeline"""
     assert _PIPELINE_GLOBAL_RANKS is not None, \
@@ -485,10 +471,6 @@ def get_data_parallel_rank():
     """Return my rank for the data parallel group."""
     return torch.distributed.get_rank(group=get_data_parallel_group())

-def get_all_reduce_launcher() -> 'GraphAllReduce':
-    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
-    return _ALL_REDUCE_LAUNCHER
-
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
@@ -515,56 +497,3 @@ def destroy_model_parallel():
     _MPU_TENSOR_MODEL_PARALLEL_RANK = None
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
-
-
-class GraphAllReduce:
-
-    def __init__(
-        self,
-        max_num_tokens: int,
-        hidden_size: int,
-        dtype: torch.dtype,
-        disable_graph: bool = False,
-    ) -> None:
-        self.max_num_tokens = max_num_tokens
-        self.hidden_size = hidden_size
-        self.disable_graph = disable_graph
-
-        tp_world_size = get_tensor_model_parallel_world_size()
-        if tp_world_size == 1:
-            return
-
-        self.group = get_tensor_model_parallel_group()
-        self.buffer = torch.empty(
-            size=(max_num_tokens, hidden_size),
-            dtype=dtype,
-            device='cuda',
-        )
-
-        # Build graphs for different number of tokens.
-        if not self.disable_graph:
-            self.graphs = {}
-            for num_tokens in range(8, max_num_tokens + 1, 8):
-                self.graphs[num_tokens] = self._build_graph(num_tokens)
-
-    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
-        # Warm up.
-        torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
-        torch.cuda.synchronize()
-
-        # Build graph.
-        graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(graph):
-            torch.distributed.all_reduce(
-                self.buffer[:num_tokens], group=self.group)
-        torch.cuda.synchronize()
-        return graph
-
-    def launch(self, x: torch.Tensor) -> torch.Tensor:
-        # NOTE: x must be a slice of self.buffer.
-        num_tokens = x.shape[0]
-        if self.disable_graph:
-            torch.distributed.all_reduce(x, group=self.group)
-        else:
-            self.graphs[num_tokens].replay()
-        return x
@@ -12,6 +12,7 @@ from .mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     gather_from_sequence_parallel_region,
+    reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
     scatter_to_sequence_parallel_region,
 )
@@ -38,7 +39,7 @@ __all__ = [
     "copy_to_tensor_model_parallel_region",
     "gather_from_tensor_model_parallel_region",
     "gather_from_sequence_parallel_region",
-    # "reduce_from_tensor_model_parallel_region",
+    "reduce_from_tensor_model_parallel_region",
     "scatter_to_tensor_model_parallel_region",
     "scatter_to_sequence_parallel_region",
     # random.py
@@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_all_reduce_launcher,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
@@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module):
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, self.world_size)
         self.skip_bias_add = skip_bias_add

         if params_dtype is None:
@@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
+        reduce_results:
     """

     def __init__(self, input_size, output_size, *,
@@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 reduce_results=True,
                  ):
         super(RowParallelLinear, self).__init__()

@@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module):
         self.input_size = input_size
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.world_size)
         self.skip_bias_add = skip_bias_add

+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
+
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
@@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        if get_tensor_model_parallel_world_size() == 1:
-            # Matrix multiply.
-            output_ = F.linear(input_parallel, self.weight)
+        # Matrix multiply.
+        output_parallel = F.linear(input_parallel, self.weight)
+        if self.reduce_results and self.world_size > 1:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         else:
-            # Matrix multiply.
-            all_reduce_launcher = get_all_reduce_launcher()
-            num_tokens = input_parallel.shape[0]
-            output_buffer = all_reduce_launcher.buffer[:num_tokens]
-            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
-            # All-reduce across all the partitions.
-            output_ = all_reduce_launcher.launch(output_buffer)
+            output_ = output_parallel

         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
@@ -5,6 +5,8 @@ from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
 _CONFIG_REGISTRY = {
     "mpt": MPTConfig,
     "baichuan": BaiChuanConfig,
+    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
+    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
 }

@@ -1,7 +1,12 @@
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
+# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
+# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
+# `FalconConfig` class from the official HuggingFace transformers library.
+from vllm.transformers_utils.configs.falcon import RWConfig

 __all__ = [
     "MPTConfig",
     "BaiChuanConfig",
+    "RWConfig",
 ]
vllm/transformers_utils/configs/falcon.py (new file, 87 lines added)

# Adapted from
# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py
# Copyright 2023 The vLLM team.
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Falcon configuration"""
from transformers.configuration_utils import PretrainedConfig


class RWConfig(PretrainedConfig):
    model_type = "falcon"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
        "num_kv_heads": "n_head_kv",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        multi_query=True,
        n_head_kv=None,
        alibi=False,
        bias=False,
        parallel_attn=False,
        new_decoder_architecture=False,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.n_head_kv = 1 if n_head_kv is None else n_head_kv
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn
        self.new_decoder_architecture = new_decoder_architecture

        if self.hidden_size == 8192:
            # Hack for falcon-40b
            self.new_decoder_architecture = True

        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi
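A short usage sketch of the new `RWConfig` (not part of the commit; the hyperparameters below roughly follow the published falcon-7b "RefinedWebModel" settings and are given only for illustration):

from vllm.transformers_utils.configs import RWConfig

# Illustrative: approximately the falcon-7b configuration.
config = RWConfig(vocab_size=65024,
                  hidden_size=4544,
                  n_layer=32,
                  n_head=71,
                  multi_query=True,
                  alibi=False,
                  bias=False,
                  parallel_attn=True)

# attribute_map exposes the generic names used elsewhere in vLLM.
print(config.num_attention_heads)  # 71
print(config.num_kv_heads)         # 1 (multi-query attention)
print(config.rotary)               # True (rotary is used whenever alibi is off)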
@@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
-    initialize_model_parallel, initialize_all_reduce_launcher)
+    initialize_model_parallel)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
 from vllm.worker.cache_engine import CacheEngine
@@ -65,11 +65,6 @@ class Worker:
         # Initialize the model.
         set_random_seed(self.model_config.seed)
         self.model = get_model(self.model_config)
-        initialize_all_reduce_launcher(
-            self.scheduler_config.max_num_batched_tokens,
-            self.model_config.get_hidden_size(),
-            self.model_config.dtype,
-        )

     @torch.inference_mode()
     def profile_num_available_blocks(