[Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend (#8770)

This commit is contained in:
Isotr0py 2024-09-25 14:16:11 +08:00 committed by GitHub
parent e3dd0692fa
commit c23953675f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 99 additions and 9 deletions

View File

@ -67,6 +67,7 @@ from vllm.multimodal.image import cached_get_image_processor
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)
@ -281,6 +282,21 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif is_cpu():
seq_length = q.size(1)
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
else:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

View File

@ -12,11 +12,13 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@ -145,6 +147,38 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
query_lens=seq_lens,
)
def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
computed_len: int):
mm_kwargs = self.multi_modal_input_mapper(mm_data)
# special processing for mrope position deltas.
mrope_positions = None
if self.runner.model_is_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
mrope_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
context_len=computed_len,
)
seq_data.mrope_position_delta = mrope_position_delta
return mm_kwargs, mrope_positions
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -153,6 +187,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
@ -171,14 +207,20 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
mrope_positions = None
if (mm_data := seq_group_metadata.multi_modal_data):
mm_kwargs, mrope_positions = self._compute_multi_modal_input(
seq_data, mm_data, computed_len)
multi_modal_inputs_list.append(mm_kwargs)
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))
if (mm_data := seq_group_metadata.multi_modal_data):
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if mrope_positions:
for idx in range(3):
input_mrope_positions[idx].extend(mrope_positions[idx])
else:
input_positions.extend(list(range(computed_len, seq_len)))
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
@ -202,12 +244,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if any(input_mrope_positions):
input_positions = None # type: ignore
else:
input_mrope_positions = None # type: ignore
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions,
input_positions = torch.tensor(input_positions
or input_mrope_positions,
dtype=torch.long,
device=self.device) # type: ignore
slot_mapping = torch.tensor(slot_mapping,
@ -238,6 +286,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
@ -255,7 +304,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
if seq_data.mrope_position_delta is not None:
context_len = seq_data.get_num_computed_tokens()
next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
for idx in range(3):
input_mrope_positions[idx].extend(next_pos[idx])
else:
input_positions.append(position)
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
@ -273,12 +332,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
if any(input_mrope_positions):
input_positions = None # type: ignore
else:
input_mrope_positions = None # type: ignore
max_decode_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
input_positions = torch.tensor(input_positions
or input_mrope_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
@ -373,6 +438,15 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,