[Hardware][intel GPU] add async output process for xpu (#8897)

This commit is contained in:
Kunshang Ji 2024-10-15 02:23:33 +08:00 committed by GitHub
parent dfe43a2071
commit 4141608c6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 4 deletions

View File

@@ -361,9 +361,9 @@ class ModelConfig:
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if device_config.device_type not in ("cuda", "tpu"):
if device_config.device_type not in ("cuda", "tpu", "xpu"):
logger.warning(
"Async output processing is only supported for CUDA or TPU. "
"Async output processing is only supported for CUDA, TPU, XPU. "
"Disabling it for other platforms.")
self.use_async_output_proc = False
return

View File

@@ -2,8 +2,8 @@ import dataclasses
import time
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)
import torch
import torch.nn as nn
@@ -57,6 +57,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@@ -582,6 +583,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,