[bugfix] fix chatglm dummy_data_for_glmv (#9955)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-02 08:03:33 -07:00 committed by GitHub
parent d6459b4516
commit 74b529ceee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,8 +14,8 @@ from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -31,8 +31,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
MultiModalInputs)
from vllm.multimodal.base import MultiModalData from vllm.multimodal.base import MultiModalData
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@ -117,16 +116,15 @@ def get_max_glmv_image_tokens(ctx: InputContext):
raise NotImplementedError(msg) raise NotImplementedError(msg)
def dummy_data_for_glmv( def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] mm_counts: Mapping[str, int]) -> DummyData:
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
hf_config = ctx.get_hf_config(ChatGLMConfig) hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None) vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None: if vision_config is None:
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
seq_data = SequenceData(token_ids) seq_data = SequenceData(token_ids)
return seq_data, None return DummyData(seq_data, None)
elif isinstance(vision_config, dict): elif isinstance(vision_config, dict):
image_size = vision_config["image_size"] image_size = vision_config["image_size"]
image_placeholder_length = calculate_image_placeholder(vision_config) image_placeholder_length = calculate_image_placeholder(vision_config)
@ -141,7 +139,7 @@ def dummy_data_for_glmv(
"image": Image.new("RGB", (image_size, image_size), color=0) "image": Image.new("RGB", (image_size, image_size), color=0)
} }
return seq_data, mm_data return DummyData(seq_data, mm_data)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)