From 9f0e69b65350fad1d4a9c71ef58d6ae70eb635e8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 31 Jul 2024 10:49:48 +0800 Subject: [PATCH] [CI/Build] Fix mypy errors (#6968) --- vllm/_custom_ops.py | 4 ++-- vllm/multimodal/base.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e351d602..2c09ca2c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,6 +1,6 @@ import contextlib import functools -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Union import torch @@ -336,7 +336,7 @@ def scaled_fp8_quant( """ # This code assumes batch_dim and num_tokens are flattened assert (input.ndim == 2) - shape = input.shape + shape: Union[Tuple[int, int], torch.Size] = input.shape if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index f13885ef..aefb5f43 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase): """ @staticmethod - def _try_concat( - tensors: List[NestedTensors], - ) -> Union[GenericSequence[NestedTensors], NestedTensors]: + def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors: """ If each input tensor in the batch has the same shape, return a single batched tensor; otherwise, return a list of :class:`NestedTensors` with @@ -105,7 +103,7 @@ class MultiModalInputs(_MultiModalInputsBase): return { k: MultiModalInputs._try_concat(item_list) for k, item_list in item_lists.items() - } # type: ignore + } @staticmethod def as_kwargs(