From a2974e850ae0782e125ebedfee55a4e6d210c620 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 26 Aug 2023 12:57:33 -0700
Subject: [PATCH] Change causal for CrossAttention in mha.py to align to bottom right

---
 flash_attn/modules/mha.py | 37 +++++++++++++++----------------------
 1 file changed, 15 insertions(+), 22 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 9a19006..d818ce4 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -271,13 +271,18 @@ class CrossAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(
-                torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1
+            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
+            row_idx = rearrange(
+                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
             )
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
+            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
+            sk = (
+                seqlen_k
+                if key_padding_mask is None
+                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+            )
+            causal_mask = col_idx > row_idx + sk - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
         output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
@@ -627,10 +632,7 @@ class MHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -677,10 +679,7 @@ class MHA(nn.Module):
                         )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
@@ -869,10 +868,7 @@ class ParallelMHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -903,10 +899,7 @@ class ParallelMHA(nn.Module):
                         )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         context = rearrange(context, "b s h d -> b s (h d)")
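Note (not part of the patch): below is a minimal standalone sketch of the bottom-right-aligned causal mask that the new CrossAttention code builds, with the key_padding_mask handling omitted. The helper name bottom_right_causal_mask is made up for illustration and does not exist in flash_attn.

    import torch

    def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int, device: str = "cpu") -> torch.Tensor:
        # Boolean mask of shape (seqlen_q, seqlen_k); True = masked out.
        # Query i may attend to key j iff j <= i + seqlen_k - seqlen_q, i.e. the
        # causal diagonal is anchored at the bottom-right corner of the score matrix.
        row_idx = torch.arange(seqlen_q, device=device).unsqueeze(-1)  # (seqlen_q, 1)
        col_idx = torch.arange(seqlen_k, device=device)                # (seqlen_k,)
        return col_idx > row_idx + seqlen_k - seqlen_q

    if __name__ == "__main__":
        # 2 queries attending to 5 keys (e.g. 2 new tokens on top of a 3-token KV cache):
        # the last query sees all 5 keys, the first query sees only the first 4.
        print(bottom_right_causal_mask(2, 5).int())
        # tensor([[0, 0, 0, 0, 1],
        #         [0, 0, 0, 0, 0]])

With this alignment, a single decoding step (seqlen_q == 1 against a longer cached seqlen_k) masks nothing, which is why the explicit causal=False override for decoding in MHA and ParallelMHA is removed in the hunks above.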