[Bugfix] Fix img_sizes Parsing in Phi3-Vision (#5888)
This commit is contained in:
parent
96354d6a29
commit
2061f0b8a7
@ -65,12 +65,6 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
|||||||
self.type_feature: str
|
self.type_feature: str
|
||||||
self.img_processor: CLIPVisionModel
|
self.img_processor: CLIPVisionModel
|
||||||
|
|
||||||
def set_img_features(self, img_features: torch.FloatTensor) -> None:
|
|
||||||
self.img_features = img_features
|
|
||||||
|
|
||||||
def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
|
|
||||||
self.img_sizes = img_sizes
|
|
||||||
|
|
||||||
def get_img_features(self,
|
def get_img_features(self,
|
||||||
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
|
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
LAYER_IDX = self.layer_idx
|
LAYER_IDX = self.layer_idx
|
||||||
@ -144,22 +138,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
|||||||
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
||||||
self.type_feature = config.img_processor.get('type_feature', 'patch')
|
self.type_feature = config.img_processor.get('type_feature', 'patch')
|
||||||
|
|
||||||
def forward(self,
|
def forward(self, input_ids: torch.LongTensor,
|
||||||
input_ids: torch.LongTensor,
|
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
image_sizes=None) -> torch.FloatTensor:
|
image_sizes: torch.Tensor) -> torch.FloatTensor:
|
||||||
"""process and merge text embeddings with image embeddings."""
|
"""process and merge text embeddings with image embeddings."""
|
||||||
|
|
||||||
|
# (batch_size, max_num_crops, 3, height, width)
|
||||||
img_embeds = pixel_values
|
img_embeds = pixel_values
|
||||||
|
|
||||||
|
# (batch_size, 2)
|
||||||
img_sizes = image_sizes
|
img_sizes = image_sizes
|
||||||
|
|
||||||
if self.img_features is not None:
|
|
||||||
img_embeds = self.img_features.clone()
|
|
||||||
self.img_features = None
|
|
||||||
|
|
||||||
if self.img_sizes is not None:
|
|
||||||
img_sizes = self.img_sizes
|
|
||||||
|
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
@ -190,11 +179,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
|||||||
output_imgs = []
|
output_imgs = []
|
||||||
output_len = []
|
output_len = []
|
||||||
|
|
||||||
if isinstance(img_sizes, torch.Tensor):
|
|
||||||
img_sizes.squeeze_(0)
|
|
||||||
|
|
||||||
for _bs in range(bs):
|
for _bs in range(bs):
|
||||||
h, w = img_sizes
|
h, w = img_sizes[_bs]
|
||||||
h = h // 336
|
h = h // 336
|
||||||
w = w // 336
|
w = w // 336
|
||||||
B_ = h * w
|
B_ = h * w
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user