From d094512296ad18968efffd925c372533e9dd12e3 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 23 Feb 2023 04:57:46 +0000
Subject: [PATCH] Move max_context_len

---
 cacheflow/decoding.py | 7 +++++--
 cacheflow/sequence.py | 2 --
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/cacheflow/decoding.py b/cacheflow/decoding.py
index f8d5980d..c4c4a13b 100644
--- a/cacheflow/decoding.py
+++ b/cacheflow/decoding.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Optional, Set
 
 
 class DecodingParams:
@@ -9,7 +9,8 @@ class DecodingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         use_beam_search: bool = False,
-        stop_token_ids: List[int] = [],
+        stop_token_ids: Set[int] = [],
+        max_context_len: Optional[int] = None,
     ) -> None:
         assert n >= 1
         assert temperature >= 0.0
@@ -22,9 +23,11 @@ class DecodingParams:
             # Zero temperature means greedy decoding.
             assert n == 1
             assert top_p == 1.0
+        assert max_context_len is None or max_context_len >= 0
 
         self.n = n
         self.temperature = temperature
         self.top_p = top_p
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
+        self.max_context_len = max_context_len
diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py
index 8a6f2908..a4fd551c 100644
--- a/cacheflow/sequence.py
+++ b/cacheflow/sequence.py
@@ -17,11 +17,9 @@ class Sequence:
         self,
         seq_id: int,
         token_ids: List[int],
-        max_context_len: int,
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
-        self.max_context_len = max_context_len
         self.block_size = block_size
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []