diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py
index e84916da..1be76fdc 100644
--- a/tests/async_engine/api_server_async_engine.py
+++ b/tests/async_engine/api_server_async_engine.py
@@ -40,8 +40,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
-                                                      start_engine_loop=False)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
     vllm.entrypoints.api_server.engine = engine
     uvicorn.run(
         app,
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index fc4b4b09..1a228c82 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -230,6 +230,8 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        start_engine_loop: If True, the background task to run the engine
+            will be automatically started in the generate call.
         *args, *kwargs: Arguments for LLMEngine.
     """
 
@@ -240,7 +242,7 @@ class AsyncLLMEngine:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 start_engine_loop: bool = False,
+                 start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
@@ -249,8 +251,7 @@ class AsyncLLMEngine:
 
         self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
-        if start_engine_loop:
-            self.start_background_loop()
+        self.start_engine_loop = start_engine_loop
 
     @property
     def is_running(self) -> bool:
@@ -330,11 +331,14 @@ class AsyncLLMEngine:
                 f"prompt token ids: {prompt_token_ids}.")
 
         if not self.is_running:
-            raise AsyncEngineDeadError(
-                "Background loop is not running. If it was running, "
-                "inspect the output to find the stacktrace of the "
-                "error that caused the background loop to stop "
-                "(AsyncEngineDeadError).")
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise AsyncEngineDeadError(
+                    "Background loop is not running. If it was running, "
+                    "inspect the output to find the stacktrace of the "
+                    "error that caused the background loop to stop "
+                    "(AsyncEngineDeadError).")
 
         stream = self.request_tracker.add_request(
             request_id,
@@ -426,7 +430,7 @@ class AsyncLLMEngine:
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = False) -> "AsyncLLMEngine":
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index b73642b2..e9d3342a 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -80,8 +77,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     uvicorn.run(app,
                 host=args.host,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 22170a05..0f1a09b7 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -192,9 +192,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     """
     logger.info(f"Received chat completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -367,9 +364,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     """
     logger.info(f"Received completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -627,8 +621,7 @@ if __name__ == "__main__":
     served_model = args.model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
     max_model_len = engine_model_config.get_max_model_len()
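
Net effect of the diff: `start_engine_loop` now defaults to True, `__init__` records the flag instead of starting the loop eagerly, and `generate` starts the background loop lazily on its first call if it is not already running (with `start_engine_loop=False` the old behavior is kept, raising `AsyncEngineDeadError`). That lazy start is what lets all three entrypoints above drop their explicit `is_running` checks. A minimal standalone sketch of the new flow, assuming this revision of vLLM; the model name and sampling settings are placeholders:

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


async def main():
    # No start_background_loop() call needed: with the new default
    # (start_engine_loop=True), the first generate() call starts the
    # background loop lazily if it is not already running.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder model
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    results_generator = engine.generate("Hello, my name is",
                                        SamplingParams(max_tokens=16),
                                        random_uuid())
    final_output = None
    async for request_output in results_generator:
        final_output = request_output  # keep the latest partial output
    print(final_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```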