[Mypy] Part 3 fix typing for nested directories for most of directory (#4161)

SangBin Cho authored on 2024-04-23 13:32:44 +09:00, committed by GitHub
parent 34128a697e
commit 0ae11f78ab
GPG Key ID: B5690EEEBB952194
29 changed files with 126 additions and 88 deletions

View File

@@ -32,19 +32,20 @@ jobs:
         pip install types-setuptools
     - name: Mypy
       run: |
-        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/attention --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
         mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
-        # TODO(sang): Follow up
-        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/distributed --config-file pyproject.toml
+        mypy vllm/entrypoints --config-file pyproject.toml
+        mypy vllm/executor --config-file pyproject.toml
+        mypy vllm/usage --config-file pyproject.toml
+        mypy vllm/*.py --config-file pyproject.toml
+        mypy vllm/transformers_utils --config-file pyproject.toml
+        mypy vllm/engine --config-file pyproject.toml
+        mypy vllm/worker --config-file pyproject.toml
+        mypy vllm/spec_decode --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
+        mypy vllm/model_executor/*.py --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
+        # mypy vllm/lora/*.py --config-file pyproject.toml
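Note on the hunk above: handing mypy a package directory instead of a top-level glob with --follow-imports=skip lets it recurse into nested directories and resolve imports according to the settings in pyproject.toml. A rough sketch of comparing the two modes through mypy's programmatic API; it assumes a vLLM checkout as the working directory and that mypy is installed, and is not part of this commit.

from glob import glob
from mypy import api

# New style: a package directory; mypy discovers nested modules itself.
out_new, err_new, status_new = api.run(
    ["vllm/attention", "--config-file", "pyproject.toml"])

# Old style: only the top-level *.py files, with imported modules skipped
# (treated as Any), so cross-module type errors go unreported.
out_old, err_old, status_old = api.run(
    glob("vllm/attention/*.py") +
    ["--follow-imports=skip", "--config-file", "pyproject.toml"])

print(status_new, status_old)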

View File

@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'
 # Run mypy
 echo 'vLLM mypy:'
-mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/attention --config-file pyproject.toml
 mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
-# TODO(sang): Follow up
-mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed --config-file pyproject.toml
+mypy vllm/entrypoints --config-file pyproject.toml
+mypy vllm/executor --config-file pyproject.toml
+mypy vllm/usage --config-file pyproject.toml
+mypy vllm/*.py --config-file pyproject.toml
+mypy vllm/transformers_utils --config-file pyproject.toml
+mypy vllm/engine --config-file pyproject.toml
+mypy vllm/worker --config-file pyproject.toml
+mypy vllm/spec_decode --config-file pyproject.toml
+mypy vllm/model_executor/*.py --config-file pyproject.toml
+# mypy vllm/lora/*.py --config-file pyproject.toml

 CODESPELL_EXCLUDES=(

View File

@@ -47,14 +47,16 @@ python_version = "3.8"
 ignore_missing_imports = true
 check_untyped_defs = true
+follow_imports = "skip"
 files = "vllm"
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
 exclude = [
     "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
+    # Ignore triton kernels in ops.
+    'vllm/attention/ops/.*\.py$'
 ]

 [tool.codespell]
 ignore-words-list = "dout, te, indicies"
 skip = "./tests/prompts,./benchmarks/sonnet.txt"

View File

@@ -116,7 +116,7 @@ class AttentionImpl(ABC):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+        attn_metadata: AttentionMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         raise NotImplementedError

View File

@@ -248,6 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
+            assert prefill_meta.prompt_lens is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the

View File

@@ -106,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: TorchSDPAMetadata,
+        attn_metadata: TorchSDPAMetadata,  # type: ignore
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -136,6 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                                                 kv_scale)
         if attn_metadata.is_prompt:
+            assert attn_metadata.prompt_lens is not None
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)

View File

@@ -288,6 +288,7 @@ class XFormersImpl(AttentionImpl):
             value: shape = [num_prefill_tokens, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         """
+        assert attn_metadata.prompt_lens is not None
         original_query = query
         if self.num_kv_heads != self.num_heads:
             # GQA/MQA requires the shape [B, M, G, H, K].

View File

@@ -104,6 +104,7 @@ class BlockTable:
             token_ids (List[int]): The sequence of token IDs to be appended.
         """
         assert self._is_allocated
+        assert self._blocks is not None
         self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                     num_lookahead_slots)

View File

@@ -99,7 +99,7 @@ class CopyOnWriteTracker:
         refcounter: RefCounter,
         allocator: BlockAllocator,
     ):
-        self._copy_on_writes = defaultdict(list)
+        self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
         self._refcounter = refcounter
         self._allocator = allocator
@@ -138,6 +138,8 @@ class CopyOnWriteTracker:
                 prev_block=block.prev_block).block_id
             # Track src/dst copy.
+            assert src_block_id is not None
+            assert block_id is not None
            self._copy_on_writes[src_block_id].append(block_id)
        return block_id
@@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
            recurse(block.prev_block, lst)
        lst.append(block)
-    all_blocks = []
+    all_blocks: List[Block] = []
     recurse(last_block, all_blocks)
     return all_blocks
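The annotations in the hunk above follow a general pattern in this commit: an empty literal such as defaultdict(list) or [] gives mypy no element type to infer, so the variable needs an explicit declaration. A small self-contained sketch of the idea; BlockId here is just an int alias standing in for the real type.

from collections import defaultdict
from typing import Dict, List

BlockId = int  # stand-in alias for the block id type used in the real code

# Without the annotation, mypy infers Dict[Any, List[Any]] and checks nothing.
copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
copy_on_writes[3].append(7)      # ok
# copy_on_writes[3].append("x")  # would now be flagged: expected int, got str

all_blocks: List[BlockId] = []   # same idea for a bare empty list
all_blocks.append(11)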

View File

@@ -52,8 +52,7 @@ class Block(ABC):
 class BlockAllocator(ABC):
     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block],
-                         device: Device) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
         pass
     @abstractmethod
@@ -98,8 +97,7 @@ class BlockAllocator(ABC):
 class DeviceAwareBlockAllocator(BlockAllocator):
     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block],
-                         device: Device) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
         pass
     @abstractmethod

View File

@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, List, Optional
 import torch
 import torch.distributed as dist
@@ -18,7 +18,7 @@ except ImportError:
 logger = init_logger(__name__)
-_CA_HANDLE = None
+_CA_HANDLE: Optional["CustomAllreduce"] = None
 _IS_CAPTURING = False
 _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
            "Cannot test GPU P2P because not all GPUs are visible to the "
            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
            " is set.")
-        return False
+        return
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
     if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
     ca_handle = get_handle()
     # when custom allreduce is disabled, this will be None
     if ca_handle is None:
-        return
+        return None
     if is_capturing():
         if torch.cuda.is_current_stream_capturing():
             if ca_handle.should_custom_ar(input):
@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
         if ca_handle.should_custom_ar(input):
             return ca_handle.all_reduce_unreg(input)
+    return None
 @contextmanager
 def _nvml():
@@ -224,14 +226,14 @@ class CustomAllreduce:
         return self._gather_ipc_meta(shard_data)
     def _gather_ipc_meta(self, shard_data):
-        all_data = [None] * self.world_size
+        all_data: List[Optional[Any]] = [None] * self.world_size
         dist.all_gather_object(all_data, shard_data)
         handles = []
         offsets = []
         for i in range(len(all_data)):
-            handles.append(all_data[i][0])
-            offsets.append(all_data[i][1])
+            handles.append(all_data[i][0])  # type: ignore
+            offsets.append(all_data[i][1])  # type: ignore
         return handles, offsets
     def register_buffer(self, inp: torch.Tensor):

View File

@@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [
     ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
 ]
-class ncclDataType_t(ctypes.c_int):
+ncclDataType_t = ctypes.c_int
+
+# enums
+class ncclDataTypeEnum:
     ncclInt8 = 0
     ncclChar = 0
     ncclUint8 = 1
@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
     ncclNumTypes = 10
     @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
+    def from_torch(cls, dtype: torch.dtype) -> int:
         if dtype == torch.int8:
             return cls.ncclInt8
         if dtype == torch.uint8:
@@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int):
         raise ValueError(f"Unsupported dtype: {dtype}")
-class ncclRedOp_t(ctypes.c_int):
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
     ncclSum = 0
     ncclProd = 1
     ncclMax = 2
@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
     ncclNumOps = 5
     @classmethod
-    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
+    def from_torch(cls, op: ReduceOp) -> int:
         if op == ReduceOp.SUM:
             return cls.ncclSum
         if op == ReduceOp.PRODUCT:
@@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int):
 _c_ncclAllReduce = nccl.ncclAllReduce
 _c_ncclAllReduce.restype = ctypes.c_int
 _c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
-    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
+    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
+    ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
 ]
@@ -251,8 +255,8 @@ class NCCLCommunicator:
         result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                                   ctypes.c_void_p(tensor.data_ptr()),
                                   tensor.numel(),
-                                  ncclDataType_t.from_torch(tensor.dtype),
-                                  ncclRedOp_t.from_torch(op), self.comm,
+                                  ncclDataTypeEnum.from_torch(tensor.dtype),
+                                  ncclRedOpTypeEnum.from_torch(op), self.comm,
                                   ctypes.c_void_p(stream.cuda_stream))
         assert result == 0
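Why the hunk above splits each NCCL "enum" in two: the class attributes of a ctypes.c_int subclass are plain ints, not instances of the subclass, so annotating from_torch as returning the subclass never type-checked. Keeping a bare ctypes.c_int alias for argtypes and a separate class of int constants does. A stripped-down sketch of the same pattern; from_name is a hypothetical helper for illustration, not the code in this commit.

import ctypes

ncclRedOp_t = ctypes.c_int  # the ctypes argtype stays a c_int


class ncclRedOpTypeEnum:
    """Plain int constants; returning int is accurate and ctypes converts it."""
    ncclSum = 0
    ncclProd = 1

    @classmethod
    def from_name(cls, name: str) -> int:  # hypothetical helper for the sketch
        if name == "sum":
            return cls.ncclSum
        if name == "prod":
            return cls.ncclProd
        raise ValueError(f"Unsupported op: {name}")


print(ncclRedOpTypeEnum.from_name("sum"))  # 0, accepted wherever c_int is expected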

View File

@@ -30,6 +30,7 @@ def is_initialized() -> bool:
 def set_pynccl_stream(stream: torch.cuda.Stream):
     """Set the cuda stream for communication"""
     try:
+        assert comm is not None
         comm.stream = stream
         yield
     finally:
@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
 def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
     """All-reduces the input tensor across the process group."""
     assert input_.is_cuda, f"{input_} should be a cuda tensor"
+    assert comm is not None
     comm.all_reduce(input_, op)
@@ -62,8 +64,9 @@ def destroy_process_group() -> None:
 def get_world_size() -> int:
     """Returns the world size."""
+    assert comm is not None
     return comm.world_size
-def get_nccl_backend():
+def get_nccl_backend() -> Optional["NCCLCommunicator"]:
     return comm
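The assert comm is not None lines above (and the similar asserts across this commit) exist for mypy's type narrowing: after the assert, an Optional[X] is treated as X for the rest of the scope. A minimal sketch of the pattern with a hypothetical module-level handle, not the real NCCLCommunicator.

from typing import Optional


class Handle:  # hypothetical stand-in for the communicator object
    world_size = 1


comm: Optional[Handle] = None  # assigned by an init function at runtime


def get_world_size() -> int:
    # Without the assert, mypy reports: Item "None" of "Optional[Handle]"
    # has no attribute "world_size". The assert narrows comm to Handle.
    assert comm is not None
    return comm.world_size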

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Callable, List
 from transformers import PreTrainedTokenizer
@@ -8,6 +8,7 @@ from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
 class SequenceGroupOutputProcessor(ABC):
@@ -27,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: "StopChecker",
     ):
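Tightening seq_counter from Iterable[int] to vllm.utils.Counter (here and in the next two files) gives call sites like next(self.seq_counter) a concrete int result type; an Iterable is not even guaranteed to support next(). A rough sketch of a counter with that interface, assuming the real Counter behaves like a simple incrementing iterator.

class Counter:
    """Sketch of an incrementing id source compatible with next()."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        value = self.counter
        self.counter += 1
        return value


seq_counter = Counter()
new_child_seq_id: int = next(seq_counter)  # mypy now knows this is an int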

View File

@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, List
+from typing import Callable, List
 from transformers import PreTrainedTokenizer
@@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
 logger = init_logger(__name__)
@@ -33,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         self,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: StopChecker,
     ):

View File

@@ -1,4 +1,4 @@
-from typing import Iterable, List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
 logger = init_logger(__name__)
@@ -33,7 +34,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         stop_checker: StopChecker,
     ):
         self.scheduler_config = scheduler_config
@@ -69,7 +70,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
-        parent_child_dict = {
+        parent_child_dict: Dict[int, List[SequenceOutput]] = {
             parent_seq.seq_id: []
             for parent_seq in parent_seqs
         }
@@ -92,7 +93,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 continue
             # Fork the parent sequence if there are multiple child samples.
             for child_sample in child_samples[:-1]:
-                new_child_seq_id = next(self.seq_counter)
+                new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
                                       child_sample.logprobs)

View File

@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group = [[] for _ in range(num_seq_groups)]
+    output_by_sequence_group: List[List[SamplerOutput]] = [
+        [] for _ in range(num_seq_groups)
+    ]
     for step in sampler_outputs:
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)

View File

@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              ChatCompletionResponse,
                                               CompletionRequest, ErrorResponse)
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext
 TIMEOUT_KEEP_ALIVE = 5  # seconds
-openai_serving_chat: OpenAIServingChat = None
-openai_serving_completion: OpenAIServingCompletion = None
+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
 logger = init_logger(__name__)
@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return StreamingResponse(content=generator,
                                  media_type="text/event-stream")
     else:
+        assert isinstance(generator, ChatCompletionResponse)
         return JSONResponse(content=generator.model_dump())
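Two patterns show up in the hunk above: a module-level variable annotated without a = None initializer, so its declared type is never Optional, and an isinstance assert that narrows a union-typed result before calling a member only one branch has. A compact sketch with hypothetical classes standing in for the serving and response types.

from typing import Union


class StreamingReply:  # hypothetical
    pass


class JSONReply:  # hypothetical
    def model_dump(self) -> dict:
        return {}


handler: JSONReply  # declared for the type checker, assigned later at startup


def respond(generator: Union[StreamingReply, JSONReply]) -> dict:
    # isinstance() narrows the union; without it, mypy complains that
    # StreamingReply has no attribute "model_dump".
    assert isinstance(generator, JSONReply)
    return generator.model_dump()


print(respond(JSONReply()))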

View File

@@ -4,7 +4,8 @@ import time
 from typing import Dict, List, Literal, Optional, Union
 import torch
-from pydantic import BaseModel, Field, conint, model_validator
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Annotated
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
@@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
     allow_fine_tuning: bool = False
     organization: str = "*"
     group: Optional[str] = None
-    is_blocking: str = False
+    is_blocking: bool = False
 class ModelCard(BaseModel):
@@ -56,7 +57,7 @@ class UsageInfo(BaseModel):
 class ResponseFormat(BaseModel):
     # type must be "json_object" or "text"
-    type: str = Literal["text", "json_object"]
+    type: Literal["text", "json_object"]
 class ChatCompletionRequest(BaseModel):
@@ -152,6 +153,7 @@ class ChatCompletionRequest(BaseModel):
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
                 for token_id, bias in self.logit_bias.items():
                     # Clamp the bias between -100 and 100 per OpenAI API spec
                     bias = min(100, max(-100, bias))
@@ -213,7 +215,7 @@ class CompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[int] = None
     max_tokens: Optional[int] = 16
-    n: Optional[int] = 1
+    n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
@@ -235,7 +237,7 @@ class CompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
-    truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     # doc: end-completion-sampling-params
     # doc: begin-completion-extra-params
@@ -289,6 +291,7 @@ class CompletionRequest(BaseModel):
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
                 for token_id, bias in self.logit_bias.items():
                     # Clamp the bias between -100 and 100 per OpenAI API spec
                     bias = min(100, max(-100, bias))
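The conint(ge=1) to Annotated[int, Field(ge=1)] swap above keeps constrained fields visible to static type checkers: conint() builds a type at runtime that mypy cannot evaluate, while Annotated[int, ...] is plain int to the checker and still validated by pydantic. A small standalone model showing the pattern; ExampleRequest is hypothetical and not part of vllm.

from typing import Optional

from pydantic import BaseModel, Field, ValidationError
from typing_extensions import Annotated


class ExampleRequest(BaseModel):  # hypothetical model for illustration
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None


ExampleRequest(truncate_prompt_tokens=4)       # ok, and mypy sees Optional[int]
try:
    ExampleRequest(truncate_prompt_tokens=0)   # still rejected at runtime
except ValidationError as exc:
    print(exc.errors()[0]["type"])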

View File

@@ -115,12 +115,12 @@ class OpenAIServingChat(OpenAIServing):
         first_iteration = True
         # Send response for each token for each request.n (index)
+        assert request.n is not None
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
         finish_reason_sent = [False] * request.n
         try:
             async for res in result_generator:
-                res: RequestOutput
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).

View File

@@ -185,6 +185,7 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name: str,
         num_prompts: int,
     ) -> AsyncGenerator[str, None]:
+        assert request.n is not None
         previous_texts = [""] * request.n * num_prompts
         previous_num_tokens = [0] * request.n * num_prompts
         has_echoed = [False] * request.n * num_prompts
@@ -202,6 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     # TODO(simon): optimize the performance by avoiding full
                     # text O(n^2) sending.
+                    assert request.max_tokens is not None
                     if request.echo and request.max_tokens == 0:
                         # only return the prompt
                         delta_text = res.prompt
@@ -279,7 +281,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
     ) -> CompletionResponse:
-        choices = []
+        choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
         for final_res in final_res_batch:
@@ -289,6 +291,7 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_text = final_res.prompt
             for output in final_res.outputs:
+                assert request.max_tokens is not None
                 if request.echo and request.max_tokens == 0:
                     token_ids = prompt_token_ids
                     top_logprobs = prompt_logprobs

View File

@@ -4,7 +4,9 @@ from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Tuple, Union
-from pydantic import conint
+from pydantic import Field
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing_extensions import Annotated
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -45,7 +47,8 @@ class OpenAIServing:
         ]
         self.max_model_len = 0
-        self.tokenizer = None
+        # Lazy initialized
+        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
         try:
             event_loop = asyncio.get_running_loop()
@@ -92,7 +95,7 @@ class OpenAIServing:
     def _create_logprobs(
         self,
         token_ids: List[int],
-        top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
+        top_logprobs: List[Optional[Dict[int, Logprob]]],
         num_output_top_logprobs: Optional[int] = None,
         initial_text_offset: int = 0,
     ) -> LogProbs:
@@ -108,6 +111,7 @@ class OpenAIServing:
                 token = self.tokenizer.decode(token_id)
                 logprobs.tokens.append(token)
                 logprobs.token_logprobs.append(None)
+                assert logprobs.top_logprobs is not None
                 logprobs.top_logprobs.append(None)
             else:
                 token_logprob = step_top_logprobs[token_id].logprob
@@ -116,6 +120,7 @@ class OpenAIServing:
                 logprobs.token_logprobs.append(token_logprob)
                 if num_output_top_logprobs:
+                    assert logprobs.top_logprobs is not None
                     logprobs.top_logprobs.append({
                         # Convert float("-inf") to the
                         # JSON-serializable float that OpenAI uses
@@ -155,9 +160,9 @@ class OpenAIServing:
     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
-            return
+            return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
-            return
+            return None
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
@@ -165,7 +170,7 @@ class OpenAIServing:
     def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
         if request.model in self.served_model_names:
-            return
+            return None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
                 return lora
@@ -177,7 +182,7 @@ class OpenAIServing:
         request: Union[ChatCompletionRequest, CompletionRequest],
         prompt: Optional[str] = None,
         prompt_ids: Optional[List[int]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")

View File

@@ -33,7 +33,7 @@ class LoRALayerWeights:
     def optimize(self) -> "LoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""
         if self.scaling == 1:
-            return
+            return self
         self.lora_b *= self.scaling
         self.scaling = 1
         return self

View File

@@ -29,8 +29,8 @@ def _multi_split_sample(
     sampled_tokens_size: Tuple[int, int],
     sampled_logprobs_size: Tuple[int, int],
     sample_indices: torch.Tensor,
+    logprobs: torch.Tensor,
     *,
-    logprobs: Optional[torch.Tensor] = None,
     modify_greedy_probs: bool = False,
     save_logprobs: bool = False,
 ):
@@ -167,6 +167,7 @@ def sample(
         sampled_logprobs_size = (0, 0)
         logprobs = probs
+    assert logprobs is not None
     if _save_modified_probs:
         sampled_modified_probs_size = sampled_tokens_size
     else:

View File

@@ -108,7 +108,8 @@ class RotaryEmbedding(nn.Module):
         query_pass = query[..., self.rotary_dim:]
         key_pass = key[..., self.rotary_dim:]
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)

View File

@@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig):
                     f"got {alibi_scaling_type}")
             if (alibi_scaling_factor is not None
                     and not isinstance(alibi_scaling_factor, float)
-                    or alibi_scaling_factor <= 1.0):
+                    or (alibi_scaling_factor is not None
+                        and alibi_scaling_factor <= 1.0)):
                 raise ValueError(
                     f"`alibi_scaling`'s factor field must be a float > 1.0,"
                     f"got {alibi_scaling_factor}")
             if (alibi_dynamic_scaling is not None
                     and not isinstance(alibi_dynamic_scaling, int)
-                    or alibi_dynamic_scaling <= 1):
+                    or (alibi_dynamic_scaling is not None
+                        and alibi_dynamic_scaling <= 1)):
                 raise ValueError(
                     f"`alibi_scaling`'s `train_seq_len` field must be an"
                     f"integer > 1, got {alibi_dynamic_scaling}")

View File

@@ -11,7 +11,7 @@ if ray:
     from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
         RayTokenizerGroupPool)
 else:
-    RayTokenizerGroupPool = None
+    RayTokenizerGroupPool = None  # type: ignore
 def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],

View File

@@ -89,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         This is blocking.
         """
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None
         if self._idle_actors.empty():
             raise RuntimeError("No idle actors available.")
@@ -120,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         This is non-blocking.
         """
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None
         actor = await self._idle_actors.get()
         try:

View File

@@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
         token = self.sp_model.IdToPiece(index)
         return token
-    def convert_tokens_to_string(self, tokens):
+    def convert_tokens_to_string(self, tokens: List[str]):
         """Converts a sequence of tokens (string) in a single string."""
-        current_sub_tokens = []
+        current_sub_tokens: List[str] = []
         out_string = ""
         prev_is_special = False
         for i, token in enumerate(tokens):