Fix test with alibi and cache_leftpad
This commit is contained in:
parent
4488acee8d
commit
299563626f
@ -27,7 +27,7 @@ is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
|
|||||||
|
|
||||||
|
|
||||||
def attn_bias_from_alibi_slopes(
|
def attn_bias_from_alibi_slopes(
|
||||||
slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
|
slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
|
||||||
):
|
):
|
||||||
batch, nheads = slopes.shape
|
batch, nheads = slopes.shape
|
||||||
device = slopes.device
|
device = slopes.device
|
||||||
@ -37,6 +37,10 @@ def attn_bias_from_alibi_slopes(
|
|||||||
else:
|
else:
|
||||||
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
||||||
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
||||||
|
if key_leftpad is not None:
|
||||||
|
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
|
||||||
|
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
|
||||||
|
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
|
||||||
sk = (
|
sk = (
|
||||||
seqlen_k
|
seqlen_k
|
||||||
if key_padding_mask is None
|
if key_padding_mask is None
|
||||||
@ -1993,7 +1997,7 @@ def test_flash_attn_kvcache(
|
|||||||
if alibi:
|
if alibi:
|
||||||
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
||||||
attn_bias = attn_bias_from_alibi_slopes(
|
attn_bias = attn_bias_from_alibi_slopes(
|
||||||
alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
|
alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
alibi_slopes, attn_bias = None, None
|
alibi_slopes, attn_bias = None, None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user