From 4141608c6a636952242b86e50d8f90ca674b7425 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Tue, 15 Oct 2024 02:23:33 +0800 Subject: [PATCH] [Hardware][intel GPU] add async output process for xpu (#8897) --- vllm/config.py | 4 ++-- vllm/worker/xpu_model_runner.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index b0761ae0..7a3248f4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -361,9 +361,9 @@ class ModelConfig: # Reminder: Please update docs/source/serving/compatibility_matrix.rst # If the feature combo become valid - if device_config.device_type not in ("cuda", "tpu"): + if device_config.device_type not in ("cuda", "tpu", "xpu"): logger.warning( - "Async output processing is only supported for CUDA or TPU. " + "Async output processing is only supported for CUDA, TPU, XPU. " "Disabling it for other platforms.") self.use_async_output_proc = False return diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 20dceee8..5ff4626c 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -2,8 +2,8 @@ import dataclasses import time import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, TypeVar) import torch import torch.nn as nn @@ -57,6 +57,7 @@ class ModelInputForXPU(ModelRunnerInputBase): virtual_engine: Optional[int] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None + async_callback: Optional[Callable] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -582,6 +583,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): if not self.is_driver_worker: return [] + if model_input.async_callback is not None: + model_input.async_callback() + # Sample the next token. output: SamplerOutput = self.model.sample( logits=logits,