[VLM] Use SequenceData.from_token_counts to create dummy data (#8687)
This commit is contained in:
parent
71c60491f2
commit
5e85f4f82a
@ -125,7 +125,7 @@ class InputRegistry:
|
||||
# Avoid circular import
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
dummy_seq_data = SequenceData.from_counts({0: seq_len})
|
||||
dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
|
||||
dummy_multi_modal_data = None
|
||||
|
||||
return dummy_seq_data, dummy_multi_modal_data
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Minimal implementation of BlipVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from array import array
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@ -19,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
@ -53,6 +52,7 @@ def get_max_blip_image_tokens(
|
||||
def dummy_seq_data_for_blip(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
@ -62,11 +62,10 @@ def dummy_seq_data_for_blip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size)
|
||||
return SequenceData(token_ids)
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_image_for_blip(
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from array import array
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
@ -18,8 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.opt import OPTModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||
get_max_blip_image_tokens)
|
||||
@ -429,11 +427,10 @@ def dummy_seq_data_for_blip2(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids)
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from array import array
|
||||
from functools import cached_property
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
|
||||
Tuple, TypedDict)
|
||||
@ -32,8 +31,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsMultiModal
|
||||
@ -72,11 +70,10 @@ def dummy_seq_data_for_chameleon(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids)
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_image_for_chameleon(
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Minimal implementation of CLIPVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from array import array
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -20,7 +19,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
@ -62,11 +61,10 @@ def dummy_seq_data_for_clip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids)
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_image_for_clip(
|
||||
|
||||
@ -23,7 +23,6 @@
|
||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||
import math
|
||||
import re
|
||||
from array import array
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
|
||||
TypedDict)
|
||||
@ -56,8 +55,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
|
||||
@ -259,8 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
|
||||
|
||||
|
||||
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
|
||||
return SequenceData(token_ids)
|
||||
return SequenceData.from_token_counts((0, seq_len))
|
||||
|
||||
|
||||
def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from array import array
|
||||
from dataclasses import dataclass, fields
|
||||
from itertools import tee
|
||||
from typing import Iterable, List, Mapping, Optional, Tuple, Union
|
||||
@ -24,8 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .utils import init_vllm_registered_model
|
||||
@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
|
||||
image_feature_size = (size**2) // (patch_size**2)
|
||||
|
||||
num_image_tokens = image_feature_size * num_images
|
||||
seq_data = SequenceData.from_token_counts(
|
||||
(image_token_id, num_image_tokens),
|
||||
(0, seq_len - num_image_tokens),
|
||||
)
|
||||
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * num_image_tokens
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - num_image_tokens)
|
||||
|
||||
seq_data = SequenceData(token_ids)
|
||||
mm_data = {"image": num_images * [image]}
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
|
||||
import math
|
||||
import re
|
||||
from array import array
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Tuple, TypedDict, Union)
|
||||
@ -45,8 +44,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
|
||||
@ -819,7 +817,7 @@ def dummy_data_for_qwen(
|
||||
# The presence of a visual config indicates this is a multimodal model.
|
||||
# If we don't have it, the model is considered an LLM for warmup purposes.
|
||||
if not hasattr(hf_config, "visual"):
|
||||
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len))
|
||||
seq_data = SequenceData.from_token_counts((0, seq_len))
|
||||
mm_data = None
|
||||
return seq_data, mm_data
|
||||
|
||||
@ -846,11 +844,13 @@ def dummy_data_for_qwen(
|
||||
if len(toks) < seq_len:
|
||||
toks += [0] * (seq_len - len(toks))
|
||||
|
||||
seq_data = SequenceData.from_seqs(toks)
|
||||
|
||||
# Build the input images; width/height doesn't actually matter here since
|
||||
# the data will get resized and the # of tokens per image is constant
|
||||
image = Image.new("RGB", (224, 224), color=0)
|
||||
mm_data = {"image": image if num_images == 1 else [image] * num_images}
|
||||
return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
|
||||
|
||||
@ -22,7 +22,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
from array import array
|
||||
from functools import lru_cache, partial
|
||||
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
|
||||
Union)
|
||||
@ -66,8 +65,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
from vllm.multimodal.base import MultiModalData
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.processor import get_processor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -681,15 +679,14 @@ def dummy_data_for_qwen2_vl(
|
||||
"--limit-mm-per-prompt.")
|
||||
|
||||
hf_config = ctx.get_hf_config(Qwen2VLConfig)
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[hf_config.vision_start_token_id])
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[hf_config.image_token_id]) * max_llm_image_tokens
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[hf_config.vision_end_token_id])
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - max_llm_image_tokens - 2)
|
||||
dummy_seqdata = SequenceData(token_ids)
|
||||
|
||||
dummy_seqdata = SequenceData.from_token_counts(
|
||||
(hf_config.vision_start_token_id, 1),
|
||||
(hf_config.image_token_id, max_llm_image_tokens),
|
||||
(hf_config.vision_end_token_id, 1),
|
||||
(0, seq_len - max_llm_image_tokens - 2),
|
||||
)
|
||||
|
||||
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
|
||||
color=0)
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
within a vision language model."""
|
||||
|
||||
import math
|
||||
from array import array
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -24,7 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
@ -67,11 +66,10 @@ def dummy_seq_data_for_siglip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size)
|
||||
return SequenceData(token_ids)
|
||||
return SequenceData.from_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_image_for_siglip(
|
||||
|
||||
@ -77,15 +77,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
|
||||
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
|
||||
|
||||
|
||||
def dummy_data_for_ultravox(
|
||||
def dummy_seq_data_for_ultravox(
|
||||
ctx: InputContext,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
audio_count: int,
|
||||
):
|
||||
feature_extractor = whisper_feature_extractor(ctx)
|
||||
|
||||
audio_count = mm_counts["audio"]
|
||||
|
||||
audio_placeholder = array(
|
||||
VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
|
||||
@ -96,10 +92,28 @@ def dummy_data_for_ultravox(
|
||||
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - len(audio_token_ids))
|
||||
|
||||
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
|
||||
mm_dict = {"audio": [audio_and_sr] * audio_count}
|
||||
return SequenceData(audio_token_ids + other_token_ids)
|
||||
|
||||
return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
|
||||
|
||||
def dummy_audio_for_ultravox(
|
||||
ctx: InputContext,
|
||||
audio_count: int,
|
||||
):
|
||||
feature_extractor = whisper_feature_extractor(ctx)
|
||||
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
|
||||
return {"audio": [audio_and_sr] * audio_count}
|
||||
|
||||
|
||||
def dummy_data_for_ultravox(
|
||||
ctx: InputContext,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
):
|
||||
audio_count = mm_counts["audio"]
|
||||
seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
|
||||
mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
|
||||
|
||||
return (seq_data, mm_dict)
|
||||
|
||||
|
||||
def input_mapper_for_ultravox(ctx: InputContext, data: object):
|
||||
|
||||
@ -171,13 +171,13 @@ class SequenceData(msgspec.Struct,
|
||||
_mrope_position_delta: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
|
||||
if len(counts_by_token) == 0:
|
||||
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
|
||||
if len(token_counts) == 0:
|
||||
return SequenceData.from_seqs([])
|
||||
|
||||
arrs = [
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
|
||||
for token_id, count in counts_by_token.items()
|
||||
for token_id, count in token_counts
|
||||
]
|
||||
|
||||
return SequenceData(reduce(array.__add__, arrs))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user