[mypy] Misc. typing improvements (#7417)

commit 9ba85bc152 (parent 198d6a2898)
@@ -1,10 +1,12 @@
 import contextlib
 import functools
 import gc
+from typing import Callable, TypeVar

 import pytest
 import ray
 import torch
+from typing_extensions import ParamSpec

 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
@@ -22,12 +24,16 @@ def cleanup():
     torch.cuda.empty_cache()


-def retry_until_skip(n):
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
+
+def retry_until_skip(n: int):

-    def decorator_retry(func):
+    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:

         @functools.wraps(func)
-        def wrapper_retry(*args, **kwargs):
+        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
             for i in range(n):
                 try:
                     return func(*args, **kwargs)
@@ -35,7 +41,9 @@ def retry_until_skip(n):
                     gc.collect()
                     torch.cuda.empty_cache()
                     if i == n - 1:
-                        pytest.skip("Skipping test after attempts..")
+                        pytest.skip(f"Skipping test after {n} attempts.")

+            raise AssertionError("Code should not be reached")
+
         return wrapper_retry
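
The pattern introduced here, a `ParamSpec` for the decorated function's parameters plus a `TypeVar` for its return value, is what lets mypy check calls through `retry_until_skip` instead of erasing the signature to `(*args: Any, **kwargs: Any) -> Any`. A minimal, self-contained sketch of the same idea (illustrative only, not the vLLM code; it assumes `typing_extensions` is installed, and on Python 3.10+ `ParamSpec` is also available from `typing`):

```python
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def log_calls(func: Callable[_P, _R]) -> Callable[_P, _R]:
    """Wrap `func` without hiding its real signature from type checkers."""

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@log_calls
def add(x: int, y: int) -> int:
    return x + y


total: int = add(1, 2)  # mypy still sees (x: int, y: int) -> int
# add("a", 2)           # would now be flagged as a bad argument by mypy
```

With a plain `Callable[..., _R]` annotation the wrapper itself would type-check, but callers would lose all parameter checking, which is exactly what the diff avoids.
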
@@ -1,10 +1,8 @@
 import asyncio
 import os
 import socket
-import sys
 from functools import partial
-from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
-                    Tuple, TypeVar)
+from typing import AsyncIterator, Tuple

 import pytest

@@ -13,26 +11,11 @@ from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,

 from .utils import error_on_warning

-if sys.version_info < (3, 10):
-    if TYPE_CHECKING:
-        _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
-        _AwaitableT_co = TypeVar("_AwaitableT_co",
-                                 bound=Awaitable[Any],
-                                 covariant=True)
-
-        class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
-
-            def __anext__(self) -> _AwaitableT_co:
-                ...
-
-    def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
-        return i.__anext__()
-

 @pytest.mark.asyncio
 async def test_merge_async_iterators():

-    async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
+    async def mock_async_iterator(idx: int):
         try:
             while True:
                 yield f"item from iterator {idx}"
@@ -41,8 +24,10 @@ async def test_merge_async_iterators():
             print(f"iterator {idx} cancelled")

     iterators = [mock_async_iterator(i) for i in range(3)]
-    merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
-        *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
+    merged_iterator = merge_async_iterators(*iterators,
+                                            is_cancelled=partial(asyncio.sleep,
+                                                                 0,
+                                                                 result=False))

     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:
@@ -56,7 +41,8 @@ async def test_merge_async_iterators():

     for iterator in iterators:
         try:
-            await asyncio.wait_for(anext(iterator), 1)
+            # Can use anext() in python >= 3.10
+            await asyncio.wait_for(iterator.__anext__(), 1)
         except StopAsyncIteration:
             # All iterators should be cancelled and print this message.
             print("Iterator was cancelled normally")
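
The deleted block above existed only because the builtin `anext()` arrived in Python 3.10; calling the iterator's `__anext__()` method directly, as the updated test does, behaves the same on every supported version. A small runnable sketch of that idea (not taken from the diff):

```python
import asyncio
from typing import AsyncIterator


async def numbers() -> AsyncIterator[int]:
    for i in range(3):
        yield i


async def main() -> None:
    it = numbers()
    # Portable alternative to `await anext(it)` (builtin only on 3.10+).
    first = await it.__anext__()
    print(first)  # -> 0


asyncio.run(main())
```
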
@@ -7,12 +7,13 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import openai
 import ray
 import requests
 from transformers import AutoTokenizer
+from typing_extensions import ParamSpec

 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -360,13 +361,17 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
         time.sleep(5)


-def fork_new_process_for_each_test(f):
+_P = ParamSpec("_P")
+
+
+def fork_new_process_for_each_test(
+        f: Callable[_P, None]) -> Callable[_P, None]:
     """Decorator to fork a new process for each test function.
     See https://github.com/vllm-project/vllm/issues/7053 for more details.
     """

     @functools.wraps(f)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
         # Make the process the leader of its own process group
         # to avoid sending SIGTERM to the parent process
         os.setpgrp()

@@ -1,10 +1,10 @@
 import functools
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
-                    TypeVar)
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type

 from torch import nn
 from transformers import PretrainedConfig
+from typing_extensions import TypeVar

 from vllm.logger import init_logger

@@ -17,7 +17,7 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

-C = TypeVar("C", bound=PretrainedConfig)
+C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)


 @dataclass(frozen=True)
@@ -44,7 +44,7 @@ class InputContext:

         return multimodal_config

-    def get_hf_config(self, hf_config_type: Type[C]) -> C:
+    def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
         """
         Get the HuggingFace configuration
         (:class:`transformers.PretrainedConfig`) of the model,
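
The `default=` argument comes from `typing_extensions.TypeVar` (PEP 696): when the caller omits `hf_config_type`, the type variable `C` falls back to `PretrainedConfig`, which is why the call sites below can drop their explicit argument. A rough standalone sketch of the pattern, using hypothetical `BaseConfig`/`VisionConfig` classes in place of transformers' `PretrainedConfig`:

```python
from typing import Type

from typing_extensions import TypeVar  # needed for the `default=` parameter


class BaseConfig:
    pass


class VisionConfig(BaseConfig):
    pass


C = TypeVar("C", bound=BaseConfig, default=BaseConfig)


def get_config(config: BaseConfig, config_type: Type[C] = BaseConfig) -> C:
    # Mirror InputContext.get_hf_config: validate, then return the narrowed type.
    if not isinstance(config, config_type):
        raise TypeError(f"Expected {config_type.__name__}, "
                        f"got {type(config).__name__}")
    return config


base = get_config(VisionConfig())               # inferred as BaseConfig
vis = get_config(VisionConfig(), VisionConfig)  # inferred as VisionConfig
```
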
@@ -165,7 +165,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int,


 def get_max_internvl_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config

     use_thumbnail = hf_config.use_thumbnail
@@ -187,7 +187,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs

     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config

     image_size = vision_config.image_size
@@ -260,7 +260,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):

     image_feature_size = get_max_internvl_image_tokens(ctx)
     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
@@ -34,7 +34,7 @@ import torch.types
 from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
-from transformers.configuration_utils import PretrainedConfig
+from transformers import PretrainedConfig

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -404,7 +404,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:


 def get_max_minicpmv_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     return getattr(hf_config, "query_num", 64)

@@ -420,7 +420,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):


 def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()

     seq_data = dummy_seq_data_for_minicpmv(seq_len)
     mm_data = dummy_image_for_minicpmv(hf_config)
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
 def get_max_phi3v_image_tokens(ctx: InputContext):

     return get_phi3v_image_feature_size(
-        ctx.get_hf_config(PretrainedConfig),
+        ctx.get_hf_config(),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
     )
@@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs

     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()

     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
@@ -3,13 +3,12 @@ from typing import List, Optional, Tuple, TypeVar

 import torch
 from PIL import Image
-from transformers import PreTrainedTokenizerBase

 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.image_processor import get_image_processor
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 from vllm.utils import is_list_of

 from .base import MultiModalInputs, MultiModalPlugin
@@ -40,7 +39,7 @@ def repeat_and_pad_token(


 def repeat_and_pad_image_tokens(
-    tokenizer: PreTrainedTokenizerBase,
+    tokenizer: AnyTokenizer,
     prompt: Optional[str],
     prompt_token_ids: List[int],
     *,
@@ -4,9 +4,10 @@ pynvml. However, it should not initialize cuda context.

 import os
 from functools import lru_cache, wraps
-from typing import List, Tuple
+from typing import Callable, List, Tuple, TypeVar

 import pynvml
+from typing_extensions import ParamSpec

 from vllm.logger import init_logger

@@ -14,16 +15,19 @@ from .interface import Platform, PlatformEnum

 logger = init_logger(__name__)

+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
 # all the related functions work on real physical device ids.
 # the major benefit of using NVML is that it will not initialize CUDA


-def with_nvml_context(fn):
+def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

     @wraps(fn)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         pynvml.nvmlInit()
         try:
             return fn(*args, **kwargs)
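
Typing `with_nvml_context` with the same `ParamSpec`/`TypeVar` pair keeps each wrapped NVML helper's signature intact while the wrapper owns initialization and teardown. A hypothetical minimal version of such a decorator follows; the `pynvml.nvmlInit()`/`nvmlShutdown()` calls are the real pynvml API, but the `finally` half sits outside the hunk and is assumed here:

```python
from functools import wraps
from typing import Callable, TypeVar

import pynvml  # assumes the pynvml package is installed
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Run `fn` inside an NVML init/shutdown pair, preserving its signature."""

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper
```
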
@@ -1,10 +1,9 @@
-from typing import Dict, List, Optional, Tuple, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Dict, List, Optional, Tuple

 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
+
+from .tokenizer import AnyTokenizer
+from .tokenizer_group import BaseTokenizerGroup

 # Used eg. for marking rejected tokens in spec decoding.
 INVALID_TOKEN_ID = -1
@@ -16,8 +15,7 @@ class Detokenizer:
     def __init__(self, tokenizer_group: BaseTokenizerGroup):
         self.tokenizer_group = tokenizer_group

-    def get_tokenizer_for_seq(self,
-                              sequence: Sequence) -> "PreTrainedTokenizer":
+    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
         """Returns the HF tokenizer to use for a given sequence."""
         return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

@@ -174,7 +172,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):


 def _convert_tokens_to_string_with_added_encoders(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     output_tokens: List[str],
     skip_special_tokens: bool,
     spaces_between_special_tokens: bool,
@@ -213,7 +211,7 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


 def convert_prompt_ids_to_tokens(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     prompt_ids: List[int],
     skip_special_tokens: bool = False,
 ) -> Tuple[List[str], int, int]:
@@ -240,7 +238,7 @@ def convert_prompt_ids_to_tokens(
 # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
 # under Apache 2.0 license
 def detokenize_incrementally(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     all_input_ids: List[int],
     prev_tokens: Optional[List[str]],
     prefix_offset: int,
@@ -12,10 +12,10 @@ from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async

-from .tokenizer_group import AnyTokenizer
-
 logger = init_logger(__name__)

+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+

 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     """Get tokenizer with cached properties.
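
Defining `AnyTokenizer` once in `tokenizer.py` gives every module a single import for the "slow or fast HF tokenizer" union and removes the back-import from the tokenizer-group package. A short sketch of how such an alias reads at a call site (illustrative; assumes `transformers` is installed):

```python
from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def token_count(tokenizer: AnyTokenizer, text: str) -> int:
    # Both tokenizer classes share the encode() API, so one alias
    # covers every call site without repeating the Union.
    return len(tokenizer.encode(text))
```
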
@@ -141,7 +141,7 @@ def get_tokenizer(


 def get_lora_tokenizer(lora_request: LoRARequest, *args,
-                       **kwargs) -> Optional[PreTrainedTokenizer]:
+                       **kwargs) -> Optional[AnyTokenizer]:
     if lora_request is None:
         return None
     try:
@@ -8,8 +8,7 @@ from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
 from .tokenizer_group import TokenizerGroup

 if ray:
-    from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
-        RayTokenizerGroupPool)
+    from .ray_tokenizer_group import RayTokenizerGroupPool
 else:
     RayTokenizerGroupPool = None  # type: ignore

@@ -1,12 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import List, Optional

 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
-
-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+from vllm.transformers_utils.tokenizer import AnyTokenizer


 class BaseTokenizerGroup(ABC):
@@ -24,9 +21,10 @@ class BaseTokenizerGroup(ABC):
         pass

     @abstractmethod
-    def get_max_input_len(self,
-                          lora_request: Optional[LoRARequest] = None
-                          ) -> Optional[int]:
+    def get_max_input_len(
+            self,
+            lora_request: Optional[LoRARequest] = None,
+    ) -> Optional[int]:
         """Get the maximum input length for the LoRA request."""
         pass

@@ -13,8 +13,9 @@ from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer import AnyTokenizer

-from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .base_tokenizer_group import BaseTokenizerGroup
 from .tokenizer_group import TokenizerGroup

 logger = init_logger(__name__)
@@ -2,12 +2,13 @@ from typing import List, Optional

 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_lora_tokenizer,
                                                get_lora_tokenizer_async,
                                                get_tokenizer)
 from vllm.utils import LRUCache

-from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .base_tokenizer_group import BaseTokenizerGroup


 class TokenizerGroup(BaseTokenizerGroup):
@@ -1101,9 +1101,9 @@ def cuda_device_count_stateless() -> int:


 #From: https://stackoverflow.com/a/4104188/2749989
-def run_once(f):
+def run_once(f: Callable[P, None]) -> Callable[P, None]:

-    def wrapper(*args, **kwargs) -> Any:
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
         if not wrapper.has_run:  # type: ignore[attr-defined]
             wrapper.has_run = True  # type: ignore[attr-defined]
             return f(*args, **kwargs)
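
`run_once` shows the same `ParamSpec` treatment applied to a decorator that keeps its state as an attribute on the wrapper function, which mypy only accepts with targeted `# type: ignore[attr-defined]` comments. A standalone sketch, assuming the pieces outside the hunk (initializing `has_run` and returning `wrapper`) look roughly like this:

```python
from typing import Callable

from typing_extensions import ParamSpec

P = ParamSpec("P")


def run_once(f: Callable[P, None]) -> Callable[P, None]:

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper


@run_once
def announce(msg: str) -> None:
    print(msg)


announce("hello")  # prints "hello"
announce("hello")  # silently skipped on the second call
```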