diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7fb1af15..0ad46cbe 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -234,15 +234,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
         logits_processors = None
         if self.logit_bias:
+            logit_bias: Dict[int, float] = {}
+            try:
+                for token_id, bias in self.logit_bias.items():
+                    # Convert token_id to integer before we add to LLMEngine
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    logit_bias[int(token_id)] = min(100, max(-100, bias))
+            except ValueError as exc:
+                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
+                                 f"but token_id must be an integer or string "
+                                 f"representing an integer") from exc
 
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
-                assert self.logit_bias is not None
-                for token_id, bias in self.logit_bias.items():
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    bias = min(100, max(-100, bias))
-                    logits[int(token_id)] += bias
+                for token_id, bias in logit_bias.items():
+                    logits[token_id] += bias
                 return logits
 
             logits_processors = [logit_bias_logits_processor]
@@ -419,15 +426,22 @@ class CompletionRequest(OpenAIBaseModel):
 
         logits_processors = None
         if self.logit_bias:
+            logit_bias: Dict[int, float] = {}
+            try:
+                for token_id, bias in self.logit_bias.items():
+                    # Convert token_id to integer
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    logit_bias[int(token_id)] = min(100, max(-100, bias))
+            except ValueError as exc:
+                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
+                                 f"but token_id must be an integer or string "
+                                 f"representing an integer") from exc
 
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
-                assert self.logit_bias is not None
-                for token_id, bias in self.logit_bias.items():
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    bias = min(100, max(-100, bias))
-                    logits[int(token_id)] += bias
+                for token_id, bias in logit_bias.items():
+                    logits[token_id] += bias
                 return logits
 
             logits_processors = [logit_bias_logits_processor]
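
For reference, a minimal standalone sketch of the normalization this patch performs (names such as `build_logit_bias_processor` are illustrative, not vLLM API): the `logit_bias` mapping is validated, converted to integer keys, and clamped once at request-validation time, so the per-step logits processor only applies pre-checked biases and malformed keys fail fast instead of asserting during generation.

    from typing import Dict, List, Union

    import torch


    def build_logit_bias_processor(raw_bias: Dict[Union[str, int], float]):
        """Validate and clamp a logit_bias mapping up front, as in the patch."""
        logit_bias: Dict[int, float] = {}
        try:
            for token_id, bias in raw_bias.items():
                # Convert token_id to int and clamp bias to [-100, 100]
                # per the OpenAI API spec.
                logit_bias[int(token_id)] = min(100, max(-100, bias))
        except ValueError as exc:
            raise ValueError(f"Found token_id `{token_id}` in logit_bias "
                             f"but token_id must be an integer or string "
                             f"representing an integer") from exc

        def logit_bias_logits_processor(token_ids: List[int],
                                        logits: torch.Tensor) -> torch.Tensor:
            # Hot path: apply the already-validated, already-clamped biases.
            for token_id, bias in logit_bias.items():
                logits[token_id] += bias
            return logits

        return logit_bias_logits_processor


    # String keys (as sent by OpenAI clients) are accepted and clamped;
    # non-integer keys raise a clear ValueError instead of an assertion.
    proc = build_logit_bias_processor({"1234": 250, 42: -3})
    logits = proc([], torch.zeros(50000))
    assert logits[1234] == 100 and logits[42] == -3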