[Core] Comment out unused code in sampler (#7023)

Author: Peng Guanwen, 2024-08-02 15:58:26 +08:00 (committed by GitHub)
Parent: 660dea1235
Commit: db35186391


```diff
@@ -13,6 +13,8 @@ from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
 
 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
 
+# Some triton sampler related code is guarded before it is ready.
+_USE_TRITON_SAMPLER = False
 
 @dataclass
```
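The new `_USE_TRITON_SAMPLER` flag is a module-level feature switch that fences off code that is merged but not yet ready to run. A minimal sketch of the pattern, where every name except `_USE_TRITON_SAMPLER` is illustrative rather than vLLM's actual API:

```python
import torch

# Module-level feature flag: flipped to True only once the guarded
# code path is ready to be exercised.
_USE_TRITON_SAMPLER = False


def _torch_sample(logits: torch.Tensor) -> torch.Tensor:
    # Current default path: sample one token id per row of logits.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def _triton_sample(logits: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError("guarded until the Triton sampler is ready")


def sample(logits: torch.Tensor) -> torch.Tensor:
    # The guarded branch is dead code while the flag is False, so it can
    # live in the tree without affecting behavior.
    if _USE_TRITON_SAMPLER:
        return _triton_sample(logits)
    return _torch_sample(logits)
```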
```diff
@@ -347,11 +349,13 @@ class SamplingTensors:
         repetition_penalties: List[float] = []
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
-        prompt_best_of: List[int] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
 
+        if _USE_TRITON_SAMPLER:
+            prompt_best_of: List[int] = []
+
         # We need one base seed per Triton slice.
         seeds_to_generate = (extra_seeds_to_generate +
                              get_num_triton_sampler_splits(vocab_size))
```
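The "one base seed per Triton slice" comment reflects that the Triton sampler splits a large vocabulary across several kernel launches, each needing its own base seed. A plausible sketch of `get_num_triton_sampler_splits`, assuming a fixed cap on how many columns one launch handles (the cap below is an assumption, not a value taken from vLLM):

```python
import math

# Assumed per-launch column limit; the real constant in vLLM may differ.
_MAX_TRITON_N_COLS = 2**25


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Number of vocab slices (hence base seeds) needed to cover n_cols."""
    return math.ceil(n_cols / _MAX_TRITON_N_COLS)


# Under this assumption a 32k-token vocab fits in a single slice:
assert get_num_triton_sampler_splits(32_000) == 1
```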
```diff
@@ -366,9 +370,6 @@ class SamplingTensors:
             r = sampling_params.repetition_penalty
             top_p = sampling_params.top_p
             min_p = sampling_params.min_p
-            seed = sampling_params.seed
-            is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
-
 
             # k should not be greater than the vocab size.
             top_k = min(sampling_params.top_k, vocab_size)
```
```diff
@@ -389,8 +390,7 @@ class SamplingTensors:
                 do_penalties = True
 
             is_prompt = seq_group.is_prompt
-            if (seq_group.is_prompt
-                    and sampling_params.prompt_logprobs is not None):
+            if (is_prompt and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get
                 # their logprobs
                 query_len = seq_group.query_len
```
```diff
@@ -415,11 +415,15 @@ class SamplingTensors:
             frequency_penalties += [f] * len(seq_ids)
             repetition_penalties += [r] * len(seq_ids)
 
+            if _USE_TRITON_SAMPLER:
                 if is_prompt:
                     prompt_best_of.append(sampling_params.best_of)
                     query_len = seq_group.query_len
                     assert query_len is not None
 
+                seed = sampling_params.seed
+                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
+
             for seq_id in seq_ids:
                 seq_data = seq_group.seq_data[seq_id]
                 extra_entropy = extra_entropy or ()
```
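This hunk is the counterpart of the earlier removal: `seed` and `is_greedy` feed only the Triton sampler's per-sequence seed generation, so both move under the same guard. The `_SEED_0_REPLACEMENT` constant at the top of the file hints at how such seeds are consumed: zero is typically reserved as a sentinel (e.g. for greedy sampling), so any generated seed that collides with it gets remapped. A sketch of that idea; the helper name and derivation scheme are assumptions, not vLLM's actual implementation:

```python
import random
from typing import List, Tuple

_SEED_0_REPLACEMENT = 3403598558  # fixed stand-in for a generated seed of 0


def _generate_seeds(base_seed: int,
                    extra_entropy: Tuple[int, ...],
                    count: int) -> List[int]:
    # Hypothetical helper: derive `count` per-slice seeds from a base seed
    # plus extra entropy (e.g. request and sequence ids).
    rng = random.Random(hash((base_seed,) + extra_entropy))
    seeds = [rng.getrandbits(32) for _ in range(count)]
    # 0 is reserved as the "greedy / unseeded" sentinel, so remap it.
    return [s if s != 0 else _SEED_0_REPLACEMENT for s in seeds]
```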
```diff
@@ -549,7 +553,7 @@ class SamplingTensors:
             device="cpu",
             dtype=torch.long,
             pin_memory=pin_memory,
-        ).T.contiguous()
+        ).t().contiguous()
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
```
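Finally, `.T` becomes `.t()`. On a 2-D tensor the two are equivalent, but `.t()` states the 2-D intent explicitly and raises on higher-rank inputs, whereas `.T` silently reverses all dimensions (and is deprecated for non-2-D tensors in recent PyTorch). A small self-contained illustration, including the pinned-memory non-blocking copy the surrounding comment describes:

```python
import torch

pin = torch.cuda.is_available()  # pinning host memory requires CUDA
seeds = torch.tensor([[1, 2, 3], [4, 5, 6]],
                     dtype=torch.long,
                     pin_memory=pin)

transposed = seeds.t().contiguous()      # explicit 2-D transpose
assert torch.equal(transposed, seeds.T)  # same result for 2-D input

# .t() rejects tensors with more than two dims, catching shape bugs early:
# torch.rand(2, 3, 4).t()  ->  RuntimeError

if pin:
    # Pinned host memory lets the host-to-device copy run asynchronously.
    seeds_gpu = transposed.to("cuda", non_blocking=True)
```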