[Distributed][PP] only create embedding & lm head when necessary (#6455)

original title: [Distributed][Model] Rank-based Component Creation for Pipeline Parallelism Memory Optimization
Wushi Dong 2024-07-16 19:20:26 -07:00 committed by GitHub
parent ce37be7ba0
commit 1d094fd7c0

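The diff below is against vLLM's Llama implementation (vllm/model_executor/models/llama.py). With this change, each pipeline-parallel (PP) rank builds only the components it actually executes: embed_tokens on the first rank (and also on the last rank when config.tie_word_embeddings is set, so the tied lm_head weight has a tensor to alias), and norm, lm_head, the logits processor, and the sampler on the last rank only. Components a rank does not own are replaced with PPMissingLayer placeholders, which hold no parameters and therefore no weight memory. The snippet below is a minimal sketch of such a placeholder, assuming an Identity-based implementation; it is an illustration, not the actual PPMissingLayer defined in vllm/model_executor/models/utils.py.

    import torch.nn as nn


    class PPMissingLayer(nn.Identity):
        """Placeholder for a module that lives on a different pipeline rank.

        It holds no parameters, so it costs essentially no GPU memory, and it
        simply passes its input through if it is ever called.
        """

        def __init__(self, *args, **kwargs):
            # Accept and ignore any constructor arguments so the placeholder
            # can stand in for arbitrary layers (embedding, norm, lm_head, ...).
            super().__init__()


    # A middle pipeline rank keeps the attribute but allocates nothing:
    embed_tokens = PPMissingLayer()
    print(sum(p.numel() for p in embed_tokens.parameters()))  # -> 0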

@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
 
 class LlamaMLP(nn.Module):

@@ -257,17 +257,24 @@ class LlamaModel(nn.Module):
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-        )
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                self.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
                                       cache_config=cache_config,
                                       quant_config=quant_config))
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -360,26 +367,30 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
-            quant_config=quant_config,
-        )
-        if config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
-        self.sampler = Sampler()
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+            self.sampler = Sampler()
+        else:
+            self.lm_head = PPMissingLayer()
 
     def forward(
         self,
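
Because of the rank gating in LlamaForCausalLM, the tied-weight assignment self.lm_head.weight = self.model.embed_tokens.weight only runs on the last rank, which is why the first hunk also creates embed_tokens there when tie_word_embeddings is set. The toy script below (plain PyTorch, with hypothetical build_stage/pp_rank/pp_size names rather than vLLM's get_pp_group() API) illustrates the per-rank memory effect of constructing components this way:

    import torch.nn as nn


    def build_stage(pp_rank: int, pp_size: int, vocab: int = 32000,
                    hidden: int = 4096, num_layers: int = 8) -> nn.Module:
        """Build only the pieces one pipeline stage actually needs."""
        is_first = pp_rank == 0
        is_last = pp_rank == pp_size - 1
        per_stage = num_layers // pp_size  # contiguous slice of decoder layers
        return nn.ModuleDict({
            # Stand-ins for the real layers; nn.Identity plays the role of
            # PPMissingLayer on ranks that do not own a component.
            "embed_tokens": nn.Embedding(vocab, hidden) if is_first else nn.Identity(),
            "layers": nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(per_stage)),
            "norm": nn.LayerNorm(hidden) if is_last else nn.Identity(),
            "lm_head": nn.Linear(hidden, vocab, bias=False) if is_last else nn.Identity(),
        })


    if __name__ == "__main__":
        for rank in range(4):
            params = sum(p.numel() for p in build_stage(rank, 4).parameters())
            print(f"pp_rank={rank}: {params / 1e6:.1f}M parameters")
        # Middle ranks skip the ~131M-parameter embedding and lm_head entirely.

With vocab=32000 and hidden=4096, the embedding and lm_head each hold about 131M parameters, so in this toy setup the intermediate ranks carry roughly 34M parameters instead of about 165M; before this change every rank allocated all of these weights regardless of whether it used them.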