From 9e68a6827ef7e08c71b9c4a92cf3c2be7b60fc84 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 16 Feb 2023 01:33:03 +0000 Subject: [PATCH] Fix return type error --- cacheflow/worker/cache_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py index 7c63c563..fe7b562e 100644 --- a/cacheflow/worker/cache_engine.py +++ b/cacheflow/worker/cache_engine.py @@ -42,7 +42,7 @@ class CacheEngine: # Initialize the events for stream synchronization. self.events = [torch.cuda.Event() for _ in range(self.num_layers)] - def get_key_block_shape(self) -> Tuple[int, int, int, int, int]: + def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() x = 16 // element_size return ( @@ -52,7 +52,7 @@ class CacheEngine: x, ) - def get_value_block_shape(self) -> Tuple[int, int, int, int]: + def get_value_block_shape(self) -> Tuple[int, int, int]: return ( self.num_heads, self.block_size,