[Bugfix] Fix MQLLMEngine hanging (#9973)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
parent
6e056bcf04
commit
04cef2c6ab
@@ -112,7 +112,11 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
|
|
||||||
# Stream for each individual request.
|
# Stream for each individual request.
|
||||||
self.output_queues: Dict[str, asyncio.Queue] = {}
|
self.output_queues: Dict[str, asyncio.Queue] = {}
|
||||||
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
|
|
||||||
|
# Loop to handle output of the LLMEngine periodically.
|
||||||
|
# Started after the MQLLMEngine is ready so that we can
|
||||||
|
# build the Client in an executor to enable clean shutdown.
|
||||||
|
self.output_loop: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
# Loop to check health of the LLMEngine periodically.
|
# Loop to check health of the LLMEngine periodically.
|
||||||
# Started after the MQLLMEngine is ready.
|
# Started after the MQLLMEngine is ready.
|
||||||
@@ -247,6 +251,9 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""Setup the client before it starts sending server requests."""
|
"""Setup the client before it starts sending server requests."""
|
||||||
|
|
||||||
|
# Start output_loop
|
||||||
|
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
|
||||||
|
|
||||||
with self.get_data_socket() as socket:
|
with self.get_data_socket() as socket:
|
||||||
# Wait until server is ready.
|
# Wait until server is ready.
|
||||||
response = await self._wait_for_server_rpc(socket)
|
response = await self._wait_for_server_rpc(socket)
|
||||||
@@ -265,6 +272,7 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
# Cancel background tasks.
|
# Cancel background tasks.
|
||||||
if self.health_loop is not None:
|
if self.health_loop is not None:
|
||||||
self.health_loop.cancel()
|
self.health_loop.cancel()
|
||||||
|
if self.output_loop is not None:
|
||||||
self.output_loop.cancel()
|
self.output_loop.cancel()
|
||||||
|
|
||||||
def _set_errored(self, e: BaseException):
|
def _set_errored(self, e: BaseException):
|
||||||
|
|||||||
@@ -349,16 +349,22 @@ class MQLLMEngine:
|
|||||||
self.engine.model_executor._run_workers("stop_profile")
|
self.engine.model_executor._run_workers("stop_profile")
|
||||||
|
|
||||||
|
|
||||||
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
|
def signal_handler(*_) -> None:
|
||||||
ipc_path: str):
|
|
||||||
|
|
||||||
def signal_handler(*_) -> None:
|
|
||||||
# Interrupt server on sigterm
|
|
||||||
raise KeyboardInterrupt("MQLLMEngine terminated")
|
raise KeyboardInterrupt("MQLLMEngine terminated")
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
|
|
||||||
|
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
|
||||||
|
ipc_path: str, engine_alive):
|
||||||
|
try:
|
||||||
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
|
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
|
||||||
usage_context=usage_context,
|
usage_context=usage_context,
|
||||||
ipc_path=ipc_path)
|
ipc_path=ipc_path)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
engine.start()
|
engine.start()
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logger.exception(e)
|
||||||
|
engine_alive.value = False
|
||||||
|
raise e
|
||||||
|
|||||||
@@ -171,39 +171,44 @@ async def build_async_engine_client_from_engine_args(
|
|||||||
# so we need to spawn a new process
|
# so we need to spawn a new process
|
||||||
context = multiprocessing.get_context("spawn")
|
context = multiprocessing.get_context("spawn")
|
||||||
|
|
||||||
|
# The Process can raise an exception during startup, which may
|
||||||
|
# not actually result in an exitcode being reported. As a result
|
||||||
|
# we use a shared variable to communicate the information.
|
||||||
|
engine_alive = multiprocessing.Value('b', True, lock=False)
|
||||||
engine_process = context.Process(target=run_mp_engine,
|
engine_process = context.Process(target=run_mp_engine,
|
||||||
args=(engine_args,
|
args=(engine_args,
|
||||||
UsageContext.OPENAI_API_SERVER,
|
UsageContext.OPENAI_API_SERVER,
|
||||||
ipc_path))
|
ipc_path, engine_alive))
|
||||||
engine_process.start()
|
engine_process.start()
|
||||||
engine_pid = engine_process.pid
|
engine_pid = engine_process.pid
|
||||||
assert engine_pid is not None, "Engine process failed to start"
|
assert engine_pid is not None, "Engine process failed to start."
|
||||||
logger.info("Started engine process with PID %d", engine_pid)
|
logger.info("Started engine process with PID %d", engine_pid)
|
||||||
|
|
||||||
# Build RPCClient, which conforms to EngineClient Protocol.
|
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||||
# NOTE: Actually, this is not true yet. We still need to support
|
|
||||||
# embedding models via RPC (see TODO above)
|
|
||||||
engine_config = engine_args.create_engine_config()
|
engine_config = engine_args.create_engine_config()
|
||||||
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config,
|
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
|
||||||
engine_pid)
|
engine_pid)
|
||||||
|
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None, build_client)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await mp_engine_client.setup()
|
await mq_engine_client.setup()
|
||||||
break
|
break
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
if not engine_process.is_alive():
|
if (not engine_process.is_alive()
|
||||||
|
or not engine_alive.value):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Engine process failed to start") from None
|
"Engine process failed to start. See stack "
|
||||||
|
"trace for the root cause.") from None
|
||||||
|
|
||||||
yield mp_engine_client # type: ignore[misc]
|
yield mq_engine_client # type: ignore[misc]
|
||||||
finally:
|
finally:
|
||||||
# Ensure rpc server process was terminated
|
# Ensure rpc server process was terminated
|
||||||
engine_process.terminate()
|
engine_process.terminate()
|
||||||
|
|
||||||
# Close all open connections to the backend
|
# Close all open connections to the backend
|
||||||
mp_engine_client.close()
|
mq_engine_client.close()
|
||||||
|
|
||||||
# Wait for engine process to join
|
# Wait for engine process to join
|
||||||
engine_process.join(4)
|
engine_process.join(4)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user