diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index cb277c4..129d233 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -30,6 +30,11 @@ try:
 except ImportError:
     RotaryEmbedding = None
 
+try:
+    import ft_attention
+except ImportError:
+    ft_attention = None
+
 
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
@@ -360,23 +365,32 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
         # Pre-allocate memory for key-values for inference.
         if self.layer_idx not in inference_params.key_value_memory_dict:
-            inference_kv_cache = torch.empty(
+            kv_cache = torch.empty(
                 inference_params.max_batch_size, inference_params.max_sequence_len, 2,
                 self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
             )
-            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+            inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= inference_kv_cache.shape[0]
+        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= inference_kv_cache.shape[1]
+        assert sequence_end <= kv_cache.shape[1]
         # Copy key and values.
-        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+        if inference_params.fused_ft_kernel:
+            # FT kernel requires different layouts for the k_cache and v_cache.
+            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv_cache.dtype == torch.float32 else 8
+            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                packsize=packsize).contiguous()
+            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
         return kv
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
@@ -430,14 +444,24 @@ class MHA(nn.Module):
                 else:
                     context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
-                q = qkv[:, :, 0]
-                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                # If we're processing the prompt, causal=None (use self.causal).
-                # If we're decoding, then causal=False.
-                causal = None if inference_params.sequence_len_offset == 0 else False
-                context = self.inner_cross_attn(q, kv, causal=causal)
+                if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
+                    if self.rotary_emb_dim > 0:
+                        qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                    q = qkv[:, :, 0]
+                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
+                    # If we're processing the prompt, causal=None (use self.causal).
+                    # If we're decoding, then causal=False.
+                    causal = None if inference_params.sequence_len_offset == 0 else False
+                    context = self.inner_cross_attn(q, kv, causal=causal)
+                else:
+                    assert ft_attention is not None
+                    context = ft_attention.single_query_attention(
+                        *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
+                        *inference_params.key_value_memory_dict[self.layer_idx],
+                        inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                        self.rotary_emb_dim
+                    )
+                    context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
                 q = self.Wq(x)
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index c578f37..e50fd27 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -1,7 +1,10 @@
 # Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
+from typing import Optional
+
 from dataclasses import dataclass, field
 
 import torch
+from torch import Tensor
 
 from einops import rearrange
@@ -17,9 +20,11 @@ class InferenceParams:
     sequence_len_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
+    fused_ft_kernel: bool = False
+    lengths_per_sample: Optional[Tensor] = None
 
 
-def greedy_decode(input_ids, model, max_length):
+def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
     """Greedy decoding. This is a very simple implementation.
     We assume that all sequences in the same batch have the same length.
     Arguments:
@@ -30,7 +35,8 @@ def greedy_decode(input_ids, model, max_length):
         scores: tuples of (batch, vocab_size)
     """
     batch_size, seqlen_og = input_ids.shape
-    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
+                                       fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
@@ -57,8 +63,9 @@ def greedy_decode(input_ids, model, max_length):
 
 
 class GenerationMixin:
-    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
-        output = greedy_decode(input_ids, self, max_length)
+    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
+                 **kwargs):
+        output = greedy_decode(input_ids, self, max_length, **kwargs)
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py
index 0aba58f..1a28687 100644
--- a/tests/models/test_gpt_generation.py
+++ b/tests/models/test_gpt_generation.py
@@ -15,10 +15,11 @@ from flash_attn.utils.generation import greedy_decode
 
 
 # TODO: test with rotary embedding
+@pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('optimized', [False])
+# @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized):
+def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -62,6 +63,7 @@ def test_greedy_decode(model_name, optimized):
     scores = tuple(scores)
 
     out = model.generate(input_ids=input_ids, max_length=max_length,
+                         fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True)
     out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
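For context, below is a minimal usage sketch of the new fused_ft_kernel flag. It is illustrative only: the model class (GPTLMHeadModel from flash_attn.models.gpt, the model exercised by tests/models/test_gpt_generation.py) and its construction from a transformers GPT2Config are assumptions not shown in the patch, and running it requires the ft_attention CUDA extension to be built; only the generate(..., fused_ft_kernel=...) keyword and its routing through greedy_decode and InferenceParams come from the diff above.

# Minimal sketch (assumptions noted above): exercise the fused FT decoding path.
import torch
from transformers import GPT2Config

# Assumption: GPTLMHeadModel is the FlashAttention GPT model used in
# tests/models/test_gpt_generation.py; it mixes in GenerationMixin.
from flash_attn.models.gpt import GPTLMHeadModel

device, dtype = 'cuda', torch.float16
config = GPT2Config()
model = GPTLMHeadModel(config).to(device=device, dtype=dtype).eval()

batch_size, prompt_len, max_length = 2, 16, 48
input_ids = torch.randint(0, config.vocab_size, (batch_size, prompt_len),
                          dtype=torch.long, device=device)

# fused_ft_kernel=True flows through generate -> greedy_decode -> InferenceParams;
# the prompt is processed by the regular attention path, while each subsequent
# single-token decoding step is dispatched to ft_attention.single_query_attention.
out = model.generate(input_ids=input_ids, max_length=max_length,
                     fused_ft_kernel=True,
                     return_dict_in_generate=True, output_scores=True)
print(out.sequences.shape)  # (batch_size, max_length)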