From ba2fe7f378c938263e8b5eeeac0fb2766c754551 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 20 Apr 2023 18:15:12 -0700
Subject: [PATCH] [Gen] Move allocate_inference_cache to within the model

---
 flash_attn/models/gpt.py       |  8 ++++++++
 flash_attn/modules/block.py    |  3 +++
 flash_attn/modules/mha.py      | 16 ++++++++++++++++
 flash_attn/utils/generation.py | 17 ++++++++++-------
 4 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 77744b4..5b165b7 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -335,6 +335,10 @@ class GPTModel(GPTPreTrainedModel):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+                for i, layer in enumerate(self.layers)}
+
     def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
@@ -426,6 +430,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype,
+                                                          **kwargs)
+
     def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py
index b181c22..7ea2c2f 100644
--- a/flash_attn/modules/block.py
+++ b/flash_attn/modules/block.py
@@ -105,6 +105,9 @@ class Block(nn.Module):
             for p in self.norm2.parameters():
                 p._shared_params = True
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                 mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index f4d7d64..098d982 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -416,6 +416,22 @@ class MHA(nn.Module):
                                                      attention_dropout=dropout)
         self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+        dtype = self.out_proj.weight.dtype if dtype is None else dtype
+        device = self.out_proj.weight.device
+        if not fused_ft_kernel:
+            return torch.empty(batch_size, max_seqlen, 2, self.num_heads, self.head_dim,
+                               dtype=dtype, device=device)
+        else:
+            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if dtype == torch.float32 else 8
+            assert self.head_dim % packsize == 0
+            k_cache = torch.empty(batch_size, self.num_heads, self.head_dim // packsize, max_seqlen,
+                                  packsize, dtype=dtype, device=device)
+            v_cache = torch.empty(batch_size, self.num_heads, max_seqlen, self.head_dim,
+                                  dtype=dtype, device=device)
+            return k_cache, v_cache
+
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
         """
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 6527534..b45fba5 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -167,8 +167,8 @@ class GenerationMixin:
         return output if return_dict_in_generate else output.sequences
 
 
-def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
-                      device, dtype=torch.float16):
+def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
+                             device, dtype=torch.float16):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
     assert headdim % packsize == 0
@@ -226,14 +226,17 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
         headdim = getattr(model.config, 'head_dim',
                           model.config.hidden_size // model.config.num_attention_heads)
-        kv_cache = allocate_kv_cache(
-            batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
-            model.config.num_hidden_layers, device, dtype
-        )
+        if hasattr(model, 'allocate_inference_cache'):
+            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+        else:
+            inf_cache = allocate_inference_cache(
+                batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
+                model.config.num_hidden_layers, device, dtype
+            )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
             max_sequence_len=max_seqlen, max_batch_size=batch_size,
-            sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache, fused_ft_kernel=True,
+            sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
             lengths_per_sample=lengths_per_sample
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
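
Note on the cache layout (not part of the patch itself): the sketch below mirrors the allocation logic that the new MHA.allocate_inference_cache performs, rewritten as a standalone function so the resulting shapes are easy to inspect. The batch size, sequence length, and head sizes at the bottom are hypothetical, chosen only for illustration. Without the fused FT kernel the cache is a single (batch, seqlen, 2, nheads, head_dim) tensor holding K and V together; with it, K is stored in a packed layout whose innermost dimension has packsize elements (8 for fp16/bf16, 4 for fp32).

import torch

def allocate_mha_inference_cache(batch_size, max_seqlen, num_heads, head_dim,
                                 dtype=torch.float16, device='cpu', fused_ft_kernel=True):
    # Standalone sketch of the allocation done by MHA.allocate_inference_cache above.
    if not fused_ft_kernel:
        # One tensor holding both K and V: (batch, seqlen, 2, nheads, head_dim).
        return torch.empty(batch_size, max_seqlen, 2, num_heads, head_dim,
                           dtype=dtype, device=device)
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8  # pack size used by the fused kernel
    assert head_dim % packsize == 0
    # K is packed: head_dim is split into head_dim // packsize groups of packsize elements.
    k_cache = torch.empty(batch_size, num_heads, head_dim // packsize, max_seqlen, packsize,
                          dtype=dtype, device=device)
    # V keeps the plain (batch, nheads, seqlen, head_dim) layout.
    v_cache = torch.empty(batch_size, num_heads, max_seqlen, head_dim,
                          dtype=dtype, device=device)
    return k_cache, v_cache

# Hypothetical sizes, only to show the resulting shapes.
k_cache, v_cache = allocate_mha_inference_cache(batch_size=2, max_seqlen=256,
                                                num_heads=16, head_dim=64)
print(k_cache.shape)  # torch.Size([2, 16, 8, 256, 8])
print(v_cache.shape)  # torch.Size([2, 16, 256, 64])

At the model level, GPTModel.allocate_inference_cache builds one such allocation per layer and returns them as a dict keyed by layer index, which update_graph_cache then passes to InferenceParams as key_value_memory_dict.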