Change causal for CrossAttention in mha.py to align to bottom right

Tri Dao 2023-08-26 12:57:33 -07:00
parent 9b713872ea
commit a2974e850a


@@ -271,13 +271,18 @@ class CrossAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(
-                torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1
+            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
+            row_idx = rearrange(
+                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
             )
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
+            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
+            sk = (
+                seqlen_k
+                if key_padding_mask is None
+                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+            )
+            causal_mask = col_idx > row_idx + sk - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
         output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
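
For reference, here is a minimal standalone sketch of the bottom-right-aligned mask this hunk builds. It mirrors the names used in CrossAttention (seqlen_q, seqlen_k, key_padding_mask), but it is an illustration rather than code from the file; the helper name bottom_right_causal_mask is made up for this note.

import torch
from einops import rearrange

def bottom_right_causal_mask(seqlen_q, seqlen_k, key_padding_mask=None, device="cpu"):
    # True entries are masked out. Aligning to the bottom right means that when
    # seqlen_q < seqlen_k (e.g. new tokens appended to a KV cache), query i may
    # attend to the first (seqlen_k - seqlen_q + i + 1) keys, i.e. the causal
    # diagonal ends in the bottom-right corner of the (seqlen_q, seqlen_k) matrix.
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")  # per-sample count of valid keys
    )
    return col_idx > row_idx + sk - seqlen_q

# Two new queries against five cached+current keys:
# query 0 sees keys 0..3, query 1 sees keys 0..4.
print(bottom_right_causal_mask(2, 5))
# tensor([[False, False, False, False,  True],
#         [False, False, False, False, False]])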
@@ -627,10 +632,7 @@ class MHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
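
A note on why this call site (and the identical ones below) can drop the explicit override; this is a reading of the change, not text from the commit. With the old top-left-aligned mask, a decoding step passes a single query (seqlen_q == 1) against seqlen_k cached keys, and a top-left torch.triu mask would have blocked every key except the first, so the code had to force causal=False. With the bottom-right alignment, the lone query row is the last row of the score matrix and masks nothing, so inner_cross_attn can fall back to self.causal in both the prompt and decoding cases. Reusing the hypothetical helper from the sketch above:

# During decoding, one query against an 8-entry KV cache is never masked, so
# causal=True now behaves the same as the old explicit causal=False override.
mask = bottom_right_causal_mask(seqlen_q=1, seqlen_k=8)
assert not mask.any()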
@@ -677,10 +679,7 @@ class MHA(nn.Module):
                     )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
@@ -869,10 +868,7 @@ class ParallelMHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -903,10 +899,7 @@ class ParallelMHA(nn.Module):
                     )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         context = rearrange(context, "b s h d -> b s (h d)")
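
As a final sanity check (again with the hypothetical helper from the first sketch), when seqlen_q == seqlen_k and there is no key_padding_mask, the bottom-right-aligned mask reduces to the ordinary upper-triangular mask that the removed torch.triu code produced, so plain prompt processing without padding is unaffected:

import torch

# Square case: both formulations mask exactly the entries strictly above the diagonal.
sq = 6
new_mask = bottom_right_causal_mask(sq, sq)
old_mask = torch.triu(torch.ones(sq, sq, dtype=torch.bool), diagonal=1)
assert torch.equal(new_mask, old_mask)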