[Gen] Minor tweak to allocate_inference_cache
parent ba2fe7f378
commit fcab93b43a
@@ -158,6 +158,9 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
 
 class GenerationMixin:
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        raise NotImplementedError
+
     def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                  return_dict_in_generate=False, output_scores=False, **kwargs):
         output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
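The new allocate_inference_cache hook gives model classes a single place to pre-allocate their per-layer key/value buffers; the mixin itself only fixes the signature and raises NotImplementedError. Below is a minimal sketch, not taken from this commit, of how a hypothetical subclass-style model could satisfy the hook. ToyBackbone and its cache layout (a dict of per-layer tensors shaped (batch, seqlen, 2, heads, headdim)) are illustrative assumptions, not the library's actual model code.

# Illustrative sketch only -- ToyBackbone and its cache layout are assumptions,
# not part of this commit or of the library's model classes.
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    def __init__(self, num_layers=2, num_heads=4, head_dim=16):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # One pre-allocated K/V buffer per layer, keyed by layer index;
        # the "2" dimension packs keys and values into a single tensor.
        dtype = dtype if dtype is not None else torch.float16
        kv_shape = (batch_size, max_seqlen, 2, self.num_heads, self.head_dim)
        return {i: torch.empty(kv_shape, dtype=dtype) for i in range(self.num_layers)}


# Because the method exists, the hasattr(...) check in update_graph_cache
# (next hunk) would take the model-provided path instead of the generic fallback.
cache = ToyBackbone().allocate_inference_cache(batch_size=1, max_seqlen=128)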
@@ -224,11 +227,11 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         gc.collect()
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
-        headdim = getattr(model.config, 'head_dim',
-                          model.config.hidden_size // model.config.num_attention_heads)
         if hasattr(model, 'allocate_inference_cache'):
             inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         else:
+            headdim = getattr(model.config, 'head_dim',
+                              model.config.hidden_size // model.config.num_attention_heads)
             inf_cache = allocate_inference_cache(
                 batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
                 model.config.num_hidden_layers, device, dtype
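When the model does not implement the hook, update_graph_cache falls back to the module-level allocate_inference_cache, and headdim is only needed on that path, which is why this hunk moves its computation into the else branch. The sketch below is a stand-in allocator with the same call signature as that fallback call site; the per-layer tensor layout is an assumption for illustration and may differ from flash-attention's actual helper.

# Stand-in sketch matching the fallback call's signature:
# allocate_inference_cache(batch_size, max_seqlen, nheads, headdim, num_layers, device, dtype).
# The (batch, seqlen, 2, heads, headdim) per-layer layout is an assumption.
import torch


def allocate_inference_cache(batch_size, max_seqlen, nheads, headdim, num_layers,
                             device=None, dtype=torch.float16):
    kv_shape = (batch_size, max_seqlen, 2, nheads, headdim)
    return {i: torch.empty(kv_shape, device=device, dtype=dtype) for i in range(num_layers)}


# Mirroring the else branch: heads are split across tensor-parallel ranks, and
# headdim defaults to hidden_size // num_attention_heads when head_dim is absent.
inf_cache = allocate_inference_cache(1, 128, nheads=8 // 1, headdim=64, num_layers=2,
                                     device='cpu', dtype=torch.float32)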