This commit is contained in:
Woosuk Kwon 2023-02-16 01:42:53 +00:00
parent 9e68a6827e
commit a1c67e6db8

View File

@ -20,7 +20,8 @@ class CacheEngine:
dtype: torch.dtype,
) -> None:
if head_size % 16 != 0:
raise ValueError(f'head_size ({head_size}) must be a multiple of 16.')
raise ValueError(
f'head_size ({head_size}) must be a multiple of 16.')
self.worker_id = worker_id
self.gpu_id = gpu_id
@ -40,7 +41,7 @@ class CacheEngine:
self.cache_stream = torch.cuda.Stream(device=gpu_id)
assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
# Initialize the events for stream synchronization.
self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
self.events = [torch.cuda.Event() for _ in range(num_layers)]
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
element_size = torch.tensor([], dtype=self.dtype).element_size()