[Gen] Minor tweak to allocate_inference_cache
This commit is contained in:
parent
ba2fe7f378
commit
fcab93b43a
@ -158,6 +158,9 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
|||||||
|
|
||||||
class GenerationMixin:
|
class GenerationMixin:
|
||||||
|
|
||||||
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||||
return_dict_in_generate=False, output_scores=False, **kwargs):
|
return_dict_in_generate=False, output_scores=False, **kwargs):
|
||||||
output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
|
output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
|
||||||
@ -224,11 +227,11 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
cache.device, cache.dtype = device, dtype
|
cache.device, cache.dtype = device, dtype
|
||||||
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
||||||
headdim = getattr(model.config, 'head_dim',
|
|
||||||
model.config.hidden_size // model.config.num_attention_heads)
|
|
||||||
if hasattr(model, 'allocate_inference_cache'):
|
if hasattr(model, 'allocate_inference_cache'):
|
||||||
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
||||||
else:
|
else:
|
||||||
|
headdim = getattr(model.config, 'head_dim',
|
||||||
|
model.config.hidden_size // model.config.num_attention_heads)
|
||||||
inf_cache = allocate_inference_cache(
|
inf_cache = allocate_inference_cache(
|
||||||
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
|
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
|
||||||
model.config.num_hidden_layers, device, dtype
|
model.config.num_hidden_layers, device, dtype
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user