From c134a4640214e760a8480b6d80b2038bf9ba3a8e Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Tue, 13 Aug 2024 22:31:44 -0700
Subject: [PATCH] Fix empty output when temp is too low (#2937)

Co-authored-by: Cyrus Leung
---
 vllm/model_executor/layers/sampler.py | 3 ++-
 vllm/sampling_params.py               | 7 +++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index cc78a0ea..41abdf21 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -118,8 +118,9 @@ class Sampler(nn.Module):
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)
 
-        # Apply temperature scaling.
+        # Use float32 to apply temperature scaling.
         # Use in-place division to avoid creating a new tensor.
+        logits = logits.to(torch.float)
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 1c1e5f16..04250c68 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -13,6 +13,7 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 _SAMPLING_EPS = 1e-5
+_MAX_TEMP = 1e-2
 
 
 class SamplingType(IntEnum):
@@ -145,6 +146,12 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.repetition_penalty = repetition_penalty
+        if 0 < temperature < _MAX_TEMP:
+            logger.warning(
+                "temperature %s is less than %s, which may cause numerical "
+                "errors nan or inf in tensors. We have maxed it out to %s.",
+                temperature, _MAX_TEMP, _MAX_TEMP)
+            temperature = max(temperature, _MAX_TEMP)
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k