[mypy] Misc. typing improvements (#7417)

Cyrus Leung 2024-08-13 09:20:20 +08:00 committed by GitHub
parent 198d6a2898
commit 9ba85bc152
16 changed files with 74 additions and 75 deletions

View File

@ -1,10 +1,12 @@
import contextlib
import functools
import gc
from typing import Callable, TypeVar
import pytest
import ray
import torch
from typing_extensions import ParamSpec
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
@ -22,12 +24,16 @@ def cleanup():
torch.cuda.empty_cache()
def retry_until_skip(n):
_P = ParamSpec("_P")
_R = TypeVar("_R")
def decorator_retry(func):
def retry_until_skip(n: int):
def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(func)
def wrapper_retry(*args, **kwargs):
def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
for i in range(n):
try:
return func(*args, **kwargs)
@ -35,7 +41,9 @@ def retry_until_skip(n):
gc.collect()
torch.cuda.empty_cache()
if i == n - 1:
pytest.skip("Skipping test after attempts..")
pytest.skip(f"Skipping test after {n} attempts.")
raise AssertionError("Code should not be reached")
return wrapper_retry
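
For reference, a minimal standalone sketch of the `ParamSpec`/`TypeVar` decorator pattern this hunk introduces; `simple_retry` and `flaky` are made-up names, and a bare `Exception` stands in for the pytest-specific handling elided above:

```python
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def simple_retry(n: int):
    """Retry a function up to ``n`` times before re-raising."""

    def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:

        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            for i in range(n):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    if i == n - 1:
                        raise
            raise AssertionError("Code should not be reached")

        return wrapper

    return decorator


@simple_retry(3)
def flaky(x: int) -> int:
    return x * 2


print(flaky(21))  # 42; mypy now sees flaky as (x: int) -> int rather than untyped
```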

View File

@ -1,10 +1,8 @@
import asyncio
import os
import socket
import sys
from functools import partial
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)
from typing import AsyncIterator, Tuple
import pytest
@ -13,26 +11,11 @@ from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
from .utils import error_on_warning
if sys.version_info < (3, 10):
if TYPE_CHECKING:
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
_AwaitableT_co = TypeVar("_AwaitableT_co",
bound=Awaitable[Any],
covariant=True)
class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
def __anext__(self) -> _AwaitableT_co:
...
def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
return i.__anext__()
@pytest.mark.asyncio
async def test_merge_async_iterators():
async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
async def mock_async_iterator(idx: int):
try:
while True:
yield f"item from iterator {idx}"
@ -41,8 +24,10 @@ async def test_merge_async_iterators():
print(f"iterator {idx} cancelled")
iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
*iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
merged_iterator = merge_async_iterators(*iterators,
is_cancelled=partial(asyncio.sleep,
0,
result=False))
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
@ -56,7 +41,8 @@ async def test_merge_async_iterators():
for iterator in iterators:
try:
await asyncio.wait_for(anext(iterator), 1)
# Can use anext() in python >= 3.10
await asyncio.wait_for(iterator.__anext__(), 1)
except StopAsyncIteration:
# All iterators should be cancelled and print this message.
print("Iterator was cancelled normally")

View File

@ -7,12 +7,13 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import openai
import ray
import requests
from transformers import AutoTokenizer
from typing_extensions import ParamSpec
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
@ -360,13 +361,17 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
time.sleep(5)
def fork_new_process_for_each_test(f):
_P = ParamSpec("_P")
def fork_new_process_for_each_test(
f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
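
The wrapper body is truncated in this view; as a rough, hedged sketch, a fork-per-call decorator typed the same way usually looks like the following (POSIX-only; the name and the exit-code handling are illustrative rather than the commit's exact logic):

```python
import functools
import os
from typing import Callable

from typing_extensions import ParamSpec

_P = ParamSpec("_P")


def fork_for_each_call(f: Callable[_P, None]) -> Callable[_P, None]:
    """Run ``f`` in a forked child process and propagate failures."""

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        pid = os.fork()
        if pid == 0:
            # Child: run the body and report success/failure via exit code.
            try:
                f(*args, **kwargs)
            except BaseException:
                os._exit(1)
            os._exit(0)
        # Parent: wait for the child and fail if it did not exit cleanly.
        _, status = os.waitpid(pid, 0)
        assert os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0

    return wrapper


@fork_for_each_call
def smoke_test(value: int) -> None:
    assert value > 0


smoke_test(1)  # runs in a child process; the parent checks its exit status
```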

View File

@ -1,10 +1,10 @@
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
TypeVar)
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
from torch import nn
from transformers import PretrainedConfig
from typing_extensions import TypeVar
from vllm.logger import init_logger
@ -17,7 +17,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
@dataclass(frozen=True)
@ -44,7 +44,7 @@ class InputContext:
return multimodal_config
def get_hf_config(self, hf_config_type: Type[C]) -> C:
def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,

View File

@ -165,7 +165,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def get_max_internvl_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
use_thumbnail = hf_config.use_thumbnail
@ -187,7 +187,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
image_size = vision_config.image_size
@ -260,7 +260,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
image_feature_size = get_max_internvl_image_tokens(ctx)
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
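
The calls to `ctx.get_hf_config()` above work without an explicit type argument because of the `default=PretrainedConfig` added to the `C` TypeVar in the registry hunk earlier. A minimal sketch of that PEP 696 pattern with stand-in config classes (names are illustrative, and it assumes a `typing_extensions`/mypy recent enough to understand TypeVar defaults):

```python
from typing import Type

from typing_extensions import TypeVar  # this TypeVar accepts default= (PEP 696)


class BaseConfig:
    pass


class VisionConfig(BaseConfig):
    image_size = 448


# If C cannot be solved from the arguments, it falls back to BaseConfig.
C = TypeVar("C", bound=BaseConfig, default=BaseConfig)


class Context:

    def __init__(self, config: BaseConfig) -> None:
        self._config = config

    def get_config(self, config_type: Type[C] = BaseConfig) -> C:
        if not isinstance(self._config, config_type):
            raise TypeError("unexpected config type")
        return self._config


ctx = Context(VisionConfig())
base = ctx.get_config()                # inferred as BaseConfig via the default
vision = ctx.get_config(VisionConfig)  # narrowed to VisionConfig
```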

View File

@ -34,7 +34,7 @@ import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers.configuration_utils import PretrainedConfig
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
@ -404,7 +404,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def get_max_minicpmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
return getattr(hf_config, "query_num", 64)
@ -420,7 +420,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
seq_data = dummy_seq_data_for_minicpmv(seq_len)
mm_data = dummy_image_for_minicpmv(hf_config)

View File

@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
def get_max_phi3v_image_tokens(ctx: InputContext):
return get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
ctx.get_hf_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):

View File

@ -3,13 +3,12 @@ from typing import List, Optional, Tuple, TypeVar
import torch
from PIL import Image
from transformers import PreTrainedTokenizerBase
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalInputs, MultiModalPlugin
@ -40,7 +39,7 @@ def repeat_and_pad_token(
def repeat_and_pad_image_tokens(
tokenizer: PreTrainedTokenizerBase,
tokenizer: AnyTokenizer,
prompt: Optional[str],
prompt_token_ids: List[int],
*,

View File

@ -4,9 +4,10 @@ pynvml. However, it should not initialize cuda context.
import os
from functools import lru_cache, wraps
from typing import List, Tuple
from typing import Callable, List, Tuple, TypeVar
import pynvml
from typing_extensions import ParamSpec
from vllm.logger import init_logger
@ -14,16 +15,19 @@ from .interface import Platform, PlatformEnum
logger = init_logger(__name__)
_P = ParamSpec("_P")
_R = TypeVar("_R")
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
def with_nvml_context(fn):
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
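
Typing wrappers with `ParamSpec` instead of bare `*args, **kwargs` means mypy keeps checking call sites through the decorator. A small illustration with made-up names (the last line is an intentional type error, shown to demonstrate what mypy now catches):

```python
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec, reveal_type

_P = ParamSpec("_P")
_R = TypeVar("_R")


def passthrough(fn: Callable[_P, _R]) -> Callable[_P, _R]:

    @functools.wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return fn(*args, **kwargs)

    return wrapper


@passthrough
def device_count(per_node: bool = False) -> int:
    return 8 if per_node else 64


reveal_type(device_count)  # mypy: "def (per_node: bool = ...) -> int"
device_count(per_node=1)   # mypy: error, "per_node" expects "bool", got "int"
```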

View File

@ -1,10 +1,9 @@
from typing import Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import Dict, List, Optional, Tuple
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup
# Used eg. for marking rejected tokens in spec decoding.
INVALID_TOKEN_ID = -1
@ -16,8 +15,7 @@ class Detokenizer:
def __init__(self, tokenizer_group: BaseTokenizerGroup):
self.tokenizer_group = tokenizer_group
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
@ -174,7 +172,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: AnyTokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
@ -213,7 +211,7 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: AnyTokenizer,
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
@ -240,7 +238,7 @@ def convert_prompt_ids_to_tokens(
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: AnyTokenizer,
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int,

View File

@ -12,10 +12,10 @@ from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async
from .tokenizer_group import AnyTokenizer
logger = init_logger(__name__)
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties.
@ -141,7 +141,7 @@ def get_tokenizer(
def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[PreTrainedTokenizer]:
**kwargs) -> Optional[AnyTokenizer]:
if lora_request is None:
return None
try:

View File

@ -8,8 +8,7 @@ from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
from .ray_tokenizer_group import RayTokenizerGroupPool
else:
RayTokenizerGroupPool = None # type: ignore

View File

@ -1,12 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import List, Optional
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
from vllm.transformers_utils.tokenizer import AnyTokenizer
class BaseTokenizerGroup(ABC):
@ -24,9 +21,10 @@ class BaseTokenizerGroup(ABC):
pass
@abstractmethod
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_max_input_len(
self,
lora_request: Optional[LoRARequest] = None,
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
pass
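
These tokenizer hunks consolidate the repeated `Union[PreTrainedTokenizer, PreTrainedTokenizerFast]` spelling into the single `AnyTokenizer` alias now defined in `vllm/transformers_utils/tokenizer.py`. A minimal sketch of defining and consuming such an alias (`token_count` and the `gpt2` checkpoint are illustrative, not part of the commit):

```python
from typing import List, Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

# One alias instead of repeating the Union in every signature.
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def token_count(tokenizer: AnyTokenizer, text: str) -> int:
    """Count tokens with either a slow or a fast tokenizer."""
    ids: List[int] = tokenizer.encode(text)
    return len(ids)


# Downloads the GPT-2 tokenizer on first use.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(token_count(tokenizer, "typing improvements"))
```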

View File

@ -13,8 +13,9 @@ from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .base_tokenizer_group import BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
logger = init_logger(__name__)

View File

@ -2,12 +2,13 @@ from typing import List, Optional
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.utils import LRUCache
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .base_tokenizer_group import BaseTokenizerGroup
class TokenizerGroup(BaseTokenizerGroup):

View File

@ -1101,9 +1101,9 @@ def cuda_device_count_stateless() -> int:
#From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):
def run_once(f: Callable[P, None]) -> Callable[P, None]:
def wrapper(*args, **kwargs) -> Any:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
if not wrapper.has_run: # type: ignore[attr-defined]
wrapper.has_run = True # type: ignore[attr-defined]
return f(*args, **kwargs)
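
For completeness, a self-contained, hedged version of the `run_once` pattern in this last hunk, with the `has_run` flag attached to the wrapper (`report` is a made-up example, and the commit's own wrapper body is truncated here):

```python
import functools
from typing import Callable

from typing_extensions import ParamSpec

P = ParamSpec("P")


def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Execute ``f`` only on the first call; later calls are no-ops."""

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper


@run_once
def report(msg: str) -> None:
    print(msg)


report("printed once")
report("silently ignored")
```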