[Bugfix] Fix dtype mismatch in PaliGemma (#6367)
parent aea19f0989
commit 024ad87cdc
tests/models/test_paligemma.py
@@ -129,7 +129,7 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["float", "half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
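For context on the test change: stacked `pytest.mark.parametrize` markers expand into the cross-product of their values, so adding "half" reruns every existing size-factor/max-tokens combination in fp16, which is exactly the configuration that previously crashed. A minimal sketch using only pytest (a toy test, not the vLLM suite):

import pytest

@pytest.mark.parametrize("dtype", ["float", "half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_matrix(dtype: str, max_tokens: int) -> None:
    # pytest generates one test case per combination: here, "float" and
    # "half" are each paired with max_tokens=128.
    assert dtype in ("float", "half")
    assert max_tokens == 128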
vllm/model_executor/models/gemma.py
@@ -277,6 +277,7 @@ class GemmaModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
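The context lines show the pattern PaliGemma depends on: when the caller supplies precomputed `inputs_embeds` (merged text and image embeddings), the model skips its own token-embedding lookup. A stripped-down sketch of that pattern (a toy module, not vLLM's `GemmaModel`):

from typing import Optional

import torch
import torch.nn as nn

class ToyLanguageModel(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(self,
                input_ids: torch.Tensor,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        if inputs_embeds is not None:
            # Multimodal path: embeddings were built outside (text + image).
            hidden_states = inputs_embeds
        else:
            # Text-only path: look up token embeddings as usual.
            hidden_states = self.embed_tokens(input_ids)
        return hidden_states

model = ToyLanguageModel()
ids = torch.tensor([[1, 2, 3]])
print(model(ids).shape)                                       # torch.Size([1, 3, 8])
print(model(ids, inputs_embeds=torch.zeros(1, 3, 8)).shape)   # same shape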
vllm/model_executor/models/paligemma.py
@@ -19,7 +19,7 @@ from vllm.model_executor.models.gemma import GemmaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
-from vllm.sequence import SamplerOutput, SequenceData
+from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData

 from .interfaces import SupportsVision
 from .utils import merge_vision_embeddings
@@ -111,7 +111,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     orig_prompt = llm_inputs.get("prompt")
     orig_prompt_ids = llm_inputs.get("prompt_token_ids")

-    if image_token_str in orig_prompt:
+    if orig_prompt is not None and image_token_str in orig_prompt:
         logger.warning(
             "The image token '%s' was detected in the prompt and "
             "will be removed. Please follow the proper prompt format"
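The added guard matters because `llm_inputs.get("prompt")` returns None for requests submitted as raw token IDs, and applying `in` to None raises a TypeError. A small illustration of the failure mode:

llm_inputs = {"prompt_token_ids": [2, 108, 2364]}  # no raw-text "prompt" key
orig_prompt = llm_inputs.get("prompt")             # -> None

image_token_str = "<image>"
# Without the guard, `image_token_str in orig_prompt` would raise:
# TypeError: argument of type 'NoneType' is not iterable
if orig_prompt is not None and image_token_str in orig_prompt:
    print("image token found; it will be stripped from the prompt")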
@@ -214,7 +214,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:

-        image_outputs = vision_tower(pixel_values, output_hidden_states=True)
+        target_dtype = vision_tower.get_input_embeddings().weight.dtype
+        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
+                                     output_hidden_states=True)

         selected_image_features = image_outputs.last_hidden_state
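This hunk is the actual dtype fix: image preprocessing produces float32 pixel values, but with `--dtype half` the SigLIP tower's weights are fp16, and PyTorch refuses mixed-dtype matmuls and convolutions. A minimal reproduction of the pattern (float64 stands in for fp16 so the snippet runs on any device; the principle is identical):

import torch
import torch.nn as nn

layer = nn.Linear(4, 2).double()       # model weights in a non-default dtype
pixel_values = torch.randn(1, 4)       # preprocessing yields float32

# layer(pixel_values) would raise a dtype-mismatch RuntimeError.
target_dtype = layer.weight.dtype      # read the dtype off the weights
out = layer(pixel_values.to(dtype=target_dtype))
print(out.dtype)                       # torch.float64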
@@ -236,9 +238,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):

         return self.multi_modal_projector(image_features)

-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object) -> SamplerOutput:

         parsed_image_input = self._parse_and_validate_image_input(**kwargs)
@@ -263,6 +268,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
                                             positions,
                                             kv_caches,
                                             attn_metadata,
+                                            None,
                                             inputs_embeds=inputs_embeds)

         return hidden_states
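The last two hunks keep call signatures aligned: PaliGemma's `forward` now accepts `intermediate_tensors`, and because `GemmaModel.forward` takes that parameter positionally before `inputs_embeds`, the call site fills the slot with an explicit `None`. A toy sketch of the convention (hypothetical names, not vLLM's API):

from typing import Any, List, Optional

def language_model(input_ids: Any,
                   positions: Any,
                   kv_caches: List[Any],
                   attn_metadata: Any,
                   intermediate_tensors: Optional[Any] = None,
                   inputs_embeds: Optional[Any] = None) -> Any:
    # Toy stand-in mirroring the parameter order after this change.
    return inputs_embeds if inputs_embeds is not None else input_ids

# None fills the intermediate_tensors slot positionally; the merged
# text+image embeddings are passed by keyword, as in the hunk above.
hidden_states = language_model("input_ids", "positions", [], "attn_metadata",
                               None, inputs_embeds="merged_embeds")
print(hidden_states)  # merged_embeds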