Add __repr__
This commit is contained in:
parent
2729087efe
commit
3363c27d19
@ -49,3 +49,8 @@ class PhysicalTokenBlock:
|
|||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
|
|
||||||
self.ref_count = 0
|
self.ref_count = 0
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f'PhysicalTokenBlock(device={self.device}, '
|
||||||
|
f'block_number={self.block_number}, '
|
||||||
|
f'ref_count={self.ref_count})')
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import enum
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from cacheflow.block import LogicalTokenBlock
|
from cacheflow.block import LogicalTokenBlock
|
||||||
from cacheflow.decoding import DecodingParams
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceStatus(enum.Enum):
|
class SequenceStatus(enum.Enum):
|
||||||
@ -58,6 +57,11 @@ class Sequence:
|
|||||||
token_ids.extend(block.get_token_ids())
|
token_ids.extend(block.get_token_ids())
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f'Sequence(seq_id={self.seq_id}, '
|
||||||
|
f'status={self.status.name}, '
|
||||||
|
f'num_blocks={len(self.logical_token_blocks)})')
|
||||||
|
|
||||||
|
|
||||||
class SequenceGroup:
|
class SequenceGroup:
|
||||||
|
|
||||||
@ -65,11 +69,9 @@ class SequenceGroup:
|
|||||||
self,
|
self,
|
||||||
group_id: int,
|
group_id: int,
|
||||||
seqs: List[Sequence],
|
seqs: List[Sequence],
|
||||||
decoding_params: DecodingParams,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.group_id = group_id
|
self.group_id = group_id
|
||||||
self.seqs = seqs
|
self.seqs = seqs
|
||||||
self.decoding_params = decoding_params
|
|
||||||
|
|
||||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||||
if status is None:
|
if status is None:
|
||||||
@ -82,3 +84,7 @@ class SequenceGroup:
|
|||||||
if seq.seq_id == seq_id:
|
if seq.seq_id == seq_id:
|
||||||
return seq
|
return seq
|
||||||
raise ValueError(f'Sequence {seq_id} not found.')
|
raise ValueError(f'Sequence {seq_id} not found.')
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f'SequenceGroup(group_id={self.group_id}, '
|
||||||
|
f'num_seqs={len(self.seqs)})')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user