Fix attention

This commit is contained in:
Woosuk Kwon 2023-02-23 23:02:25 +00:00
parent ba84b8728a
commit 932844f1cd
2 changed files with 21 additions and 6 deletions

View File

@ -53,20 +53,19 @@ class OPTCacheFlowAttention(nn.Module):
context_len = int(input_metadata.context_lens[i])
keys = []
values = []
for j in range(context_len):
block_number = block_table[j // block_size]
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size)
keys.append(k)
keys = torch.stack(keys, dim=0)
values = []
for j in range(context_len):
block_number = block_table[j // block_size]
block_offset = j % block_size
v = value_cache[block_number, :, block_offset, :]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
q = q.unsqueeze(0)
@ -87,6 +86,11 @@ class OPTCacheFlowAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Prune out invalid tokens.
query = query[:input_metadata.num_valid_tokens]
key = key[:input_metadata.num_valid_tokens]
value = value[:input_metadata.num_valid_tokens]
# Reshape the input tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]

View File

@ -11,6 +11,7 @@ class InputMetadata:
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
# FIXME: Rename
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
@ -23,9 +24,19 @@ class InputMetadata:
self.num_prompts = len(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = len(slot_mapping)
if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1]
else:
self.max_num_blocks_per_seq = 0
assert self.num_generation_tokens == block_tables.shape[0]
assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
def __repr__(self) -> str:
    # Compact debug summary of the batch metadata. Shows only scalar /
    # small fields; the tensor attributes visible in __init__ above
    # (slot_mapping, context_lens, block_tables) are intentionally omitted
    # to keep the repr short.
    return (f'InputMetadata('
            f'seq_ids={self.seq_ids}, '
            f'num_prompts={self.num_prompts}, '
            f'num_generation_tokens={self.num_generation_tokens}, '
            f'num_valid_tokens={self.num_valid_tokens}, '
            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
            f'max_context_len={self.max_context_len})')