Fix sampler test (#1379)
This commit is contained in:
parent
e8ef4c0820
commit
d3a5bd9fb7
@ -1,9 +1,9 @@
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import pytest
|
|
||||||
import random
|
import random
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
|
|||||||
input_metadata=input_metadata)
|
input_metadata=input_metadata)
|
||||||
expected = torch.argmax(fake_logits, dim=-1)
|
expected = torch.argmax(fake_logits, dim=-1)
|
||||||
for i, sequence_output in enumerate(sampler_output):
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
for nth_output in sequence_output:
|
for nth_output in sequence_output.samples:
|
||||||
assert nth_output.output_token == expected[i].item()
|
assert nth_output.output_token == expected[i].item()
|
||||||
|
|
||||||
|
|
||||||
@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
|
|||||||
hidden_states=input_tensor,
|
hidden_states=input_tensor,
|
||||||
input_metadata=input_metadata)
|
input_metadata=input_metadata)
|
||||||
for i, sequence_output in enumerate(sampler_output):
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
for nth_output in sequence_output:
|
for nth_output in sequence_output.samples:
|
||||||
assert nth_output.output_token == i
|
assert nth_output.output_token == i
|
||||||
|
|
||||||
|
|
||||||
@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
|
|||||||
for i, sequence_output in enumerate(sampler_output):
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
||||||
continue
|
continue
|
||||||
for nth_output in sequence_output:
|
for nth_output in sequence_output.samples:
|
||||||
assert nth_output.output_token in expected_tokens
|
assert nth_output.output_token in expected_tokens
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user