[Distributed][PP] only create embedding & lm head when necessary (#6455)
original title: [Distributed][Model] Rank-based Component Creation for Pipeline Parallelism Memory Optimization
parent ce37be7ba0
commit 1d094fd7c0
@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers


 class LlamaMLP(nn.Module):
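The only change in this hunk is the new PPMissingLayer import: the placeholder that stands in for modules a pipeline-parallel rank does not own. Its real definition lives in .utils and is not part of this diff; below is a minimal sketch of the idea, assuming the placeholder only needs to occupy the attribute slot and hold no parameters.

import torch.nn as nn


class PPMissingLayer(nn.Identity):
    """Sketch of a placeholder for a layer owned by another pipeline rank.

    Keeping the attribute present lets the surrounding model code stay
    uniform, while the module itself holds no parameters and therefore
    costs no GPU memory on ranks that skip the real layer.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore whatever the real layer's constructor would
        # take, so it can be dropped in wherever the real layer is built.
        super().__init__()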
@@ -257,17 +257,24 @@ class LlamaModel(nn.Module):
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
             self.embed_tokens = VocabParallelEmbedding(
                 self.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
             )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
                                       cache_config=cache_config,
                                       quant_config=quant_config))
+        if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
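One practical consequence of leaving embed_tokens and norm out on most ranks is that checkpoint loading must skip tensors whose owning module is a placeholder on the current rank, which is what the already-imported is_pp_missing_parameter helper is for. Below is a hedged sketch of such a skip check (not the actual implementation in .utils), reusing the PPMissingLayer sketch above; the typical use is to continue past such names in the weight-loading loop.

import torch.nn as nn


def is_pp_missing_parameter(name: str, model: nn.Module) -> bool:
    """Sketch: True if the named checkpoint tensor maps onto a placeholder,
    i.e. the real weights live on another pipeline rank and this rank should
    skip them while loading its shard of the checkpoint."""
    for prefix, module in model.named_modules():
        # PPMissingLayer as sketched above; prefix is the module's dotted
        # path, e.g. "embed_tokens" or "norm".
        if isinstance(module, PPMissingLayer) and name.startswith(prefix):
            return True
    return False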
@@ -360,6 +367,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
+        if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
             if lora_config:
                 self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -378,8 +386,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):

             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size, logit_scale)
+                                                    config.vocab_size,
+                                                    logit_scale)
             self.sampler = Sampler()
+        else:
+            self.lm_head = PPMissingLayer()

     def forward(
         self,
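Taken together (the hunks hide whitespace-only re-indentation, which is why the newly gated lines appear as unchanged context), each rank now builds only what it actually uses: the input embedding on the first rank; the final norm, LM head, logits processor and sampler on the last rank; PPMissingLayer placeholders everywhere else. The tie_word_embeddings clause is why the last rank may also build embed_tokens: a tied LM head reuses the embedding weight, so that weight has to exist on the rank that produces logits. Below is a small standalone illustration of the selection logic; the RankInfo type and helper are illustrative only, not part of vLLM.

from dataclasses import dataclass


@dataclass
class RankInfo:
    is_first_rank: bool
    is_last_rank: bool


def components_to_build(rank: RankInfo, tie_word_embeddings: bool) -> set:
    """Illustrative only: which heavy modules a pipeline rank materializes."""
    built = set()
    # Embedding: needed where token ids enter the pipeline, and also on the
    # last rank when the LM head ties (shares) the embedding weight.
    if rank.is_first_rank or (tie_word_embeddings and rank.is_last_rank):
        built.add("embed_tokens")
    # The final norm and everything that turns hidden states into tokens
    # only matters where logits are produced.
    if rank.is_last_rank:
        built.update({"norm", "lm_head", "logits_processor", "sampler"})
    return built


# Example: a 4-rank pipeline with tied word embeddings. Ranks 1 and 2 build
# none of these modules, which is the memory saving this commit is after.
for i in range(4):
    rank = RankInfo(is_first_rank=(i == 0), is_last_rank=(i == 3))
    print(i, sorted(components_to_build(rank, tie_word_embeddings=True)))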