From d3a5bd9fb7d778c2f2f74bcf8d5343f185f69b61 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 16 Oct 2023 12:57:26 -0700
Subject: [PATCH] Fix sampler test (#1379)

---
 tests/samplers/test_sampler.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 74c819ef..c4d33711 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -1,9 +1,9 @@
 # pylint: disable=protected-access
-import pytest
 import random
 from typing import Tuple
 from unittest.mock import patch
 
+import pytest
 import torch
 
 from vllm.model_executor.layers.sampler import Sampler
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
                              input_metadata=input_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()
 
 
@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
                              hidden_states=input_tensor,
                              input_metadata=input_metadata)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
 
@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token in expected_tokens
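
Context for the change: each element of the sampler's output is a per-sequence-group result object whose sampled tokens live under a .samples attribute, so the test can no longer iterate the element directly. The sketch below is illustrative only; SequenceGroupOutput, SequenceOutput, and output_token are hypothetical stand-ins for the real vLLM classes, and only the .samples / .output_token access pattern mirrors what the patched test relies on.

    # Minimal sketch of the output shape the fixed test assumes. These
    # dataclasses are hypothetical stand-ins, not vLLM's actual classes.
    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class SequenceOutput:
        output_token: int  # token id sampled for one sequence


    @dataclass
    class SequenceGroupOutput:
        # Per-sequence results live under .samples; the group object
        # itself is not iterable, which is why the old loop fails.
        samples: List[SequenceOutput] = field(default_factory=list)


    sampler_output = [
        SequenceGroupOutput(samples=[SequenceOutput(output_token=7)])
    ]

    for i, sequence_output in enumerate(sampler_output):
        # Old (broken): `for nth_output in sequence_output:` raises
        # TypeError because the group object defines no __iter__.
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == 7

Under these assumptions, the four one-line hunks above are the same mechanical substitution applied to each test: iterate the .samples list rather than the group object itself.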