diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index c36f06c7..ac02c2fd 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -44,19 +44,18 @@ class OPTCacheFlowAttention(nn.Module):
 
         # FIXME(woosuk): Replace the following with a custom op.
         for i in range(input_metadata.num_generation_tokens):
-            q = query[i]
+            q = query[i].unsqueeze(0)
             block_table = block_tables[i]
             context_len = int(input_metadata.context_lens[i])
+
             keys = []
             for j in range(context_len):
                 block_number = block_table[j // block_size]
                 block_offset = j % block_size
                 k = key_cache[block_number, :, :, block_offset, :]
-                k = k.view(num_heads, head_size)
+                k = k.reshape(num_heads, head_size)
                 keys.append(k)
-            keys = torch.stack(keys, dim=-1)
-            logits = q @ keys
-            attention_weights = torch.softmax(logits, dim=-1)
+            keys = torch.stack(keys, dim=0)
 
             values = []
             for j in range(context_len):
@@ -64,8 +63,14 @@ class OPTCacheFlowAttention(nn.Module):
                 block_offset = j % block_size
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
-            values = torch.stack(values, dim=-1)
-            out = attention_weights @ values
+            values = torch.stack(values, dim=0)
+
+            q = q.unsqueeze(0)
+            keys = keys.unsqueeze(0)
+            values = values.unsqueeze(0)
+            out = xops.memory_efficient_attention(
+                q, keys, values, scale=self.scale)
+            out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
 
     def forward(
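
A minimal standalone sketch (not part of the patch) of the shape contract this change relies on: xformers' memory_efficient_attention takes (batch, seq_len, num_heads, head_size) tensors, so the single-token query and the gathered keys/values get a batch dimension (and, for the query, a seq_len dimension) of size 1 before the call, and the result is reshaped back to (num_heads, head_size). The sizes, dtype, and einsum reference below are illustrative assumptions, not code from cacheflow.

import torch
import xformers.ops as xops

# Illustrative sizes; requires a CUDA device with xformers installed.
num_heads, head_size, context_len = 12, 64, 17
scale = head_size ** -0.5

q = torch.randn(num_heads, head_size, dtype=torch.float16, device="cuda")
keys = torch.randn(context_len, num_heads, head_size, dtype=torch.float16, device="cuda")
values = torch.randn(context_len, num_heads, head_size, dtype=torch.float16, device="cuda")

# Reference: plain scaled dot-product attention for a single query token,
# computed per head with einsum.
logits = torch.einsum("hd,chd->hc", q, keys) * scale
weights = torch.softmax(logits.float(), dim=-1).to(q.dtype)
ref = torch.einsum("hc,chd->hd", weights, values)

# xformers path, mirroring the patch: add the batch dimension (and a
# seq_len=1 dimension for the query) expected by memory_efficient_attention.
out = xops.memory_efficient_attention(
    q.unsqueeze(0).unsqueeze(0),   # (1, 1, num_heads, head_size)
    keys.unsqueeze(0),             # (1, context_len, num_heads, head_size)
    values.unsqueeze(0),           # (1, context_len, num_heads, head_size)
    scale=scale,
)
out = out.reshape(num_heads, head_size)

torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)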