[Bugfix] Fix dtype mismatch in PaliGemma (#6367)
parent aea19f0989
commit 024ad87cdc
@@ -129,7 +129,7 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["float", "half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
@@ -277,6 +277,7 @@ class GemmaModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
@@ -19,7 +19,7 @@ from vllm.model_executor.models.gemma import GemmaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
-from vllm.sequence import SamplerOutput, SequenceData
+from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 
 from .interfaces import SupportsVision
 from .utils import merge_vision_embeddings
@@ -111,7 +111,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     orig_prompt = llm_inputs.get("prompt")
     orig_prompt_ids = llm_inputs.get("prompt_token_ids")
 
-    if image_token_str in orig_prompt:
+    if orig_prompt is not None and image_token_str in orig_prompt:
         logger.warning(
             "The image token '%s' was detected in the prompt and "
             "will be removed. Please follow the proper prompt format"
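The guard above matters because llm_inputs may carry only pre-tokenized input, in which case llm_inputs.get("prompt") returns None and the bare `in` check raises a TypeError. A minimal standalone sketch of the same pattern follows; strip_image_token and the dict shape are illustrative stand-ins, not the actual vLLM input processor.

from typing import Optional

def strip_image_token(llm_inputs: dict, image_token_str: str) -> Optional[str]:
    orig_prompt: Optional[str] = llm_inputs.get("prompt")
    # "prompt" may be missing when callers supply only prompt_token_ids, so the
    # substring test must not run against None.
    if orig_prompt is not None and image_token_str in orig_prompt:
        return orig_prompt.replace(image_token_str, "")
    return orig_prompt

print(strip_image_token({"prompt_token_ids": [1, 2, 3]}, "<image>"))   # None
print(strip_image_token({"prompt": "<image>caption en"}, "<image>"))   # "caption en"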
@@ -214,7 +214,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
 
-        image_outputs = vision_tower(pixel_values, output_hidden_states=True)
+        target_dtype = vision_tower.get_input_embeddings().weight.dtype
+        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
+                                     output_hidden_states=True)
 
         selected_image_features = image_outputs.last_hidden_state
 
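This hunk is the core of the fix: when the model runs in half precision, the SigLIP vision tower's weights are fp16 while the preprocessed pixel values arrive as float32, and the tower's first layer rejects the mixed dtypes; casting the pixels to the dtype of the tower's own parameters resolves it. Below is a minimal, self-contained sketch of the same pattern, using a toy module in place of SiglipVisionModel; ToyVisionTower and its layer sizes are invented for illustration, and bfloat16 stands in for whatever reduced precision the weights use.

import torch
import torch.nn as nn


class ToyVisionTower(nn.Module):
    """Toy stand-in for a vision tower; only the dtype handling matters here."""

    def __init__(self) -> None:
        super().__init__()
        # Plays the role of the patch embedding whose weight dtype gets inspected.
        self.patch_embedding = nn.Linear(3 * 14 * 14, 8)

    def get_input_embeddings(self) -> nn.Module:
        return self.patch_embedding

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Flatten each image and project it, as a patch embedding would.
        return self.patch_embedding(pixel_values.flatten(1))


tower = ToyVisionTower().to(torch.bfloat16)   # weights loaded in reduced precision
pixel_values = torch.rand(2, 3, 14, 14)       # image preprocessing yields float32

# Feeding float32 pixels into bfloat16 weights raises a dtype-mismatch error, so
# cast the input to the dtype of the tower's own parameters first -- the same
# idea as vision_tower.get_input_embeddings().weight.dtype in the hunk above.
target_dtype = tower.get_input_embeddings().weight.dtype
features = tower(pixel_values.to(dtype=target_dtype))
print(features.dtype)  # torch.bfloat16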
@@ -236,9 +238,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
 
         return self.multi_modal_projector(image_features)
 
-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object) -> SamplerOutput:
 
         parsed_image_input = self._parse_and_validate_image_input(**kwargs)
@@ -263,6 +268,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
             positions,
             kv_caches,
             attn_metadata,
+            None,
             inputs_embeds=inputs_embeds)
 
         return hidden_states