remove floats == 0 comparison (#285)
This commit is contained in:
parent
4338cc4750
commit
425040d4c1
@ -11,6 +11,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SequenceOutputs
|
||||||
|
|
||||||
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
class Sampler(nn.Module):
|
class Sampler(nn.Module):
|
||||||
"""Samples the next tokens from the model's outputs.
|
"""Samples the next tokens from the model's outputs.
|
||||||
@ -74,7 +75,7 @@ class Sampler(nn.Module):
|
|||||||
# Apply top-p and top-k truncation.
|
# Apply top-p and top-k truncation.
|
||||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
||||||
if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
|
if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
|
||||||
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
@ -152,7 +153,7 @@ def _apply_penalties(
|
|||||||
continue
|
continue
|
||||||
p = presence_penalties[i]
|
p = presence_penalties[i]
|
||||||
f = frequency_penalties[i]
|
f = frequency_penalties[i]
|
||||||
if p == 0.0 and f == 0.0:
|
if p < _SAMPLING_EPS and f < _SAMPLING_EPS:
|
||||||
continue
|
continue
|
||||||
indices.append(i)
|
indices.append(i)
|
||||||
|
|
||||||
@ -190,7 +191,7 @@ def _get_temperatures(
|
|||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
temperature = sampling_params.temperature
|
temperature = sampling_params.temperature
|
||||||
if temperature == 0.0:
|
if temperature < _SAMPLING_EPS:
|
||||||
# NOTE: Zero temperature means deterministic sampling
|
# NOTE: Zero temperature means deterministic sampling
|
||||||
# (i.e., greedy sampling or beam search).
|
# (i.e., greedy sampling or beam search).
|
||||||
# Set the temperature to 1 to avoid division by zero.
|
# Set the temperature to 1 to avoid division by zero.
|
||||||
@ -286,7 +287,7 @@ def _sample_from_prompt(
|
|||||||
beam_width = sampling_params.best_of
|
beam_width = sampling_params.best_of
|
||||||
_, next_token_ids = torch.topk(prob, beam_width)
|
_, next_token_ids = torch.topk(prob, beam_width)
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
elif sampling_params.temperature == 0.0:
|
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||||
# Greedy sampling.
|
# Greedy sampling.
|
||||||
assert sampling_params.best_of == 1
|
assert sampling_params.best_of == 1
|
||||||
next_token_id = torch.argmax(prob)
|
next_token_id = torch.argmax(prob)
|
||||||
@ -343,7 +344,7 @@ def _sample_from_generation_tokens(
|
|||||||
|
|
||||||
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
|
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
|
||||||
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
|
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
|
||||||
elif sampling_params.temperature == 0.0:
|
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||||
# Greedy sampling.
|
# Greedy sampling.
|
||||||
assert len(seq_ids) == 1
|
assert len(seq_ids) == 1
|
||||||
next_token_id = torch.argmax(probs, dim=-1)
|
next_token_id = torch.argmax(probs, dim=-1)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
class SamplingParams:
|
class SamplingParams:
|
||||||
"""Sampling parameters for text generation.
|
"""Sampling parameters for text generation.
|
||||||
@ -71,7 +72,7 @@ class SamplingParams:
|
|||||||
self._verify_args()
|
self._verify_args()
|
||||||
if self.use_beam_search:
|
if self.use_beam_search:
|
||||||
self._verity_beam_search()
|
self._verity_beam_search()
|
||||||
elif self.temperature == 0.0:
|
elif self.temperature < _SAMPLING_EPS:
|
||||||
# Zero temperature means greedy sampling.
|
# Zero temperature means greedy sampling.
|
||||||
self._verify_greedy_sampling()
|
self._verify_greedy_sampling()
|
||||||
|
|
||||||
@ -106,9 +107,9 @@ class SamplingParams:
|
|||||||
if self.best_of == 1:
|
if self.best_of == 1:
|
||||||
raise ValueError("best_of must be greater than 1 when using beam "
|
raise ValueError("best_of must be greater than 1 when using beam "
|
||||||
f"search. Got {self.best_of}.")
|
f"search. Got {self.best_of}.")
|
||||||
if self.temperature > 0.0:
|
if self.temperature > _SAMPLING_EPS:
|
||||||
raise ValueError("temperature must be 0 when using beam search.")
|
raise ValueError("temperature must be 0 when using beam search.")
|
||||||
if self.top_p < 1.0:
|
if self.top_p < 1.0 - _SAMPLING_EPS:
|
||||||
raise ValueError("top_p must be 1 when using beam search.")
|
raise ValueError("top_p must be 1 when using beam search.")
|
||||||
if self.top_k != -1:
|
if self.top_k != -1:
|
||||||
raise ValueError("top_k must be -1 when using beam search.")
|
raise ValueError("top_k must be -1 when using beam search.")
|
||||||
@ -117,7 +118,7 @@ class SamplingParams:
|
|||||||
if self.best_of > 1:
|
if self.best_of > 1:
|
||||||
raise ValueError("best_of must be 1 when using greedy sampling."
|
raise ValueError("best_of must be 1 when using greedy sampling."
|
||||||
f"Got {self.best_of}.")
|
f"Got {self.best_of}.")
|
||||||
if self.top_p < 1.0:
|
if self.top_p < 1.0 - _SAMPLING_EPS:
|
||||||
raise ValueError("top_p must be 1 when using greedy sampling.")
|
raise ValueError("top_p must be 1 when using greedy sampling.")
|
||||||
if self.top_k != -1:
|
if self.top_k != -1:
|
||||||
raise ValueError("top_k must be -1 when using greedy sampling.")
|
raise ValueError("top_k must be -1 when using greedy sampling.")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user