Fix attention
This commit is contained in:
parent 1ce1333573
commit 87e0bcd426
@@ -44,19 +44,18 @@ class OPTCacheFlowAttention(nn.Module):
         # FIXME(woosuk): Replace the following with a custom op.
         for i in range(input_metadata.num_generation_tokens):
-            q = query[i]
+            q = query[i].unsqueeze(0)
             block_table = block_tables[i]
             context_len = int(input_metadata.context_lens[i])
 
             keys = []
             for j in range(context_len):
                 block_number = block_table[j // block_size]
                 block_offset = j % block_size
                 k = key_cache[block_number, :, :, block_offset, :]
-                k = k.view(num_heads, head_size)
+                k = k.reshape(num_heads, head_size)
                 keys.append(k)
-            keys = torch.stack(keys, dim=-1)
-            logits = q @ keys
-            attention_weights = torch.softmax(logits, dim=-1)
+            keys = torch.stack(keys, dim=0)
 
             values = []
             for j in range(context_len):
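The loop above gathers each context token's key vectors out of the block-paged key cache via the block table (the matching value loop continues in the next hunk), and the new dim=0 stacking produces token-major [context_len, num_heads, head_size] tensors. A minimal, runnable sketch of that gather follows; the cache shapes and sizes are illustrative assumptions, not values taken from this commit:

import torch

# Assumed paged-cache layouts (x is the vectorized chunk width of the key cache):
#   key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
#   value_cache: [num_blocks, num_heads, block_size, head_size]
num_blocks, num_heads, head_size, block_size, x = 16, 12, 64, 8, 8
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)

block_table = [3, 7, 1]   # logical block j // block_size -> physical block number
context_len = 20          # spans three blocks of size 8

keys, values = [], []
for j in range(context_len):
    # Token j lives in physical block block_table[j // block_size] at offset j % block_size.
    block_number = block_table[j // block_size]
    block_offset = j % block_size
    k = key_cache[block_number, :, :, block_offset, :].reshape(num_heads, head_size)
    v = value_cache[block_number, :, block_offset, :]
    keys.append(k)
    values.append(v)

# Token-major layout, ready for the attention call in the next hunk.
keys = torch.stack(keys, dim=0)      # [context_len, num_heads, head_size]
values = torch.stack(values, dim=0)  # [context_len, num_heads, head_size]
assert keys.shape == values.shape == (context_len, num_heads, head_size)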
@@ -64,8 +63,14 @@ class OPTCacheFlowAttention(nn.Module):
                 block_offset = j % block_size
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
-            values = torch.stack(values, dim=-1)
-            out = attention_weights @ values
+            values = torch.stack(values, dim=0)
+
+            q = q.unsqueeze(0)
+            keys = keys.unsqueeze(0)
+            values = values.unsqueeze(0)
+            out = xops.memory_efficient_attention(
+                q, keys, values, scale=self.scale)
+            out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
 
     def forward(
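For context on the change itself, here is a small equivalence check. It is purely illustrative and not part of this commit, and it assumes made-up sizes, fp16 tensors, a CUDA device, and an xformers install: the xops.memory_efficient_attention call on the unsqueezed [batch, seq, num_heads, head_size] tensors computes the same scaled softmax(q @ k^T) @ v that the removed hand-rolled logits/softmax path was meant to compute.

import torch
import xformers.ops as xops

# Illustrative sizes (assumptions, not taken from the model config).
num_heads, head_size, context_len = 12, 64, 17
scale = head_size ** -0.5

q = torch.randn(num_heads, head_size, dtype=torch.float16, device="cuda")
keys = torch.randn(context_len, num_heads, head_size, dtype=torch.float16, device="cuda")
values = torch.randn_like(keys)

# Reference: per-head scaled dot-product attention over the cached context tokens.
logits = torch.einsum("hk,shk->hs", q.float(), keys.float()) * scale  # [num_heads, context_len]
weights = torch.softmax(logits, dim=-1)
ref = torch.einsum("hs,shk->hk", weights, values.float())             # [num_heads, head_size]

# xformers takes [batch, seq_len, num_heads, head_size] tensors, hence the
# unsqueezes in the diff: q -> [1, 1, H, K], keys/values -> [1, S, H, K].
out = xops.memory_efficient_attention(
    q.unsqueeze(0).unsqueeze(0),
    keys.unsqueeze(0),
    values.unsqueeze(0),
    scale=scale,
).reshape(num_heads, head_size)

assert torch.allclose(out.float(), ref, atol=1e-2)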