[Model] Pipeline parallel support for JAIS (#7603)

Besher Alkurdi 2024-08-17 21:11:09 +03:00 committed by GitHub
parent d95cc0a55c
commit e73f76eec6
2 changed files with 51 additions and 18 deletions
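
With "JAISLMHeadModel" on the pipeline-parallel allowlist, a JAIS checkpoint can be split across pipeline stages like the other architectures already listed. A minimal offline-inference sketch follows; the checkpoint name and stage count are illustrative values, not part of this commit, and a multi-GPU machine is assumed:

# Illustrative only: the model ID and pipeline_parallel_size are example values,
# not taken from this commit; requires a machine with at least two GPUs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="core42/jais-13b-chat",  # any checkpoint whose architecture is JAISLMHeadModel
    trust_remote_code=True,
    pipeline_parallel_size=2,      # shard the decoder layers over two pipeline stages
)
out = llm.generate(["The capital of the UAE is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)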


@@ -36,6 +36,7 @@ _PP_SUPPORTED_MODELS = [
     "AquilaForCausalLM",
     "DeepseekV2ForCausalLM",
     "InternLMForCausalLM",
+    "JAISLMHeadModel",
     "LlamaForCausalLM",
     "LLaMAForCausalLM",
     "MistralForCausalLM",


@@ -20,14 +20,14 @@
 """Inference-only Jais model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -43,6 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.transformers_utils.configs import JAISConfig

+from .utils import is_pp_missing_parameter, make_layers
+

 class SwiGLUActivation(nn.Module):
@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
         config: JAISConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
             self.embeddings_scale = config.embeddings_scale
         else:
             self.embeddings_scale = config.mup_embeddings_scale
-        self.h = nn.ModuleList([
-            JAISBlock(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: JAISBlock(config=config,
+                                     cache_config=cache_config,
+                                     quant_config=quant_config),
+            prefix=f"{prefix}.h",
+        )
+
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

     def forward(
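
make_layers (imported from .utils above) gives each pipeline rank a contiguous slice of the decoder stack and returns the (start_layer, end_layer) bounds alongside the module list; layers a rank does not own are replaced by lightweight placeholders so indexing stays global. A simplified sketch of the idea, not the actual vLLM helper:

# Simplified sketch: the real helper derives rank/world_size from the PP group
# and uses dedicated placeholder modules rather than nn.Identity.
from typing import Callable, Tuple
import torch.nn as nn

def make_layers_sketch(num_layers: int,
                       layer_fn: Callable[..., nn.Module],
                       rank: int,
                       world_size: int,
                       prefix: str = "") -> Tuple[int, int, nn.ModuleList]:
    per_rank = (num_layers + world_size - 1) // world_size  # ceil division
    start = rank * per_rank
    end = min(start + per_rank, num_layers)
    layers = nn.ModuleList(
        layer_fn(prefix=f"{prefix}.{i}") if start <= i < end else nn.Identity()
        for i in range(num_layers))
    return start, end, layers

Keeping indices global is why the forward pass below iterates range(self.start_layer, self.end_layer) instead of range(len(self.h)).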
@@ -243,7 +251,9 @@ class JAISModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        inputs_embeds = self.wte(input_ids)
-        if self.wpe is not None:
-            position_embeds = self.wpe(position_ids)
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[IntermediateTensors, torch.Tensor]:
+        if get_pp_group().is_first_rank:
+            inputs_embeds = self.wte(input_ids)
+            if self.wpe is not None:
+                position_embeds = self.wpe(position_ids)
@@ -252,10 +262,18 @@ class JAISModel(nn.Module):
-        hidden_states = inputs_embeds
-        hidden_states *= torch.tensor(float(self.embeddings_scale),
-                                      dtype=hidden_states.dtype)
+            hidden_states = inputs_embeds
+            hidden_states *= torch.tensor(float(self.embeddings_scale),
+                                          dtype=hidden_states.dtype)
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]

-        for i in range(len(self.h)):
+        for i in range(self.start_layer, self.end_layer):
             layer = self.h[i]
-            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
+            hidden_states = layer(hidden_states,
+                                  kv_caches[i - self.start_layer],
+                                  attn_metadata)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})

         hidden_states = self.ln_f(hidden_states)
         return hidden_states
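
One detail worth noting in the loop above: each rank is handed a kv_caches list that covers only its own layers, so the global layer index i is shifted by start_layer to find the local cache slot. With illustrative numbers (not taken from any real configuration):

# Illustrative numbers only: 40 decoder layers split over 2 pipeline ranks.
start_layer, end_layer = 20, 40                     # this rank owns global layers 20..39
kv_caches = [f"cache_{j}" for j in range(end_layer - start_layer)]  # 20 local slots
i = 25                                              # a global layer index inside the loop
assert kv_caches[i - start_layer] == "cache_5"      # global layer 25 -> local slot 5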
@@ -290,9 +308,9 @@ class JAISLMHeadModel(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[IntermediateTensors, torch.Tensor]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+                                         attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -304,6 +322,16 @@ class JAISLMHeadModel(nn.Module):
                                        sampling_metadata)
         return logits

+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: torch.Tensor,
@@ -327,6 +355,10 @@ class JAISLMHeadModel(nn.Module):
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
             param = params_dict[name]
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
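
The is_pp_missing_parameter guard keeps each rank from failing the params_dict lookup on checkpoint weights that belong to decoder layers hosted by other pipeline ranks. A simplified sketch of the idea, not the actual helper (which inspects the module tree for placeholder layers):

import re

def is_pp_missing_parameter_sketch(name: str, model) -> bool:
    # A weight is "missing" on this rank if it names a decoder layer outside the
    # local [start_layer, end_layer) slice produced by make_layers. In this model
    # the embeddings and final LayerNorm are still created on every rank.
    m = re.search(r"\.h\.(\d+)\.", name)
    if m is None:
        return False
    layer_idx = int(m.group(1))
    transformer = model.transformer
    return not (transformer.start_layer <= layer_idx < transformer.end_layer)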