[bugfix] fix chatglm dummy_data_for_glmv (#9955)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent d6459b4516
commit 74b529ceee
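Context for this fix: dummy-data factories consumed by vLLM's INPUT_REGISTRY now return a DummyData object rather than a bare (SequenceData, Optional[MultiModalDataDict]) tuple, and dummy_data_for_glmv had not been migrated, leaving GLM-4V's dummy data incompatible with callers that expect the new type. The diff below updates the imports, the signature, and both return paths. As a minimal sketch of the shape the new return type carries (the field names here are an assumption inferred from the positional calls DummyData(seq_data, None) and DummyData(seq_data, mm_data) in the diff; the authoritative definition lives in vllm.inputs at this revision):

# Sketch only: assumed structure of DummyData, inferred from how the
# constructor is called in this diff. Not the authoritative definition.
from typing import NamedTuple, Optional

from vllm.multimodal import MultiModalDataDict
from vllm.sequence import SequenceData


class DummyData(NamedTuple):
    """Dummy request data used when profiling memory for a model."""
    seq_data: SequenceData
    multi_modal_data: Optional[MultiModalDataDict] = None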
@@ -14,8 +14,8 @@ from torch.nn import LayerNorm
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
-                         token_inputs)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -31,8 +31,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
-                             MultiModalInputs)
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@@ -117,16 +116,15 @@ def get_max_glmv_image_tokens(ctx: InputContext):
     raise NotImplementedError(msg)


-def dummy_data_for_glmv(
-        ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
-) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
+def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
+                        mm_counts: Mapping[str, int]) -> DummyData:
     hf_config = ctx.get_hf_config(ChatGLMConfig)
     vision_config = getattr(hf_config, 'vision_config', None)

     if vision_config is None:
         token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
         seq_data = SequenceData(token_ids)
-        return seq_data, None
+        return DummyData(seq_data, None)
     elif isinstance(vision_config, dict):
         image_size = vision_config["image_size"]
         image_placeholder_length = calculate_image_placeholder(vision_config)
@@ -141,7 +139,7 @@ def dummy_data_for_glmv(
             "image": Image.new("RGB", (image_size, image_size), color=0)
         }

-        return seq_data, mm_data
+        return DummyData(seq_data, mm_data)

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
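For context on how this factory is consumed: it is registered on the model class through the input registry, so the registry, not user code, invokes it during profiling. A rough sketch of that registration, reproduced from memory of this file at this revision and therefore an assumption rather than part of the diff:

# Assumed registration (not shown in this diff): the factory is attached
# to the ChatGLM model class as an input-registry decorator, so profiling
# calls dummy_data_for_glmv and now receives a DummyData instance.
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    ...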
|