Fix sampler test (#1379)
This commit is contained in:
parent
e8ef4c0820
commit
d3a5bd9fb7
@ -1,9 +1,9 @@
|
||||
# pylint: disable=protected-access
|
||||
import pytest
|
||||
import random
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
|
||||
input_metadata=input_metadata)
|
||||
expected = torch.argmax(fake_logits, dim=-1)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output:
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == expected[i].item()
|
||||
|
||||
|
||||
@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output:
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == i
|
||||
|
||||
|
||||
@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
||||
continue
|
||||
for nth_output in sequence_output:
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token in expected_tokens
|
||||
|
||||
Loading…
Reference in New Issue
Block a user