Support eos_token_id from generation_config.json (#4182)

Author: Simon Mo, 2024-04-18 21:13:36 -07:00, committed by GitHub
parent 8a7a3e4436
commit a134ef6f5e
2 changed files with 30 additions and 3 deletions

vllm/engine/llm_engine.py

@@ -1,7 +1,7 @@
 import time
 from typing import Iterable, List, Optional, Type, Union
 
-from transformers import PreTrainedTokenizer
+from transformers import GenerationConfig, PreTrainedTokenizer
 
 import vllm
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 
+def _load_generation_config_dict(model_config: ModelConfig):
+    try:
+        return GenerationConfig.from_pretrained(
+            model_config.model,
+            revision=model_config.revision,
+        ).to_diff_dict()
+    except OSError:
+        # Not found.
+        return {}
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -124,6 +135,8 @@ class LLMEngine:
         self._init_tokenizer()
         self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
+        self.generation_config_fields = _load_generation_config_dict(
+            model_config)
 
         self.model_executor = executor_class(
             model_config=model_config,
@@ -391,6 +404,8 @@
         # inject the eos token id into the sampling_params to support min_tokens
         # processing
         sampling_params.eos_token_id = seq.eos_token_id
+        sampling_params.update_from_generation_config(
+            self.generation_config_fields)
 
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
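The dict is loaded once in the engine constructor and reused for every add_request call, so there is no per-request filesystem or hub lookup. A hedged usage sketch of the end-to-end effect (the model name is illustrative; any model whose generation_config.json lists extra ids under "eos_token_id" behaves the same way):

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # illustrative model
params = SamplingParams(max_tokens=64)  # no explicit stop_token_ids
outputs = llm.generate(["Hello"], params)
# Inside add_request, the engine merges the model's declared EOS ids into
# the request's stop token ids, so generation also halts on secondary
# terminators such as <|eot_id|>.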

vllm/sampling_params.py

@@ -2,7 +2,7 @@
 import copy
 from enum import IntEnum
 from functools import cached_property
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from pydantic import Field
@@ -271,6 +271,18 @@ class SamplingParams:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")
 
+    def update_from_generation_config(
+            self, generation_config: Dict[str, Any]) -> None:
+        """Update if there are non-default values from generation_config"""
+        # Update eos_token_id for generation
+        if eos_ids := generation_config.get("eos_token_id"):
+            # it can be either int or list of int
+            if isinstance(eos_ids, int):
+                eos_ids = [eos_ids]
+            original_stop_token_ids = set(self.stop_token_ids)
+            original_stop_token_ids.update(eos_ids)
+            self.stop_token_ids = list(original_stop_token_ids)
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.use_beam_search:
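The merge is a set union: user-supplied stop ids are preserved, duplicates collapse, and the resulting order is not guaranteed. Note that the truthiness check on the walrus assignment also skips an empty eos_token_id list, not just a missing key. A small self-contained sketch of the semantics (token ids illustrative):

from vllm import SamplingParams

params = SamplingParams(stop_token_ids=[42])
# An int eos_token_id is normalized to a one-element list before the union.
params.update_from_generation_config({"eos_token_id": 128001})
params.update_from_generation_config({"eos_token_id": [128001, 128009]})
print(sorted(params.stop_token_ids))  # [42, 128001, 128009]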