Support eos_token_id from generation_config.json (#4182)
This commit is contained in:
parent
8a7a3e4436
commit
a134ef6f5e
@ -1,7 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Iterable, List, Optional, Type, Union
|
from typing import Iterable, List, Optional, Type, Union
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import GenerationConfig, PreTrainedTokenizer
|
||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||||
@ -34,6 +34,17 @@ logger = init_logger(__name__)
|
|||||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _load_generation_config_dict(model_config: ModelConfig):
    """Return the model's generation config as a plain dict.

    Loads ``GenerationConfig`` for ``model_config.model`` at
    ``model_config.revision`` and returns the result of its
    ``to_diff_dict()`` call. If loading raises ``OSError`` (treated here as
    "no generation config found"), an empty dict is returned instead.
    """
    try:
        generation_config = GenerationConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
        )
        diff_fields = generation_config.to_diff_dict()
    except OSError:
        # Not found: fall back to "no overrides".
        return {}
    return diff_fields
|
||||||
|
|
||||||
|
|
||||||
class LLMEngine:
|
class LLMEngine:
|
||||||
"""An LLM engine that receives requests and generates texts.
|
"""An LLM engine that receives requests and generates texts.
|
||||||
|
|
||||||
@ -124,6 +135,8 @@ class LLMEngine:
|
|||||||
self._init_tokenizer()
|
self._init_tokenizer()
|
||||||
self.detokenizer = Detokenizer(self.tokenizer)
|
self.detokenizer = Detokenizer(self.tokenizer)
|
||||||
self.seq_counter = Counter()
|
self.seq_counter = Counter()
|
||||||
|
self.generation_config_fields = _load_generation_config_dict(
|
||||||
|
model_config)
|
||||||
|
|
||||||
self.model_executor = executor_class(
|
self.model_executor = executor_class(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@ -391,6 +404,8 @@ class LLMEngine:
|
|||||||
# inject the eos token id into the sampling_params to support min_tokens
|
# inject the eos token id into the sampling_params to support min_tokens
|
||||||
# processing
|
# processing
|
||||||
sampling_params.eos_token_id = seq.eos_token_id
|
sampling_params.eos_token_id = seq.eos_token_id
|
||||||
|
sampling_params.update_from_generation_config(
|
||||||
|
self.generation_config_fields)
|
||||||
|
|
||||||
# Create the sequence group.
|
# Create the sequence group.
|
||||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||||
@ -435,7 +450,7 @@ class LLMEngine:
|
|||||||
scheduled_seq_groups: List[SequenceGroup],
|
scheduled_seq_groups: List[SequenceGroup],
|
||||||
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
|
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
|
||||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||||
|
|
||||||
Returns RequestOutputs that can be returned to the client.
|
Returns RequestOutputs that can be returned to the client.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -271,6 +271,18 @@ class SamplingParams:
|
|||||||
raise ValueError("best_of must be 1 when using greedy sampling."
|
raise ValueError("best_of must be 1 when using greedy sampling."
|
||||||
f"Got {self.best_of}.")
|
f"Got {self.best_of}.")
|
||||||
|
|
||||||
|
def update_from_generation_config(
        self, generation_config: Dict[str, Any]) -> None:
    """Merge ``eos_token_id`` from a generation config into stop_token_ids.

    Args:
        generation_config: Mapping of generation-config fields. Only the
            ``"eos_token_id"`` key is read; its value may be a single int
            or an iterable of ints. A missing key, ``None`` or an empty
            iterable leaves ``self.stop_token_ids`` untouched.

    The resulting ``self.stop_token_ids`` is a new list without duplicates:
    the existing stop token ids keep their order and any new EOS ids are
    appended after them.
    """
    eos_ids = generation_config.get("eos_token_id")
    # Compare against None instead of relying on truthiness: 0 is a valid
    # token id and must not be silently dropped.
    if eos_ids is None:
        return
    # The field can be either an int or a list of ints.
    eos_ids = [eos_ids] if isinstance(eos_ids, int) else list(eos_ids)
    if not eos_ids:
        return
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # caller's stop tokens are not reshuffled by the merge.
    existing_ids = self.stop_token_ids or []
    self.stop_token_ids = list(dict.fromkeys([*existing_ids, *eos_ids]))
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sampling_type(self) -> SamplingType:
|
def sampling_type(self) -> SamplingType:
|
||||||
if self.use_beam_search:
|
if self.use_beam_search:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user