vllm/vllm/transformers_utils/tokenizer_group/__init__.py

34 lines
1.2 KiB
Python

from typing import Optional
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
else:
RayTokenizerGroupPool = None # type: ignore
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
if tokenizer_pool_config is None:
return TokenizerGroup(**init_kwargs)
if tokenizer_pool_config.pool_type == "ray":
if RayTokenizerGroupPool is None:
raise ImportError(
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool.")
return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
**init_kwargs)
else:
raise ValueError(
f"Unknown pool type: {tokenizer_pool_config.pool_type}")
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]