Change causal for CrossAttention in mha.py to align to bottom right
parent 9b713872ea
commit a2974e850a
@@ -271,13 +271,18 @@ class CrossAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(
-                torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1
+            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
+            row_idx = rearrange(
+                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
             )
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
+            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
+            sk = (
+                seqlen_k
+                if key_padding_mask is None
+                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+            )
+            causal_mask = col_idx > row_idx + sk - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
         output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
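
As a quick illustration (not part of the diff), this is the mask the new branch builds for a toy case with seqlen_q=2, seqlen_k=5 and no key_padding_mask:

import torch
from einops import rearrange

# Sketch of the bottom-right-aligned causal mask from the hunk above,
# assuming key_padding_mask is None so sk == seqlen_k.
seqlen_q, seqlen_k = 2, 5
row_idx = rearrange(torch.arange(seqlen_q, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, dtype=torch.long)
sk = seqlen_k
causal_mask = col_idx > row_idx + sk - seqlen_q
print(causal_mask.int())
# tensor([[0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])
# The last query row sees every key: the diagonal is anchored at the bottom-right
# corner. The old torch.triu mask was anchored at the top-left and would have
# hidden keys 1..4 from query row 0 instead.

When key_padding_mask is given, sk becomes key_padding_mask.sum(-1) per batch row, so the diagonal is anchored at the bottom-right of each sequence's valid keys rather than of the padded length.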
@@ -627,10 +632,7 @@ class MHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
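
The decode path above can drop causal=False because of the new alignment: with a single query (seqlen_q == 1) sitting in the bottom-right corner, the causal mask is empty. A minimal sketch, assuming no padding:

import torch

# Single-token decoding against a cache of 7 keys.
seqlen_q, seqlen_k = 1, 7
row_idx = torch.arange(seqlen_q).unsqueeze(-1)
col_idx = torch.arange(seqlen_k)
causal_mask = col_idx > row_idx + seqlen_k - seqlen_q  # col_idx > 6: never true
print(causal_mask.any())  # tensor(False) -- nothing is masked
# With the old top-left-aligned mask, causal=True would have hidden keys 1..6
# from the new token, which is why the caller previously forced causal=False.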
@@ -677,10 +679,7 @@ class MHA(nn.Module):
                     )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
@@ -869,10 +868,7 @@ class ParallelMHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -903,10 +899,7 @@ class ParallelMHA(nn.Module):
                     )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         context = rearrange(context, "b s h d -> b s (h d)")
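
The prompt pass is unaffected: when seqlen_q == seqlen_k and there is no key padding, the bottom-right-aligned mask coincides with the old top-left torch.triu mask. A quick standalone check:

import torch

s = 6
row_idx = torch.arange(s).unsqueeze(-1)
col_idx = torch.arange(s)
new_mask = col_idx > row_idx + s - s                          # bottom-right aligned
old_mask = torch.triu(torch.ones(s, s, dtype=torch.bool), 1)  # top-left aligned
assert torch.equal(new_mask, old_mask)  # identical when seqlen_q == seqlen_k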