[Misc] Add vision language model support to CPU backend (#3968)

Isotr0py 2024-04-22 15:44:16 +08:00 committed by GitHub
parent 747b1a7147
commit 296cdf8ac7
3 changed files with 53 additions and 32 deletions

vllm/executor/cpu_executor.py

@@ -45,6 +45,7 @@ class CPUExecutor(ExecutorBase):
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
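
The executor change is pure plumbing: the optional vision-language config the engine already holds is now forwarded into the CPU worker, just as the GPU executor does. A minimal sketch of that pattern, using hypothetical stand-in classes rather than vLLM's real ones:

from dataclasses import dataclass
from typing import Optional


@dataclass
class VLConfig:  # stand-in for vllm.config.VisionLanguageConfig
    image_token_id: int
    image_feature_size: int


class Runner:  # stand-in for CPUModelRunner
    def __init__(self, vision_language_config: Optional[VLConfig]) -> None:
        self.vision_language_config = vision_language_config


class Worker:  # stand-in for CPUWorker
    def __init__(self, vision_language_config: Optional[VLConfig] = None) -> None:
        self.model_runner = Runner(vision_language_config)


class Executor:  # stand-in for CPUExecutor
    def __init__(self, vision_language_config: Optional[VLConfig]) -> None:
        self.driver_worker = Worker(
            vision_language_config=vision_language_config)


# None means "text-only"; a populated config unlocks the multi-modal path.
executor = Executor(VLConfig(image_token_id=32000, image_feature_size=576))
assert executor.driver_worker.model_runner.vision_language_config is not None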

vllm/worker/cpu_model_runner.py

@@ -5,7 +5,7 @@ from torch import nn
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -29,6 +29,7 @@ class CPUModelRunner:
         device_config: DeviceConfig,
         load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         *args,
@@ -38,6 +39,7 @@ class CPUModelRunner:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
@@ -59,13 +61,14 @@ class CPUModelRunner:
         self.block_size: int  # Set after initial profiling.

     def load_model(self) -> None:
-        self.model = get_model(model_config=self.model_config,
-                               load_config=self.load_config,
-                               device_config=self.device_config,
-                               vision_language_config=None,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)

     def _prepare_prompt(
         self,
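
Previously load_model hard-coded vision_language_config=None, so get_model could never instantiate a vision-language architecture on the CPU backend; now the runner forwards whatever config it was given. A hedged sketch of why the loader needs the real value (the toy factory below is illustrative, not vLLM's loader internals):

from typing import Any, Dict, Optional


def build_model(architecture: str,
                vision_language_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Toy model factory: only builds the vision pieces when a config exists."""
    if vision_language_config is None:
        # Text-only path: no vision tower, image inputs will be rejected.
        return {"arch": architecture, "vision_tower": None}
    # Multi-modal path: the image projector is sized from the config.
    return {
        "arch": architecture,
        "vision_tower": "clip-vit",
        "image_feature_size": vision_language_config["image_feature_size"],
    }


text_only = build_model("llama", None)                          # old CPU behaviour
multimodal = build_model("llava", {"image_feature_size": 576})  # new CPU behaviour
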
@@ -76,6 +79,7 @@ class CPUModelRunner:
         input_positions: List[int] = []
         slot_mapping: List[int] = []
         prompt_lens: List[int] = []
+        multi_modal_input_list: List[torch.Tensor] = []

         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -96,6 +100,10 @@ class CPUModelRunner:
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, prompt_len)))

+            if seq_group_metadata.multi_modal_data:
+                multi_modal_input_list.append(
+                    seq_group_metadata.multi_modal_data.data)
+
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@@ -118,6 +126,15 @@ class CPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        if multi_modal_input_list:
+            assert self.vision_language_config, (
+                "Multi-modal inputs are only supported by "
+                "vision language models.")
+            multi_modal_input = torch.cat(multi_modal_input_list,
+                                          dim=0).to(self.device)
+        else:
+            multi_modal_input = None
+
         num_prompt_tokens = len(input_tokens)

         input_tokens = torch.tensor(input_tokens,
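
The new batching logic collects one multi_modal_data tensor per prompt sequence and concatenates them along the batch dimension before handing them to the runner's device. A runnable sketch of that rule (the [1, 3, 336, 336] pixel-value shape is an assumption for illustration):

import torch

# One pixel-value tensor per prompt sequence, as gathered in _prepare_prompt.
per_sequence_images = [
    torch.randn(1, 3, 336, 336),  # sequence 0 carries one image
    torch.randn(1, 3, 336, 336),  # sequence 1 carries one image
]

if per_sequence_images:
    # Same rule as above: concatenate along dim 0 and move to the CPU device.
    multi_modal_input = torch.cat(per_sequence_images, dim=0).to("cpu")
else:
    multi_modal_input = None

print(multi_modal_input.shape)  # torch.Size([2, 3, 336, 336])
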
@@ -144,12 +161,8 @@ class CPUModelRunner:
             slot_mapping=slot_mapping,
             kv_cache_dtype=self.kv_cache_dtype,
         )
-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            prompt_lens,
-        )
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_input)

     def _prepare_decode(
         self,
@@ -336,14 +349,16 @@ class CPUModelRunner:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
                SamplingMetadata]:
+        multi_modal_input = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, attn_metadata,
-                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+                (input_tokens, input_positions, attn_metadata, prompt_lens,
+                 multi_modal_input
+                 ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions,
                  attn_metadata) = self._prepare_decode(seq_group_metadata_list)
@@ -376,12 +391,8 @@ class CPUModelRunner:
             perform_sampling=False,
         )

-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            sampling_metadata,
-        )
+        return (input_tokens, input_positions, attn_metadata,
+                sampling_metadata, multi_modal_input)

     @torch.inference_mode()
     def execute_model(
@@ -389,7 +400,8 @@ class CPUModelRunner:
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata
+        (input_tokens, input_positions, attn_metadata, sampling_metadata,
+         multi_modal_input
          ) = self.prepare_input_tensors(seq_group_metadata_list)

         model_executable = self.model
@@ -399,6 +411,8 @@ class CPUModelRunner:
             "kv_caches": kv_caches,
             "attn_metadata": attn_metadata,
         }
+        if self.vision_language_config:
+            execute_model_kwargs.update({"image_input": multi_modal_input})

         hidden_states = model_executable(**execute_model_kwargs)
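
When a vision-language config is present, the batched image tensor is passed to the model as an extra image_input keyword argument. A sketch of how that call lands on the model side; ToyVLM and its forward signature are stand-ins, not vLLM's real model interface:

import torch


class ToyVLM(torch.nn.Module):
    def forward(self, input_ids, positions, kv_caches, attn_metadata,
                image_input=None):
        # A real vision-language model would encode image_input and merge the
        # image features into the token embeddings; here we only report it.
        return {"num_tokens": input_ids.numel(),
                "has_image": image_input is not None}


model_executable = ToyVLM()
execute_model_kwargs = {
    "input_ids": torch.zeros(4, dtype=torch.long),
    "positions": torch.arange(4),
    "kv_caches": [],
    "attn_metadata": None,
}
vision_language_config = object()  # pretend the runner was given a config
if vision_language_config:
    execute_model_kwargs.update({"image_input": torch.randn(1, 3, 336, 336)})

print(model_executable(**execute_model_kwargs))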

vllm/worker/cpu_worker.py

@@ -6,7 +6,8 @@ import torch.distributed
 from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig)
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -122,6 +123,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         rank: int,
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ) -> None:
@@ -135,21 +137,25 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()

-        self.model_runner = CPUModelRunner(model_config,
-                                           parallel_config,
-                                           scheduler_config,
-                                           device_config,
-                                           load_config=self.load_config,
-                                           lora_config=self.lora_config,
-                                           kv_cache_dtype=kv_cache_dtype,
-                                           is_driver_worker=is_driver_worker)
+        self.model_runner = CPUModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            load_config=self.load_config,
+            lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: CPUCacheEngine
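
With the executor, model runner, and worker all threading the config through, a vision-language model should run on the CPU backend the same way it does on GPU. A hedged end-to-end sketch of what a caller might look like; the engine arguments, the MultiModalData import path, and the LLaVA checkpoint below are taken from the GPU example of this vLLM release and are assumptions, not something this diff adds:

import torch
from vllm import LLM, SamplingParams
from vllm.sequence import MultiModalData

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    image_input_type="pixel_values",   # assumption: multimodal flags of this release
    image_token_id=32000,
    image_input_shape="1,3,336,336",
    image_feature_size=576,
    device="cpu",                      # assumption: selects the CPU backend
)

prompt = "<image>" * 576 + "\nUSER: What is shown in the image?\nASSISTANT:"
pixel_values = torch.load("pixel_values.pt")  # pre-processed [1, 3, 336, 336] tensor

outputs = llm.generate(
    prompt,
    SamplingParams(temperature=0.0, max_tokens=64),
    multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                    data=pixel_values),
)
print(outputs[0].outputs[0].text)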