[Gen] Minor tweak to allocate_inference_cache

Tri Dao 2023-04-21 11:56:47 -07:00
parent ba2fe7f378
commit fcab93b43a


@@ -158,6 +158,9 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
 
 
 class GenerationMixin:
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        raise NotImplementedError
+
     def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                  return_dict_in_generate=False, output_scores=False, **kwargs):
         output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
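Note: a model that mixes in GenerationMixin is expected to override this hook and return its own preallocated cache. Below is a minimal sketch of such an override; the class name, the (batch, seqlen, 2, nheads, headdim) per-layer layout, and the import path are illustrative assumptions, not part of this diff.

import torch
from flash_attn.utils.generation import GenerationMixin  # assumed import path for the mixin above

class ToyTransformer(torch.nn.Module, GenerationMixin):
    # Hypothetical model (not from this repo) implementing the new hook.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = next(self.parameters()).device
        dtype = dtype if dtype is not None else torch.float16
        headdim = getattr(self.config, 'head_dim',
                          self.config.hidden_size // self.config.num_attention_heads)
        # One preallocated buffer per layer, keyed by layer index; the
        # (batch, seqlen, 2, nheads, headdim) layout is an assumption for illustration.
        return {
            i: torch.empty(batch_size, max_seqlen, 2, self.config.num_attention_heads, headdim,
                           device=device, dtype=dtype)
            for i in range(self.config.num_hidden_layers)
        }

With an override like this, the hasattr branch in update_graph_cache below is taken and the config-based fallback is skipped.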
@@ -224,11 +227,11 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         gc.collect()
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
-        headdim = getattr(model.config, 'head_dim',
-                          model.config.hidden_size // model.config.num_attention_heads)
         if hasattr(model, 'allocate_inference_cache'):
             inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         else:
+            headdim = getattr(model.config, 'head_dim',
+                              model.config.hidden_size // model.config.num_attention_heads)
             inf_cache = allocate_inference_cache(
                 batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
                 model.config.num_hidden_layers, device, dtype
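The else branch falls back to the module-level allocate_inference_cache helper with the argument order (batch_size, max_seqlen, nheads // tensor_parallel, headdim, num_hidden_layers, device, dtype). A rough sketch inferred from that call site follows; the helper name used here is hypothetical, the per-layer buffer layout is an assumption, and the config values are made up.

import torch
from types import SimpleNamespace

# Hypothetical name; argument order mirrors the fallback call in the diff above.
def allocate_inference_cache_fallback(max_batch_size, max_seqlen, nheads, headdim, layers,
                                      device=None, dtype=torch.float16):
    # Assumed layout: one (batch, seqlen, 2, nheads, headdim) key/value buffer per layer.
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in range(layers)}

# A config without head_dim derives it exactly as the getattr() default does above.
config = SimpleNamespace(hidden_size=1024, num_attention_heads=16, num_hidden_layers=2)
headdim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)  # -> 64
cache = allocate_inference_cache_fallback(4, 128, config.num_attention_heads, headdim,
                                          config.num_hidden_layers, dtype=torch.float32)
assert cache[0].shape == (4, 128, 2, 16, 64)

The point of the tweak is visible in the getattr line: headdim and the config attributes it depends on are only read when the model does not provide its own allocate_inference_cache.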