Fix sampler test (#1379)

This commit is contained in:
Woosuk Kwon 2023-10-16 12:57:26 -07:00 committed by GitHub
parent e8ef4c0820
commit d3a5bd9fb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,9 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import pytest
import random import random
from typing import Tuple from typing import Tuple
from unittest.mock import patch from unittest.mock import patch
import pytest
import torch import torch
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
input_metadata=input_metadata) input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1) expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output: for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item() assert nth_output.output_token == expected[i].item()
@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
hidden_states=input_tensor, hidden_states=input_tensor,
input_metadata=input_metadata) input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output: for nth_output in sequence_output.samples:
assert nth_output.output_token == i assert nth_output.output_token == i
@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search: if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue continue
for nth_output in sequence_output: for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens assert nth_output.output_token in expected_tokens