[Misc] Add vision language model support to CPU backend (#3968)
This commit is contained in:
parent 747b1a7147
commit 296cdf8ac7
@@ -45,6 +45,7 @@ class CPUExecutor(ExecutorBase):
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
@@ -5,7 +5,7 @@ from torch import nn

 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -29,6 +29,7 @@ class CPUModelRunner:
         device_config: DeviceConfig,
         load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         *args,
@@ -38,6 +39,7 @@ class CPUModelRunner:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker

@@ -59,13 +61,14 @@ class CPUModelRunner:
         self.block_size: int  # Set after initial profiling.

     def load_model(self) -> None:
-        self.model = get_model(model_config=self.model_config,
-                               load_config=self.load_config,
-                               device_config=self.device_config,
-                               vision_language_config=None,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)

     def _prepare_prompt(
         self,
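The hunks above are mostly plumbing: CPUExecutor, CPUWorker, and CPUModelRunner each accept an optional VisionLanguageConfig and hand it to get_model, so the loader can also build the vision encoder when a VLM is requested. A minimal sketch of that pattern, using hypothetical stand-in classes rather than vLLM's real ones:

from dataclasses import dataclass
from typing import Optional


@dataclass
class VisionLanguageConfig:
    # Hypothetical subset of fields; the real config carries more.
    image_input_shape: str = "1,3,336,336"
    image_feature_size: int = 576


def get_model(vision_language_config: Optional[VisionLanguageConfig]):
    # Stand-in loader: only builds a vision tower when a config is given.
    return {"has_vision_tower": vision_language_config is not None}


class ModelRunner:
    def __init__(self,
                 vision_language_config: Optional[VisionLanguageConfig] = None):
        # None means a plain text-only model.
        self.vision_language_config = vision_language_config

    def load_model(self):
        # The point of the hunk above: pass the runner's config through
        # instead of hard-coding None.
        return get_model(vision_language_config=self.vision_language_config)


class Worker:
    def __init__(self,
                 vision_language_config: Optional[VisionLanguageConfig] = None):
        # The worker threads the config down to its runner unchanged.
        self.model_runner = ModelRunner(
            vision_language_config=vision_language_config)


print(Worker(VisionLanguageConfig()).model_runner.load_model())
# -> {'has_vision_tower': True}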
@@ -76,6 +79,7 @@ class CPUModelRunner:
         input_positions: List[int] = []
         slot_mapping: List[int] = []
         prompt_lens: List[int] = []
+        multi_modal_input_list: List[torch.Tensor] = []

         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -96,6 +100,10 @@ class CPUModelRunner:
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, prompt_len)))

+            if seq_group_metadata.multi_modal_data:
+                multi_modal_input_list.append(
+                    seq_group_metadata.multi_modal_data.data)
+
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@@ -118,6 +126,15 @@ class CPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        if multi_modal_input_list:
+            assert self.vision_language_config, (
+                "Multi-modal inputs are only supported by "
+                "vision language models.")
+            multi_modal_input = torch.cat(multi_modal_input_list,
+                                          dim=0).to(self.device)
+        else:
+            multi_modal_input = None
+
         num_prompt_tokens = len(input_tokens)

         input_tokens = torch.tensor(input_tokens,
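The new block in _prepare_prompt batches whatever per-sequence image tensors were collected along the leading dimension. A quick shape illustration; the 1x3x336x336 pixel-value shape is only an assumed example (CLIP-style input), not something this diff prescribes:

import torch

# Two sequences in the batch, each carrying one image as pixel values.
# The per-image shape is an assumption for illustration only.
multi_modal_input_list = [
    torch.zeros(1, 3, 336, 336),  # image for sequence 0
    torch.zeros(1, 3, 336, 336),  # image for sequence 1
]

# Same call as in the hunk: concatenate along dim 0 into one batch tensor.
multi_modal_input = torch.cat(multi_modal_input_list, dim=0)
print(multi_modal_input.shape)  # torch.Size([2, 3, 336, 336])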
@@ -144,12 +161,8 @@ class CPUModelRunner:
             slot_mapping=slot_mapping,
             kv_cache_dtype=self.kv_cache_dtype,
         )
-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            prompt_lens,
-        )
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_input)

     def _prepare_decode(
         self,
@@ -336,14 +349,16 @@ class CPUModelRunner:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
                SamplingMetadata]:
+        multi_modal_input = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, attn_metadata,
-                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+                (input_tokens, input_positions, attn_metadata, prompt_lens,
+                 multi_modal_input
+                 ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions,
                  attn_metadata) = self._prepare_decode(seq_group_metadata_list)
@@ -376,12 +391,8 @@ class CPUModelRunner:
                 perform_sampling=False,
             )

-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            sampling_metadata,
-        )
+        return (input_tokens, input_positions, attn_metadata,
+                sampling_metadata, multi_modal_input)

     @torch.inference_mode()
     def execute_model(
@@ -389,7 +400,8 @@ class CPUModelRunner:
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata
+        (input_tokens, input_positions, attn_metadata, sampling_metadata,
+         multi_modal_input
          ) = self.prepare_input_tensors(seq_group_metadata_list)

         model_executable = self.model
@@ -399,6 +411,8 @@ class CPUModelRunner:
             "kv_caches": kv_caches,
             "attn_metadata": attn_metadata,
         }
+        if self.vision_language_config:
+            execute_model_kwargs.update({"image_input": multi_modal_input})

         hidden_states = model_executable(**execute_model_kwargs)

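Because the image kwarg is added only when a vision-language config is present, text-only models never see an unexpected argument. A small sketch of that guard with a hypothetical toy module (not a real vLLM model):

import torch


class ToyVLM(torch.nn.Module):
    """Hypothetical model whose forward accepts an optional image_input,
    mirroring how the runner passes it as a keyword argument."""

    def forward(self, input_ids, image_input=None):
        num_images = 0 if image_input is None else image_input.shape[0]
        return {"tokens": int(input_ids.shape[0]), "images": num_images}


model = ToyVLM()
execute_model_kwargs = {"input_ids": torch.tensor([1, 2, 3])}

vision_language_config = object()  # pretend a VLM config is set
multi_modal_input = torch.zeros(1, 3, 336, 336)

if vision_language_config:
    # Same pattern as the hunk: only VLM runs ever add the image kwarg.
    execute_model_kwargs.update({"image_input": multi_modal_input})

print(model(**execute_model_kwargs))  # {'tokens': 3, 'images': 1}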
@@ -6,7 +6,8 @@ import torch.distributed

 from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig)
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -122,6 +123,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         rank: int,
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ) -> None:
@@ -135,21 +137,25 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
-        self.model_runner = CPUModelRunner(model_config,
-                                           parallel_config,
-                                           scheduler_config,
-                                           device_config,
-                                           load_config=self.load_config,
-                                           lora_config=self.lora_config,
-                                           kv_cache_dtype=kv_cache_dtype,
-                                           is_driver_worker=is_driver_worker)
+        self.model_runner = CPUModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            load_config=self.load_config,
+            lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: CPUCacheEngine
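With these changes a vision-language model can be exercised on the CPU backend roughly the way the GPU llava example does it. The snippet below is a sketch modeled on vLLM's llava example from around this change; the exact argument names, token id, and shapes are assumptions to check against the installed version:

import torch
from vllm import LLM
from vllm.sequence import MultiModalData

# Assumed arguments, copied in spirit from the llava example of this era.
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    image_input_type="pixel_values",
    image_token_id=32000,
    image_input_shape="1,3,336,336",
    image_feature_size=576,
    device="cpu",  # select the CPU backend
)

# One <image> placeholder per image feature, then the text prompt.
prompt = "<image>" * 576 + "\nWhat is shown in this image?"
pixel_values = torch.zeros(1, 3, 336, 336)  # stand-in for real pixel values

outputs = llm.generate(
    prompt,
    multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                    data=pixel_values))
print(outputs[0].outputs[0].text)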