[TPU] Remove multi-modal args in TPU backend (#6504)
parent 5fa6e9876e
commit e09ce759aa
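
This commit strips the unused multi-modal plumbing from the TPU model runner: `_prepare_prompt` and `_prepare_decode` no longer build or return `multi_modal_kwargs`, and `ModelWrapper.forward` drops the corresponding parameter. A minimal sketch of the resulting call flow, reconstructed from the hunks below (`runner` and `seq_groups` are illustrative names, not part of the diff):

    # _prepare_prompt now returns exactly four values:
    input_tokens, input_positions, attn_metadata, prompt_lens = \
        runner._prepare_prompt(seq_groups)

    # The model wrapper is invoked without the multi-modal placeholder:
    runner.model(token_ids, position_ids, kv_caches, attn_metadata,
                 input_lens, t, p, num_samples)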
@@ -1,5 +1,5 @@
 import time
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
-                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -68,10 +66,6 @@ class TPUModelRunner:
             False,
         )
 
-        # Multi-modal data support
-        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
-            .create_input_mapper(self.model_config)
-
     def load_model(self) -> None:
         self.device = self.device_config.device
 
@@ -154,7 +148,7 @@ class TPUModelRunner:
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, None, t, p, num_samples)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
@@ -199,14 +193,12 @@ class TPUModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -232,11 +224,6 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -274,24 +261,17 @@ class TPUModelRunner:
             block_tables=None,
             context_lens=None,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, prompt_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, prompt_lens
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -317,11 +297,6 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -355,12 +330,7 @@ class TPUModelRunner:
             block_tables=block_tables,
             context_lens=context_lens,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, input_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, input_lens
 
     def _prepare_sample(
         self,
@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module):
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
-        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module):
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
-            multi_modal_kwargs: Keyword arguments from multi-modal data to
-                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module):
             position_ids,
             kv_caches,
             attn_metadata,
-            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
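
For context, the deleted code formed one pipeline; a sketch reassembled from the removed lines above (illustrative, not runnable in isolation):

    # Inside the prepare loops: map each sequence group's multi-modal data.
    mm_data = seq_group_metadata.multi_modal_data
    if mm_data:
        mm_kwargs = self.multi_modal_input_mapper(mm_data)
        multi_modal_inputs_list.append(mm_kwargs)

    # After the loop: collate the per-sequence kwargs on the target device.
    multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                device=self.device)

    # In ModelWrapper.forward: unpack the kwargs into the model call.
    hidden_states = self.model(token_ids, position_ids, kv_caches,
                               attn_metadata, **(multi_modal_kwargs or {}))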