Move max_context_len

This commit is contained in:
Woosuk Kwon 2023-02-23 04:57:46 +00:00
parent 4b1ac23f53
commit d094512296
2 changed files with 5 additions and 4 deletions

View File

@ -1,4 +1,4 @@
from typing import List from typing import Optional, Set
class DecodingParams: class DecodingParams:
@ -9,7 +9,8 @@ class DecodingParams:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
use_beam_search: bool = False, use_beam_search: bool = False,
stop_token_ids: List[int] = [], stop_token_ids: Set[int] = [],
max_context_len: Optional[int] = None,
) -> None: ) -> None:
assert n >= 1 assert n >= 1
assert temperature >= 0.0 assert temperature >= 0.0
@ -22,9 +23,11 @@ class DecodingParams:
# Zero temperature means greedy decoding. # Zero temperature means greedy decoding.
assert n == 1 assert n == 1
assert top_p == 1.0 assert top_p == 1.0
assert max_context_len is None or max_context_len >= 0
self.n = n self.n = n
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
self.use_beam_search = use_beam_search self.use_beam_search = use_beam_search
self.stop_token_ids = stop_token_ids self.stop_token_ids = stop_token_ids
self.max_context_len = max_context_len

View File

@ -17,11 +17,9 @@ class Sequence:
self, self,
seq_id: int, seq_id: int,
token_ids: List[int], token_ids: List[int],
max_context_len: int,
block_size: int, block_size: int,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.max_context_len = max_context_len
self.block_size = block_size self.block_size = block_size
self.logical_token_blocks: List[LogicalTokenBlock] = [] self.logical_token_blocks: List[LogicalTokenBlock] = []