diff --git a/flash_attn/models/gpt_neox.py b/flash_attn/models/gpt_neox.py
index 3a8fa07..c389404 100644
--- a/flash_attn/models/gpt_neox.py
+++ b/flash_attn/models/gpt_neox.py
@@ -27,7 +27,7 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
     state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
         word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
     )
-    if getattr(config, "tie_word_embeddings"):
+    if getattr(config, "tie_word_embeddings", False):
         state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
     else:
         output_embeddings = state_dict.pop("embed_out.weight")
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index bbbe34b..d5d1139 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -591,10 +591,6 @@ def allocate_inference_cache(
     dtype=torch.float16,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-    packsize = 4 if dtype == torch.float32 else 8
-    assert headdim % packsize == 0
-    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
-    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
     kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
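
Note on the second hunk: the dtype-dependent packed K/V layout (headdim split into packsize groups, with separate k_cache and v_cache shapes) is dropped in favor of a single unpacked kv_cache of shape (batch, seqlen, 2, nheads, headdim). Below is a minimal standalone sketch of an allocator with that layout; the function name, signature, and dict-of-tensors return value are assumptions for illustration, not the library's exact API.

import torch

def allocate_kv_cache_sketch(
    max_batch_size, max_seqlen, nheads, headdim, layers, device="cpu", dtype=torch.float16
):
    # Hypothetical sketch: one unpacked KV cache tensor per layer, shaped
    # (batch, seqlen, 2, nheads, headdim), where dim 2 holds K and V.
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}

# Example (assumed usage): cache = allocate_kv_cache_sketch(2, 2048, 16, 64, layers=24)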