From 7013a80170146ca4f7fd9539e4d7fd852d4718b6 Mon Sep 17 00:00:00 2001
From: Dan Lord
Date: Mon, 30 Oct 2023 16:52:56 -0700
Subject: [PATCH] Add support for `spaces_between_special_tokens`

---
 vllm/engine/llm_engine.py             |  6 +++---
 vllm/entrypoints/openai/api_server.py |  4 ++++
 vllm/entrypoints/openai/protocol.py   |  2 ++
 vllm/sampling_params.py               |  8 +++++++-
 vllm/transformers_utils/tokenizer.py  | 15 ++++++++++++---
 5 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f0d868d3..c3752b11 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -632,8 +632,7 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now
 
-    def _decode_sequence(self, seq: Sequence,
-                         sampling_params: SamplingParams) -> None:
+    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -642,7 +641,8 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=sampling_params.skip_special_tokens,
+            skip_special_tokens=prms.skip_special_tokens,
+            spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 80d6f271..a0adf4d0 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -212,6 +212,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             presence_penalty=request.presence_penalty,
@@ -226,6 +227,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -413,6 +415,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             best_of=request.best_of,
@@ -428,6 +431,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 12b7453d..7700c5dd 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class CompletionRequest(BaseModel):
@@ -98,6 +99,7 @@ class CompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class LogProbs(BaseModel):
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 7ddefcc9..00a9135a 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -71,6 +71,8 @@ class SamplingParams:
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         skip_special_tokens: Whether to skip special tokens in the output.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens in the output. Defaults to True.
     """

     def __init__(
@@ -93,6 +95,7 @@ class SamplingParams:
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -120,6 +123,7 @@ class SamplingParams:
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
 
         self._verify_args()
         if self.use_beam_search:
@@ -222,4 +226,6 @@
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens})")
+                f"skip_special_tokens={self.skip_special_tokens}, "
+                "spaces_between_special_tokens="
+                f"{self.spaces_between_special_tokens})")
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 424c102c..5b048148 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -73,6 +73,7 @@ def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
     skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
 ) -> str:
     # Adapted from
     # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
@@ -96,7 +97,10 @@
     if current_sub_text:
         sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
         sub_texts.append(sub_text)
-    return " ".join(sub_texts)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
 
 
 # Based on
@@ -109,6 +113,7 @@ def detokenize_incrementally(
     prefix_offset: int = 0,
     read_offset: int = 0,
     skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
 ) -> Tuple[List[str], str, int, int]:
     new_token_id = all_input_ids[-1]
     # This is the first iteration for this sequence
@@ -143,11 +148,15 @@
         prefix_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         new_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
 
     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
         # utf-8 char at the end means it's a potential unfinished byte sequence
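
Usage note (illustrative; not part of the patch): with this change a caller can keep special tokens in the decoded output (skip_special_tokens=False) while choosing whether consecutive special tokens are joined with a space or concatenated directly. A minimal sketch against the offline LLM API, assuming a vLLM build that includes this patch; the model name is only a placeholder:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model

    # spaces_between_special_tokens only matters when special tokens are kept
    # in the output, i.e. skip_special_tokens=False.
    params = SamplingParams(
        max_tokens=32,
        skip_special_tokens=False,
        spaces_between_special_tokens=False,  # e.g. "<s><pad>" rather than "<s> <pad>"
    )

    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)

The OpenAI-compatible server exposes the same switch as a request field, e.g. "spaces_between_special_tokens": false in the JSON body of a /v1/completions or /v1/chat/completions request.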