Start background task in AsyncLLMEngine.generate (#988)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
parent 4b5bcf8906
commit 080438477f
@@ -40,8 +40,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
-                                                      start_engine_loop=False)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
     vllm.entrypoints.api_server.engine = engine
     uvicorn.run(
         app,
@@ -230,6 +230,8 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        start_engine_loop: If True, the background task to run the engine
+            will be automatically started in the generate call.
         *args, *kwargs: Arguments for LLMEngine.
     """
 
@@ -240,7 +242,7 @@ class AsyncLLMEngine:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 start_engine_loop: bool = False,
+                 start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
@@ -249,8 +251,7 @@ class AsyncLLMEngine:
 
         self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
-        if start_engine_loop:
-            self.start_background_loop()
+        self.start_engine_loop = start_engine_loop
 
     @property
     def is_running(self) -> bool:
@@ -330,11 +331,14 @@ class AsyncLLMEngine:
                      f"prompt token ids: {prompt_token_ids}.")
 
         if not self.is_running:
-            raise AsyncEngineDeadError(
-                "Background loop is not running. If it was running, "
-                "inspect the output to find the stacktrace of the "
-                "error that caused the background loop to stop "
-                "(AsyncEngineDeadError).")
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise AsyncEngineDeadError(
+                    "Background loop is not running. If it was running, "
+                    "inspect the output to find the stacktrace of the "
+                    "error that caused the background loop to stop "
+                    "(AsyncEngineDeadError).")
 
         stream = self.request_tracker.add_request(
             request_id,
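The hunk above is the core of the change: generate() now lazily starts the background loop on first use instead of raising immediately. Deferring task creation out of __init__ matters because creating an asyncio task requires a running event loop, while the servers touched below construct the engine before uvicorn starts one. A minimal sketch of the pattern in isolation (a toy stand-in, not vLLM's actual class; all names here are illustrative):

import asyncio


class LazyLoopEngine:
    """Toy stand-in for AsyncLLMEngine's deferred-start pattern."""

    def __init__(self, start_engine_loop: bool = True):
        # Only record the flag: the constructor may run before any
        # event loop exists, so no task is created here.
        self.start_engine_loop = start_engine_loop
        self.background_loop = None

    @property
    def is_running(self) -> bool:
        return (self.background_loop is not None
                and not self.background_loop.done())

    def start_background_loop(self) -> None:
        # Safe here: generate() only calls this from inside a running loop.
        self.background_loop = asyncio.get_event_loop().create_task(
            self._engine_loop())

    async def _engine_loop(self) -> None:
        while True:
            await asyncio.sleep(0.1)  # stand-in for one engine step

    async def generate(self, prompt: str) -> str:
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise RuntimeError("Background loop is not running.")
        await asyncio.sleep(0.2)  # pretend the loop produced an output
        return prompt + " ...generated text"


async def main():
    engine = LazyLoopEngine()              # constructed inside a running loop
    print(await engine.generate("Hello"))  # first call starts the loop


asyncio.run(main())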
@@ -426,7 +430,7 @@ class AsyncLLMEngine:
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = False) -> "AsyncLLMEngine":
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
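With the default flipped to True in from_engine_args as well, callers inside an already-running event loop need no explicit loop management. A hypothetical end-to-end sketch (the model name and sampling values are placeholders; the imports reflect vLLM's public API around this commit):

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.utils import random_uuid


async def main():
    # Placeholder model; any HF model vLLM supports would do.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    # start_engine_loop now defaults to True, so no explicit
    # start_background_loop() call is needed anywhere.
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(temperature=0.8, max_tokens=16)
    results = engine.generate("San Francisco is a", sampling_params,
                              random_uuid())
    final_output = None
    async for request_output in results:
        final_output = request_output  # each step yields the latest partial
    print(final_output.outputs[0].text)


asyncio.run(main())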
@@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -80,8 +77,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     uvicorn.run(app,
                 host=args.host,
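The demo server accordingly drops both the in-endpoint priming and the start_engine_loop=False argument at startup; the first incoming request starts the loop. A hypothetical smoke test against a locally running instance (assumes the default --port 8000 and the third-party requests package):

import requests  # third-party: pip install requests

# "prompt" is consumed by the endpoint; the remaining keys are passed
# straight into SamplingParams on the server side.
response = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "San Francisco is a", "max_tokens": 16},
)
print(response.json()["text"])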
@@ -192,9 +192,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     """
     logger.info(f"Received chat completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -367,9 +364,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     """
     logger.info(f"Received completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -627,8 +621,7 @@ if __name__ == "__main__":
     served_model = args.model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
     max_model_len = engine_model_config.get_max_model_len()
 