[Bugfix] Fix img_sizes Parsing in Phi3-Vision (#5888)
This commit is contained in:
parent
96354d6a29
commit
2061f0b8a7
@ -65,12 +65,6 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
||||
self.type_feature: str
|
||||
self.img_processor: CLIPVisionModel
|
||||
|
||||
def set_img_features(self, img_features: torch.FloatTensor) -> None:
|
||||
self.img_features = img_features
|
||||
|
||||
def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
|
||||
self.img_sizes = img_sizes
|
||||
|
||||
def get_img_features(self,
|
||||
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
|
||||
LAYER_IDX = self.layer_idx
|
||||
@ -144,22 +138,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
||||
self.type_feature = config.img_processor.get('type_feature', 'patch')
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.LongTensor,
|
||||
def forward(self, input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes=None) -> torch.FloatTensor:
|
||||
image_sizes: torch.Tensor) -> torch.FloatTensor:
|
||||
"""process and merge text embeddings with image embeddings."""
|
||||
|
||||
# (batch_size, max_num_crops, 3, height, width)
|
||||
img_embeds = pixel_values
|
||||
|
||||
# (batch_size, 2)
|
||||
img_sizes = image_sizes
|
||||
|
||||
if self.img_features is not None:
|
||||
img_embeds = self.img_features.clone()
|
||||
self.img_features = None
|
||||
|
||||
if self.img_sizes is not None:
|
||||
img_sizes = self.img_sizes
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
@ -190,11 +179,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
output_imgs = []
|
||||
output_len = []
|
||||
|
||||
if isinstance(img_sizes, torch.Tensor):
|
||||
img_sizes.squeeze_(0)
|
||||
|
||||
for _bs in range(bs):
|
||||
h, w = img_sizes
|
||||
h, w = img_sizes[_bs]
|
||||
h = h // 336
|
||||
w = w // 336
|
||||
B_ = h * w
|
||||
|
||||
Loading…
Reference in New Issue
Block a user