Fix input_metadata.selected_token_indices in worker prepare_inputs (#1546)
parent 06458a0b42
commit 8efe23f150
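When a batch of prompts is padded to max_seq_len, the last real token of prompt i (the position that selected_token_indices should point at) lands at i * max_seq_len + prompt_len_i - 1. A minimal Python sketch of that index computation, mirroring the expectation built in the new test below (the helper name is illustrative, not part of the commit):

from typing import List


def last_token_indices(prompt_lens: List[int], max_seq_len: int) -> List[int]:
    # Each padded prompt occupies a full max_seq_len slice of the batched
    # input, so row i starts at i * max_seq_len and its last real token sits
    # at i * max_seq_len + prompt_lens[i] - 1.
    indices: List[int] = []
    start = 0
    for prompt_len in prompt_lens:
        indices.append(start + prompt_len - 1)
        start += max_seq_len  # advance by the padded width, not by prompt_len
    return indices


# e.g. prompts of lengths 3, 1, 2 padded to max_seq_len=3 -> indices 2, 3, 7
assert last_token_indices([3, 1, 2], max_seq_len=3) == [2, 3, 7]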
tests/worker/test_worker.py (new file, +44)

@@ -0,0 +1,44 @@
+# pylint: disable=protected-access
+import random
+import torch
+
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.worker.worker import Worker
+
+
+def test_worker_prepare_inputs_for_prompt():
+    worker = Worker(None, None, None)
+    worker.block_size = 16
+    batch_size = random.randint(1, 256)
+    prompt_lens = []
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        # make sure all tokens fit into one block
+        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_lens.append(prompt_len)
+        seq_data = list(range(prompt_len))
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData(seq_data)},
+                sampling_params=SamplingParams(temperature=0),
+                block_tables={0: [1]},
+            ))
+    expected_selected_token_indices = []
+    selected_token_start_idx = 0
+    max_seq_len = max(prompt_lens)
+    for prompt_len in prompt_lens:
+        expected_selected_token_indices.append(selected_token_start_idx +
+                                               prompt_len - 1)
+        selected_token_start_idx += max_seq_len
+    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+        seq_group_metadata_list)
+    assert input_tokens.shape == input_positions.shape == (batch_size,
+                                                           max_seq_len)
+    torch.testing.assert_close(input_tokens, input_positions)
+    actual = input_metadata.selected_token_indices
+    expected = torch.tensor(expected_selected_token_indices,
+                            device=actual.device,
+                            dtype=actual.dtype)
+    torch.testing.assert_close(actual, expected)

vllm/worker/worker.py
@@ -211,12 +211,14 @@ class Worker:
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
         max_seq_len = max(prompt_lens) if prompt_lens else 1
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             if seq_group_metadata.is_prompt:
                 # We need to do this in this loop as we need to know max_seq_len
                 assert len(
                     seq_ids) == 1, "Prompt input should have only one seq."
                 sampling_params = seq_group_metadata.sampling_params
+                assert len(prompt_lens) == len(seq_group_metadata_list)
+                prompt_len = prompt_lens[i]
                 if sampling_params.prompt_logprobs is not None:
                     selected_token_indices.extend(
                         range(selected_token_start_idx,