diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 0853c4a..ae49728 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -1,6 +1,7 @@
+from typing import Optional, Union
+
 import torch
 import torch.nn as nn
-from einops import rearrange
 
 # isort: off
 # We need to import the CUDA kernels after importing torch
@@ -799,7 +800,7 @@ def flash_attn_with_kvcache(
     v_cache,
     k=None,
     v=None,
-    cache_seqlens=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     softmax_scale=None,
     causal=False,
     num_splits=0,
@@ -840,7 +841,8 @@ def flash_attn_with_kvcache(
         k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
-        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -858,6 +860,10 @@ def flash_attn_with_kvcache(
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+        )
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
     )
diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index 8f74e93..a9a827d 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -3,7 +3,12 @@ import re
 import pytest
 import torch
 from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
+from flash_attn.models.gpt import (
+    GPTLMHeadModel,
+    remap_state_dict_hf_gpt2,
+    shard_state_dict_tp,
+    combine_state_dicts_tp,
+)
 from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
@@ -130,9 +135,9 @@ def test_gpt2_optimized(model_name):
 
 
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [True])
+# @pytest.mark.parametrize('fused_ft_kernel', [False])
 @pytest.mark.parametrize("optimized", [False, True])
-# @pytest.mark.parametrize('optimized', [False])
+# @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
@@ -204,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel:
+    if fused_ft_kernel or config.use_flash_attn:
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
@@ -263,7 +268,6 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
     out = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         teacher_outputs=teacher_outputs,
         return_dict_in_generate=True,
         output_scores=True,
@@ -277,8 +281,9 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
 @pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
 # @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
+def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
     dtype = torch.float16
     device = "cuda"
@@ -308,8 +313,17 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
         0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
     )
 
-    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
-    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    logits = get_logits(
+        model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
+    )
+    logits_cg = get_logits(
+        model,
+        input_ids,
+        maxlen,
+        teacher_outputs=teacher_outputs,
+        fused_ft_kernel=fused_ft_kernel,
+        cg=True,
+    )
     assert torch.equal(logits, logits_cg)
 
     # Try increasing batch size and seqlen, then decrease them to see if it's still correct
@@ -446,11 +460,14 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
     print(tokenizer.batch_decode(out_og.sequences))
 
 
-@pytest.mark.parametrize("n_heads_q_kv", [
-    (8, 8),  # Regular attention
-    (8, 4),  # GQA
-    (8, 2),  # MQA
-])
+@pytest.mark.parametrize(
+    "n_heads_q_kv",
+    [
+        (8, 8),  # Regular attention
+        (8, 4),  # GQA
+        (8, 2),  # MQA
+    ],
+)
 def test_gpt2_shard_unshard(n_heads_q_kv):
     world_size = 2
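
For context, a minimal usage sketch of the new int form of cache_seqlens accepted by flash_attn_with_kvcache. This is not part of the patch: the tensor shapes, the 128 cache length, and the causal=True flag are illustrative assumptions, and running it requires a CUDA device with the flash-attn extension built.

# Illustrative sketch only (not part of the patch); assumes a CUDA device and a
# flash-attn build that includes this change. Shapes and values are arbitrary examples.
import torch

from flash_attn.flash_attn_interface import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch, 256, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# New in this change: a plain Python int is accepted and applies the same cache
# length to every batch element.
out_int = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=128, causal=True)

# Equivalent per-batch tensor form, which is what the int is expanded to internally
# via torch.full((batch,), ..., dtype=torch.int32).
seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")
out_tensor = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=seqlens, causal=True)

assert torch.allclose(out_int, out_tensor)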