diff --git a/cacheflow/decoding.py b/cacheflow/decoding.py new file mode 100644 index 00000000..f8d5980d --- /dev/null +++ b/cacheflow/decoding.py @@ -0,0 +1,30 @@ +from typing import List + + +class DecodingParams: + + def __init__( + self, + n: int = 1, + temperature: float = 1.0, + top_p: float = 1.0, + use_beam_search: bool = False, + stop_token_ids: List[int] = [], + ) -> None: + assert n >= 1 + assert temperature >= 0.0 + assert 0.0 < top_p <= 1.0 + if use_beam_search: + assert n > 1 + assert temperature > 0.0 + assert top_p == 1.0 + elif temperature == 0.0: + # Zero temperature means greedy decoding. + assert n == 1 + assert top_p == 1.0 + + self.n = n + self.temperature = temperature + self.top_p = top_p + self.use_beam_search = use_beam_search + self.stop_token_ids = stop_token_ids