[Misc] use AutoTokenizer for benchmark serving when vLLM not installed (#5588)

This commit is contained in:
zhyncs 2024-06-18 00:40:35 +08:00 committed by GitHub
parent 890d8d960b
commit 1f12122b17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 2 deletions

View File

@ -4,10 +4,13 @@ import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union

import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@ -388,6 +391,30 @@ def remove_prefix(text: str, prefix: str) -> str:
    return text
def get_model(pretrained_model_name_or_path: str) -> str:
    """Download the model repo (config/tokenizer files only) and return its local path.

    Weight files are excluded from the download since only the tokenizer is
    needed for benchmarking. The source hub is selected by the
    ``VLLM_USE_MODELSCOPE`` environment variable.
    """
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download

        # modelscope API: `model_id` plus regex-style `ignore_file_pattern`.
        return snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

    from huggingface_hub import snapshot_download

    # huggingface_hub API differs: `repo_id` plus glob-style `ignore_patterns`.
    # Passing modelscope kwargs here raises TypeError, so each branch builds
    # its own call.
    return snapshot_download(
        repo_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_patterns=["*.pt", "*.safetensors", "*.bin"])
def get_tokenizer(
    pretrained_model_name_or_path: str, trust_remote_code: bool
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load a tokenizer via transformers, resolving remote model IDs first.

    If the given identifier is not an existing local path, the repo is
    fetched with ``get_model`` and the resulting local directory is used.
    """
    path = pretrained_model_name_or_path
    if path is not None and not os.path.exists(path):
        # Not a local directory: treat it as a hub model ID and download it.
        path = get_model(path)
    return AutoTokenizer.from_pretrained(path,
                                         trust_remote_code=trust_remote_code)
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,

View File

@ -39,7 +39,10 @@ from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer


@dataclass