Support eos_token_id from generation_config.json (#4182)
commit a134ef6f5e
parent 8a7a3e4436
vllm/engine/llm_engine.py
@@ -1,7 +1,7 @@
 import time
 from typing import Iterable, List, Optional, Type, Union
 
-from transformers import PreTrainedTokenizer
+from transformers import GenerationConfig, PreTrainedTokenizer
 
 import vllm
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 
+def _load_generation_config_dict(model_config: ModelConfig):
+    try:
+        return GenerationConfig.from_pretrained(
+            model_config.model,
+            revision=model_config.revision,
+        ).to_diff_dict()
+    except OSError:
+        # Not found.
+        return {}
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
 
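For context, GenerationConfig.from_pretrained(...).to_diff_dict() returns only the fields that differ from the Transformers defaults, so the new helper hands the engine a small dict such as {"eos_token_id": [...]} (or {} when the model repo ships no generation_config.json). A minimal sketch of what the helper does, using an illustrative model name and values that are not part of this commit:

    from transformers import GenerationConfig

    # Illustrative model name; any HF repo with a generation_config.json works.
    gen_cfg = GenerationConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    # to_diff_dict() keeps only non-default fields, e.g.
    # {'bos_token_id': 128000, 'eos_token_id': [128001, 128009], ...}
    fields = gen_cfg.to_diff_dict()
    print(fields.get("eos_token_id"))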
@@ -124,6 +135,8 @@ class LLMEngine:
         self._init_tokenizer()
         self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
+        self.generation_config_fields = _load_generation_config_dict(
+            model_config)
 
         self.model_executor = executor_class(
             model_config=model_config,
@@ -391,6 +404,8 @@
         # inject the eos token id into the sampling_params to support min_tokens
         # processing
         sampling_params.eos_token_id = seq.eos_token_id
+        sampling_params.update_from_generation_config(
+            self.generation_config_fields)
 
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
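The practical effect of the added call is that eos token ids declared in a model's generation_config.json act as implicit stop tokens for every request. A hedged sketch of the user-visible behaviour, with an illustrative model name that is not taken from this commit:

    from vllm import LLM, SamplingParams

    # Hypothetical example: a chat model whose generation_config.json lists
    # several eos_token_id values (e.g. an extra end-of-turn token).
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    params = SamplingParams(max_tokens=256)

    # With this change the engine merges the config's eos ids into
    # params.stop_token_ids when the request is added, so generation stops at
    # the end-of-turn token even though the caller never specified it.
    outputs = llm.generate(["Hello, how are you?"], params)
    print(outputs[0].outputs[0].text)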
@@ -435,7 +450,7 @@
             scheduled_seq_groups: List[SequenceGroup],
             ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.
-
+
         Returns RequestOutputs that can be returned to the client.
         """
-
+
vllm/sampling_params.py
@@ -2,7 +2,7 @@
 import copy
 from enum import IntEnum
 from functools import cached_property
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from pydantic import Field
@@ -271,6 +271,18 @@ class SamplingParams:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")
 
+    def update_from_generation_config(
+            self, generation_config: Dict[str, Any]) -> None:
+        """Update if there are non-default values from generation_config"""
+        # Update eos_token_id for generation
+        if eos_ids := generation_config.get("eos_token_id"):
+            # it can be either int or list of int
+            if isinstance(eos_ids, int):
+                eos_ids = [eos_ids]
+            original_stop_token_ids = set(self.stop_token_ids)
+            original_stop_token_ids.update(eos_ids)
+            self.stop_token_ids = list(original_stop_token_ids)
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.use_beam_search:
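To illustrate the merge semantics of the new method: an int eos_token_id is wrapped in a list, the ids are unioned with any caller-supplied stop_token_ids, and a missing or falsy value leaves the params untouched. A small sketch (token id values are made up, not from this commit):

    from vllm import SamplingParams

    # Union with caller-supplied stop tokens; duplicates collapse via the set.
    params = SamplingParams(stop_token_ids=[128001])
    params.update_from_generation_config({"eos_token_id": [128001, 128009]})
    print(sorted(params.stop_token_ids))   # [128001, 128009]

    # A plain int is normalised to a single-element list.
    params = SamplingParams()
    params.update_from_generation_config({"eos_token_id": 2})
    print(params.stop_token_ids)           # [2]

    # No eos_token_id in the config dict -> nothing changes.
    params = SamplingParams(stop_token_ids=[7])
    params.update_from_generation_config({})
    print(params.stop_token_ids)           # [7]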