Fix attention
parent 87e0bcd426
commit ba84b8728a
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -24,8 +24,12 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        value = value.unsqueeze(0)
         out = xops.memory_efficient_attention(
             query, key, value, attn_bias=self.attention_mask, scale=self.scale)
+        out = out.squeeze(0)
         # FIXME(woosuk): Directly write the attention output.
         output.copy_(out, non_blocking=True)
 
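For context, xformers' xops.memory_efficient_attention expects query, key, and value tensors with a leading batch dimension, shaped [batch, seq_len, num_heads, head_size], which is why this commit wraps the call in an unsqueeze(0)/squeeze(0) pair around the packed [seq_len, num_heads, head_size] tensors. The sketch below is a minimal, self-contained illustration of that pattern, not code from this repository: the concrete shapes, the LowerTriangularMask causal bias (standing in for self.attention_mask), and the preallocated output buffer are illustrative assumptions.

# Minimal sketch (assumes a CUDA device and xformers installed) of the
# reshaping this commit adds around memory_efficient_attention.
import torch
import xformers.ops as xops

seq_len, num_heads, head_size = 16, 12, 64
scale = head_size ** -0.5

# Packed tensors without a batch dimension: [seq_len, num_heads, head_size].
query = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
key = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
value = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
output = torch.empty_like(query)  # caller-provided output buffer, as in the diff

# unsqueeze(0) adds the batch dimension xformers requires:
# [1, seq_len, num_heads, head_size].
out = xops.memory_efficient_attention(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=xops.LowerTriangularMask(),  # assumed causal mask
    scale=scale,
)

# Drop the batch dimension and copy into the preallocated buffer,
# mirroring the FIXME'd copy in the diff.
output.copy_(out.squeeze(0), non_blocking=True)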