[Gen] Remove minor dead code
parent e4f726fc44
commit 0a146185d6
@@ -27,7 +27,7 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
     state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
         word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
     )
-    if getattr(config, "tie_word_embeddings"):
+    if getattr(config, "tie_word_embeddings", False):
         state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
     else:
         output_embeddings = state_dict.pop("embed_out.weight")
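A note on the change (my reading, not text from the commit): the one-argument form getattr(config, "tie_word_embeddings") raises AttributeError when the config omits the attribute, while the two-argument form falls back to False and takes the untied-embeddings branch. A minimal sketch of the difference, using a hypothetical bare config object:

from types import SimpleNamespace

config = SimpleNamespace()  # hypothetical config that does not define the flag

# Old form: no default, so a missing attribute raises AttributeError.
try:
    getattr(config, "tie_word_embeddings")
except AttributeError:
    print("old form raises for configs that omit the flag")

# New form: a missing attribute falls back to False (untied embeddings).
tie = getattr(config, "tie_word_embeddings", False)
print(tie)  # False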
@@ -591,10 +591,6 @@ def allocate_inference_cache(
     dtype=torch.float16,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-    packsize = 4 if dtype == torch.float32 else 8
-    assert headdim % packsize == 0
-    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
-    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
     kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
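For context (an assumption about the surrounding code, which the diff does not show): after this removal the function only needs kv_cache_shape, i.e. one combined key/value buffer per layer, so the packsize-based k_cache_shape and v_cache_shape were dead. A minimal sketch of that allocation pattern, with a hypothetical function name and assuming a plain torch.empty per layer:

import torch

def allocate_inference_cache_sketch(
    max_batch_size, max_seqlen, nheads, headdim, layers, device=None, dtype=torch.float16
):
    # One combined (key, value) buffer per layer: (batch, seqlen, 2, nheads, headdim).
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}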