Add support for spaces_between_special_tokens
commit 7013a80170
parent 79a30912b8
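This change threads a spaces_between_special_tokens option from the OpenAI-compatible request models through SamplingParams and into the engine's incremental detokenizer. When the flag is True (the default, matching the previous behavior), special tokens are joined with single spaces; when it is False they are concatenated directly. A minimal offline-usage sketch, assuming a vLLM build that contains this commit; the model name and prompt are only illustrations:

from vllm import LLM, SamplingParams

# Keep special tokens in the output, but do not insert spaces between them.
sampling_params = SamplingParams(
    max_tokens=32,
    skip_special_tokens=False,
    spaces_between_special_tokens=False,
)

llm = LLM(model="facebook/opt-125m")  # any supported model works here
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)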
@@ -632,8 +632,7 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now
 
-    def _decode_sequence(self, seq: Sequence,
-                         sampling_params: SamplingParams) -> None:
+    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -642,7 +641,8 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=sampling_params.skip_special_tokens,
+            skip_special_tokens=prms.skip_special_tokens,
+            spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
@@ -212,6 +212,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             presence_penalty=request.presence_penalty,
@@ -226,6 +227,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -413,6 +415,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             best_of=request.best_of,
@@ -428,6 +431,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class CompletionRequest(BaseModel):
@@ -98,6 +99,7 @@ class CompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class LogProbs(BaseModel):
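With the field added to ChatCompletionRequest and CompletionRequest above, clients of the OpenAI-compatible server can set it per request. A sketch against the /v1/completions route touched above; the host, port, and model name are assumptions about how the server was launched:

import requests

payload = {
    "model": "facebook/opt-125m",  # whichever model the server was started with
    "prompt": "Hello, my name is",
    "max_tokens": 32,
    "skip_special_tokens": False,
    "spaces_between_special_tokens": False,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json()["choices"][0]["text"])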
@@ -71,6 +71,8 @@ class SamplingParams:
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         skip_special_tokens: Whether to skip special tokens in the output.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens in the output. Defaults to True.
     """
 
     def __init__(
@@ -93,6 +95,7 @@ class SamplingParams:
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -120,6 +123,7 @@ class SamplingParams:
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
 
         self._verify_args()
         if self.use_beam_search:
@@ -222,4 +226,6 @@ class SamplingParams:
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens})")
+                f"skip_special_tokens={self.skip_special_tokens}, "
+                "spaces_between_special_tokens="
+                f"{self.spaces_between_special_tokens})")
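A quick sanity check of the new constructor argument and the extended __repr__, as a sketch that only assumes the package is importable:

from vllm import SamplingParams

params = SamplingParams(skip_special_tokens=False,
                        spaces_between_special_tokens=False)
assert params.spaces_between_special_tokens is False
print(params)  # the repr now ends with "... spaces_between_special_tokens=False)"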
@@ -73,6 +73,7 @@ def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
     skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
 ) -> str:
     # Adapted from
     # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
@@ -96,7 +97,10 @@ def _convert_tokens_to_string_with_added_encoders(
     if current_sub_text:
         sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
         sub_texts.append(sub_text)
-    return " ".join(sub_texts)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
 
 
 # Based on
@@ -109,6 +113,7 @@ def detokenize_incrementally(
     prefix_offset: int = 0,
     read_offset: int = 0,
     skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
 ) -> Tuple[List[str], str, int, int]:
     new_token_id = all_input_ids[-1]
     # This is the first iteration for this sequence
@@ -143,11 +148,15 @@ def detokenize_incrementally(
         prefix_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         new_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
 
     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
         # utf-8 char at the end means it's a potential unfinished byte sequence
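The behavioral core of the change is the final join in _convert_tokens_to_string_with_added_encoders. The following self-contained sketch mimics that grouping and join logic without a real tokenizer; the SPECIAL_TOKENS set and the join_sub_texts helper are illustrative stand-ins, not library code:

from typing import List

SPECIAL_TOKENS = {"<s>", "</s>"}  # hypothetical special tokens for illustration


def join_sub_texts(tokens: List[str],
                   spaces_between_special_tokens: bool) -> str:
    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    for token in tokens:
        if token in SPECIAL_TOKENS:
            if current_sub_text:
                # Stand-in for tokenizer.convert_tokens_to_string(...)
                sub_texts.append("".join(current_sub_text))
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_texts.append("".join(current_sub_text))
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    return "".join(sub_texts)


tokens = ["<s>", "Hello", "</s>", "<s>", "world", "</s>"]
print(join_sub_texts(tokens, spaces_between_special_tokens=True))
# -> "<s> Hello </s> <s> world </s>"
print(join_sub_texts(tokens, spaces_between_special_tokens=False))
# -> "<s>Hello</s><s>world</s>"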