[Bugfix] Fix img_sizes Parsing in Phi3-Vision (#5888)

This commit is contained in:
Roger Wang 2024-06-27 01:29:24 -07:00 committed by GitHub
parent 96354d6a29
commit 2061f0b8a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -65,12 +65,6 @@ class Phi3ImageEmbeddingBase(nn.Module):
self.type_feature: str
self.img_processor: CLIPVisionModel
def set_img_features(self, img_features: torch.FloatTensor) -> None:
self.img_features = img_features
def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
self.img_sizes = img_sizes
def get_img_features(self,
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
LAYER_IDX = self.layer_idx
@ -144,22 +138,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
self.layer_idx = config.img_processor.get('layer_idx', -2)
self.type_feature = config.img_processor.get('type_feature', 'patch')
def forward(self,
input_ids: torch.LongTensor,
def forward(self, input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_sizes=None) -> torch.FloatTensor:
image_sizes: torch.Tensor) -> torch.FloatTensor:
"""process and merge text embeddings with image embeddings."""
# (batch_size, max_num_crops, 3, height, width)
img_embeds = pixel_values
# (batch_size, 2)
img_sizes = image_sizes
if self.img_features is not None:
img_embeds = self.img_features.clone()
self.img_features = None
if self.img_sizes is not None:
img_sizes = self.img_sizes
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
@ -190,11 +179,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
output_imgs = []
output_len = []
if isinstance(img_sizes, torch.Tensor):
img_sizes.squeeze_(0)
for _bs in range(bs):
h, w = img_sizes
h, w = img_sizes[_bs]
h = h // 336
w = w // 336
B_ = h * w