[Hardware] [Intel GPU] refactor xpu worker/executor (#7686)

Kunshang Ji authored 2024-08-21 00:54:10 +08:00, committed by GitHub
parent aae6927be0
commit c42590f97a
3 changed files with 26 additions and 28 deletions

vllm/executor/xpu_executor.py

@@ -1,16 +1,16 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union

 import torch

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig)
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.utils import make_async
-from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -30,6 +30,7 @@ class XPUExecutor(GPUExecutor):
         lora_config: Optional[LoRAConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
         speculative_config: Optional[SpeculativeConfig],
+        observability_config: Optional[ObservabilityConfig],
     ) -> None:
         assert device_config.device_type == "xpu"
         assert (not speculative_config
@@ -46,32 +47,23 @@ class XPUExecutor(GPUExecutor):
         self.device_config = device_config
         self.prompt_adapter_config = prompt_adapter_config
         self.speculative_config = None
+        self.observability_config = observability_config

         # Instantiate the worker and load the model to GPU.
         self._init_executor()

-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "vllm.worker.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "vllm.worker.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)

     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output
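
The executor-local `_create_worker` override is dropped in favor of the smaller `_get_worker_module_and_class` hook, so worker construction can live in the shared `GPUExecutor` base class; `execute_model` also widens its return type to cover pooling models (`PoolerOutput`) as well as sampling. Below is a sketch of how such a base class can consume the hook, reconstructed from the removed `_create_worker` body rather than copied from the actual `GPUExecutor` code, so details may differ:

    from typing import Optional, Tuple
    from vllm.worker.worker_base import WorkerWrapperBase

    class _ExecutorSketch:
        """Stand-in for the GPUExecutor base; worker-creation path only."""

        def _get_worker_module_and_class(self) -> Tuple[str, str]:
            # Subclasses such as XPUExecutor override this hook.
            return ("vllm.worker.xpu_worker", "XPUWorker")

        def _get_worker_kwargs(self, local_rank, rank,
                               distributed_init_method):
            raise NotImplementedError  # provided by the real executor

        def _create_worker(self,
                           local_rank: int = 0,
                           rank: int = 0,
                           distributed_init_method: Optional[str] = None):
            # Resolve which worker to load from the subclass hook, then
            # let WorkerWrapperBase import and construct it lazily.
            worker_module_name, worker_class_name = (
                self._get_worker_module_and_class())
            wrapper = WorkerWrapperBase(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
            )
            wrapper.init_worker(**self._get_worker_kwargs(
                local_rank, rank, distributed_init_method))
            return wrapper.worker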

vllm/worker/xpu_model_runner.py

@@ -137,7 +137,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             device_config=self.device_config,
             load_config=self.load_config,
             lora_config=self.lora_config,
-            multimodal_config=self.multimodal_config,
             parallel_config=self.parallel_config,
             scheduler_config=self.scheduler_config,
             cache_config=self.cache_config,

vllm/worker/xpu_worker.py

@@ -9,8 +9,8 @@ import torch
 import torch.distributed

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -50,6 +50,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         speculative_config: Optional[SpeculativeConfig] = None,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
+        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
         assert device_config.device_type == "xpu"
         assert is_xpu()
@@ -67,8 +68,10 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        self.observability_config = observability_config
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."

         self.multimodal_config = multimodal_config
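
The driver check is relaxed here: instead of requiring the single global rank 0, any rank that starts a tensor-parallel group may act as a driver worker. A worked example with assumed sizes (illustrative values, not taken from this diff):

    # With tensor_parallel_size = 2 and world_size = 4, ranks 0 and 2
    # each lead a tensor-parallel group and satisfy the new assertion.
    tensor_parallel_size = 2
    world_size = 4
    driver_ranks = [rank for rank in range(world_size)
                    if rank % tensor_parallel_size == 0]
    assert driver_ranks == [0, 2]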
@@ -183,7 +186,11 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
             # dependency (libdrm and drm headers) on your system.
             ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                                 "sockets")
+            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                             str(parallel_config.world_size))
             os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+            os.environ["LOCAL_RANK"] = str(self.local_rank)
             init_distributed_environment(
                 world_size=parallel_config.world_size,
                 rank=rank,
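
Exporting `LOCAL_WORLD_SIZE` and `LOCAL_RANK` before `init_distributed_environment` lets the oneCCL bindings discover the per-node process layout, and the `os.getenv` default means a value already exported by the launcher wins over the engine's `world_size`. A minimal standalone sketch of that precedence, with assumed values standing in for the worker's state:

    import os

    world_size = 2   # stands in for parallel_config.world_size
    local_rank = 1   # stands in for self.local_rank

    # If the launcher already exported LOCAL_WORLD_SIZE, keep it;
    # otherwise fall back to the engine's world size.
    os.environ["LOCAL_WORLD_SIZE"] = os.getenv("LOCAL_WORLD_SIZE",
                                               str(world_size))
    os.environ["LOCAL_RANK"] = str(local_rank)
    print(os.environ["LOCAL_WORLD_SIZE"], os.environ["LOCAL_RANK"])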