From ba84b8728a8d0a766a636b30661836c30b17fbe6 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 23 Feb 2023 22:29:46 +0000
Subject: [PATCH] Fix attention

---
 cacheflow/models/attention.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index ac02c2fd..f6d48a3a 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -24,8 +24,12 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        value = value.unsqueeze(0)
         out = xops.memory_efficient_attention(
             query, key, value, attn_bias=self.attention_mask, scale=self.scale)
+        out = out.squeeze(0)
         # FIXME(woosuk): Directly write the attention output.
         output.copy_(out, non_blocking=True)
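
Note: the sketch below is a minimal, self-contained illustration of the pattern this patch applies, assuming the xformers package is installed and a CUDA device is available. The tensor sizes (seq_len, num_heads, head_dim) are illustrative assumptions, not values from the patch. xops.memory_efficient_attention operates on batched inputs, so the unbatched [seq_len, num_heads, head_dim] tensors are given a leading batch dimension of 1 before the call, and the output is squeezed back afterwards.

import torch
import xformers.ops as xops

# Illustrative shapes; any consistent sizes work the same way.
seq_len, num_heads, head_dim = 16, 8, 64
query = torch.randn(seq_len, num_heads, head_dim,
                    dtype=torch.float16, device="cuda")
key = torch.randn_like(query)
value = torch.randn_like(query)

# Add a batch dimension of 1 so the inputs are
# [1, seq_len, num_heads, head_dim], as the kernel expects.
out = xops.memory_efficient_attention(
    query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
    scale=head_dim ** -0.5)

# Drop the batch dimension to recover [seq_len, num_heads, head_dim].
out = out.squeeze(0)
assert out.shape == (seq_len, num_heads, head_dim)

Without the unsqueeze, a 3-D input would be interpreted under a different layout convention than intended, which is the mismatch this patch corrects.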