Minor
This commit is contained in:
parent
9e68a6827e
commit
a1c67e6db8
@ -20,7 +20,8 @@ class CacheEngine:
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
if head_size % 16 != 0:
|
||||
raise ValueError(f'head_size ({head_size}) must be a multiple of 16.')
|
||||
raise ValueError(
|
||||
f'head_size ({head_size}) must be a multiple of 16.')
|
||||
|
||||
self.worker_id = worker_id
|
||||
self.gpu_id = gpu_id
|
||||
@ -40,7 +41,7 @@ class CacheEngine:
|
||||
self.cache_stream = torch.cuda.Stream(device=gpu_id)
|
||||
assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
|
||||
# Initialize the events for stream synchronization.
|
||||
self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
|
||||
self.events = [torch.cuda.Event() for _ in range(num_layers)]
|
||||
|
||||
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
|
||||
element_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user