[VLM] Use SequenceData.from_token_counts to create dummy data (#8687)

Cyrus Leung 2024-09-21 14:28:56 +08:00 committed by GitHub
parent 71c60491f2
commit 5e85f4f82a
12 changed files with 74 additions and 81 deletions
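
In short: callers no longer hand-build dummy prompt token arrays with `array(VLLM_TOKEN_ID_ARRAY_TYPE, ...)` concatenation; they declare the token layout as `(token_id, count)` pairs and let `SequenceData.from_token_counts` expand it. A minimal before/after sketch of the call sites touched below (placeholder values are illustrative only, and the snippet assumes vLLM at this commit is installed):

```python
from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

# Illustrative values; real callers derive these from the model config.
image_token_id = 32000
image_feature_size = 576
num_images = 1
seq_len = 2048

# Before: build the flat token array by hand.
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                  [image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                   [0]) * (seq_len - image_feature_size * num_images)
old_style = SequenceData(token_ids)

# After: declare the layout as (token_id, count) pairs.
new_style = SequenceData.from_token_counts(
    (image_token_id, image_feature_size * num_images),
    (0, seq_len - image_feature_size * num_images),
)

# Both should expand to the same prompt token ids.
assert old_style.prompt_token_ids == new_style.prompt_token_ids
```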

vllm/inputs/registry.py

@@ -125,7 +125,7 @@ class InputRegistry:
         # Avoid circular import
         from vllm.sequence import SequenceData

-        dummy_seq_data = SequenceData.from_counts({0: seq_len})
+        dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
         dummy_multi_modal_data = None

         return dummy_seq_data, dummy_multi_modal_data

vllm/model_executor/models/blip.py

@@ -1,6 +1,5 @@
 """Minimal implementation of BlipVisionModel intended to be only used
 within a vision language model."""
-from array import array
 from typing import Optional, Union

 import torch
@@ -19,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+from vllm.sequence import SequenceData

 try:
     from xformers import ops as xops
@@ -53,6 +52,7 @@ def get_max_blip_image_tokens(
 def dummy_seq_data_for_blip(
     hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
     seq_len: int,
+    num_images: int,
     *,
     image_token_id: int,
     image_feature_size_override: Optional[int] = None,
@@ -62,11 +62,10 @@ def dummy_seq_data_for_blip(
     else:
         image_feature_size = image_feature_size_override

-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [image_token_id]) * image_feature_size
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [0]) * (seq_len - image_feature_size)
-    return SequenceData(token_ids)
+    return SequenceData.from_token_counts(
+        (image_token_id, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    )


 def dummy_image_for_blip(

vllm/model_executor/models/blip2.py

@@ -1,4 +1,3 @@
-from array import array
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)

@@ -18,8 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.opt import OPTModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData

 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
@@ -429,11 +427,10 @@ def dummy_seq_data_for_blip2(
     else:
         image_feature_size = image_feature_size_override

-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [image_token_id]) * image_feature_size * num_images
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [0]) * (seq_len - image_feature_size * num_images)
-    return SequenceData(token_ids)
+    return SequenceData.from_token_counts(
+        (image_token_id, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    )


 def dummy_data_for_blip2(ctx: InputContext, seq_len: int,

vllm/model_executor/models/chameleon.py

@@ -1,4 +1,3 @@
-from array import array
 from functools import cached_property
 from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                     Tuple, TypedDict)
@@ -32,8 +31,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import print_warning_once

 from .interfaces import SupportsMultiModal
@@ -72,11 +70,10 @@ def dummy_seq_data_for_chameleon(
     else:
         image_feature_size = image_feature_size_override

-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [image_token_id]) * image_feature_size * num_images
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [0]) * (seq_len - image_feature_size * num_images)
-    return SequenceData(token_ids)
+    return SequenceData.from_token_counts(
+        (image_token_id, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    )


 def dummy_image_for_chameleon(

vllm/model_executor/models/clip.py

@@ -1,6 +1,5 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
-from array import array
 from typing import Iterable, List, Optional, Tuple, Union

 import torch
@@ -20,7 +19,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+from vllm.sequence import SequenceData

 try:
     from xformers import ops as xops
@@ -62,11 +61,10 @@ def dummy_seq_data_for_clip(
     else:
         image_feature_size = image_feature_size_override

-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [image_token_id]) * image_feature_size * num_images
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [0]) * (seq_len - image_feature_size * num_images)
-    return SequenceData(token_ids)
+    return SequenceData.from_token_counts(
+        (image_token_id, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    )


 def dummy_image_for_clip(

vllm/model_executor/models/minicpmv.py

@@ -23,7 +23,6 @@
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
 import math
 import re
-from array import array
 from functools import partial
 from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
                     TypedDict)
@@ -56,8 +55,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData

 from .idefics2_vision_model import Idefics2VisionTransformer
@@ -259,8 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
 def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
-    return SequenceData(token_ids)
+    return SequenceData.from_token_counts((0, seq_len))


 def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):

vllm/model_executor/models/pixtral.py

@@ -1,4 +1,3 @@
-from array import array
 from dataclasses import dataclass, fields
 from itertools import tee
 from typing import Iterable, List, Mapping, Optional, Tuple, Union
@@ -24,8 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData

 from .interfaces import SupportsMultiModal
 from .utils import init_vllm_registered_model
@@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     image_feature_size = (size**2) // (patch_size**2)

     num_image_tokens = image_feature_size * num_images
-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [image_token_id]) * num_image_tokens
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [0]) * (seq_len - num_image_tokens)
-    seq_data = SequenceData(token_ids)
+    seq_data = SequenceData.from_token_counts(
+        (image_token_id, num_image_tokens),
+        (0, seq_len - num_image_tokens),
+    )

     mm_data = {"image": num_images * [image]}
     return seq_data, mm_data

vllm/model_executor/models/qwen.py

@@ -7,7 +7,6 @@
 import math
 import re
-from array import array
 from functools import partial
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
                     Optional, Tuple, TypedDict, Union)
@@ -45,8 +44,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of

 from .utils import flatten_bn, is_pp_missing_parameter, make_layers
@@ -819,7 +817,7 @@ def dummy_data_for_qwen(
     # The presence of a visual config indicates this is a multimodal model.
     # If we don't have it, the model is considered an LLM for warmup purposes.
     if not hasattr(hf_config, "visual"):
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
+        seq_data = SequenceData.from_token_counts((0, seq_len))
         mm_data = None
         return seq_data, mm_data
@@ -846,11 +844,13 @@ def dummy_data_for_qwen(
     if len(toks) < seq_len:
         toks += [0] * (seq_len - len(toks))

+    seq_data = SequenceData.from_seqs(toks)
+
     # Build the input images; width/height doesn't actually matter here since
     # the data will get resized and the # of tokens per image is constant
     image = Image.new("RGB", (224, 224), color=0)
     mm_data = {"image": image if num_images == 1 else [image] * num_images}
-    return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
+    return seq_data, mm_data


 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)

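Aside: the Qwen path above also switches to `SequenceData.from_seqs`, which builds the dummy sequence from an explicit token list rather than from `(token_id, count)` pairs. A tiny sketch of that call, with made-up token ids purely for illustration (assuming vLLM at this commit is installed):

```python
from vllm.sequence import SequenceData

# Explicit token list, e.g. image placeholder tokens followed by padding.
toks = [42] * 4 + [0] * 4
seq_data = SequenceData.from_seqs(toks)
```
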
vllm/model_executor/models/qwen2_vl.py

@@ -22,7 +22,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
-from array import array
 from functools import lru_cache, partial
 from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
                     Union)
@@ -66,8 +65,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.platforms import current_platform
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import get_processor

 logger = init_logger(__name__)
@@ -681,15 +679,14 @@ def dummy_data_for_qwen2_vl(
                     "--limit-mm-per-prompt.")

     hf_config = ctx.get_hf_config(Qwen2VLConfig)
-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [hf_config.vision_start_token_id])
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [hf_config.image_token_id]) * max_llm_image_tokens
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [hf_config.vision_end_token_id])
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [0]) * (seq_len - max_llm_image_tokens - 2)
-    dummy_seqdata = SequenceData(token_ids)
+    dummy_seqdata = SequenceData.from_token_counts(
+        (hf_config.vision_start_token_id, 1),
+        (hf_config.image_token_id, max_llm_image_tokens),
+        (hf_config.vision_end_token_id, 1),
+        (0, seq_len - max_llm_image_tokens - 2),
+    )
     dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
                             color=0)

vllm/model_executor/models/siglip.py

@@ -2,7 +2,6 @@
 within a vision language model."""

 import math
-from array import array
 from typing import Iterable, List, Optional, Tuple, Union

 import torch
@@ -24,7 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+from vllm.sequence import SequenceData

 try:
     from xformers import ops as xops
@@ -67,11 +66,10 @@ def dummy_seq_data_for_siglip(
     else:
         image_feature_size = image_feature_size_override

-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [image_token_id]) * image_feature_size
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                       [0]) * (seq_len - image_feature_size)
-    return SequenceData(token_ids)
+    return SequenceData.from_token_counts(
+        (image_token_id, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    )


 def dummy_image_for_siglip(

vllm/model_executor/models/ultravox.py

@@ -77,15 +77,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
     return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)


-def dummy_data_for_ultravox(
+def dummy_seq_data_for_ultravox(
     ctx: InputContext,
     seq_len: int,
-    mm_counts: Mapping[str, int],
+    audio_count: int,
 ):
-    feature_extractor = whisper_feature_extractor(ctx)
-    audio_count = mm_counts["audio"]
-
     audio_placeholder = array(
         VLLM_TOKEN_ID_ARRAY_TYPE,
         [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
@@ -96,10 +92,28 @@ def dummy_data_for_ultravox(
     other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                             [0]) * (seq_len - len(audio_token_ids))

-    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-    mm_dict = {"audio": [audio_and_sr] * audio_count}
+    return SequenceData(audio_token_ids + other_token_ids)

-    return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
+
+def dummy_audio_for_ultravox(
+    ctx: InputContext,
+    audio_count: int,
+):
+    feature_extractor = whisper_feature_extractor(ctx)
+    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
+    return {"audio": [audio_and_sr] * audio_count}
+
+
+def dummy_data_for_ultravox(
+    ctx: InputContext,
+    seq_len: int,
+    mm_counts: Mapping[str, int],
+):
+    audio_count = mm_counts["audio"]
+
+    seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
+    mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
+
+    return (seq_data, mm_dict)


 def input_mapper_for_ultravox(ctx: InputContext, data: object):

vllm/sequence.py

@@ -171,13 +171,13 @@ class SequenceData(msgspec.Struct,
     _mrope_position_delta: Optional[int] = None

     @staticmethod
-    def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
-        if len(counts_by_token) == 0:
+    def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
+        if len(token_counts) == 0:
             return SequenceData.from_seqs([])

         arrs = [
             array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
-            for token_id, count in counts_by_token.items()
+            for token_id, count in token_counts
         ]
         return SequenceData(reduce(array.__add__, arrs))
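
For reference, the expansion that `from_token_counts` performs can be reproduced in a standalone sketch (a hypothetical stand-in for illustration, not the vLLM class itself; it assumes the same `"l"` array typecode that `VLLM_TOKEN_ID_ARRAY_TYPE` maps to):

```python
from array import array
from functools import reduce
from typing import List, Tuple

TOKEN_ID_ARRAY_TYPE = "l"  # assumed to match vllm.sequence.VLLM_TOKEN_ID_ARRAY_TYPE


def expand_token_counts(*token_counts: Tuple[int, int]) -> List[int]:
    """Expand (token_id, count) pairs into a flat list of prompt token ids."""
    if len(token_counts) == 0:
        return []

    # One array per pair, then concatenate them in order.
    arrs = [
        array(TOKEN_ID_ARRAY_TYPE, [token_id]) * count
        for token_id, count in token_counts
    ]
    return list(reduce(array.__add__, arrs))


# One start token, three image tokens, one end token, two padding tokens.
assert expand_token_counts((5, 1), (7, 3), (6, 1), (0, 2)) == [5, 7, 7, 7, 6, 0, 0]
```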