[Gen] Remove minor dead code

This commit is contained in:
Tri Dao 2023-12-19 22:57:05 -08:00
parent e4f726fc44
commit 0a146185d6
2 changed files with 1 additions and 5 deletions

View File

@ -27,7 +27,7 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, "tie_word_embeddings"):
if getattr(config, "tie_word_embeddings", False):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop("embed_out.weight")

View File

@ -591,10 +591,6 @@ def allocate_inference_cache(
dtype=torch.float16,
):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0
k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
if isinstance(layers, int):
layers = range(layers)