Add support for spaces_between_special_tokens

Dan Lord 2023-10-30 16:52:56 -07:00 committed by GitHub
parent 79a30912b8
commit 7013a80170
5 changed files with 28 additions and 7 deletions

View File

@@ -632,8 +632,7 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now
 
-    def _decode_sequence(self, seq: Sequence,
-                         sampling_params: SamplingParams) -> None:
+    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -642,7 +641,8 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=sampling_params.skip_special_tokens,
+            skip_special_tokens=prms.skip_special_tokens,
+            spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens

View File

@@ -212,6 +212,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             presence_penalty=request.presence_penalty,
@@ -226,6 +227,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -413,6 +415,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             best_of=request.best_of,
@@ -428,6 +431,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
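
Note (illustrative, not part of the commit): with the server changes above, the new field can be exercised over HTTP like any other sampling option. A minimal sketch using the requests library; the server URL and model name are placeholders:

import requests

# Sketch: smoke-test the new request field against a running
# OpenAI-compatible vLLM server (URL and model are placeholders).
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",            # placeholder model
        "prompt": "Hello",
        "max_tokens": 16,
        "skip_special_tokens": False,            # keep special tokens in the text
        "spaces_between_special_tokens": False,  # new flag; defaults to true
    },
)
print(resp.json()["choices"][0]["text"])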

View File

@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class CompletionRequest(BaseModel):
@@ -98,6 +99,7 @@ class CompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class LogProbs(BaseModel):

View File

@@ -71,6 +71,8 @@ class SamplingParams:
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         skip_special_tokens: Whether to skip special tokens in the output.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens in the output. Defaults to True.
     """
 
     def __init__(
@@ -93,6 +95,7 @@ class SamplingParams:
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -120,6 +123,7 @@ class SamplingParams:
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
 
         self._verify_args()
         if self.use_beam_search:
@@ -222,4 +226,6 @@ class SamplingParams:
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens})")
+                f"skip_special_tokens={self.skip_special_tokens}, "
+                "spaces_between_special_tokens="
+                f"{self.spaces_between_special_tokens})")

View File

@@ -73,6 +73,7 @@ def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
     skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
 ) -> str:
     # Adapted from
     # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
@@ -96,7 +97,10 @@ def _convert_tokens_to_string_with_added_encoders(
     if current_sub_text:
         sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
         sub_texts.append(sub_text)
-    return " ".join(sub_texts)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
 
 
 # Based on
@@ -109,6 +113,7 @@ def detokenize_incrementally(
     prefix_offset: int = 0,
     read_offset: int = 0,
     skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
 ) -> Tuple[List[str], str, int, int]:
     new_token_id = all_input_ids[-1]
     # This is the first iteration for this sequence
@@ -143,11 +148,15 @@ def detokenize_incrementally(
         prefix_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         new_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
         # utf-8 char at the end means it's a potential unfinished byte sequence
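
To make the new branch concrete, a standalone sketch (not repo code): with the flag on, the decoded sub-texts surrounding special tokens are space-joined as before; with it off, they are concatenated directly:

# Sub-texts as produced around special tokens by
# _convert_tokens_to_string_with_added_encoders above.
sub_texts = ["<s>", "Hello world", "</s>"]

print(" ".join(sub_texts))  # flag True  -> "<s> Hello world </s>"
print("".join(sub_texts))   # flag False -> "<s>Hello world</s>"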