Fix return type error

This commit is contained in:
Woosuk Kwon 2023-02-16 01:33:03 +00:00
parent 8edcabc737
commit 9e68a6827e

View File

@@ -42,7 +42,7 @@ class CacheEngine:
# Initialize the events for stream synchronization.
self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
def get_key_block_shape(self) -> Tuple[int, int, int, int, int]:
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
element_size = torch.tensor([], dtype=self.dtype).element_size()
x = 16 // element_size
return (
@@ -52,7 +52,7 @@ class CacheEngine:
x,
)
def get_value_block_shape(self) -> Tuple[int, int, int, int]:
def get_value_block_shape(self) -> Tuple[int, int, int]:
return (
self.num_heads,
self.block_size,