[Bugfix] Fix LLaVA-NeXT (#5380)
This commit is contained in:
parent
774d1035e4
commit
2c0d933594
@ -216,6 +216,30 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||||
|
strategy: str) -> torch.Tensor:
|
||||||
|
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
|
||||||
|
if strategy == "default":
|
||||||
|
return image_features[:, 1:]
|
||||||
|
elif strategy == "full":
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||||
|
|
||||||
|
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                              pixel_values: torch.Tensor) -> torch.Tensor:
    """Run the CLIP vision tower and return the configured feature subset.

    Args:
        vision_tower: HF ``CLIPVisionModel`` used as the image encoder.
        pixel_values: Preprocessed image batch; moved onto the tower's
            device before the forward pass.

    Returns:
        Hidden states from the layer named by
        ``config.vision_feature_layer``, filtered through
        ``_select_image_features`` using
        ``config.vision_feature_select_strategy``.
    """
    # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
    outputs = vision_tower(pixel_values.to(vision_tower.device),
                           output_hidden_states=True)

    # The config selects which intermediate hidden layer to tap.
    selected_layer = outputs.hidden_states[self.config.vision_feature_layer]

    return self._select_image_features(
        selected_layer,
        strategy=self.config.vision_feature_select_strategy,
    )
|
||||||
|
|
||||||
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
|
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
|
||||||
patch_embeddings: torch.Tensor, *,
|
patch_embeddings: torch.Tensor, *,
|
||||||
strategy: str) -> torch.Tensor:
|
strategy: str) -> torch.Tensor:
|
||||||
|
|||||||
@ -77,7 +77,7 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
|
|||||||
"""Combine image and text prompts for vision language model depending on
|
"""Combine image and text prompts for vision language model depending on
|
||||||
the model architecture."""
|
the model architecture."""
|
||||||
|
|
||||||
if config.hf_config.model_type == "llava":
|
if config.hf_config.model_type in ("llava", "llava_next"):
|
||||||
full_prompt = f"{image_prompt}\n{text_prompt}"
|
full_prompt = f"{image_prompt}\n{text_prompt}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user