[Intel GPU] Fix xpu decode input (#9145)
This commit is contained in:
parent
04c12f8157
commit
80b57f00d5
@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
|||||||
from vllm.distributed import get_pp_group
|
from vllm.distributed import get_pp_group
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor import SamplingMetadataCache
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||||
@ -136,7 +137,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
|
|||||||
(input_tokens, input_positions,
|
(input_tokens, input_positions,
|
||||||
attn_metadata) = self._prepare_decode(
|
attn_metadata) = self._prepare_decode(
|
||||||
self.seq_group_metadata_list)
|
self.seq_group_metadata_list)
|
||||||
seq_lens = []
|
seq_lens = None
|
||||||
multi_modal_kwargs = None
|
multi_modal_kwargs = None
|
||||||
|
|
||||||
return self.model_input_cls(
|
return self.model_input_cls(
|
||||||
@ -390,6 +391,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
|||||||
# Lazy initialization.
|
# Lazy initialization.
|
||||||
self.model: nn.Module # Set after init_Model
|
self.model: nn.Module # Set after init_Model
|
||||||
|
|
||||||
|
self.sampling_metadata_cache: SamplingMetadataCache = \
|
||||||
|
SamplingMetadataCache() \
|
||||||
|
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
with DeviceMemoryProfiler() as m:
|
with DeviceMemoryProfiler() as m:
|
||||||
self.model = get_model(
|
self.model = get_model(
|
||||||
@ -524,12 +529,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
|||||||
seq_group_metadata_list, finished_requests_ids)
|
seq_group_metadata_list, finished_requests_ids)
|
||||||
# Sampling metadata is only required for the final pp group
|
# Sampling metadata is only required for the final pp group
|
||||||
generators = self.get_generators(finished_requests_ids)
|
generators = self.get_generators(finished_requests_ids)
|
||||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
sampling_metadata = SamplingMetadata.prepare(
|
||||||
model_input.seq_lens,
|
seq_group_metadata_list,
|
||||||
model_input.query_lens,
|
model_input.seq_lens,
|
||||||
self.device,
|
model_input.query_lens,
|
||||||
pin_memory=False,
|
self.device,
|
||||||
generators=generators)
|
pin_memory=False,
|
||||||
|
generators=generators,
|
||||||
|
cache=self.sampling_metadata_cache)
|
||||||
|
|
||||||
return dataclasses.replace(model_input,
|
return dataclasses.replace(model_input,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user