Remove xformers
This commit is contained in:
parent afdbe5d373
commit 7f22f90e8c
@@ -2,7 +2,6 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import xformers.ops as xops
 
 from cacheflow import ops
 from cacheflow.models import InputMetadata
@@ -14,8 +13,20 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = scale
 
-        # Shape-agnostic attention mask.
-        self.attention_mask = xops.LowerTriangularMask()
+    def _masked_attention(
+        self,
+        query: torch.Tensor,  # [num_queries, num_heads, head_size]
+        key: torch.Tensor,  # [num_keys, num_heads, head_size]
+        value: torch.Tensor,  # [num_keys, num_heads, head_size]
+        attn_mask: Optional[torch.Tensor] = None,  # [num_queries, num_keys]
+    ) -> torch.Tensor:  # [num_queries, num_heads, head_size]
+        query = query * self.scale
+        attn = torch.einsum('qhd,khd->hqk', query, key)
+        if attn_mask is not None:
+            attn = attn + attn_mask
+        attn = torch.softmax(attn, dim=-1)
+        out = torch.einsum('hqk,khd->qhd', attn, value)
+        return out
 
     def multi_query_kv_attention(
         self,
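The new _masked_attention is a plain einsum formulation of scaled dot-product attention over [num_tokens, num_heads, head_size] tensors. As a sanity check, here is a minimal self-contained sketch of the same math compared against PyTorch's built-in F.scaled_dot_product_attention (assumes PyTorch 2.x; the free function, toy shapes, and tolerance are illustrative and not part of this commit):

import torch
import torch.nn.functional as F

def masked_attention(query, key, value, attn_mask=None, scale=1.0):
    # Same einsum math as the _masked_attention added above; inputs are
    # [num_tokens, num_heads, head_size], the optional mask is [num_queries, num_keys].
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)    # per-head Q @ K^T
    if attn_mask is not None:
        attn = attn + attn_mask                        # additive mask, broadcast over heads
    attn = torch.softmax(attn, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn, value)   # back to [num_queries, num_heads, head_size]

num_tokens, num_heads, head_size = 6, 4, 8
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)
scale = head_size ** -0.5

out = masked_attention(q, k, v, scale=scale)

# Reference: PyTorch's built-in SDPA, which takes [num_heads, seq_len, head_size]
# and defaults to the same 1/sqrt(head_size) scaling.
ref = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)).transpose(0, 1)
print(torch.allclose(out, ref, atol=1e-5))  # True

The 'qhd,khd->hqk' contraction is the per-head Q·K^T product; broadcasting the [num_queries, num_keys] additive mask over the leading head dimension is what keeps the mask shape-agnostic with respect to the number of heads.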
@@ -24,13 +35,11 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
-        out = xops.memory_efficient_attention(
-            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
-        out = out.squeeze(0)
-        # FIXME(woosuk): Directly write the attention output.
+        # FIXME(woosuk): Replace this with a custom op call.
+        attention_mask = torch.triu(
+            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
+        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
+        out = self._masked_attention(query, key, value, attention_mask)
         output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
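With xformers' LowerTriangularMask gone, the prompt-phase path builds an explicit additive causal mask: zero on and below the diagonal, a large negative constant (-1e5) above it, so softmax gives (numerically) zero weight to future positions. A small illustrative demo of the torch.triu construction used above (toy sizes, not part of the commit):

import torch

num_queries, num_keys = 4, 4

# Additive causal mask in the style used above: ~0 where a query may attend
# (on and below the diagonal), -1e5 where it may not (future positions).
attention_mask = torch.triu(
    torch.ones(num_queries, num_keys), diagonal=1) * -1e5

# Adding the mask to the raw scores before softmax zeroes out future positions.
scores = torch.zeros(num_queries, num_keys)
probs = torch.softmax(scores + attention_mask, dim=-1)
print(probs[1])  # ~[0.5, 0.5, 0.0, 0.0]: query 1 attends only to keys 0 and 1

The mask is [num_queries, num_keys] and is broadcast over the head dimension inside _masked_attention, so one mask serves all heads.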
@@ -64,15 +73,10 @@
 
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
 
             keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
 
-            q = q.unsqueeze(0)
-            keys = keys.unsqueeze(0)
-            values = values.unsqueeze(0)
-            out = xops.memory_efficient_attention(
-                q, keys, values, scale=self.scale)
+            out = self._masked_attention(q, keys, values)
             out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
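In this single-query (decode) path, _masked_attention is called with no mask: the one new query token is newer than every key gathered from the cache, so nothing needs to be masked out. A toy sketch of that call pattern (shapes and values are illustrative, not from the commit):

import torch

num_heads, head_size, context_len = 4, 8, 7
scale = head_size ** -0.5

q = torch.randn(1, num_heads, head_size)                # the single new query token
keys = torch.randn(context_len, num_heads, head_size)   # gathered from the KV cache
values = torch.randn(context_len, num_heads, head_size)

# Same math as _masked_attention with attn_mask=None.
attn = torch.einsum('qhd,khd->hqk', q * scale, keys)    # [num_heads, 1, context_len]
attn = torch.softmax(attn, dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, values)        # [1, num_heads, head_size]
out = out.view(num_heads, head_size)                    # one output row per token, as above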