[TPU] Remove multi-modal args in TPU backend (#6504)
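
Remove the unused multi-modal plumbing from the TPU model runner: the
MULTIMODAL_REGISTRY input mapper created in __init__, the
multi_modal_inputs_list collection and extra return element in
_prepare_prompt/_prepare_decode, and the multi_modal_kwargs parameter of
ModelWrapper.forward. The warmup dummy run drops the explicit None it
passed for that slot.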

commit e09ce759aa
parent 5fa6e9876e
Author: Woosuk Kwon
Date:   2024-07-17 04:02:53 -07:00 (committed by GitHub)

vllm/worker/tpu_model_runner.py

@@ -1,5 +1,5 @@
 import time
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
-                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -68,10 +66,6 @@ class TPUModelRunner:
             False,
         )
 
-        # Multi-modal data support
-        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
-            .create_input_mapper(self.model_config)
-
     def load_model(self) -> None:
         self.device = self.device_config.device
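
For reference, the deleted __init__ lines followed vLLM's usual multi-modal
mapping flow, which the GPU runners still use at this commit. A minimal
sketch of that flow, assuming model_config, seq_group_metadata_list, and
device are in scope:

from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs

# Build a per-model mapper that turns raw multi-modal data (e.g. images)
# into tensor kwargs the model's forward() understands.
input_mapper = MULTIMODAL_REGISTRY.create_input_mapper(model_config)

multi_modal_inputs_list = []
for seq_group_metadata in seq_group_metadata_list:
    mm_data = seq_group_metadata.multi_modal_data
    if mm_data:
        multi_modal_inputs_list.append(input_mapper(mm_data))

# Batch the per-sequence kwargs and move them onto the target device.
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                            device=device)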
@@ -154,7 +148,7 @@ class TPUModelRunner:
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, None, t, p, num_samples)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
@@ -199,14 +193,12 @@ class TPUModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -232,11 +224,6 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -274,24 +261,17 @@ class TPUModelRunner:
             block_tables=None,
             context_lens=None,
         )
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-        return (input_tokens, input_positions, attn_metadata, prompt_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, prompt_lens
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
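
Call sites that unpack these helpers shrink by one element accordingly; a
before/after sketch (the runner variable here is illustrative, not an exact
call site):

# Before: a 5-tuple ending in multi_modal_kwargs.
# (input_tokens, input_positions, attn_metadata, prompt_lens,
#  multi_modal_kwargs) = runner._prepare_prompt(seq_group_metadata_list)

# After: a plain 4-tuple; _prepare_decode changes the same way.
input_tokens, input_positions, attn_metadata, prompt_lens = \
    runner._prepare_prompt(seq_group_metadata_list)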
@@ -317,11 +297,6 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -355,12 +330,7 @@ class TPUModelRunner:
             block_tables=block_tables,
             context_lens=context_lens,
         )
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-        return (input_tokens, input_positions, attn_metadata, input_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, input_lens
 
     def _prepare_sample(
         self,
@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module):
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
-        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module):
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
-            multi_modal_kwargs: Keyword arguments from multi-modal data to
-                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module):
             position_ids,
             kv_caches,
             attn_metadata,
-            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
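
Pulling the hunks together, ModelWrapper.forward after this commit reduces
to the sketch below (body and docstring elided; the lack of a return
annotation mirrors the hunks above, which do not show one):

def forward(
    self,
    token_ids: torch.Tensor,
    position_ids: torch.Tensor,
    kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
    attn_metadata: AttentionMetadata,
    input_lens: torch.Tensor,
    t: torch.Tensor,
    p: torch.Tensor,
    num_samples: int,
):
    # All remaining parameters are tensors, tensor containers, or a static
    # int; the dict-valued multi_modal_kwargs slot is removed.
    ...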