From 299563626fbfcd8345e7da2f4e1bb93886b58341 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 23 Jul 2024 02:04:15 -0700
Subject: [PATCH] Fix test with alibi and cache_leftpad

---
 tests/test_flash_attn.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 5e3b126..72d5513 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -27,7 +27,7 @@ is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
 
 
 def attn_bias_from_alibi_slopes(
-    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
+    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
 ):
     batch, nheads = slopes.shape
     device = slopes.device
@@ -37,6 +37,10 @@ def attn_bias_from_alibi_slopes(
     else:
         row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
         col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+        if key_leftpad is not None:
+            key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+            col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+            col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
         sk = (
             seqlen_k
             if key_padding_mask is None
@@ -1993,7 +1997,7 @@ def test_flash_attn_kvcache(
     if alibi:
         alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
         attn_bias = attn_bias_from_alibi_slopes(
-            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
+            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
         )
     else:
         alibi_slopes, attn_bias = None, None
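
For reference only (not part of the patch), the following is a minimal standalone sketch of the index shift that the new key_leftpad branch performs; the batch size, seqlen_k, and leftpad values below are made up for illustration.

# Minimal sketch: reproduces the key_leftpad column-index shift from
# attn_bias_from_alibi_slopes on toy values (illustrative, outside the test suite).
import torch
from einops import rearrange, repeat

seqlen_k = 6
# Hypothetical per-batch left padding of the KV cache (batch size 2).
key_leftpad = torch.tensor([2, 0])

col_idx = torch.arange(seqlen_k, dtype=torch.long)
leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=leftpad.shape[0])
# Columns inside the left-padded region are mapped to a huge index (2**32);
# columns past the padding are renumbered to start at 0.
col_idx = torch.where(col_idx >= leftpad, col_idx - leftpad, 2**32)
print(col_idx[:, 0, 0])
# tensor([[4294967296, 4294967296, 0, 1, 2, 3],
#         [         0,          1, 2, 3, 4, 5]])

Since the reference bias is -slope * |relative position|, mapping left-padded cache slots to 2**32 makes their distance enormous and their bias strongly negative, which appears to be how the test keeps those slots from contributing to attention when cache_leftpad is set.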