[Intel GPU] Fix xpu decode input (#9145)

This commit is contained in:
Kunshang Ji 2024-10-08 11:51:14 +08:00 committed by GitHub
parent 04c12f8157
commit 80b57f00d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
@ -136,7 +137,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
(input_tokens, input_positions, (input_tokens, input_positions,
attn_metadata) = self._prepare_decode( attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list) self.seq_group_metadata_list)
seq_lens = [] seq_lens = None
multi_modal_kwargs = None multi_modal_kwargs = None
return self.model_input_cls( return self.model_input_cls(
@ -390,6 +391,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
def load_model(self) -> None: def load_model(self) -> None:
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
self.model = get_model( self.model = get_model(
@ -524,12 +529,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
seq_group_metadata_list, finished_requests_ids) seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group # Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids) generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
model_input.seq_lens, seq_group_metadata_list,
model_input.query_lens, model_input.seq_lens,
self.device, model_input.query_lens,
pin_memory=False, self.device,
generators=generators) pin_memory=False,
generators=generators,
cache=self.sampling_metadata_cache)
return dataclasses.replace(model_input, return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,