[Core] Comment out unused code in sampler (#7023)
This commit is contained in:
parent
660dea1235
commit
db35186391
@ -13,6 +13,8 @@ from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
_SEED_0_REPLACEMENT = 3403598558
|
||||
# Some Triton sampler related code is guarded until it is ready.
|
||||
_USE_TRITON_SAMPLER = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -347,11 +349,13 @@ class SamplingTensors:
|
||||
repetition_penalties: List[float] = []
|
||||
sampling_seeds: List[int] = []
|
||||
sample_indices: List[int] = []
|
||||
prompt_best_of: List[int] = []
|
||||
do_penalties = False
|
||||
do_top_p_top_k = False
|
||||
do_min_p = False
|
||||
|
||||
if _USE_TRITON_SAMPLER:
|
||||
prompt_best_of: List[int] = []
|
||||
|
||||
# We need one base seed per Triton slice.
|
||||
seeds_to_generate = (extra_seeds_to_generate +
|
||||
get_num_triton_sampler_splits(vocab_size))
|
||||
@ -366,9 +370,6 @@ class SamplingTensors:
|
||||
r = sampling_params.repetition_penalty
|
||||
top_p = sampling_params.top_p
|
||||
min_p = sampling_params.min_p
|
||||
seed = sampling_params.seed
|
||||
|
||||
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
|
||||
|
||||
# k should not be greater than the vocab size.
|
||||
top_k = min(sampling_params.top_k, vocab_size)
|
||||
@ -389,8 +390,7 @@ class SamplingTensors:
|
||||
do_penalties = True
|
||||
|
||||
is_prompt = seq_group.is_prompt
|
||||
if (seq_group.is_prompt
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
if (is_prompt and sampling_params.prompt_logprobs is not None):
|
||||
# For tokens in the prompt that we only need to get
|
||||
# their logprobs
|
||||
query_len = seq_group.query_len
|
||||
@ -415,11 +415,15 @@ class SamplingTensors:
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
repetition_penalties += [r] * len(seq_ids)
|
||||
|
||||
if _USE_TRITON_SAMPLER:
|
||||
if is_prompt:
|
||||
prompt_best_of.append(sampling_params.best_of)
|
||||
query_len = seq_group.query_len
|
||||
assert query_len is not None
|
||||
|
||||
seed = sampling_params.seed
|
||||
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group.seq_data[seq_id]
|
||||
extra_entropy = extra_entropy or ()
|
||||
@ -549,7 +553,7 @@ class SamplingTensors:
|
||||
device="cpu",
|
||||
dtype=torch.long,
|
||||
pin_memory=pin_memory,
|
||||
).T.contiguous()
|
||||
).t().contiguous()
|
||||
|
||||
# Because the memory is pinned, we can do non-blocking
|
||||
# transfer to device.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user