From 7f22f90e8cb423fdaa35203d41badd734d9c2e86 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 24 Feb 2023 08:36:16 +0000
Subject: [PATCH] Remove xformers

---
 cacheflow/models/attention.py | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index 6e6b8e98..068c7d49 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -2,7 +2,6 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import xformers.ops as xops
 
 from cacheflow import ops
 from cacheflow.models import InputMetadata
@@ -14,8 +13,20 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = scale
 
-        # Shape-agnostic attention mask.
-        self.attention_mask = xops.LowerTriangularMask()
+    def _masked_attention(
+        self,
+        query: torch.Tensor,  # [num_queries, num_heads, head_size]
+        key: torch.Tensor,  # [num_keys, num_heads, head_size]
+        value: torch.Tensor,  # [num_keys, num_heads, head_size]
+        attn_mask: Optional[torch.Tensor] = None,  # [num_queries, num_keys]
+    ) -> torch.Tensor:  # [num_queries, num_heads, head_size]
+        query = query * self.scale
+        attn = torch.einsum('qhd,khd->hqk', query, key)
+        if attn_mask is not None:
+            attn = attn + attn_mask
+        attn = torch.softmax(attn, dim=-1)
+        out = torch.einsum('hqk,khd->qhd', attn, value)
+        return out
 
     def multi_query_kv_attention(
         self,
@@ -24,13 +35,11 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
-        out = xops.memory_efficient_attention(
-            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
-        out = out.squeeze(0)
-        # FIXME(woosuk): Directly write the attention output.
+        # FIXME(woosuk): Replace this with a custom op call.
+        attention_mask = torch.triu(
+            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
+        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
+        out = self._masked_attention(query, key, value, attention_mask)
         output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
@@ -64,15 +73,10 @@ class OPTCacheFlowAttention(nn.Module):
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
-
             keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
-            q = q.unsqueeze(0)
-            keys = keys.unsqueeze(0)
-            values = values.unsqueeze(0)
-            out = xops.memory_efficient_attention(
-                q, keys, values, scale=self.scale)
+            out = self._masked_attention(q, keys, values)
             out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
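
As a sanity check on the new code path, here is a minimal, self-contained sketch (not part of the patch) that re-implements the einsum-based masked attention above and compares it against PyTorch's scaled_dot_product_attention under a causal mask. The helper name naive_masked_attention is made up for illustration, and the reference call assumes PyTorch >= 2.0; the -1e5 mask value and tensor shapes follow the patch.

import torch
import torch.nn.functional as F

def naive_masked_attention(query, key, value, scale, attn_mask=None):
    # query: [num_queries, num_heads, head_size]; key/value: [num_keys, num_heads, head_size]
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)   # [num_heads, num_queries, num_keys]
    if attn_mask is not None:
        attn = attn + attn_mask                       # additive [num_queries, num_keys] mask, broadcast over heads
    attn = torch.softmax(attn, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn, value)  # [num_queries, num_heads, head_size]

num_tokens, num_heads, head_size = 7, 4, 32
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)

# Causal mask built the same way as in multi_query_kv_attention above.
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1) * -1e5
out = naive_masked_attention(q, k, v, scale=head_size ** -0.5, attn_mask=mask)

# The reference implementation expects [num_heads, seq_len, head_size].
ref = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True,
).transpose(0, 1)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

Note that exp(-1e5) underflows to exactly zero in float32, so the additive -1e5 mask and the reference's true causal mask agree to within rounding error.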
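
For context on the decode path, the sketch below illustrates how single_query_cached_kv_attention gathers one sequence's K/V history out of a block-paged cache before running the same masked attention with no mask (a single new query may attend to every cached position). The cache layout [num_blocks, num_heads, block_size, head_size] and the block_table shape are assumptions inferred from the value_cache indexing visible in the hunk; both helper names are hypothetical.

import torch

def gather_cached_kv(cache, block_table, context_len, block_size):
    # cache: [num_blocks, num_heads, block_size, head_size] (assumed layout)
    rows = []
    for pos in range(context_len):
        block_number = int(block_table[pos // block_size])    # which physical block holds this token
        block_offset = pos % block_size                       # slot within that block
        rows.append(cache[block_number, :, block_offset, :])  # [num_heads, head_size]
    return torch.stack(rows, dim=0)  # [context_len, num_heads, head_size]

def single_query_attention(q, keys, values, scale):
    # q: [1, num_heads, head_size]; keys/values: [context_len, num_heads, head_size]
    attn = torch.einsum('qhd,khd->hqk', q * scale, keys)
    attn = torch.softmax(attn, dim=-1)  # no mask: the query is the newest position
    return torch.einsum('hqk,khd->qhd', attn, values)

num_blocks, num_heads, block_size, head_size = 8, 4, 4, 32
key_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
block_table = torch.tensor([3, 0, 5])  # physical blocks backing one sequence
context_len = 10                       # tokens already cached (<= 3 * block_size)

keys = gather_cached_kv(key_cache, block_table, context_len, block_size)
values = gather_cached_kv(value_cache, block_table, context_len, block_size)
q = torch.randn(1, num_heads, head_size)
out = single_query_attention(q, keys, values, scale=head_size ** -0.5)
print(out.view(num_heads, head_size).shape)  # matches the patch's final reshape

The FIXME in the patch points at the obvious cost of this design: both the Python gather loop and the unfused masked attention are placeholders to be replaced by a custom op.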