[Mypy] Part 3 fix typing for nested directories for most of directory (#4161)

parent 34128a697e
commit 0ae11f78ab

.github/workflows/mypy.yaml (29 lines changed)
@@ -32,19 +32,20 @@ jobs:
         pip install types-setuptools
     - name: Mypy
       run: |
-        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/attention --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
         mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/distributed --config-file pyproject.toml
-        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/entrypoints --config-file pyproject.toml
-        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/executor --config-file pyproject.toml
-        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/usage --config-file pyproject.toml
-        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/*.py --config-file pyproject.toml
-        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/transformers_utils --config-file pyproject.toml
-        mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/engine --config-file pyproject.toml
-        mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/worker --config-file pyproject.toml
-        mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/spec_decode --config-file pyproject.toml
-        mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
+        mypy vllm/model_executor/*.py --config-file pyproject.toml
-        # TODO(sang): Follow up
+        # TODO(sang): Fix nested dir
-        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/lora/*.py --config-file pyproject.toml
format.sh (26 lines changed)

@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'

 # Run mypy
 echo 'vLLM mypy:'
-mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/attention --config-file pyproject.toml
 mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed --config-file pyproject.toml
-mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/entrypoints --config-file pyproject.toml
-mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/executor --config-file pyproject.toml
-mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/usage --config-file pyproject.toml
-mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/*.py --config-file pyproject.toml
-mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/transformers_utils --config-file pyproject.toml
-# TODO(sang): Follow up
-mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/engine --config-file pyproject.toml
-mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/worker --config-file pyproject.toml
-mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/model_executor/*.py --config-file pyproject.toml
-# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/lora/*.py --config-file pyproject.toml

 CODESPELL_EXCLUDES=(
@@ -47,14 +47,16 @@ python_version = "3.8"

 ignore_missing_imports = true
 check_untyped_defs = true
+follow_imports = "skip"

 files = "vllm"
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
 exclude = [
     "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
+    # Ignore triton kernels in ops.
+    'vllm/attention/ops/.*\.py$'
 ]

 [tool.codespell]
 ignore-words-list = "dout, te, indicies"
 skip = "./tests/prompts,./benchmarks/sonnet.txt"
@@ -116,7 +116,7 @@ class AttentionImpl(ABC):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+        attn_metadata: AttentionMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         raise NotImplementedError
@@ -248,6 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):

         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
+            assert prefill_meta.prompt_lens is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
@@ -106,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: TorchSDPAMetadata,
+        attn_metadata: TorchSDPAMetadata,  # type: ignore
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -136,6 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                                         kv_scale)

         if attn_metadata.is_prompt:
+            assert attn_metadata.prompt_lens is not None
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
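The `# type: ignore` on the TorchSDPA signature is needed because the override declares a narrower metadata type than the abstract `AttentionImpl.forward`, which mypy flags as a Liskov violation. A minimal sketch of that pattern, with illustrative names rather than vLLM's actual classes:

```python
from abc import ABC, abstractmethod


class Metadata:
    """Broad metadata type used by the abstract signature."""


class SDPAMetadata(Metadata):
    is_prompt: bool = True


class Impl(ABC):

    @abstractmethod
    def forward(self, x: float, metadata: Metadata) -> float:
        raise NotImplementedError


class SDPAImpl(Impl):

    # Narrowing a parameter type in an override breaks the substitution
    # rule mypy enforces, so the line is silenced instead of widened.
    def forward(self, x: float, metadata: SDPAMetadata) -> float:  # type: ignore
        return x if metadata.is_prompt else -x


print(SDPAImpl().forward(1.0, SDPAMetadata()))  # 1.0
```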
@@ -288,6 +288,7 @@ class XFormersImpl(AttentionImpl):
             value: shape = [num_prefill_tokens, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         """
+        assert attn_metadata.prompt_lens is not None
         original_query = query
         if self.num_kv_heads != self.num_heads:
             # GQA/MQA requires the shape [B, M, G, H, K].
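The `assert ... is not None` lines added to the ROCm, SDPA, and xFormers backends exist purely to narrow `Optional` metadata fields before they are used. A small self-contained sketch of the idiom (the class here is a stand-in, not vLLM's metadata type):

```python
from typing import List, Optional


class PrefillMetadata:

    def __init__(self, prompt_lens: Optional[List[int]]) -> None:
        # Optional because decode-only batches carry no prompt lengths.
        self.prompt_lens: Optional[List[int]] = prompt_lens


def total_prompt_tokens(meta: PrefillMetadata) -> int:
    # Without the assert, mypy rejects sum(meta.prompt_lens) because the
    # attribute may be None; the assert narrows it to List[int].
    assert meta.prompt_lens is not None
    return sum(meta.prompt_lens)


print(total_prompt_tokens(PrefillMetadata([3, 5])))  # 8
```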
@@ -104,6 +104,7 @@ class BlockTable:
             token_ids (List[int]): The sequence of token IDs to be appended.
         """
         assert self._is_allocated
+        assert self._blocks is not None

         self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                     num_lookahead_slots)
@@ -99,7 +99,7 @@ class CopyOnWriteTracker:
         refcounter: RefCounter,
         allocator: BlockAllocator,
     ):
-        self._copy_on_writes = defaultdict(list)
+        self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
         self._refcounter = refcounter
         self._allocator = allocator

@@ -138,6 +138,8 @@ class CopyOnWriteTracker:
                     prev_block=block.prev_block).block_id

             # Track src/dst copy.
+            assert src_block_id is not None
+            assert block_id is not None
             self._copy_on_writes[src_block_id].append(block_id)

         return block_id

@@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
         recurse(block.prev_block, lst)
         lst.append(block)

-    all_blocks = []
+    all_blocks: List[Block] = []
     recurse(last_block, all_blocks)
     return all_blocks
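Several of these changes annotate containers that start out empty, since mypy infers an `Any`-typed list or defaultdict from a bare constructor. A minimal sketch, with `BlockId` stubbed as `int` purely for illustration:

```python
from collections import defaultdict
from typing import DefaultDict, List

BlockId = int  # stand-in alias; the real type lives in vLLM's block code

# Bare `defaultdict(list)` and `[]` would be inferred with Any elements,
# so the element types are spelled out where the containers are created.
copy_on_writes: DefaultDict[BlockId, List[BlockId]] = defaultdict(list)
all_block_ids: List[BlockId] = []

copy_on_writes[1].append(2)
all_block_ids.append(3)
print(copy_on_writes, all_block_ids)
```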
@@ -52,8 +52,7 @@ class Block(ABC):
 class BlockAllocator(ABC):

     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block],
-                         device: Device) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
         pass

     @abstractmethod
@@ -98,8 +97,7 @@ class BlockAllocator(ABC):
 class DeviceAwareBlockAllocator(BlockAllocator):

     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block],
-                         device: Device) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
         pass

     @abstractmethod
@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, List, Optional

 import torch
 import torch.distributed as dist
@@ -18,7 +18,7 @@ except ImportError:

 logger = init_logger(__name__)

-_CA_HANDLE = None
+_CA_HANDLE: Optional["CustomAllreduce"] = None
 _IS_CAPTURING = False
 _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
             "Cannot test GPU P2P because not all GPUs are visible to the "
             "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
             " is set.")
-        return False
+        return
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
     if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
     ca_handle = get_handle()
     # when custom allreduce is disabled, this will be None
     if ca_handle is None:
-        return
+        return None
     if is_capturing():
         if torch.cuda.is_current_stream_capturing():
             if ca_handle.should_custom_ar(input):
@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
         if ca_handle.should_custom_ar(input):
             return ca_handle.all_reduce_unreg(input)

+    return None


 @contextmanager
 def _nvml():
@@ -224,14 +226,14 @@ class CustomAllreduce:
         return self._gather_ipc_meta(shard_data)

     def _gather_ipc_meta(self, shard_data):
-        all_data = [None] * self.world_size
+        all_data: List[Optional[Any]] = [None] * self.world_size
         dist.all_gather_object(all_data, shard_data)

         handles = []
         offsets = []
         for i in range(len(all_data)):
-            handles.append(all_data[i][0])
-            offsets.append(all_data[i][1])
+            handles.append(all_data[i][0])  # type: ignore
+            offsets.append(all_data[i][1])  # type: ignore
         return handles, offsets

     def register_buffer(self, inp: torch.Tensor):
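Two recurring fixes in this file: `init_custom_ar()` is annotated `-> None`, so its early exit can no longer `return False`, and the `Optional[torch.Tensor]`-returning `custom_all_reduce` now ends every path with an explicit `return None`. A small sketch of the first case (the helper names here are made up for the example):

```python
def init_feature() -> None:
    # Returning False here would be flagged by mypy ("No return value
    # expected") because the function is annotated as returning None.
    if not _is_supported():
        return
    _enable()


def _is_supported() -> bool:
    return False


def _enable() -> None:
    print("enabled")


init_feature()
```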
@@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [
     ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
 ]

+ncclDataType_t = ctypes.c_int

-# enums
-class ncclDataType_t(ctypes.c_int):
+class ncclDataTypeEnum:
     ncclInt8 = 0
     ncclChar = 0
     ncclUint8 = 1
@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
     ncclNumTypes = 10

     @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
+    def from_torch(cls, dtype: torch.dtype) -> int:
         if dtype == torch.int8:
             return cls.ncclInt8
         if dtype == torch.uint8:
@@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int):
         raise ValueError(f"Unsupported dtype: {dtype}")


-class ncclRedOp_t(ctypes.c_int):
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
     ncclSum = 0
     ncclProd = 1
     ncclMax = 2
@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
     ncclNumOps = 5

     @classmethod
-    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
+    def from_torch(cls, op: ReduceOp) -> int:
         if op == ReduceOp.SUM:
             return cls.ncclSum
         if op == ReduceOp.PRODUCT:
@@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int):
 _c_ncclAllReduce = nccl.ncclAllReduce
 _c_ncclAllReduce.restype = ctypes.c_int
 _c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
-    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
+    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
+    ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
 ]

 # equivalent to c declaration:
@@ -251,8 +255,8 @@ class NCCLCommunicator:
         result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                                   ctypes.c_void_p(tensor.data_ptr()),
                                   tensor.numel(),
-                                  ncclDataType_t.from_torch(tensor.dtype),
-                                  ncclRedOp_t.from_torch(op), self.comm,
+                                  ncclDataTypeEnum.from_torch(tensor.dtype),
+                                  ncclRedOpTypeEnum.from_torch(op), self.comm,
                                   ctypes.c_void_p(stream.cuda_stream))
         assert result == 0
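The change above keeps `ncclDataType_t`/`ncclRedOp_t` as plain `ctypes.c_int` aliases for the `argtypes` declarations, while the constants move into ordinary classes whose `from_torch` returns a plain `int` that mypy can check. A compact sketch of the same split, using a string key instead of `torch.dtype` so it stays self-contained:

```python
import ctypes

# The ctypes foreign-function declaration keeps a plain c_int alias...
ncclDataType_t = ctypes.c_int


# ...while the enum constants live in a regular class, so the classmethod
# can be annotated with an int return type.
class ncclDataTypeEnum:
    ncclInt8 = 0
    ncclFloat32 = 7

    @classmethod
    def from_name(cls, name: str) -> int:
        if name == "int8":
            return cls.ncclInt8
        if name == "float32":
            return cls.ncclFloat32
        raise ValueError(f"Unsupported dtype: {name}")


print(ncclDataTypeEnum.from_name("float32"))  # 7
```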
@@ -30,6 +30,7 @@ def is_initialized() -> bool:
 def set_pynccl_stream(stream: torch.cuda.Stream):
     """Set the cuda stream for communication"""
     try:
+        assert comm is not None
         comm.stream = stream
         yield
     finally:
@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
 def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
     """All-reduces the input tensor across the process group."""
     assert input_.is_cuda, f"{input_} should be a cuda tensor"
+    assert comm is not None
     comm.all_reduce(input_, op)

@@ -62,8 +64,9 @@ def destroy_process_group() -> None:

 def get_world_size() -> int:
     """Returns the world size."""
+    assert comm is not None
     return comm.world_size


-def get_nccl_backend():
+def get_nccl_backend() -> Optional["NCCLCommunicator"]:
     return comm
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Callable, List

 from transformers import PreTrainedTokenizer

@@ -8,6 +8,7 @@ from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter


 class SequenceGroupOutputProcessor(ABC):
@@ -27,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: "StopChecker",
     ):
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, List
+from typing import Callable, List

 from transformers import PreTrainedTokenizer

@@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter

 logger = init_logger(__name__)

@@ -33,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         self,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: StopChecker,
     ):
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter

 logger = init_logger(__name__)

@@ -33,7 +34,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         stop_checker: StopChecker,
     ):
         self.scheduler_config = scheduler_config
@@ -69,7 +70,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
-        parent_child_dict = {
+        parent_child_dict: Dict[int, List[SequenceOutput]] = {
             parent_seq.seq_id: []
             for parent_seq in parent_seqs
         }
@@ -92,7 +93,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 continue
             # Fork the parent sequence if there are multiple child samples.
             for child_sample in child_samples[:-1]:
-                new_child_seq_id = next(self.seq_counter)
+                new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
                                       child_sample.logprobs)
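`seq_counter` is now typed as the concrete `Counter` helper instead of `Iterable[int]`: `next()` requires an object with `__next__`, which `Iterable` does not promise, so the old annotation made `next(self.seq_counter)` untypable. A rough sketch of a Counter-like helper, reconstructed from its usage in these hunks rather than copied from `vllm.utils`:

```python
class Counter:
    """Yields consecutive integer ids via next(counter)."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i


def fork_seq_id(seq_counter: Counter) -> int:
    # With Iterable[int] this call fails type checking; Counter's
    # __next__ makes the result a plain int.
    return next(seq_counter)


print(fork_seq_id(Counter()))  # 0
```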
@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group = [[] for _ in range(num_seq_groups)]
+    output_by_sequence_group: List[List[SamplerOutput]] = [
+        [] for _ in range(num_seq_groups)
+    ]
     for step in sampler_outputs:
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)
@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              ChatCompletionResponse,
                                               CompletionRequest, ErrorResponse)
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext

 TIMEOUT_KEEP_ALIVE = 5  # seconds

-openai_serving_chat: OpenAIServingChat = None
-openai_serving_completion: OpenAIServingCompletion = None
+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
 logger = init_logger(__name__)

@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return StreamingResponse(content=generator,
                                  media_type="text/event-stream")
     else:
+        assert isinstance(generator, ChatCompletionResponse)
         return JSONResponse(content=generator.model_dump())
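Two idioms show up here: the module-level globals are declared with a bare annotation instead of `= None` (they are assigned at startup, so their type never becomes `Optional`), and `isinstance` narrows the streaming/non-streaming union before `model_dump()` is called. A self-contained sketch with stand-in classes, not vLLM's real ones:

```python
from typing import Union


class ChatCompletionResponse:

    def model_dump(self) -> dict:
        return {"object": "chat.completion"}


class ErrorResponse:

    def model_dump(self) -> dict:
        return {"object": "error"}


class ServingChat:

    def create(self) -> Union[ChatCompletionResponse, ErrorResponse]:
        return ChatCompletionResponse()


# Bare annotation: "assigned during startup", so later uses are typed as
# ServingChat rather than Optional[ServingChat].
serving_chat: ServingChat


def startup() -> None:
    global serving_chat
    serving_chat = ServingChat()


def create_chat_completion() -> dict:
    result = serving_chat.create()
    # Narrow the Union before using a ChatCompletionResponse-only method.
    assert isinstance(result, ChatCompletionResponse)
    return result.model_dump()


startup()
print(create_chat_completion())
```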
@@ -4,7 +4,8 @@ import time
 from typing import Dict, List, Literal, Optional, Union

 import torch
-from pydantic import BaseModel, Field, conint, model_validator
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Annotated

 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
@@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
     allow_fine_tuning: bool = False
     organization: str = "*"
     group: Optional[str] = None
-    is_blocking: str = False
+    is_blocking: bool = False


 class ModelCard(BaseModel):
@@ -56,7 +57,7 @@ class UsageInfo(BaseModel):

 class ResponseFormat(BaseModel):
     # type must be "json_object" or "text"
-    type: str = Literal["text", "json_object"]
+    type: Literal["text", "json_object"]


 class ChatCompletionRequest(BaseModel):
@@ -152,6 +153,7 @@ class ChatCompletionRequest(BaseModel):
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
                 for token_id, bias in self.logit_bias.items():
                     # Clamp the bias between -100 and 100 per OpenAI API spec
                     bias = min(100, max(-100, bias))
@@ -213,7 +215,7 @@ class CompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[int] = None
     max_tokens: Optional[int] = 16
-    n: Optional[int] = 1
+    n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
@@ -235,7 +237,7 @@ class CompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
-    truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params
@@ -289,6 +291,7 @@ class CompletionRequest(BaseModel):
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
                 for token_id, bias in self.logit_bias.items():
                     # Clamp the bias between -100 and 100 per OpenAI API spec
                     bias = min(100, max(-100, bias))
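The `conint(ge=1)` annotations were replaced because `conint()` builds a type at runtime, which mypy will not accept in an annotation position; `Annotated[int, Field(ge=1)]` keeps the same `ge=1` validation while staying a plain `int` for the checker. A minimal sketch, assuming pydantic v2 and typing_extensions are available:

```python
from typing import Optional

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class CompletionRequestSketch(BaseModel):
    # Validates exactly like Optional[conint(ge=1)] = None, but mypy sees
    # a plain Optional[int].
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    n: int = 1


print(CompletionRequestSketch(truncate_prompt_tokens=3, n=2))
```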
@@ -115,12 +115,12 @@ class OpenAIServingChat(OpenAIServing):
         first_iteration = True

         # Send response for each token for each request.n (index)
+        assert request.n is not None
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
         finish_reason_sent = [False] * request.n
         try:
             async for res in result_generator:
-                res: RequestOutput
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
@@ -185,6 +185,7 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name: str,
         num_prompts: int,
     ) -> AsyncGenerator[str, None]:
+        assert request.n is not None
         previous_texts = [""] * request.n * num_prompts
         previous_num_tokens = [0] * request.n * num_prompts
         has_echoed = [False] * request.n * num_prompts
@@ -202,6 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     # TODO(simon): optimize the performance by avoiding full
                     # text O(n^2) sending.

+                    assert request.max_tokens is not None
                     if request.echo and request.max_tokens == 0:
                         # only return the prompt
                         delta_text = res.prompt
@@ -279,7 +281,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
     ) -> CompletionResponse:
-        choices = []
+        choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
         for final_res in final_res_batch:
@@ -289,6 +291,7 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_text = final_res.prompt

             for output in final_res.outputs:
+                assert request.max_tokens is not None
                 if request.echo and request.max_tokens == 0:
                     token_ids = prompt_token_ids
                     top_logprobs = prompt_logprobs
@@ -4,7 +4,9 @@ from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Tuple, Union

-from pydantic import conint
+from pydantic import Field
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing_extensions import Annotated

 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -45,7 +47,8 @@ class OpenAIServing:
         ]

         self.max_model_len = 0
-        self.tokenizer = None
+        # Lazy initialized
+        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

         try:
             event_loop = asyncio.get_running_loop()
@@ -92,7 +95,7 @@ class OpenAIServing:
     def _create_logprobs(
         self,
         token_ids: List[int],
-        top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
+        top_logprobs: List[Optional[Dict[int, Logprob]]],
         num_output_top_logprobs: Optional[int] = None,
         initial_text_offset: int = 0,
     ) -> LogProbs:
@@ -108,6 +111,7 @@ class OpenAIServing:
                 token = self.tokenizer.decode(token_id)
                 logprobs.tokens.append(token)
                 logprobs.token_logprobs.append(None)
+                assert logprobs.top_logprobs is not None
                 logprobs.top_logprobs.append(None)
             else:
                 token_logprob = step_top_logprobs[token_id].logprob
@@ -116,6 +120,7 @@ class OpenAIServing:
                 logprobs.token_logprobs.append(token_logprob)

                 if num_output_top_logprobs:
+                    assert logprobs.top_logprobs is not None
                     logprobs.top_logprobs.append({
                         # Convert float("-inf") to the
                         # JSON-serializable float that OpenAI uses
@@ -155,9 +160,9 @@ class OpenAIServing:

     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
-            return
+            return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
-            return
+            return None
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
@@ -165,7 +170,7 @@ class OpenAIServing:

     def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
         if request.model in self.served_model_names:
-            return
+            return None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
                 return lora
@@ -177,7 +182,7 @@ class OpenAIServing:
         request: Union[ChatCompletionRequest, CompletionRequest],
         prompt: Optional[str] = None,
         prompt_ids: Optional[List[int]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
@@ -33,7 +33,7 @@ class LoRALayerWeights:
     def optimize(self) -> "LoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""
         if self.scaling == 1:
-            return
+            return self
         self.lora_b *= self.scaling
         self.scaling = 1
         return self
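The early exit in `optimize()` previously returned `None` even though the method is annotated to return `"LoRALayerWeights"`; returning `self` fixes both the type error and the behaviour for callers that chain on the result. A stripped-down sketch of the same shape, with just enough state to run:

```python
class LoRALayerWeightsSketch:
    """Stand-in with only the fields needed to show the early-return fix."""

    def __init__(self, lora_b: float, scaling: float) -> None:
        self.lora_b = lora_b
        self.scaling = scaling

    def optimize(self) -> "LoRALayerWeightsSketch":
        if self.scaling == 1:
            return self  # a bare `return` here would yield None
        self.lora_b *= self.scaling
        self.scaling = 1
        return self


print(LoRALayerWeightsSketch(2.0, 3.0).optimize().lora_b)  # 6.0
```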
@@ -29,8 +29,8 @@ def _multi_split_sample(
     sampled_tokens_size: Tuple[int, int],
     sampled_logprobs_size: Tuple[int, int],
     sample_indices: torch.Tensor,
+    logprobs: torch.Tensor,
     *,
-    logprobs: Optional[torch.Tensor] = None,
     modify_greedy_probs: bool = False,
     save_logprobs: bool = False,
 ):
@@ -167,6 +167,7 @@ def sample(
         sampled_logprobs_size = (0, 0)
         logprobs = probs

+    assert logprobs is not None
     if _save_modified_probs:
         sampled_modified_probs_size = sampled_tokens_size
     else:
@@ -108,7 +108,8 @@ class RotaryEmbedding(nn.Module):
         query_pass = query[..., self.rotary_dim:]
         key_pass = key[..., self.rotary_dim:]

-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig):
                 f"got {alibi_scaling_type}")
         if (alibi_scaling_factor is not None
                 and not isinstance(alibi_scaling_factor, float)
-                or alibi_scaling_factor <= 1.0):
+                or (alibi_scaling_factor is not None
+                    and alibi_scaling_factor <= 1.0)):
             raise ValueError(
                 f"`alibi_scaling`'s factor field must be a float > 1.0,"
                 f"got {alibi_scaling_factor}")
         if (alibi_dynamic_scaling is not None
                 and not isinstance(alibi_dynamic_scaling, int)
-                or alibi_dynamic_scaling <= 1):
+                or (alibi_dynamic_scaling is not None
+                    and alibi_dynamic_scaling <= 1)):
             raise ValueError(
                 f"`alibi_scaling`'s `train_seq_len` field must be an"
                 f"integer > 1, got {alibi_dynamic_scaling}")
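The rewritten conditions guard the comparison with an explicit `is not None`: `and` binds tighter than `or`, so in the old form `alibi_scaling_factor <= 1.0` could be evaluated with a `None` value, which mypy rejects and which would raise `TypeError` at runtime. A self-contained restatement of the factor check (the real config value can be any JSON type; `Optional[float]` keeps the sketch simple):

```python
from typing import Optional


def check_scaling_factor(factor: Optional[float]) -> None:
    # `A and B or C` parses as `(A and B) or C`, so the comparison in C
    # needs its own None guard; otherwise `None <= 1.0` can be reached.
    if (factor is not None and not isinstance(factor, float)
            or (factor is not None and factor <= 1.0)):
        raise ValueError(
            f"`alibi_scaling`'s factor field must be a float > 1.0, "
            f"got {factor}")


check_scaling_factor(None)  # ok: not configured
check_scaling_factor(2.0)   # ok: valid float > 1.0
```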
@@ -11,7 +11,7 @@ if ray:
     from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
         RayTokenizerGroupPool)
 else:
-    RayTokenizerGroupPool = None
+    RayTokenizerGroupPool = None  # type: ignore


 def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
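When a name is bound to a class on one branch and to `None` on the other, mypy reports an incompatible assignment, hence the `# type: ignore` on the fallback. A sketch of the same pattern built around a hypothetical optional module, not vLLM's actual layout:

```python
try:
    # Hypothetical optional dependency, used only to illustrate the pattern.
    from fast_pool import PoolImpl  # type: ignore
except ImportError:
    # Re-binding the class name to None needs an explicit ignore; callers
    # are expected to check for None before constructing a pool.
    PoolImpl = None  # type: ignore


def make_pool():
    if PoolImpl is None:
        raise RuntimeError("optional pool backend is not installed")
    return PoolImpl()
```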
@@ -89,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         This is blocking.
         """
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None

         if self._idle_actors.empty():
             raise RuntimeError("No idle actors available.")
@@ -120,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         This is non-blocking.
         """
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None

         actor = await self._idle_actors.get()
         try:
@@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
         token = self.sp_model.IdToPiece(index)
         return token

-    def convert_tokens_to_string(self, tokens):
+    def convert_tokens_to_string(self, tokens: List[str]):
         """Converts a sequence of tokens (string) in a single string."""
-        current_sub_tokens = []
+        current_sub_tokens: List[str] = []
         out_string = ""
         prev_is_special = False
         for i, token in enumerate(tokens):