[Rotary] Don't store inv_freq in state_dict

Tri Dao 2023-07-22 23:52:42 -07:00
parent a157cc8c9b
commit ec9f74ab9a
2 changed files with 2 additions and 3 deletions


@@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module):
         self.pos_idx_in_fp32 = pos_idx_in_fp32
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = self._compute_inv_freq(device)
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
         scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                  / (1.4 * dim) if scale_base is not None else None)
-        self.register_buffer("scale", scale)
+        self.register_buffer("scale", scale, persistent=False)
         self._seq_len_cached = 0
         self._cos_cached = None

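Registering the buffers with persistent=False keeps them on the module but excludes them from state_dict(), so they are never written to checkpoints and are simply recomputed in __init__. A minimal sketch of that behavior, using a hypothetical Toy module rather than the actual RotaryEmbedding:

import torch
import torch.nn as nn

# Hypothetical Toy module (not the real RotaryEmbedding) illustrating persistent=False:
# the buffer stays usable on the module, but it is left out of state_dict(),
# so it never ends up in a saved checkpoint.
class Toy(nn.Module):
    def __init__(self, dim=8, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

m = Toy()
print("inv_freq" in m.state_dict())  # False: recomputed in __init__, never serialized
print(m.inv_freq.shape)              # torch.Size([4]): still a normal buffer at runtime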

@@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
         elif model_name.startswith('EleutherAI/gpt-j-'):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
-            strict = False  # We have rotary_emb.inv_freq buffers not in the GPT-J checkpoint
         elif model_name.startswith('EleutherAI/gpt-neox-'):
             state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
         else:
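
With inv_freq and scale no longer persistent, a remapped GPT-J checkpoint that lacks those buffers loads cleanly under strict checking, which is why the strict = False escape hatch above can be dropped. Continuing the hypothetical Toy sketch from earlier:

# A checkpoint without any rotary buffers (as with a remapped GPT-J checkpoint)
# now loads under strict=True, because non-persistent buffers are not expected
# keys in the state_dict.
m = Toy()
result = m.load_state_dict({}, strict=True)  # no "Missing key(s)" error
print(result.missing_keys, result.unexpected_keys)  # [] []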