Fix test with alibi and cache_leftpad

Tri Dao 2024-07-23 02:04:15 -07:00
parent 4488acee8d
commit 299563626f

@@ -27,7 +27,7 @@ is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
 def attn_bias_from_alibi_slopes(
-    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
+    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
 ):
     batch, nheads = slopes.shape
     device = slopes.device
@@ -37,6 +37,10 @@ def attn_bias_from_alibi_slopes(
     else:
         row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
         col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+        if key_leftpad is not None:
+            key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+            col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+            col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
         sk = (
             seqlen_k
             if key_padding_mask is None
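
For reference, a minimal standalone sketch (not part of the diff) of what the added key_leftpad handling does: column indices are shifted left by the per-batch amount of left padding, and indices that fall inside the padded region are pushed to 2**32, so the reference ALiBi bias computed from them becomes extremely negative at those cache slots. The tensor values below are illustrative only.

import torch
from einops import rearrange, repeat

seqlen_k = 5
key_leftpad = torch.tensor([0, 2])  # illustrative: per-batch count of padded slots at the front of the KV cache

col_idx = torch.arange(seqlen_k, dtype=torch.long)
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)

print(col_idx[0, 0, 0])  # tensor([0, 1, 2, 3, 4])                    -> no left padding, indices unchanged
print(col_idx[1, 0, 0])  # tensor([4294967296, 4294967296, 0, 1, 2])  -> first two slots pushed out of range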
@@ -1993,7 +1997,7 @@ def test_flash_attn_kvcache(
     if alibi:
         alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
         attn_bias = attn_bias_from_alibi_slopes(
-            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
+            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
         )
     else:
         alibi_slopes, attn_bias = None, None
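
A hedged usage sketch of how the two sides of the kvcache test now line up, assuming flash-attn 2.6+ where flash_attn_with_kvcache accepts cache_leftpad; the shapes and values below are illustrative, not taken from the test.

import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, headdim = 2, 4, 64
seqlen_q, seqlen_k = 1, 128
device, dtype = "cuda", torch.float16

q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(batch_size, seqlen_k, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn(batch_size, seqlen_k, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.tensor([64, 128], dtype=torch.int32, device=device)
cache_leftpad = torch.tensor([0, 16], dtype=torch.int32, device=device)  # index where the valid cache starts per batch
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3

# Kernel side: cache_leftpad tells the kernel to start attending at that cache index.
out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    cache_leftpad=cache_leftpad,
    alibi_slopes=alibi_slopes,
)

# Reference side: the same tensor is now threaded into the ALiBi bias builder,
# mirroring the change in test_flash_attn_kvcache above (in the actual test,
# key_padding_mask is derived from cache_seqlens):
# attn_bias = attn_bias_from_alibi_slopes(
#     alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=False, key_leftpad=cache_leftpad
# )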