[Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend (#8770)
This commit is contained in:
parent
e3dd0692fa
commit
c23953675f
@ -67,6 +67,7 @@ from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.processor import get_processor
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory)
|
||||
@ -281,6 +282,21 @@ class Qwen2VisionAttention(nn.Module):
|
||||
context_layer = rearrange(output,
|
||||
"(b s) ... -> b s ...",
|
||||
b=batch_size)
|
||||
elif is_cpu():
|
||||
seq_length = q.size(1)
|
||||
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
|
||||
attention_mask = torch.zeros([1, seq_length, seq_length],
|
||||
device=q.device,
|
||||
dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
|
||||
cu_seqlens[i - 1]:cu_seqlens[i]] = True
|
||||
output = F.scaled_dot_product_attention(q,
|
||||
k,
|
||||
v,
|
||||
attention_mask,
|
||||
dropout_p=0.0)
|
||||
context_layer = rearrange(output, "b h s d -> b s h d ")
|
||||
else:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
@ -12,11 +12,13 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalInputs)
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.sequence import (IntermediateTensors, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
@ -145,6 +147,38 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
query_lens=seq_lens,
|
||||
)
|
||||
|
||||
def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
|
||||
computed_len: int):
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
||||
|
||||
# special processing for mrope position deltas.
|
||||
mrope_positions = None
|
||||
if self.runner.model_is_mrope:
|
||||
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
|
||||
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
|
||||
assert image_grid_thw is not None or video_grid_thw is not None, (
|
||||
"mrope embedding type requires multi-modal input mapper "
|
||||
"returns 'image_grid_thw' or 'video_grid_thw'.")
|
||||
|
||||
hf_config = self.runner.model_config.hf_config
|
||||
token_ids = seq_data.get_token_ids()
|
||||
|
||||
mrope_positions, mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
token_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
video_token_id=hf_config.video_token_id,
|
||||
vision_start_token_id=hf_config.vision_start_token_id,
|
||||
vision_end_token_id=hf_config.vision_end_token_id,
|
||||
spatial_merge_size=hf_config.vision_config.
|
||||
spatial_merge_size,
|
||||
context_len=computed_len,
|
||||
)
|
||||
seq_data.mrope_position_delta = mrope_position_delta
|
||||
return mm_kwargs, mrope_positions
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
@ -153,6 +187,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
|
||||
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
multi_modal_inputs_list: List[MultiModalInputs] = []
|
||||
@ -171,14 +207,20 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
seq_lens.append(seq_len) # Prompt token num
|
||||
input_tokens.extend(prompt_tokens) # Token ids
|
||||
|
||||
mrope_positions = None
|
||||
if (mm_data := seq_group_metadata.multi_modal_data):
|
||||
mm_kwargs, mrope_positions = self._compute_multi_modal_input(
|
||||
seq_data, mm_data, computed_len)
|
||||
multi_modal_inputs_list.append(mm_kwargs)
|
||||
|
||||
# Token position ids
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
input_positions.extend(list(range(computed_len, seq_len)))
|
||||
|
||||
if (mm_data := seq_group_metadata.multi_modal_data):
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
||||
multi_modal_inputs_list.append(mm_kwargs)
|
||||
if mrope_positions:
|
||||
for idx in range(3):
|
||||
input_mrope_positions[idx].extend(mrope_positions[idx])
|
||||
else:
|
||||
input_positions.extend(list(range(computed_len, seq_len)))
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
@ -202,12 +244,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
if any(input_mrope_positions):
|
||||
input_positions = None # type: ignore
|
||||
else:
|
||||
input_mrope_positions = None # type: ignore
|
||||
|
||||
num_prompt_tokens = len(input_tokens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
input_positions = torch.tensor(input_positions,
|
||||
input_positions = torch.tensor(input_positions
|
||||
or input_mrope_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
@ -238,6 +286,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
block_tables: List[List[int]] = []
|
||||
@ -255,7 +304,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append(position)
|
||||
if seq_data.mrope_position_delta is not None:
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
next_pos = MRotaryEmbedding.get_next_input_positions(
|
||||
seq_data.mrope_position_delta,
|
||||
context_len,
|
||||
seq_len,
|
||||
)
|
||||
for idx in range(3):
|
||||
input_mrope_positions[idx].extend(next_pos[idx])
|
||||
else:
|
||||
input_positions.append(position)
|
||||
|
||||
seq_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
@ -273,12 +332,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
if any(input_mrope_positions):
|
||||
input_positions = None # type: ignore
|
||||
else:
|
||||
input_mrope_positions = None # type: ignore
|
||||
|
||||
max_decode_seq_len = max(seq_lens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = torch.tensor(input_positions,
|
||||
input_positions = torch.tensor(input_positions
|
||||
or input_mrope_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
@ -373,6 +438,15 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
|
||||
|
||||
@property
|
||||
def model_is_mrope(self) -> bool:
|
||||
"""Detect if the model has "mrope" rope_scaling type.
|
||||
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
||||
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
||||
if rope_scaling is None:
|
||||
return False
|
||||
return rope_scaling.get("type", None) == "mrope"
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user