[mypy] Misc. typing improvements (#7417)
parent 198d6a2898
commit 9ba85bc152
@@ -1,10 +1,12 @@
 import contextlib
 import functools
 import gc
+from typing import Callable, TypeVar

 import pytest
 import ray
 import torch
+from typing_extensions import ParamSpec

 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
@@ -22,12 +24,16 @@ def cleanup():
     torch.cuda.empty_cache()


-def retry_until_skip(n):
+_P = ParamSpec("_P")
+_R = TypeVar("_R")

-    def decorator_retry(func):
+
+def retry_until_skip(n: int):
+
+    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:

         @functools.wraps(func)
-        def wrapper_retry(*args, **kwargs):
+        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
             for i in range(n):
                 try:
                     return func(*args, **kwargs)
@@ -35,7 +41,9 @@ def retry_until_skip(n):
                     gc.collect()
                     torch.cuda.empty_cache()
                     if i == n - 1:
-                        pytest.skip("Skipping test after attempts..")
+                        pytest.skip(f"Skipping test after {n} attempts.")

+            raise AssertionError("Code should not be reached")
+
         return wrapper_retry
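Note on the pattern above: `ParamSpec` lets mypy carry a decorated function's exact parameter list through the wrapper instead of degrading it to `Callable[..., Any]`, so calls to the decorated function keep full argument checking. A minimal, self-contained sketch of the same idiom (the `log_calls` decorator and `add` function are illustrative, not part of this diff):

import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def log_calls(func: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator whose wrapper preserves the wrapped function's signature."""

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@log_calls
def add(x: int, y: int) -> int:
    return x + y


print(add(1, 2))  # mypy sees `add` as (int, int) -> int; add("a", "b") is rejected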
@@ -1,10 +1,8 @@
 import asyncio
 import os
 import socket
-import sys
 from functools import partial
-from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
-                    Tuple, TypeVar)
+from typing import AsyncIterator, Tuple

 import pytest

@@ -13,26 +11,11 @@ from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,

 from .utils import error_on_warning

-if sys.version_info < (3, 10):
-    if TYPE_CHECKING:
-        _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
-        _AwaitableT_co = TypeVar("_AwaitableT_co",
-                                 bound=Awaitable[Any],
-                                 covariant=True)
-
-        class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
-
-            def __anext__(self) -> _AwaitableT_co:
-                ...
-
-    def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
-        return i.__anext__()
-

 @pytest.mark.asyncio
 async def test_merge_async_iterators():

-    async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
+    async def mock_async_iterator(idx: int):
         try:
             while True:
                 yield f"item from iterator {idx}"
@@ -41,8 +24,10 @@ async def test_merge_async_iterators():
             print(f"iterator {idx} cancelled")

     iterators = [mock_async_iterator(i) for i in range(3)]
-    merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
-        *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
+    merged_iterator = merge_async_iterators(*iterators,
+                                            is_cancelled=partial(asyncio.sleep,
+                                                                 0,
+                                                                 result=False))

     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:
@@ -56,7 +41,8 @@ async def test_merge_async_iterators():

     for iterator in iterators:
         try:
-            await asyncio.wait_for(anext(iterator), 1)
+            # Can use anext() in python >= 3.10
+            await asyncio.wait_for(iterator.__anext__(), 1)
         except StopAsyncIteration:
             # All iterators should be cancelled and print this message.
             print("Iterator was cancelled normally")
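Note on the removed shim: the `anext()` builtin only exists on Python 3.10 and later, which is why the old `Protocol`-based backport lived under the version check; calling `__anext__()` directly, as the test now does, behaves the same on every supported version. A small illustrative sketch of both spellings (the `count_to` generator is made up for this example):

import asyncio
import sys
from typing import AsyncIterator


async def count_to(n: int) -> AsyncIterator[int]:
    for i in range(n):
        yield i


async def main() -> None:
    it = count_to(3)
    # Portable spelling, usable on every Python version:
    print(await it.__anext__())
    if sys.version_info >= (3, 10):
        # Builtin added in Python 3.10; the removed shim emulated it.
        print(await anext(it))


asyncio.run(main())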
@@ -7,12 +7,13 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import openai
 import ray
 import requests
 from transformers import AutoTokenizer
+from typing_extensions import ParamSpec

 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -360,13 +361,17 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
         time.sleep(5)


-def fork_new_process_for_each_test(f):
+_P = ParamSpec("_P")
+
+
+def fork_new_process_for_each_test(
+        f: Callable[_P, None]) -> Callable[_P, None]:
     """Decorator to fork a new process for each test function.
     See https://github.com/vllm-project/vllm/issues/7053 for more details.
     """

     @functools.wraps(f)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
         # Make the process the leader of its own process group
         # to avoid sending SIGTERM to the parent process
         os.setpgrp()
@@ -1,10 +1,10 @@
 import functools
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
-                    TypeVar)
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type

 from torch import nn
 from transformers import PretrainedConfig
+from typing_extensions import TypeVar

 from vllm.logger import init_logger

@@ -17,7 +17,7 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

-C = TypeVar("C", bound=PretrainedConfig)
+C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)


 @dataclass(frozen=True)
@@ -44,7 +44,7 @@ class InputContext:

         return multimodal_config

-    def get_hf_config(self, hf_config_type: Type[C]) -> C:
+    def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
         """
         Get the HuggingFace configuration
         (:class:`transformers.PretrainedConfig`) of the model,
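Note on the `TypeVar` change above: `typing_extensions.TypeVar` accepts the PEP 696 `default=` argument, which the stdlib `typing.TypeVar` does not provide on older Python versions; that default is what lets callers in the model files below use `get_hf_config()` with no argument while mypy still infers a concrete `PretrainedConfig` return type. A rough, self-contained sketch of the mechanism under assumed names (`BaseConfig`, `VisionConfig`, and `Context` are made up for illustration):

from typing import Type

from typing_extensions import TypeVar


class BaseConfig:
    hidden_size: int = 768


class VisionConfig(BaseConfig):
    image_size: int = 336


# default= (PEP 696): C resolves to BaseConfig when no argument is given.
C = TypeVar("C", bound=BaseConfig, default=BaseConfig)


class Context:

    def __init__(self, config: BaseConfig) -> None:
        self._config = config

    def get_config(self, config_type: Type[C] = BaseConfig) -> C:
        # isinstance() narrows the stored config to the requested subtype.
        if not isinstance(self._config, config_type):
            raise TypeError(f"Expected a {config_type.__name__}")
        return self._config


ctx = Context(VisionConfig())
print(ctx.get_config().hidden_size)             # typed as BaseConfig
print(ctx.get_config(VisionConfig).image_size)  # typed as VisionConfig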
@@ -165,7 +165,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int,


 def get_max_internvl_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config

     use_thumbnail = hf_config.use_thumbnail
@@ -187,7 +187,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs

     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config

     image_size = vision_config.image_size
@@ -260,7 +260,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):

     image_feature_size = get_max_internvl_image_tokens(ctx)
     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
@@ -34,7 +34,7 @@ import torch.types
 from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
-from transformers.configuration_utils import PretrainedConfig
+from transformers import PretrainedConfig

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -404,7 +404,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:


 def get_max_minicpmv_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     return getattr(hf_config, "query_num", 64)


@@ -420,7 +420,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):


 def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()

     seq_data = dummy_seq_data_for_minicpmv(seq_len)
     mm_data = dummy_image_for_minicpmv(hf_config)
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
 def get_max_phi3v_image_tokens(ctx: InputContext):

     return get_phi3v_image_feature_size(
-        ctx.get_hf_config(PretrainedConfig),
+        ctx.get_hf_config(),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
     )
@@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs

     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()

     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
@@ -3,13 +3,12 @@ from typing import List, Optional, Tuple, TypeVar

 import torch
 from PIL import Image
-from transformers import PreTrainedTokenizerBase

 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.image_processor import get_image_processor
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 from vllm.utils import is_list_of

 from .base import MultiModalInputs, MultiModalPlugin
@@ -40,7 +39,7 @@ def repeat_and_pad_token(


 def repeat_and_pad_image_tokens(
-    tokenizer: PreTrainedTokenizerBase,
+    tokenizer: AnyTokenizer,
     prompt: Optional[str],
     prompt_token_ids: List[int],
     *,
@@ -4,9 +4,10 @@ pynvml. However, it should not initialize cuda context.

 import os
 from functools import lru_cache, wraps
-from typing import List, Tuple
+from typing import Callable, List, Tuple, TypeVar

 import pynvml
+from typing_extensions import ParamSpec

 from vllm.logger import init_logger

@@ -14,16 +15,19 @@ from .interface import Platform, PlatformEnum

 logger = init_logger(__name__)

+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
 # all the related functions work on real physical device ids.
 # the major benefit of using NVML is that it will not initialize CUDA


-def with_nvml_context(fn):
+def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

     @wraps(fn)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         pynvml.nvmlInit()
         try:
             return fn(*args, **kwargs)
@@ -1,10 +1,9 @@
-from typing import Dict, List, Optional, Tuple, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Dict, List, Optional, Tuple

 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
+
+from .tokenizer import AnyTokenizer
+from .tokenizer_group import BaseTokenizerGroup

 # Used eg. for marking rejected tokens in spec decoding.
 INVALID_TOKEN_ID = -1
@@ -16,8 +15,7 @@ class Detokenizer:
     def __init__(self, tokenizer_group: BaseTokenizerGroup):
         self.tokenizer_group = tokenizer_group

-    def get_tokenizer_for_seq(self,
-                              sequence: Sequence) -> "PreTrainedTokenizer":
+    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
         """Returns the HF tokenizer to use for a given sequence."""
         return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

@@ -174,7 +172,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):


 def _convert_tokens_to_string_with_added_encoders(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     output_tokens: List[str],
     skip_special_tokens: bool,
     spaces_between_special_tokens: bool,
@@ -213,7 +211,7 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


 def convert_prompt_ids_to_tokens(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     prompt_ids: List[int],
     skip_special_tokens: bool = False,
 ) -> Tuple[List[str], int, int]:
@@ -240,7 +238,7 @@ def convert_prompt_ids_to_tokens(
 # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
 # under Apache 2.0 license
 def detokenize_incrementally(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     all_input_ids: List[int],
     prev_tokens: Optional[List[str]],
     prefix_offset: int,
@@ -12,10 +12,10 @@ from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async

-from .tokenizer_group import AnyTokenizer
-
 logger = init_logger(__name__)

+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+

 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     """Get tokenizer with cached properties.
@@ -141,7 +141,7 @@ def get_tokenizer(


 def get_lora_tokenizer(lora_request: LoRARequest, *args,
-                       **kwargs) -> Optional[PreTrainedTokenizer]:
+                       **kwargs) -> Optional[AnyTokenizer]:
     if lora_request is None:
         return None
     try:
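Note on `AnyTokenizer`: the alias is now defined once alongside the tokenizer helpers and imported everywhere else, replacing repeated `Union[PreTrainedTokenizer, PreTrainedTokenizerFast]` annotations and the old import that pointed in the opposite direction. A standalone sketch of how such an alias is typically used (the `token_count` helper and the "gpt2" checkpoint are only examples, not part of the diff):

from typing import Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

# One shared alias instead of spelling out the Union at every call site.
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def token_count(tokenizer: AnyTokenizer, text: str) -> int:
    """Works with both the slow and the fast tokenizer implementations."""
    return len(tokenizer.encode(text))


if __name__ == "__main__":
    # "gpt2" is just an example checkpoint; loading it needs network access.
    tok = AutoTokenizer.from_pretrained("gpt2")
    print(token_count(tok, "Hello, world!"))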
@@ -8,8 +8,7 @@ from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
 from .tokenizer_group import TokenizerGroup

 if ray:
-    from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
-        RayTokenizerGroupPool)
+    from .ray_tokenizer_group import RayTokenizerGroupPool
 else:
     RayTokenizerGroupPool = None  # type: ignore

@@ -1,12 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import List, Optional

 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
-
-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+from vllm.transformers_utils.tokenizer import AnyTokenizer


 class BaseTokenizerGroup(ABC):
@@ -24,9 +21,10 @@ class BaseTokenizerGroup(ABC):
         pass

     @abstractmethod
-    def get_max_input_len(self,
-                          lora_request: Optional[LoRARequest] = None
-                          ) -> Optional[int]:
+    def get_max_input_len(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> Optional[int]:
         """Get the maximum input length for the LoRA request."""
         pass

@@ -13,8 +13,9 @@ from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer import AnyTokenizer

-from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .base_tokenizer_group import BaseTokenizerGroup
 from .tokenizer_group import TokenizerGroup

 logger = init_logger(__name__)
@@ -2,12 +2,13 @@ from typing import List, Optional

 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_lora_tokenizer,
                                                get_lora_tokenizer_async,
                                                get_tokenizer)
 from vllm.utils import LRUCache

-from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .base_tokenizer_group import BaseTokenizerGroup


 class TokenizerGroup(BaseTokenizerGroup):
@@ -1101,9 +1101,9 @@ def cuda_device_count_stateless() -> int:


 #From: https://stackoverflow.com/a/4104188/2749989
-def run_once(f):
+def run_once(f: Callable[P, None]) -> Callable[P, None]:

-    def wrapper(*args, **kwargs) -> Any:
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
         if not wrapper.has_run:  # type: ignore[attr-defined]
             wrapper.has_run = True  # type: ignore[attr-defined]
             return f(*args, **kwargs)
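Note on `run_once`: typing it as taking and returning `Callable[P, None]` keeps the wrapped function's parameter list while making explicit that no value can be relied upon, since the second and later calls skip the body entirely. A minimal sketch of the idea (the `configure_logging` function is illustrative, and this sketch may differ in details from the vLLM helper):

from typing import Callable

from typing_extensions import ParamSpec

P = ParamSpec("P")


def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Run `f` only on the first call; later calls are silently ignored."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper


@run_once
def configure_logging(level: str) -> None:
    print(f"configuring logging at {level}")


configure_logging("INFO")   # prints once
configure_logging("DEBUG")  # second call is a no-op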