Fix attention
parent 87e0bcd426
commit ba84b8728a
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -24,8 +24,12 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        value = value.unsqueeze(0)
         out = xops.memory_efficient_attention(
             query, key, value, attn_bias=self.attention_mask, scale=self.scale)
+        out = out.squeeze(0)
         # FIXME(woosuk): Directly write the attention output.
         output.copy_(out, non_blocking=True)
 
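For context: xformers' memory_efficient_attention operates on batched tensors shaped [batch, seq_len, num_heads, head_dim], which is why this change adds unsqueeze(0) on the inputs and squeeze(0) on the result. Below is a minimal standalone sketch of the same pattern; the shapes, causal mask, dtype, and device are illustrative assumptions, not taken from this repository.

# Sketch of the batch-dimension handling this commit introduces.
# Assumes a CUDA device and an xformers version whose
# memory_efficient_attention accepts the `scale` keyword.
import torch
import xformers.ops as xops

seq_len, num_heads, head_dim = 16, 8, 64
scale = head_dim ** -0.5

# Unbatched inputs shaped [seq_len, num_heads, head_dim] (illustrative).
query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda")
key = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda")
value = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda")
output = torch.empty_like(query)

# Add a batch dimension of 1 so the tensors become
# [1, seq_len, num_heads, head_dim], run attention, then drop it again.
out = xops.memory_efficient_attention(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=xops.LowerTriangularMask(),  # illustrative causal mask
    scale=scale,
)
output.copy_(out.squeeze(0), non_blocking=True)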