[Misc] Add vision language model support to CPU backend (#3968)

This commit is contained in:
Isotr0py 2024-04-22 15:44:16 +08:00 committed by GitHub
parent 747b1a7147
commit 296cdf8ac7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 32 deletions

View File

@ -45,6 +45,7 @@ class CPUExecutor(ExecutorBase):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)

View File

@ -5,7 +5,7 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@ -29,6 +29,7 @@ class CPUModelRunner:
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@ -38,6 +39,7 @@ class CPUModelRunner:
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
@ -59,10 +61,11 @@ class CPUModelRunner:
self.block_size: int # Set after initial profiling.
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=None,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
@ -76,6 +79,7 @@ class CPUModelRunner:
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
@ -96,6 +100,10 @@ class CPUModelRunner:
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len)))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@ -118,6 +126,15 @@ class CPUModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
@ -144,12 +161,8 @@ class CPUModelRunner:
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
return (
input_tokens,
input_positions,
attn_metadata,
prompt_lens,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_input)
def _prepare_decode(
self,
@ -336,14 +349,16 @@ class CPUModelRunner:
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
SamplingMetadata]:
multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
@ -376,12 +391,8 @@ class CPUModelRunner:
perform_sampling=False,
)
return (
input_tokens,
input_positions,
attn_metadata,
sampling_metadata,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_input)
@torch.inference_mode()
def execute_model(
@ -389,7 +400,8 @@ class CPUModelRunner:
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata
(input_tokens, input_positions, attn_metadata, sampling_metadata,
multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model
@ -399,6 +411,8 @@ class CPUModelRunner:
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)

View File

@ -6,7 +6,8 @@ import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@ -122,6 +123,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
@ -135,19 +137,23 @@ class CPUWorker(LoraNotSupportedWorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = CPUModelRunner(model_config,
self.model_runner = CPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by