Support eos_token_id from generation_config.json (#4182)

Simon Mo 2024-04-18 21:13:36 -07:00 committed by GitHub
parent 8a7a3e4436
commit a134ef6f5e
2 changed files with 30 additions and 3 deletions

vllm/engine/llm_engine.py

@@ -1,7 +1,7 @@
 import time
 from typing import Iterable, List, Optional, Type, Union
 
-from transformers import PreTrainedTokenizer
+from transformers import GenerationConfig, PreTrainedTokenizer
 
 import vllm
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
 
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
+
+def _load_generation_config_dict(model_config: ModelConfig):
+    try:
+        return GenerationConfig.from_pretrained(
+            model_config.model,
+            revision=model_config.revision,
+        ).to_diff_dict()
+    except OSError:
+        # Not found.
+        return {}
+
 
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -124,6 +135,8 @@ class LLMEngine:
         self._init_tokenizer()
         self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
+        self.generation_config_fields = _load_generation_config_dict(
+            model_config)
 
         self.model_executor = executor_class(
             model_config=model_config,
@@ -391,6 +404,8 @@ class LLMEngine:
         # inject the eos token id into the sampling_params to support min_tokens
         # processing
         sampling_params.eos_token_id = seq.eos_token_id
+        sampling_params.update_from_generation_config(
+            self.generation_config_fields)
 
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -435,7 +450,7 @@ class LLMEngine:
             scheduled_seq_groups: List[SequenceGroup],
             ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.
 
         Returns RequestOutputs that can be returned to the client.
         """

vllm/sampling_params.py

@@ -2,7 +2,7 @@
 import copy
 from enum import IntEnum
 from functools import cached_property
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from pydantic import Field
@@ -271,6 +271,18 @@ class SamplingParams:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")
 
+    def update_from_generation_config(
+            self, generation_config: Dict[str, Any]) -> None:
+        """Update if there are non-default values from generation_config"""
+        # Update eos_token_id for generation
+        if eos_ids := generation_config.get("eos_token_id"):
+            # it can be either int or list of int
+            if isinstance(eos_ids, int):
+                eos_ids = [eos_ids]
+            original_stop_token_ids = set(self.stop_token_ids)
+            original_stop_token_ids.update(eos_ids)
+            self.stop_token_ids = list(original_stop_token_ids)
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.use_beam_search:
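
To see the merge semantics in isolation, here is a small self-contained sketch; SamplingParamsSketch is a hypothetical stand-in for vllm.SamplingParams, reduced to the one field the new method touches.

from typing import Any, Dict, List, Optional


class SamplingParamsSketch:
    # Hypothetical stand-in for vllm.SamplingParams; only stop_token_ids
    # is modeled, since that is all update_from_generation_config touches.
    def __init__(self, stop_token_ids: Optional[List[int]] = None) -> None:
        self.stop_token_ids = list(stop_token_ids or [])

    def update_from_generation_config(
            self, generation_config: Dict[str, Any]) -> None:
        # eos_token_id may arrive as a single int or a list of ints.
        if eos_ids := generation_config.get("eos_token_id"):
            if isinstance(eos_ids, int):
                eos_ids = [eos_ids]
            merged = set(self.stop_token_ids)
            merged.update(eos_ids)
            self.stop_token_ids = list(merged)


params = SamplingParamsSketch(stop_token_ids=[2])
params.update_from_generation_config({"eos_token_id": [128001, 128009]})
# params.stop_token_ids now holds 2, 128001 and 128009, deduplicated
# via the set union.

Two properties worth noting: user-supplied stop_token_ids are extended, never replaced, and the walrus-guarded get() treats a falsy eos_token_id (0 or an empty list) as absent, matching the committed method.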