[Frontend][Core] Move merge_async_iterators to utils (#4026)

This commit is contained in:
Cyrus Leung 2024-04-12 13:30:54 +08:00 committed by GitHub
parent 1096717ae9
commit 7fd3949a0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 39 deletions

View File

@ -1,4 +1,3 @@
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
@ -17,7 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.utils import random_uuid
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
return prompt_is_tokens, prompts
def merge_async_iterators(*iterators):
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue = asyncio.Queue()
finished = [False] * len(iterators)
async def producer(i, iterator):
try:
async for item in iterator:
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
if isinstance(item, Exception):
raise item
yield item
await asyncio.gather(*_tasks)
return consumer()
class OpenAIServingCompletion(OpenAIServing):
def __init__(self,

View File

@ -9,8 +9,8 @@ import warnings
from collections import OrderedDict, defaultdict
from functools import lru_cache, partial
from platform import uname
from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List,
Optional, Tuple, TypeVar, Union)
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, Tuple, TypeVar, Union)
import psutil
import torch
@ -181,6 +181,42 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
return _async_wrapper
def merge_async_iterators(
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
finished = [False] * len(iterators)
async def producer(i: int, iterator: AsyncIterator[T]):
try:
async for item in iterator:
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
if isinstance(item, Exception):
raise item
yield item
await asyncio.gather(*_tasks)
return consumer()
def get_ip() -> str:
host_ip = os.environ.get("HOST_IP")
if host_ip: