Inclusion of InternVLChatModel in _PP_SUPPORTED_MODELS (Pipeline Parallelism) (#7860)

manikandan.tm@zucisystems.com 2024-09-05 17:03:37 +05:30 committed by GitHub
parent 288a938872
commit 8685ba1a1e
6 changed files with 90 additions and 35 deletions

View File

@@ -18,23 +18,26 @@ logger = init_logger("test_pipeline_parallel")
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
 
-@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
-                          "MODEL_NAME, DIST_BACKEND"),
-                         [
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                         ])
+@pytest.mark.parametrize(
+    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
+     "MODEL_NAME, DIST_BACKEND"),
+    [
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
+    ],
+)
 @fork_new_process_for_each_test
-def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND):
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
+                    TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
     if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")

@@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     if EAGER_MODE:
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
+    if TRUST_REMOTE_CODE:
+        pp_args.append("--trust-remote-code")
+        tp_args.append("--trust-remote-code")
     pp_env = None
     if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
             and CHUNKED_PREFILL):
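
The new row exercises internlm/internlm2_5-7b-chat with TP=2, PP=2, eager mode, chunked prefill, and trusted remote code on the Ray backend. For reference, a minimal offline sketch of an equivalent configuration, not part of this diff (four GPUs assumed; keyword names follow vLLM's EngineArgs and may shift between versions):

from vllm import LLM, SamplingParams

# Mirrors the added (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray") case.
llm = LLM(
    model="internlm/internlm2_5-7b-chat",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    distributed_executor_backend="ray",
    enforce_eager=True,
    enable_chunked_prefill=True,
    trust_remote_code=True,  # the repository ships custom modeling/tokenizer code
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)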

View File

@@ -178,6 +178,11 @@ def compare_two_settings(model: str,
         env2: The second set of environment variables to pass to the API server.
     """
 
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    trust_remote_code = "--trust-remote-code"
+    if trust_remote_code in arg1 or trust_remote_code in arg2:
+        tokenizer = AutoTokenizer.from_pretrained(model,
+                                                  trust_remote_code=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model)
 
     prompt = "Hello, my name is"

View File

@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
 
 _PP_SUPPORTED_MODELS = [
-    "AquilaModel",
     "AquilaForCausalLM",
+    "AquilaModel",
     "DeepseekV2ForCausalLM",
+    "GPT2LMHeadModel",
+    "InternLM2ForCausalLM",
     "InternLMForCausalLM",
+    "InternVLChatModel",
     "JAISLMHeadModel",
     "LlamaForCausalLM",
     "LLaMAForCausalLM",
     "MistralForCausalLM",
-    "Phi3ForCausalLM",
-    "GPT2LMHeadModel",
     "MixtralForCausalLM",
     "NemotronForCausalLM",
+    "Phi3ForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
     "QWenLMHeadModel",

View File

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn

@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)

@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+
 
 class InternLM2MLP(nn.Module):

@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module):
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config

@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLMDecoderLayer(config, cache_config,
+                                                quant_config),
+            prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.tok_embeddings(input_ids)
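
make_layers assigns each pipeline rank a contiguous slice of the decoder stack and returns (start_layer, end_layer, layers), so indexing stays global while each rank only instantiates its own layers. A simplified sketch of the partitioning arithmetic, assuming an even contiguous split (the real helper in .utils also handles placeholder layers and partition overrides):

def partition_layers(num_hidden_layers: int, pp_rank: int, pp_size: int):
    # Split the layer range into pp_size contiguous chunks; earlier ranks
    # absorb any remainder. Returns the [start, end) slice owned by pp_rank.
    per_rank, remainder = divmod(num_hidden_layers, pp_size)
    start = pp_rank * per_rank + min(pp_rank, remainder)
    end = start + per_rank + (1 if pp_rank < remainder else 0)
    return start, end

# e.g. 32 decoder layers over PP=2: rank 0 owns [0, 16), rank 1 owns [16, 32)
print(partition_layers(32, 0, 2), partition_layers(32, 1, 2))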
@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module):
         attn_metadata: AttentionMetadata,
         intermediate_tensors: IntermediateTensors = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
-        else:
-            hidden_states = self.tok_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
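
Note the kv_caches[i - self.start_layer] indexing: each rank allocates KV caches only for its own layer slice, so the global layer index must be mapped to a local cache index. A tiny illustration with assumed values:

# Rank 1 of a PP=2 run over a 32-layer model (assumed values).
start_layer, end_layer = 16, 32                    # slice owned by this rank
kv_caches = [None] * (end_layer - start_layer)     # one cache per local layer
for i in range(start_layer, end_layer):
    local_cache = kv_caches[i - start_layer]       # e.g. layer 20 -> kv_caches[4]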
@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module):
             self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
 
     def forward(
         self,

@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module):
         intermediate_tensors: IntermediateTensors,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(

@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)

@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

View File

@@ -341,6 +341,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
             nn.Linear(llm_hidden_size, llm_hidden_size))
 
         self.img_context_token_id = None
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()

@@ -461,7 +463,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
                                           positions,
                                           kv_caches,
                                           attn_metadata,
-                                          None,
+                                          intermediate_tensors,
                                           inputs_embeds=inputs_embeds)
         return hidden_states

View File

@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
+from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available

@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
         if name.startswith(missing_layer_name):
             return True
     return False
+
+
+def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
+
+    def make_empty_intermediate_tensors(
+            batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            key: torch.zeros((batch_size, hidden_size),
+                             dtype=dtype,
+                             device=device)
+            for key in keys
+        })
+
+    return make_empty_intermediate_tensors
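
A short usage sketch of the new factory (illustrative values; import path assumed from the .utils import added in internlm2.py): the model binds the keys and hidden size once, and the caller supplies batch size, dtype, and device when placeholder tensors are actually needed, e.g. on non-first pipeline ranks.

import torch

from vllm.model_executor.models.utils import (
    make_empty_intermediate_tensors_factory)

# Same call shape as InternLM2Model uses above.
make_empty = make_empty_intermediate_tensors_factory(
    ["hidden_states", "residual"], hidden_size=4096)

tensors = make_empty(batch_size=8, dtype=torch.float16,
                     device=torch.device("cpu"))
assert tensors["hidden_states"].shape == (8, 4096)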