[Frontend][Core] Move merge_async_iterators to utils (#4026)
parent 1096717ae9
commit 7fd3949a0b
vllm/entrypoints/openai/serving_completion.py

@@ -1,4 +1,3 @@
-import asyncio
 import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                     Optional, Tuple)
@@ -17,7 +16,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
-from vllm.utils import random_uuid
+from vllm.utils import merge_async_iterators, random_uuid

 logger = init_logger(__name__)

@@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
     return prompt_is_tokens, prompts


-def merge_async_iterators(*iterators):
-    """Merge multiple asynchronous iterators into a single iterator.
-
-    This method handle the case where some iterators finish before others.
-    When it yields, it yields a tuple (i, item) where i is the index of the
-    iterator that yields the item.
-    """
-    queue = asyncio.Queue()
-
-    finished = [False] * len(iterators)
-
-    async def producer(i, iterator):
-        try:
-            async for item in iterator:
-                await queue.put((i, item))
-        except Exception as e:
-            await queue.put(e)
-        finished[i] = True
-
-    _tasks = [
-        asyncio.create_task(producer(i, iterator))
-        for i, iterator in enumerate(iterators)
-    ]
-
-    async def consumer():
-        while not all(finished) or not queue.empty():
-            item = await queue.get()
-            if isinstance(item, Exception):
-                raise item
-            yield item
-        await asyncio.gather(*_tasks)
-
-    return consumer()
-
-
 class OpenAIServingCompletion(OpenAIServing):

     def __init__(self,
vllm/utils.py

@@ -9,8 +9,8 @@ import warnings
 from collections import OrderedDict, defaultdict
 from functools import lru_cache, partial
 from platform import uname
-from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List,
-                    Optional, Tuple, TypeVar, Union)
+from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
+                    Hashable, List, Optional, Tuple, TypeVar, Union)

 import psutil
 import torch
@@ -181,6 +181,42 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
     return _async_wrapper


+def merge_async_iterators(
+        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
+    """Merge multiple asynchronous iterators into a single iterator.
+
+    This method handle the case where some iterators finish before others.
+    When it yields, it yields a tuple (i, item) where i is the index of the
+    iterator that yields the item.
+    """
+    queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
+
+    finished = [False] * len(iterators)
+
+    async def producer(i: int, iterator: AsyncIterator[T]):
+        try:
+            async for item in iterator:
+                await queue.put((i, item))
+        except Exception as e:
+            await queue.put(e)
+        finished[i] = True
+
+    _tasks = [
+        asyncio.create_task(producer(i, iterator))
+        for i, iterator in enumerate(iterators)
+    ]
+
+    async def consumer():
+        while not all(finished) or not queue.empty():
+            item = await queue.get()
+            if isinstance(item, Exception):
+                raise item
+            yield item
+        await asyncio.gather(*_tasks)
+
+    return consumer()
+
+
 def get_ip() -> str:
     host_ip = os.environ.get("HOST_IP")
     if host_ip:
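
Usage sketch (not part of this commit): a minimal, hypothetical example of how the relocated helper can be consumed now that it lives in vllm.utils. The ticker generators and all names below are illustrative assumptions rather than code from the diff, and it assumes an environment where vLLM with this change is importable.

import asyncio
from typing import AsyncIterator

from vllm.utils import merge_async_iterators  # relocated by this commit


async def ticker(name: str, count: int, delay: float) -> AsyncIterator[str]:
    # Illustrative producer: yields a few labelled items with a fixed delay.
    for n in range(count):
        await asyncio.sleep(delay)
        yield f"{name}-{n}"


async def main() -> None:
    # The merged stream yields (index, item) tuples and keeps going until
    # every underlying iterator has finished, even though they finish at
    # different times; an exception from any producer is re-raised here.
    merged = merge_async_iterators(
        ticker("fast", 3, 0.01),
        ticker("slow", 2, 0.05),
    )
    async for i, item in merged:
        print(f"from iterator {i}: {item}")


asyncio.run(main())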