[Gen] Adjust shape of kv_cache when using FT
parent e02fd588aa
commit 0938298e4c
```diff
@@ -359,7 +359,7 @@ class MHA(nn.Module):
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
 
     def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, 1, nheads, head_dim)
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
         """
         assert not self.dwconv, 'Generation does not support dwconv yet'
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
```
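The updated docstring covers both call patterns: prompt processing passes the whole prefix at once (seqlen > 1), while incremental decoding passes a single new token (seqlen == 1), with keys and values stacked along dim 2. A minimal sketch of the two shapes (the sizes are illustrative, not from the commit):

```python
import torch

batch_size, nheads, head_dim = 2, 12, 64

# Prompt processing: the whole prefix in one call.
kv_prefill = torch.randn(batch_size, 128, 2, nheads, head_dim)  # (b, seqlen, 2, h, d)

# Incremental decoding: one new token per call.
kv_step = torch.randn(batch_size, 1, 2, nheads, head_dim)       # (b, 1, 2, h, d)

# Unstacking dim 2 gives the separate key and value tensors.
k, v = kv_step.unbind(dim=2)                                    # each (b, 1, h, d)
```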
```diff
@@ -371,27 +371,46 @@ class MHA(nn.Module):
             )
             inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            if not inference_params.fused_ft_kernel:
+                kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            else:
+                # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
+                # where packsize = 4 if fp32, 8 if fp16 or bf16.
+                # v_cache has shape (b, h, s, headdim)
+                k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+                kv_cache = None
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= kv_cache.shape[1]
+        assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+        assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
         # Copy key and values.
-        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-        if inference_params.fused_ft_kernel:
-            assert inference_params.sequence_len_offset == 0
+        if not inference_params.fused_ft_kernel:
+            assert kv_cache is not None
+            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+            kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+            return kv
+        else:
             # FT kernel requires different layouts for the k_cache and v_cache.
-            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if kv_cache.dtype == torch.float32 else 8
-            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
-                                packsize=packsize).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
-            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
-        return kv
+            assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv.dtype == torch.float32 else 8
+            if kv_cache is not None:
+                kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+                k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                    packsize=packsize).contiguous()
+                v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+                inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
+            else:
+                k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
+                )
+                v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 1], 'b s h d -> b h s d'
+                )
+            return kv
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                 inference_params=None, **kwargs):
```
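To make the FT cache layout concrete, here is a small round-trip check of the packing that the new code performs with `einops.rearrange`; the shapes are illustrative and the snippet is a sketch, not part of the commit:

```python
import torch
from einops import rearrange

b, s, h, d = 2, 16, 12, 64
dtype = torch.float16
packsize = 4 if dtype == torch.float32 else 8  # fp32 packs 4 elements, fp16/bf16 pack 8

kv_cache = torch.randn(b, s, 2, h, d, dtype=dtype)

# Keys are packed along the head dimension:
# (b, s, h, d) -> (b, h, d / packsize, s, packsize)
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                    packsize=packsize).contiguous()
# Values are only transposed: (b, s, h, d) -> (b, h, s, d)
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()

assert k_cache.shape == (b, h, d // packsize, s, packsize)
assert v_cache.shape == (b, h, s, d)

# Inverting the pattern recovers the original keys exactly.
k_back = rearrange(k_cache, 'b h d s packsize -> b s h (d packsize)')
assert torch.equal(k_back, kv_cache[:, :, 0])
```

The packed layout presumably lets the kernel issue vectorized `packsize`-element loads per head, which is why decode steps in the new `else:` branch write through `rearrange` directly into `k_cache`/`v_cache` instead of going through the (b, s, 2, h, d) `kv_cache`.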
```diff
@@ -14,10 +14,11 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 
+@pytest.mark.parametrize('fused_ft_kernel', [False, True])
+# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [False])
-# @pytest.mark.parametrize('optimized', [True])
+# @pytest.mark.parametrize('optimized', [False])
 @pytest.mark.parametrize('rotary', [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
```
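Stacked `pytest.mark.parametrize` decorators take the cross-product, so with the `fused_ft_kernel` axis active, `test_greedy_decode` is collected for every combination of `fused_ft_kernel`, `optimized`, and `rotary` (2 x 2 x 2 = 8 cases for `"gpt2"`). A tiny standalone illustration of that behavior (not from the repo):

```python
import pytest

@pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('optimized', [False, True])
@pytest.mark.parametrize('rotary', [False, True])
def test_param_grid(rotary, optimized, fused_ft_kernel):
    # Collected 8 times: once per combination of the three booleans.
    assert isinstance(fused_ft_kernel, bool)
```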