Fix empty output when temp is too low (#2937)

Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Chang Su 2024-08-13 22:31:44 -07:00 committed by GitHub
parent 199adbb7cf
commit c134a46402
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 1 deletions

View File

@ -118,8 +118,9 @@ class Sampler(nn.Module):
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Apply temperature scaling.
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
if do_top_p_top_k:

View File

@ -13,6 +13,7 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2
class SamplingType(IntEnum):
@ -145,6 +146,12 @@ class SamplingParams:
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
if 0 < temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s.",
temperature, _MAX_TEMP, _MAX_TEMP)
temperature = max(temperature, _MAX_TEMP)
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k