Fix sampler test (#1379)

This commit is contained in:
Woosuk Kwon 2023-10-16 12:57:26 -07:00 committed by GitHub
parent e8ef4c0820
commit d3a5bd9fb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,9 @@
# pylint: disable=protected-access
import pytest
import random
from typing import Tuple
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.sampler import Sampler
@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item()
@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output:
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens