Fix empty output when temp is too low (#2937)
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
199adbb7cf
commit
c134a46402
@ -118,8 +118,9 @@ class Sampler(nn.Module):
|
||||
sampling_tensors.frequency_penalties,
|
||||
sampling_tensors.repetition_penalties)
|
||||
|
||||
# Apply temperature scaling.
|
||||
# Use float32 to apply temperature scaling.
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
logits = logits.to(torch.float)
|
||||
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
|
||||
|
||||
if do_top_p_top_k:
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
# Threshold for very small sampling temperatures. Despite the "MAX" in the
# name, SamplingParams uses this as a lower bound: a positive temperature
# below this value triggers a warning and is raised to this value, because
# such small temperatures may cause nan or inf values in tensors.
_MAX_TEMP = 1e-2
|
||||
|
||||
|
||||
class SamplingType(IntEnum):
|
||||
@ -145,6 +146,12 @@ class SamplingParams:
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.repetition_penalty = repetition_penalty
|
||||
if 0 < temperature < _MAX_TEMP:
|
||||
logger.warning(
|
||||
"temperature %s is less than %s, which may cause numerical "
|
||||
"errors nan or inf in tensors. We have maxed it out to %s.",
|
||||
temperature, _MAX_TEMP, _MAX_TEMP)
|
||||
temperature = max(temperature, _MAX_TEMP)
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
|
||||
Loading…
Reference in New Issue
Block a user