diff --git a/README.md b/README.md
index ee21e711..b9de3886 100644
--- a/README.md
+++ b/README.md
@@ -44,6 +44,7 @@ vLLM seamlessly supports many Huggingface models, including the following archit
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 5389584e..69ae8890 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -382,7 +382,7 @@ void single_query_cached_kv_attention_launcher(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
-    // 32, 160, 192, 256.
+    // 32, 160, 192.
     // case 32:
     //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
     //   break;
@@ -407,9 +407,9 @@ void single_query_cached_kv_attention_launcher(
     // case 192:
     //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
     //   break;
-    // case 256:
-    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-    //   break;
+    case 256:
+      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+      break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index aebe8f76..c6ae8a2b 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`GPTBigCodeForCausalLM`
     - StarCoder, SantaCoder, WizardCoder
     - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
+  * - :code:`GPTJForCausalLM`
+    - GPT-J
+    - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
   * - :code:`GPTNeoXForCausalLM`
     - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
     - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
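Note on the kernel change above: GPT-J-6B splits a 4096-dimensional hidden state across 16 attention heads, so its per-head size is 256, which is why the previously omitted 256 head-size kernel is now compiled. A minimal sketch to confirm the head size from the Hugging Face config (assumes `transformers` is installed and the checkpoint is reachable):

```python
from transformers import GPTJConfig

# Sketch only: derive GPT-J's per-head size from its HF config.
config = GPTJConfig.from_pretrained("EleutherAI/gpt-j-6b")
head_size = config.n_embd // config.n_head
print(head_size)  # expected: 4096 // 16 == 256
```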
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 6b88cd10..bf3147bb 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -286,7 +286,7 @@ def test_single_query_cached_kv_attention() -> None:
     torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for block_size in [8, 16, 32]:
-            for head_size in [64, 80, 96, 128]:
+            for head_size in [64, 80, 96, 112, 128, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
@@ -304,7 +304,7 @@ def test_multi_query_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [64, 80, 96, 128]:
+        for head_size in [64, 80, 96, 112, 128, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             run_multi_query_kv_attention(
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 40f979b1..f4550e82 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -12,7 +12,7 @@ from vllm import cache_ops
 from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata
 
-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 
 
 class PagedAttention(nn.Module):
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index cb8ccaed..b586c98b 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -38,12 +38,15 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Dict[int, SequenceOutputs]:
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)
 
         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
         logits = gather_from_tensor_model_parallel_region(logits)
         # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]
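The sampler change above adds an optional `embedding_bias` because GPT-J's `lm_head` carries a bias term, unlike the output projections of the previously registered models. A minimal sketch of the resulting logits computation, using hypothetical shapes rather than the real vLLM call path (GPT-J-6B uses a hidden size of 4096 and a padded vocabulary of 50400):

```python
import torch

# Hypothetical shapes for illustration only.
hidden_states = torch.randn(4, 4096)    # [num_tokens, hidden_size]
embedding = torch.randn(50400, 4096)    # lm_head.weight: [vocab_size, hidden_size]
embedding_bias = torch.zeros(50400)     # lm_head.bias:   [vocab_size]

logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
    logits += embedding_bias            # applied only when the model has a bias
```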
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 1d48baab..40e5f583 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -14,6 +14,7 @@ _MODEL_REGISTRY = {
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
+    "GPTJForCausalLM": GPTJForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 64a4e628..8717c568 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -1,6 +1,7 @@
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
+from vllm.model_executor.models.gpt_j import GPTJForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.mpt import MPTForCausalLM
@@ -10,6 +11,7 @@ __all__ = [
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
+    "GPTJForCausalLM",
     "GPTNeoXForCausalLM",
     "LlamaForCausalLM",
     "MPTForCausalLM",
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
new file mode 100644
index 00000000..2f858d2d
--- /dev/null
+++ b/vllm/model_executor/models/gpt_j.py
@@ -0,0 +1,251 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
+# Copyright 2023 The vLLM team.
+# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPT-J model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import GPTJConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.sequence import SequenceOutputs + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GPTJAttention(nn.Module): + + def __init__(self, config: GPTJConfig): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + self.qkv_proj = ColumnParallelLinear(config.hidden_size, + 3 * config.hidden_size, + bias=False, + gather_output=False, + perform_initialization=False) + self.out_proj = RowParallelLinear(config.hidden_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + scaling = self.head_size**-0.5 + assert config.rotary + assert config.rotary_dim % 2 == 0 + self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size, + scaling, config.rotary_dim) + self.warmup = False + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class GPTJMLP(nn.Module): + + def __init__(self, intermediate_size: int, config: GPTJConfig): + super().__init__() + hidden_size = config.n_embd + self.fc_in = ColumnParallelLinear(hidden_size, + intermediate_size, + gather_output=False, + perform_initialization=False) + self.fc_out = RowParallelLinear(intermediate_size, + hidden_size, + input_is_parallel=True, + perform_initialization=False) + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc_out(hidden_states) + return hidden_states + + +class GPTJBlock(nn.Module): + + def __init__(self, config: GPTJConfig): + super().__init__() + if config.n_inner is None: + inner_dim = 4 * config.n_embd + else: + inner_dim = config.n_inner + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTJAttention(config) + self.mlp = GPTJMLP(inner_dim, config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = 
+        attn_output = self.attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event,
+        )
+        mlp_output = self.mlp(hidden_states)
+        hidden_states = attn_output + mlp_output + residual
+        return hidden_states
+
+
+class GPTJModel(nn.Module):
+
+    def __init__(self, config: GPTJConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.n_embd
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          perform_initialization=False)
+        self.h = nn.ModuleList(
+            [GPTJBlock(config) for _ in range(config.n_layer)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        for i in range(len(self.h)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.h[i]
+            hidden_states = layer(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                cache_event,
+            )
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPTJForCausalLM(nn.Module):
+
+    def __init__(self, config: GPTJConfig):
+        super().__init__()
+        self.config = config
+        assert not config.tie_word_embeddings
+        self.transformer = GPTJModel(config)
+        self.lm_head = ColumnParallelLinear(config.n_embd,
+                                            config.vocab_size,
+                                            gather_output=False,
+                                            perform_initialization=False)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, SequenceOutputs]:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+                                   input_metadata, self.lm_head.bias)
+        return next_tokens
+
+    _column_parallel_weights = [
+        "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
+        "lm_head.bias"
+    ]
+    _row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
+        tp_rank = get_tensor_model_parallel_rank()
+        state_dict = self.state_dict()
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "attn.bias" in name or "attn.masked_bias" in name:
+                continue
+
+            is_attention_weight = False
+            for stride_id, att_weight_name in enumerate(
+                    ["q_proj", "k_proj", "v_proj"]):
+                if att_weight_name not in name:
+                    continue
+                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
+                shard_size = param.shape[1]
+                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
+                                              (tp_rank + 1)]
+                param_slice = param.data[shard_size * stride_id:shard_size *
+                                         (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights, tp_rank)
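`GPTJBlock.forward` above implements GPT-J's parallel residual layout: a single LayerNorm feeds both the attention and the MLP, and their outputs are summed together with the residual, rather than the sequential residual path used by GPT-2. A minimal standalone sketch of that dataflow, with `nn.Linear` layers standing in for the real attention and MLP sub-modules (hypothetical stand-ins, not the vLLM classes):

```python
import torch
from torch import nn

# Stand-ins for GPTJAttention / GPTJMLP; 4096 matches GPT-J-6B's n_embd.
ln_1 = nn.LayerNorm(4096)
attn = nn.Linear(4096, 4096)
mlp = nn.Linear(4096, 4096)

hidden_states = torch.randn(2, 4096)
residual = hidden_states
normed = ln_1(hidden_states)
# Parallel residual: both branches consume the same normalized input.
hidden_states = attn(normed) + mlp(normed) + residual
```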
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 09a0d7ce..5afcb7a0
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
 from typing import Dict, List, Optional, Tuple
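With the patch applied end to end, a GPT-J checkpoint should load through vLLM's usual entry point via the new `GPTJForCausalLM` registry entry. A hedged usage sketch (assumes a GPU and that the checkpoint can be downloaded):

```python
from vllm import LLM, SamplingParams

# Sketch only: exercises the GPT-J support added in this diff.
llm = LLM(model="EleutherAI/gpt-j-6b")
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```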