[VLM][Core] Fix exceptions on ragged NestedTensors (#7974)

This commit is contained in:
Peter Salas 2024-08-28 20:24:31 -07:00 committed by GitHub
parent a7f65c2be9
commit 74d5543ec5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 11 deletions

View File

@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists():
result, result,
{"image": torch.stack([torch.stack([a, b]), {"image": torch.stack([torch.stack([a, b]),
torch.stack([c, d])])}) torch.stack([c, d])])})
def test_multimodal_input_batch_mixed_stacking_depths():
a = torch.rand([1, 2, 3])
b = torch.rand([1, 3, 3])
c = torch.rand([1, 4, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})
result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})

View File

@ -1,7 +1,6 @@
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload) Union, overload)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.func import functional_call from torch.func import functional_call
@ -96,12 +95,13 @@ def flatten_bn(
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
""" """
Recursively concatenates NestedTensors along any heterogeneously sized Recursively flattens and concatenates NestedTensors on all but the last
dimensions. dimension.
""" """
if isinstance(embeddings, torch.Tensor): if isinstance(embeddings, torch.Tensor):
return embeddings # Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
assert isinstance(num_expected_tokens, int) assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings) flattened = _flatten_embeddings(multimodal_embeddings)
*dims, embed_dim = flattened.shape if flattened.shape[0] != num_expected_tokens:
num_multimodal_embeddings = np.prod(dims)
if num_multimodal_embeddings != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings) expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError( raise ValueError(
f"Attempted to assign {expr} = {num_multimodal_embeddings} " f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders") f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) inputs_embeds[mask] = flattened
return inputs_embeds return inputs_embeds

View File

@ -54,8 +54,8 @@ class MultiModalInputs(_MultiModalInputsBase):
return nested_tensors return nested_tensors
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if is_list_of(stacked, list): if not is_list_of(stacked, torch.Tensor, check="all"):
# Do not stack nested lists # Only tensors (not lists) can be stacked.
return stacked return stacked
tensors_ = cast(List[torch.Tensor], stacked) tensors_ = cast(List[torch.Tensor], stacked)