Minor
This commit is contained in:
parent
9e68a6827e
commit
a1c67e6db8
@ -20,7 +20,8 @@ class CacheEngine:
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
) -> None:
|
) -> None:
|
||||||
if head_size % 16 != 0:
|
if head_size % 16 != 0:
|
||||||
raise ValueError(f'head_size ({head_size}) must be a multiple of 16.')
|
raise ValueError(
|
||||||
|
f'head_size ({head_size}) must be a multiple of 16.')
|
||||||
|
|
||||||
self.worker_id = worker_id
|
self.worker_id = worker_id
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
@ -40,7 +41,7 @@ class CacheEngine:
|
|||||||
self.cache_stream = torch.cuda.Stream(device=gpu_id)
|
self.cache_stream = torch.cuda.Stream(device=gpu_id)
|
||||||
assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
|
assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
|
||||||
# Initialize the events for stream synchronization.
|
# Initialize the events for stream synchronization.
|
||||||
self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
|
self.events = [torch.cuda.Event() for _ in range(num_layers)]
|
||||||
|
|
||||||
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
|
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
|
||||||
element_size = torch.tensor([], dtype=self.dtype).element_size()
|
element_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user