diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py index 7c63c563..fe7b562e 100644 --- a/cacheflow/worker/cache_engine.py +++ b/cacheflow/worker/cache_engine.py @@ -42,7 +42,7 @@ class CacheEngine: # Initialize the events for stream synchronization. self.events = [torch.cuda.Event() for _ in range(self.num_layers)] - def get_key_block_shape(self) -> Tuple[int, int, int, int, int]: + def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() x = 16 // element_size return ( @@ -52,7 +52,7 @@ class CacheEngine: x, ) - def get_value_block_shape(self) -> Tuple[int, int, int, int]: + def get_value_block_shape(self) -> Tuple[int, int, int]: return ( self.num_heads, self.block_size,