diff --git a/README.md b/README.md
index 9d81bab..9f458c4 100644
--- a/README.md
+++ b/README.md
@@ -212,7 +212,7 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
-        block_table [optional]: (num_blocks, max_num_blocks_per_seq), dtype torch.int32.
+        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 2116d59..a1ef865 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -1149,7 +1149,7 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
-        block_table [optional]: (num_blocks, max_num_blocks_per_seq), dtype torch.int32.
+        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
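
The docstring fix reflects that `block_table` is indexed per batch entry, not per physical block: row `i` lists, in logical order, the physical block indices of sequence `i` inside the `(num_blocks, page_block_size, ...)` K/V cache pool. Below is a minimal sketch of a paged-KV decode call that matches the corrected shape. It assumes a flash-attn build with paged KV cache support and a CUDA GPU; the concrete sizes (block size 256, head dimensions, the example block assignments) are illustrative only, not requirements of the API.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, headdim = 2, 8, 64
page_block_size = 256          # illustrative paged-KV block size
num_blocks = 16                # total physical blocks in the pool
max_num_blocks_per_seq = 4     # logical blocks reserved per sequence

device, dtype = "cuda", torch.float16

# Physical block pool: (num_blocks, page_block_size, nheads_k, headdim)
k_cache = torch.zeros(num_blocks, page_block_size, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

# block_table: (batch_size, max_num_blocks_per_seq) -- for each sequence in the
# batch, the physical block indices holding its KV pages, in logical order.
block_table = torch.tensor(
    [[0, 1, 2, 3],
     [4, 5, 6, 7]],
    device=device, dtype=torch.int32,
)

# Current KV length of each sequence (must fit inside its assigned blocks).
cache_seqlens = torch.tensor([300, 128], device=device, dtype=torch.int32)

# Single-token decode step: append the new K/V and attend over the paged cache.
q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
k_new = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
v_new = torch.randn_like(k_new)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # (batch_size, 1, nheads, headdim)
```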