diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py index 4a2fc350..7c63c563 100644 --- a/cacheflow/worker/cache_engine.py +++ b/cacheflow/worker/cache_engine.py @@ -19,6 +19,9 @@ class CacheEngine: num_cpu_blocks: int, dtype: torch.dtype, ) -> None: + if head_size % 16 != 0: + raise ValueError(f'head_size ({head_size}) must be a multiple of 16.') + self.worker_id = worker_id self.gpu_id = gpu_id self.num_layers = num_layers