Remove xformers

This commit is contained in:
Woosuk Kwon 2023-02-24 08:36:16 +00:00
parent afdbe5d373
commit 7f22f90e8c

View File

@ -2,7 +2,6 @@ from typing import Optional
import torch
import torch.nn as nn
import xformers.ops as xops
from cacheflow import ops
from cacheflow.models import InputMetadata
@ -14,8 +13,20 @@ class OPTCacheFlowAttention(nn.Module):
super().__init__()
self.scale = scale
# Shape-agnostic attention mask.
self.attention_mask = xops.LowerTriangularMask()
def _masked_attention(
self,
query: torch.Tensor, # [num_queries, num_heads, head_size]
key: torch.Tensor, # [num_keys, num_heads, head_size]
value: torch.Tensor, # [num_keys, num_heads, head_size]
attn_mask: Optional[torch.Tensor] = None, # [num_queries, num_keys]
) -> torch.Tensor: # [num_queries, num_heads, head_size]
query = query * self.scale
attn = torch.einsum('qhd,khd->hqk', query, key)
if attn_mask is not None:
attn = attn + attn_mask
attn = torch.softmax(attn, dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, value)
return out
def multi_query_kv_attention(
self,
@ -24,13 +35,11 @@ class OPTCacheFlowAttention(nn.Module):
key: torch.Tensor,
value: torch.Tensor,
) -> None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention(
query, key, value, attn_bias=self.attention_mask, scale=self.scale)
out = out.squeeze(0)
# FIXME(woosuk): Directly write the attention output.
# FIXME(woosuk): Replace this with a custom op call.
attention_mask = torch.triu(
torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
out = self._masked_attention(query, key, value, attention_mask)
output.copy_(out, non_blocking=True)
def single_query_cached_kv_attention(
@ -64,15 +73,10 @@ class OPTCacheFlowAttention(nn.Module):
v = value_cache[block_number, :, block_offset, :]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
q = q.unsqueeze(0)
keys = keys.unsqueeze(0)
values = values.unsqueeze(0)
out = xops.memory_efficient_attention(
q, keys, values, scale=self.scale)
out = self._masked_attention(q, keys, values)
out = out.view(num_heads, head_size)
output[i].copy_(out, non_blocking=True)