diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index ac02c2fd..f6d48a3a 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -24,8 +24,12 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        value = value.unsqueeze(0)
         out = xops.memory_efficient_attention(
             query, key, value, attn_bias=self.attention_mask, scale=self.scale)
+        out = out.squeeze(0)
         # FIXME(woosuk): Directly write the attention output.
         output.copy_(out, non_blocking=True)
 
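
For context, a minimal sketch of the pattern this patch applies. xformers' memory_efficient_attention operates on inputs with a leading batch dimension, so the patch adds a dummy batch of size 1 before the call and squeezes it off the output. The [num_tokens, num_heads, head_size] layout below is an assumption (it is not visible in the diff), and the attn_bias argument from the original call is omitted.

# Sketch only; assumes a CUDA device with xformers installed, and that
# query/key/value are 3-D tensors of shape [num_tokens, num_heads, head_size].
import torch
import xformers.ops as xops

num_tokens, num_heads, head_size = 16, 12, 64
scale = head_size ** -0.5  # mirrors self.scale in the module above

query = torch.randn(num_tokens, num_heads, head_size,
                    device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# memory_efficient_attention expects a leading batch dimension
# ([batch, seq_len, num_heads, head_size]), so add a dummy batch of 1.
# The original call also passes attn_bias=self.attention_mask; omitted here.
out = xops.memory_efficient_attention(
    query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
    scale=scale)

# Strip the dummy batch dimension to recover the original layout.
out = out.squeeze(0)
assert out.shape == (num_tokens, num_heads, head_size)

Note that unsqueeze(0)/squeeze(0) are views, so this round trip adds no copies; the only data movement remains the final output.copy_(out, non_blocking=True) that the FIXME in the diff proposes to eliminate.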