[Rotary] Don't store inv_freq in state_dict
This commit is contained in:
parent
a157cc8c9b
commit
ec9f74ab9a
@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
||||
# Generate and save the inverse frequency buffer (non trainable)
|
||||
inv_freq = self._compute_inv_freq(device)
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.interleaved = interleaved
|
||||
self.scale_base = scale_base
|
||||
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
||||
/ (1.4 * dim) if scale_base is not None else None)
|
||||
self.register_buffer("scale", scale)
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
|
||||
@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module):
|
||||
state_dict = remap_state_dict_hf_opt(state_dict, config)
|
||||
elif model_name.startswith('EleutherAI/gpt-j-'):
|
||||
state_dict = remap_state_dict_hf_gptj(state_dict, config)
|
||||
strict = False # We have rotary_emb.inf_freq buffers not in the GPT-J checkpoint
|
||||
elif model_name.startswith('EleutherAI/gpt-neox-'):
|
||||
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user