Fix attention

Woosuk Kwon 2023-02-23 22:29:46 +00:00
parent 87e0bcd426
commit ba84b8728a


@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
 import torch
 import torch.nn as nn
@@ -24,8 +24,12 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        value = value.unsqueeze(0)
         out = xops.memory_efficient_attention(
             query, key, value, attn_bias=self.attention_mask, scale=self.scale)
+        out = out.squeeze(0)
         # FIXME(woosuk): Directly write the attention output.
         output.copy_(out, non_blocking=True)
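
For context on the pattern the added lines introduce: xformers' xops.memory_efficient_attention expects tensors with a leading batch dimension, so the unbatched query/key/value are unsqueezed to a batch of one before the call and the result is squeezed back before being copied into the preallocated output tensor. Below is a minimal standalone sketch of that pattern; the shapes, the fp16/CUDA setup, and the LowerTriangularMask bias standing in for self.attention_mask are illustrative assumptions, not taken from this commit.

import torch
import xformers.ops as xops

# Assumed shapes for illustration: [num_tokens, num_heads, head_size].
num_tokens, num_heads, head_size = 8, 12, 64
scale = head_size ** -0.5

query = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
output = torch.empty_like(query)

# Hypothetical stand-in for self.attention_mask: a causal (lower-triangular) bias.
attn_bias = xops.LowerTriangularMask()

# memory_efficient_attention takes [batch, seq_len, num_heads, head_size],
# hence unsqueeze(0) on the way in and squeeze(0) on the way out.
out = xops.memory_efficient_attention(
    query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
    attn_bias=attn_bias, scale=scale)
output.copy_(out.squeeze(0), non_blocking=True)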