Fix attention

commit 87e0bcd426
parent 1ce1333573
Author: Woosuk Kwon
Date:   2023-02-23 21:32:02 +00:00


@@ -44,19 +44,18 @@ class OPTCacheFlowAttention(nn.Module):
         # FIXME(woosuk): Replace the following with a custom op.
         for i in range(input_metadata.num_generation_tokens):
-            q = query[i]
+            q = query[i].unsqueeze(0)
             block_table = block_tables[i]
             context_len = int(input_metadata.context_lens[i])
             keys = []
             for j in range(context_len):
                 block_number = block_table[j // block_size]
                 block_offset = j % block_size
                 k = key_cache[block_number, :, :, block_offset, :]
-                k = k.view(num_heads, head_size)
+                k = k.reshape(num_heads, head_size)
                 keys.append(k)
-            keys = torch.stack(keys, dim=-1)
-            logits = q @ keys
-            attention_weights = torch.softmax(logits, dim=-1)
+            keys = torch.stack(keys, dim=0)
             values = []
             for j in range(context_len):
@@ -64,8 +63,14 @@ class OPTCacheFlowAttention(nn.Module):
                 block_offset = j % block_size
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
-            values = torch.stack(values, dim=-1)
-            out = attention_weights @ values
+            values = torch.stack(values, dim=0)
+            q = q.unsqueeze(0)
+            keys = keys.unsqueeze(0)
+            values = values.unsqueeze(0)
+            out = xops.memory_efficient_attention(
+                q, keys, values, scale=self.scale)
             out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)

     def forward(
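
The fix drops the hand-rolled attention path, which stacked keys along the last dimension and applied no scaling factor before the softmax, and instead gathers keys and values from the paged cache into [seq_len, num_heads, head_size] tensors, adds a batch dimension, and calls xformers' memory_efficient_attention. Below is a minimal PyTorch-only sketch of that per-token bookkeeping; the sizes, the simplified 4-D key-cache layout, and the reference softmax standing in for the xformers kernel are illustrative assumptions, not code from the repository.

    import torch

    # Illustrative sizes; not taken from the repository.
    num_heads, head_size, block_size, num_blocks = 4, 8, 2, 3
    context_len = 5                     # tokens already cached for this sequence
    scale = head_size ** -0.5           # plays the role of self.scale (assumption)

    # Paged KV cache. key_cache here is [num_blocks, num_heads, head_size, block_size],
    # a simplification of the 5-D layout indexed in the diff; value_cache matches the
    # diff's [num_blocks, num_heads, block_size, head_size] indexing.
    key_cache = torch.randn(num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
    block_table = [0, 2, 1]             # logical block index -> physical block number

    query = torch.randn(num_heads, head_size)    # query of the single new token
    q = query.unsqueeze(0)                       # [1, num_heads, head_size]

    # Gather K/V token by token from the cache blocks, as in the diff's Python loop.
    keys, values = [], []
    for j in range(context_len):
        block_number = block_table[j // block_size]
        block_offset = j % block_size
        k = key_cache[block_number, :, :, block_offset]       # [num_heads, head_size]
        keys.append(k.reshape(num_heads, head_size))          # reshape, not view: the slice may be non-contiguous
        values.append(value_cache[block_number, :, block_offset, :])
    keys = torch.stack(keys, dim=0)                           # [context_len, num_heads, head_size]
    values = torch.stack(values, dim=0)                       # [context_len, num_heads, head_size]

    # Reference scaled-dot-product attention for one query token. With xformers on GPU,
    # the diff's call computes the same result:
    #   out = xops.memory_efficient_attention(q.unsqueeze(0), keys.unsqueeze(0),
    #                                         values.unsqueeze(0), scale=scale)
    # where unsqueeze(0) adds the batch dimension of the [batch, seq, heads, head_dim] layout.
    logits = torch.einsum("qhd,khd->hqk", q, keys) * scale    # [num_heads, 1, context_len]
    weights = torch.softmax(logits, dim=-1)
    out = torch.einsum("hqk,khd->qhd", weights, values)       # [1, num_heads, head_size]
    print(out.reshape(num_heads, head_size).shape)            # torch.Size([4, 8])

Stacking along dim=0 keeps the per-token [seq, heads, head_dim] layout the kernel expects, which is why the old dim=-1 stacking had to go; the FIXME in the diff notes that this Python gather loop is itself a placeholder for a custom op.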