[Bugfix] Fix InternVL2 vision embeddings process with pipeline parallel (#8299)

This commit is contained in:
Isotr0py 2024-09-11 10:11:01 +08:00 committed by GitHub
parent e497b8aeff
commit 1230263e16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 3 deletions

View File

@ -32,7 +32,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"),
],
)
@fork_new_process_for_each_test
@ -46,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"8192",
"--pipeline-parallel-size",
str(PP_SIZE),
"--tensor-parallel-size",
@ -62,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
tp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"float16",
"--max-model-len",
"8192",
"--tensor-parallel-size",
str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI.
"--distributed-executor-backend",

View File

@ -17,6 +17,7 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
@ -480,7 +481,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
**kwargs: object,
) -> SamplerOutput:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
if image_input is not None and get_pp_group().is_first_rank:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)