diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index 5963e67e..b311e203 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -23,6 +23,9 @@ class InputMetadata: self.num_prompts = len(prompt_lens) self.num_generation_tokens = context_lens.shape[0] - self.max_num_blocks_per_seq = block_tables.shape[1] + if block_tables.numel() > 0: + self.max_num_blocks_per_seq = block_tables.shape[1] + else: + self.max_num_blocks_per_seq = 0 assert self.num_generation_tokens == block_tables.shape[0] assert self.num_prompts + self.num_generation_tokens == len(seq_ids)