diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index c9fb5a0..e6477cc 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -359,7 +359,7 @@ class MHA(nn.Module):
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

     def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, 1, nheads, head_dim)
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
         """
         assert not self.dwconv, 'Generation does not support dwconv yet'
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
@@ -371,27 +371,46 @@ class MHA(nn.Module):
             )
             inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            if not inference_params.fused_ft_kernel:
+                kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            else:
+                # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
+                # where packsize = 4 if fp32, 8 if fp16 or bf16.
+                # v_cache has shape (b, h, s, headdim)
+                k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+                kv_cache = None
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= kv_cache.shape[1]
+        assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+        assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
         # Copy key and values.
-        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-        if inference_params.fused_ft_kernel:
+        if not inference_params.fused_ft_kernel:
+            assert kv_cache is not None
+            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+            kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+            return kv
+        else:
+            assert inference_params.sequence_len_offset == 0
             # FT kernel requires different layouts for the k_cache and v_cache.
-            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if kv_cache.dtype == torch.float32 else 8
-            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
-                                packsize=packsize).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
-            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
-        return kv
+            assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv.dtype == torch.float32 else 8
+            if kv_cache is not None:
+                kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+                k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                    packsize=packsize).contiguous()
+                v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+                inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
+            else:
+                k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
+                )
+                v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 1], 'b s h d -> b h s d'
+                )
+            return kv

     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                 inference_params=None, **kwargs):
diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py
index 793d247..dd15bda 100644
--- a/tests/models/test_gpt_generation.py
+++ b/tests/models/test_gpt_generation.py
@@ -14,10 +14,11 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained

 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
+# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [False])
-# @pytest.mark.parametrize('optimized', [True])
+# @pytest.mark.parametrize('optimized', [False])
 @pytest.mark.parametrize('rotary', [False, True])
+# @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
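For reference, the cache-layout conversion that the fused path performs can be exercised in isolation. The sketch below is illustrative only, not part of the patch: the tensor sizes and names like prompt_len are made up, and only the rearrange patterns and the packsize rule come from the diff above.

# Standalone sketch (illustrative only) of the two cache layouts the FT kernel
# expects. The rearrange patterns and packsize rule mirror the diff above;
# all sizes here are hypothetical.
import torch
from einops import rearrange

batch, max_seqlen, nheads, headdim = 2, 32, 4, 64
prompt_len = 16
dtype = torch.float16
packsize = 4 if dtype == torch.float32 else 8  # 16-byte packs in either dtype

# Pre-allocated FT-layout caches:
#   k_cache: (b, h, headdim / packsize, s, packsize)
#   v_cache: (b, h, s, headdim)
k_cache = torch.zeros(batch, nheads, headdim // packsize, max_seqlen, packsize, dtype=dtype)
v_cache = torch.zeros(batch, nheads, max_seqlen, headdim, dtype=dtype)

# Prompt step (sequence_len_offset == 0): kv has shape (b, s, 2, h, d).
# k is packed along headdim with seqlen moved inside; v is a plain transpose.
kv = torch.randn(batch, prompt_len, 2, nheads, headdim, dtype=dtype)
k_cache[:, :, :, :prompt_len, :] = rearrange(
    kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
)
v_cache[:, :, :prompt_len, :] = rearrange(kv[:, :, 1], 'b s h d -> b h s d')

# Decode step: a single new token's kv, shape (b, 1, 2, h, d), is written
# in place at the next sequence position in the same layouts.
new_kv = torch.randn(batch, 1, 2, nheads, headdim, dtype=dtype)
k_cache[:, :, :, prompt_len:prompt_len + 1, :] = rearrange(
    new_kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
)
v_cache[:, :, prompt_len:prompt_len + 1, :] = rearrange(new_kv[:, :, 1], 'b s h d -> b h s d')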