diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index b45fba5..adaa0b5 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -158,6 +158,9 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
 
 
 class GenerationMixin:
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        raise NotImplementedError
+
     def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                  return_dict_in_generate=False, output_scores=False, **kwargs):
         output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
@@ -224,11 +227,11 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         gc.collect()
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
-        headdim = getattr(model.config, 'head_dim',
-                          model.config.hidden_size // model.config.num_attention_heads)
         if hasattr(model, 'allocate_inference_cache'):
             inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         else:
+            headdim = getattr(model.config, 'head_dim',
+                              model.config.hidden_size // model.config.num_attention_heads)
             inf_cache = allocate_inference_cache(
                 batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel,
                 headdim, model.config.num_hidden_layers, device, dtype
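For context on how the new hook is meant to be used: update_graph_cache now asks the model for its own inference cache first, and only falls back to deriving head_dim from model.config when the model does not provide allocate_inference_cache. The sketch below is a hypothetical example of a model overriding the hook; ToyConfig, ToyModel, and the per-layer KV-buffer shape are assumptions for illustration, not the library's actual implementation.

# Hypothetical example (not part of this diff): a model that provides its own
# allocate_inference_cache, so update_graph_cache never needs to read
# head_dim / hidden_size / num_attention_heads from its config.
import torch
import torch.nn as nn
from flash_attn.utils.generation import GenerationMixin


class ToyConfig:
    num_hidden_layers = 2
    num_attention_heads = 8
    head_dim = 64
    hidden_size = 512


class ToyModel(nn.Module, GenerationMixin):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(1000, config.hidden_size)  # placeholder layer

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # One pre-allocated KV buffer per layer. The (seqlen, 2, nheads, headdim)
        # layout is an assumption for this sketch; a real model would allocate
        # whatever layout its attention layers expect.
        device = self.embed.weight.device
        dtype = dtype if dtype is not None else torch.float16
        return {
            i: torch.empty(batch_size, max_seqlen, 2, self.config.num_attention_heads,
                           self.config.head_dim, device=device, dtype=dtype)
            for i in range(self.config.num_hidden_layers)
        }

Because ToyModel defines the method, the hasattr(model, 'allocate_inference_cache') check in update_graph_cache succeeds and the config-based fallback (and its head_dim lookup) is skipped entirely.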