[Bugfix] Fix MQLLMEngine hanging (#9973)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw 2024-11-04 16:01:43 -05:00 committed by GitHub
parent 6e056bcf04
commit 04cef2c6ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 23 deletions

View File

@ -112,7 +112,11 @@ class MQLLMEngineClient(EngineClient):
# Stream for each individual request. # Stream for each individual request.
self.output_queues: Dict[str, asyncio.Queue] = {} self.output_queues: Dict[str, asyncio.Queue] = {}
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
# Loop to handle output of the LLMEngine periodically.
# Started after the MQLLMEngine is ready so that we can
# build the Client in an executor to enable clean shutdown.
self.output_loop: Optional[asyncio.Task] = None
# Loop to check health of the LLMEngine periodically. # Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready. # Started after the MQLLMEngine is ready.
@ -247,6 +251,9 @@ class MQLLMEngineClient(EngineClient):
async def setup(self): async def setup(self):
"""Setup the client before it starts sending server requests.""" """Setup the client before it starts sending server requests."""
# Start output_loop
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
with self.get_data_socket() as socket: with self.get_data_socket() as socket:
# Wait until server is ready. # Wait until server is ready.
response = await self._wait_for_server_rpc(socket) response = await self._wait_for_server_rpc(socket)
@ -265,7 +272,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks. # Cancel background tasks.
if self.health_loop is not None: if self.health_loop is not None:
self.health_loop.cancel() self.health_loop.cancel()
self.output_loop.cancel() if self.output_loop is not None:
self.output_loop.cancel()
def _set_errored(self, e: BaseException): def _set_errored(self, e: BaseException):
logger.exception(repr(e)) logger.exception(repr(e))

View File

@ -349,16 +349,22 @@ class MQLLMEngine:
self.engine.model_executor._run_workers("stop_profile") self.engine.model_executor._run_workers("stop_profile")
def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated")
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str): ipc_path: str, engine_alive):
try:
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path)
def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler)
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated")
signal.signal(signal.SIGTERM, signal_handler) engine.start()
engine = MQLLMEngine.from_engine_args(engine_args=engine_args, except BaseException as e:
usage_context=usage_context, logger.exception(e)
ipc_path=ipc_path) engine_alive.value = False
engine.start() raise e

View File

@ -171,39 +171,44 @@ async def build_async_engine_client_from_engine_args(
# so we need to spawn a new process # so we need to spawn a new process
context = multiprocessing.get_context("spawn") context = multiprocessing.get_context("spawn")
# The Process can raise an exception during startup, which may
# not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information.
engine_alive = multiprocessing.Value('b', True, lock=False)
engine_process = context.Process(target=run_mp_engine, engine_process = context.Process(target=run_mp_engine,
args=(engine_args, args=(engine_args,
UsageContext.OPENAI_API_SERVER, UsageContext.OPENAI_API_SERVER,
ipc_path)) ipc_path, engine_alive))
engine_process.start() engine_process.start()
engine_pid = engine_process.pid engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start" assert engine_pid is not None, "Engine process failed to start."
logger.info("Started engine process with PID %d", engine_pid) logger.info("Started engine process with PID %d", engine_pid)
# Build RPCClient, which conforms to EngineClient Protocol. # Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config, build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
engine_pid) engine_pid)
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_client)
try: try:
while True: while True:
try: try:
await mp_engine_client.setup() await mq_engine_client.setup()
break break
except TimeoutError: except TimeoutError:
if not engine_process.is_alive(): if (not engine_process.is_alive()
or not engine_alive.value):
raise RuntimeError( raise RuntimeError(
"Engine process failed to start") from None "Engine process failed to start. See stack "
"trace for the root cause.") from None
yield mp_engine_client # type: ignore[misc] yield mq_engine_client # type: ignore[misc]
finally: finally:
# Ensure rpc server process was terminated # Ensure rpc server process was terminated
engine_process.terminate() engine_process.terminate()
# Close all open connections to the backend # Close all open connections to the backend
mp_engine_client.close() mq_engine_client.close()
# Wait for engine process to join # Wait for engine process to join
engine_process.join(4) engine_process.join(4)