Support cache_seqlens being integer
commit fd20f16a4e (parent 913922cac5)

@@ -1,6 +1,7 @@
+from typing import Optional, Union
 import torch
 import torch.nn as nn
 from einops import rearrange
 
 # isort: off
 # We need to import the CUDA kernels after importing torch
@@ -799,7 +800,7 @@ def flash_attn_with_kvcache(
     v_cache,
     k=None,
     v=None,
-    cache_seqlens=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     softmax_scale=None,
     causal=False,
     num_splits=0,
@@ -840,7 +841,8 @@ def flash_attn_with_kvcache(
         k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
             k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
-        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -858,6 +860,10 @@ def flash_attn_with_kvcache(
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+        )
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
    )
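
Taken together, the hunks above let flash_attn_with_kvcache accept cache_seqlens either as an int32 tensor of shape (batch_size,) or as a plain Python int, which is expanded internally with torch.full before the kernel call. Below is a minimal usage sketch of the new calling convention; the shapes, dtypes, and the top-level "from flash_attn import flash_attn_with_kvcache" re-export are illustrative assumptions, not part of this commit.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
seqlen_q, seqlen_cache = 1, 128

# Query for the current decoding step and a pre-allocated KV cache (fp16 on GPU).
q = torch.randn(batch, seqlen_q, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch, seqlen_cache, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Before this commit, cache_seqlens had to be an int32 tensor of shape (batch,).
# A plain int now works too: every sequence in the batch is treated as having
# 16 valid tokens in the cache.
out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=16, causal=True)

Passing cache_seqlens=16 here is equivalent to passing torch.full((batch,), 16, dtype=torch.int32, device=k_cache.device), which is exactly the conversion the new code performs.
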
@@ -3,7 +3,12 @@ import re
 import pytest
 import torch
 from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
+from flash_attn.models.gpt import (
+    GPTLMHeadModel,
+    remap_state_dict_hf_gpt2,
+    shard_state_dict_tp,
+    combine_state_dicts_tp,
+)
 from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
@@ -130,9 +135,9 @@ def test_gpt2_optimized(model_name):
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('fused_ft_kernel', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [False])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
@@ -204,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel:
+    if fused_ft_kernel or config.use_flash_attn:
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
@@ -263,7 +268,6 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
     out = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         teacher_outputs=teacher_outputs,
         return_dict_in_generate=True,
         output_scores=True,
@@ -277,8 +281,9 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
 @pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
 # @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
+def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
     dtype = torch.float16
     device = "cuda"
@@ -308,8 +313,17 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
         0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
     )
 
-    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
-    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    logits = get_logits(
+        model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
+    )
+    logits_cg = get_logits(
+        model,
+        input_ids,
+        maxlen,
+        teacher_outputs=teacher_outputs,
+        fused_ft_kernel=fused_ft_kernel,
+        cg=True,
+    )
     assert torch.equal(logits, logits_cg)
 
     # Try increasing batch size and seqlen, then decrease them to see if it's still correct
@@ -446,11 +460,14 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
     print(tokenizer.batch_decode(out_og.sequences))
 
 
-@pytest.mark.parametrize("n_heads_q_kv", [
-    (8, 8), # Regular attention
-    (8, 4), # GQA
-    (8, 2), # MQA
-])
+@pytest.mark.parametrize(
+    "n_heads_q_kv",
+    [
+        (8, 8), # Regular attention
+        (8, 4), # GQA
+        (8, 2), # MQA
+    ],
+)
 def test_gpt2_shard_unshard(n_heads_q_kv):
     world_size = 2
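
For reference, a hedged sketch of how one could sanity-check the new integer path against the existing tensor path; this is not part of the commit's test suite, and the names, shapes, and top-level flash_attn import are assumed.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, seqlen_cache = 2, 8, 64, 128
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch, seqlen_cache, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# The int path expands to exactly this tensor internally, so the two calls
# should produce identical outputs (relax to torch.allclose if needed).
seqlens_tensor = torch.full((batch,), 32, dtype=torch.int32, device=k_cache.device)
out_int = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=32)
out_tensor = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=seqlens_tensor)
assert torch.equal(out_int, out_tensor)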