From e09ce759aa8d4fa41304df59b7e888fe12724d58 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 17 Jul 2024 04:02:53 -0700
Subject: [PATCH] [TPU] Remove multi-modal args in TPU backend (#6504)

---
 vllm/worker/tpu_model_runner.py | 46 +++++----------------------
 1 file changed, 6 insertions(+), 40 deletions(-)

diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 6c1149ee..bbf0db31 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
-                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -68,10 +66,6 @@ class TPUModelRunner:
             False,
         )
 
-        # Multi-modal data support
-        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
-            .create_input_mapper(self.model_config)
-
     def load_model(self) -> None:
         self.device = self.device_config.device
 
@@ -154,7 +148,7 @@ class TPUModelRunner:
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, None, t, p, num_samples)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
@@ -199,14 +193,12 @@ class TPUModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -232,11 +224,6 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -274,24 +261,17 @@ class TPUModelRunner:
             block_tables=None,
             context_lens=None,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, prompt_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, prompt_lens
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -317,11 +297,6 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -355,12 +330,7 @@ class TPUModelRunner:
             block_tables=block_tables,
             context_lens=context_lens,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, input_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, input_lens
 
     def _prepare_sample(
         self,
@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module):
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
-        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module):
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
-            multi_modal_kwargs: Keyword arguments from multi-modal data to
-                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module):
             position_ids,
             kv_caches,
             attn_metadata,
-            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
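
Note (illustrative, not part of the patch): after this change the wrapped TPU model is called without a multi-modal kwargs argument, and the _prepare_prompt/_prepare_decode helpers return plain 4-tuples. The minimal sketch below mirrors that simplified call path; fake_model, run_dummy_prefill, and the dummy tensors are hypothetical stand-ins, not vLLM code.

from typing import Any, List, Optional, Tuple

import torch


def fake_model(token_ids, position_ids, kv_caches, attn_metadata,
               input_lens, t, p, num_samples):
    # Hypothetical stand-in for the wrapped TPU model. Post-patch it is called
    # positionally as:
    #   model(token_ids, position_ids, kv_caches, attn_metadata,
    #         input_lens, t, p, num_samples)
    # i.e. without the multi_modal_kwargs argument that used to sit between
    # input_lens and t.
    batch_size = token_ids.shape[0]
    return torch.zeros(batch_size, num_samples, dtype=torch.long)


def run_dummy_prefill(batch_size: int = 1, seq_len: int = 8) -> torch.Tensor:
    # Dummy inputs with the shapes described in ModelWrapper.forward's docstring.
    token_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(
        batch_size, 1)
    kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]] = []
    attn_metadata: Any = None  # placeholder for the Pallas attention metadata
    input_lens = torch.full((batch_size,), seq_len, dtype=torch.int32)
    t = torch.ones(batch_size)  # sampling temperature, shape [batch_size]
    p = torch.ones(batch_size)  # top-p, shape [batch_size]
    return fake_model(token_ids, position_ids, kv_caches, attn_metadata,
                      input_lens, t, p, num_samples=1)


if __name__ == "__main__":
    print(run_dummy_prefill().shape)  # torch.Size([1, 1])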