vllm/vllm/executor/hpu_executor.py

###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import contextlib
import os
from typing import Any, Dict, List, Optional, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class HPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self._init_worker()
def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=rank == 0,
)
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
def _init_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# HPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
from vllm_hpu_extension.profiler import HabanaMemoryProfiler
with HabanaMemoryProfiler() as cache_init_m:
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
msg = f"init_cache_engine took {cache_init_m.get_summary_string()}"
logger.info(msg)
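    # NOTE (assumption): finish_measurements() delegates to the HPU worker and
    # is expected to close out Intel Neural Compressor (INC) measurement
    # collection used for FP8 calibration on Gaudi.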
def finish_measurements(self):
self.driver_worker.finish_measurements()
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there were any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
        # VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there were any # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
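        # Example (assumed shell usage): to capture per-step graph compilation
        # logs together with detailed GC metrics, set the flags before
        # launching vLLM, e.g.:
        #   PT_HPU_METRICS_GC_DETAILS=1 \
        #   VLLM_HPU_LOG_STEP_GRAPH_COMPILATION=1 \
        #       python -m vllm.entrypoints.openai.api_server ...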
log_graph_compilation_all = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
log_graph_compilation = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
'0') != '0' or log_graph_compilation_all
log_cpu_fallbacks_all = os.environ.get(
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
'0') != '0' or log_cpu_fallbacks_all
if log_graph_compilation or log_cpu_fallbacks:
from habana_frameworks.torch.hpu.metrics import metric_localcontext
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
is_prompt = any([
seq_group_metadata.is_prompt
for seq_group_metadata in seq_group_metadata_list
])
max_context_len = max([
max([
len(v.prompt_token_ids) + len(v.output_token_ids)
for v in seq_group_metadata.seq_data.values()
]) for seq_group_metadata in seq_group_metadata_list
            ])  # longest context (prompt + generated tokens) in the batch
max_num_blocks = (
(max_context_len - 1) // self.cache_config.block_size) + 1
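            # e.g. (illustrative numbers): with block_size=128, a context of
            # 129 tokens needs (129 - 1) // 128 + 1 = 2 KV-cache blocks.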
input_stats = (f'is_prompt: {is_prompt}, '
f'num_seqs: {len(seq_group_metadata_list)}, '
f'max_context_len: {max_context_len}, '
f'max_num_blocks {max_num_blocks}')
gc_ctx = metric_localcontext(
"graph_compilation"
) if log_graph_compilation else contextlib.nullcontext()
cpu_fallback_ctx = metric_localcontext(
"cpu_fallback"
) if log_cpu_fallbacks else contextlib.nullcontext()
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = self.driver_worker.execute_model(execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
return output
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def check_health(self) -> None:
        # HPUExecutor will always be healthy as long as
        # it's running.
return
def start_profile(self) -> None:
self.driver_worker.start_profile()
def stop_profile(self) -> None:
self.driver_worker.stop_profile()
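    # NOTE (assumption): shutdown_inc() on the worker is expected to tear down
    # INC (Intel Neural Compressor) quantization/measurement state, if any.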
def shutdown(self) -> None:
        self.driver_worker.shutdown_inc()


class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
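# Example (hypothetical, for illustration only): HPUExecutor is normally
# constructed by the LLMEngine from a VllmConfig rather than instantiated
# directly; a typical high-level entry point would look like:
#
#   from vllm import LLM
#   llm = LLM(model="facebook/opt-125m", device="hpu")  # 'device' arg assumed
#   outputs = llm.generate("Hello, my name is")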