[Core] Streamline stream termination in AsyncLLMEngine (#7336)

Nick Hill 2024-08-09 00:06:36 -07:00 committed by GitHub
parent 57b7be0e1c
commit b4e9528f95
2 changed files with 26 additions and 21 deletions

tests/async_engine/test_request_tracker.py

@@ -47,8 +47,10 @@ async def test_request_tracker():
     assert tracker.new_requests_event.is_set()
     await tracker.wait_for_new_requests()
     new, aborted = tracker.get_new_and_aborted_requests()
-    assert len(aborted) == 1
-    assert "4" in aborted
+    # aborted new requests will cancel each other out -
+    # there's no need for them to propagate into the
+    # engine
+    assert not aborted
     assert not new
     assert stream_4.finished

vllm/engine/async_llm_engine.py

@@ -85,11 +85,14 @@ class AsyncStream:
             return
         self._queue.put_nowait(item)
 
-    def finish(self, cancelled: bool = False) -> None:
+    def finish(
+        self,
+        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
+    ) -> None:
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                asyncio.CancelledError if cancelled else STOP_ITERATION)
+                exception if exception is not None else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
@@ -133,14 +136,12 @@ class RequestTracker:
         """Propagate an exception to request streams
         (all if request_id is None)."""
         if request_id is not None:
-            self._request_streams[request_id].put(exc)
-            self.abort_request(request_id)
+            self.abort_request(request_id, exception=exc)
         else:
-            # NB: list() used here because self.abort_request pops the stream
+            # NB: tuple() used here because self.abort_request pops the stream
             # out of self._request_streams, so we can't iterate on it directly
-            for rid, stream in list(self._request_streams.items()):
-                stream.put(exc)
-                self.abort_request(rid)
+            for rid in tuple(self._request_streams.keys()):
+                self.abort_request(rid, exception=exc)
 
     def process_request_output(self,
                                request_output: Union[RequestOutput,
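The list() to tuple() tweak preserves the same safety property; either works. The snapshot is what matters: abort_request() pops entries out of _request_streams, and mutating a dict while iterating its live view raises RuntimeError. A toy illustration (not part of the diff):

streams = {"req-1": "stream-1", "req-2": "stream-2"}

# Iterating the live view while popping would fail:
#     for rid in streams:
#         streams.pop(rid)  # RuntimeError: dictionary changed size during iteration

# Snapshotting the keys first is safe; tuple() also signals that the
# snapshot itself is never mutated.
for rid in tuple(streams.keys()):
    streams.pop(rid)

assert not streams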
@ -167,14 +168,13 @@ class RequestTracker:
def process_exception(self, def process_exception(self,
request_id: str, request_id: str,
exception: Exception, exception: BaseException,
*, *,
verbose: bool = False) -> None: verbose: bool = False) -> None:
"""Propagate an exception from the engine.""" """Propagate an exception from the engine."""
self._request_streams[request_id].put(exception)
if verbose: if verbose:
logger.info("Finished request %s.", request_id) logger.info("Finished request %s.", request_id)
self.abort_request(request_id) self.abort_request(request_id, exception=exception)
def add_request(self, def add_request(self,
request_id: str, request_id: str,
@@ -203,7 +203,8 @@ class RequestTracker:
     def abort_request(self,
                       request_id: str,
                       *,
-                      cancelled: bool = False,
+                      exception: Optional[Union[BaseException,
+                                                Type[BaseException]]] = None,
                       verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
@@ -213,7 +214,7 @@ class RequestTracker:
         stream = self._request_streams.pop(request_id, None)
         if stream is not None:
-            stream.finish(cancelled=cancelled)
+            stream.finish(exception=exception)
 
     def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
@@ -227,12 +228,14 @@ class RequestTracker:
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
-            if stream.request_id in finished_requests:
+            request_id = stream.request_id
+            if request_id in finished_requests:
                 # The request has already been aborted.
-                stream.finish(cancelled=True)
-                continue
-            self._request_streams[stream.request_id] = stream
-            new_requests.append(new_request)
+                stream.finish(asyncio.CancelledError)
+                finished_requests.discard(request_id)
+            else:
+                self._request_streams[request_id] = stream
+                new_requests.append(new_request)
 
         return new_requests, finished_requests
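This is the hunk the updated test exercises: when an abort arrives for a request still sitting in the new-request queue, the two now cancel each other out locally. The waiting client gets asyncio.CancelledError and the abort is dropped before it can reach the engine. A rough walk-through with invented values, reusing the MiniAsyncStream sketch above (plain Python, not part of the diff):

import asyncio

async def walkthrough() -> None:
    new_requests = []
    finished_requests = {"4"}             # "4" was aborted early
    pending = [("4", MiniAsyncStream())]  # but is still queued as new

    for request_id, stream in pending:
        if request_id in finished_requests:
            # Cancel the client awaiting this stream...
            stream.finish(asyncio.CancelledError)
            # ...and drop the abort so the engine never sees it.
            finished_requests.discard(request_id)
        else:
            new_requests.append(request_id)

    # Matches the updated test: nothing propagates to the engine.
    assert not new_requests and not finished_requests

asyncio.run(walkthrough())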
@@ -1015,7 +1018,7 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
         """
         self._request_tracker.abort_request(request_id,
-                                            cancelled=True,
+                                            exception=asyncio.CancelledError,
                                             verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig:
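From a caller's perspective the observable behavior is unchanged: aborting a request still cancels the awaiting coroutine, only now the CancelledError travels through the stream's queue instead of a boolean flag. A hedged usage sketch; the generate() argument order is approximated, so check the real signature before reusing this:

import asyncio

async def consume(engine, prompt, params, request_id: str) -> None:
    try:
        # Argument order approximates AsyncLLMEngine.generate().
        async for output in engine.generate(prompt, params, request_id):
            print(output)
    except asyncio.CancelledError:
        # engine.abort(request_id) finishes the stream with
        # asyncio.CancelledError; the stream's __anext__ re-raises it here.
        print(f"request {request_id} was aborted")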