diff --git a/vllm/config.py b/vllm/config.py
index 95c0b95f..e03adb5f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -36,6 +36,7 @@ _PP_SUPPORTED_MODELS = [
     "AquilaForCausalLM",
     "DeepseekV2ForCausalLM",
     "InternLMForCausalLM",
+    "JAISLMHeadModel",
     "LlamaForCausalLM",
     "LLaMAForCausalLM",
     "MistralForCausalLM",
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 8c606916..ec6bea92 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -20,14 +20,14 @@
 """Inference-only Jais model compatible with HuggingFace weights."""
 
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -43,6 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.transformers_utils.configs import JAISConfig
 
+from .utils import is_pp_missing_parameter, make_layers
+
 
 class SwiGLUActivation(nn.Module):
 
@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
         config: JAISConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
             self.embeddings_scale = config.embeddings_scale
         else:
             self.embeddings_scale = config.mup_embeddings_scale
-        self.h = nn.ModuleList([
-            JAISBlock(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: JAISBlock(config=config,
+                                     cache_config=cache_config,
+                                     quant_config=quant_config),
+            prefix=f"{prefix}.h",
+        )
+
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
     def forward(
@@ -243,19 +251,29 @@ class JAISModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        inputs_embeds = self.wte(input_ids)
-        if self.wpe is not None:
-            position_embeds = self.wpe(position_ids)
-            hidden_states = inputs_embeds + position_embeds
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[IntermediateTensors, torch.Tensor]:
+        if get_pp_group().is_first_rank:
+            inputs_embeds = self.wte(input_ids)
+            if self.wpe is not None:
+                position_embeds = self.wpe(position_ids)
+                hidden_states = inputs_embeds + position_embeds
+            else:
+                hidden_states = inputs_embeds
+            hidden_states *= torch.tensor(float(self.embeddings_scale),
+                                          dtype=hidden_states.dtype)
         else:
-            hidden_states = inputs_embeds
-        hidden_states *= torch.tensor(float(self.embeddings_scale),
-                                      dtype=hidden_states.dtype)
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
 
-        for i in range(len(self.h)):
+        for i in range(self.start_layer, self.end_layer):
             layer = self.h[i]
-            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
+            hidden_states = layer(hidden_states,
+                                  kv_caches[i - self.start_layer],
+                                  attn_metadata)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
 
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -290,9 +308,9 @@ class JAISLMHeadModel(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[IntermediateTensors, torch.Tensor]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+                                         attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(
@@ -304,6 +322,16 @@ class JAISLMHeadModel(nn.Module):
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: torch.Tensor,
@@ -327,6 +355,10 @@ class JAISLMHeadModel(nn.Module):
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
             param = params_dict[name]
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
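For context, the layer-partitioning pattern this change relies on can be sketched as follows. This is a minimal, self-contained illustration only: the even split, the `pp_rank`/`pp_size` arguments, the `nn.Identity` placeholders, and the `partition_layers`/`make_local_layers` names are assumptions made for this sketch, not vLLM's actual `make_layers` implementation. It shows why the forward pass indexes the per-rank cache list with `kv_caches[i - self.start_layer]`: each pipeline rank materializes (and caches) only its own contiguous slice of blocks, while the module list keeps global layer indices.

```python
# Illustrative sketch of contiguous layer partitioning for pipeline parallelism.
# Assumed/simplified: even split across ranks, nn.Identity placeholders for
# layers owned by other ranks, and a plain layer_fn factory.
from typing import Callable, Tuple

from torch import nn


def partition_layers(num_layers: int, pp_rank: int,
                     pp_size: int) -> Tuple[int, int]:
    """Return the [start, end) slice of layers owned by this pipeline rank."""
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    # Last rank absorbs any remainder.
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end


def make_local_layers(
    num_layers: int,
    pp_rank: int,
    pp_size: int,
    layer_fn: Callable[[], nn.Module],
) -> Tuple[int, int, nn.ModuleList]:
    start, end = partition_layers(num_layers, pp_rank, pp_size)
    # Keep global indices in the module list so self.h[i] works for
    # i in [start, end); other ranks' layers are cheap placeholders.
    layers = nn.ModuleList([
        layer_fn() if start <= i < end else nn.Identity()
        for i in range(num_layers)
    ])
    return start, end, layers


if __name__ == "__main__":
    # Example: 40 blocks over 4 pipeline stages -> rank 1 owns layers 10-19.
    start, end, h = make_local_layers(40, pp_rank=1, pp_size=4,
                                      layer_fn=lambda: nn.Linear(8, 8))
    print(start, end)  # 10 20
    # The per-rank kv_caches list has (end - start) entries, hence the
    # kv_caches[i - start] indexing in the forward pass.
```

The rest of the diff handles the hand-off between stages: non-final ranks skip `ln_f` and return `IntermediateTensors({"hidden_states": ...})`, the next rank consumes it through the new `intermediate_tensors` argument, and `make_empty_intermediate_tensors` pre-allocates the receive buffer for that transfer.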