[Rotary] Don't store inv_freq in state_dict
This commit is contained in:
parent
a157cc8c9b
commit
ec9f74ab9a
@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module):
|
|||||||
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
||||||
# Generate and save the inverse frequency buffer (non trainable)
|
# Generate and save the inverse frequency buffer (non trainable)
|
||||||
inv_freq = self._compute_inv_freq(device)
|
inv_freq = self._compute_inv_freq(device)
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.interleaved = interleaved
|
self.interleaved = interleaved
|
||||||
self.scale_base = scale_base
|
self.scale_base = scale_base
|
||||||
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
||||||
/ (1.4 * dim) if scale_base is not None else None)
|
/ (1.4 * dim) if scale_base is not None else None)
|
||||||
self.register_buffer("scale", scale)
|
self.register_buffer("scale", scale, persistent=False)
|
||||||
|
|
||||||
self._seq_len_cached = 0
|
self._seq_len_cached = 0
|
||||||
self._cos_cached = None
|
self._cos_cached = None
|
||||||
|
|||||||
@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module):
|
|||||||
state_dict = remap_state_dict_hf_opt(state_dict, config)
|
state_dict = remap_state_dict_hf_opt(state_dict, config)
|
||||||
elif model_name.startswith('EleutherAI/gpt-j-'):
|
elif model_name.startswith('EleutherAI/gpt-j-'):
|
||||||
state_dict = remap_state_dict_hf_gptj(state_dict, config)
|
state_dict = remap_state_dict_hf_gptj(state_dict, config)
|
||||||
strict = False # We have rotary_emb.inf_freq buffers not in the GPT-J checkpoint
|
|
||||||
elif model_name.startswith('EleutherAI/gpt-neox-'):
|
elif model_name.startswith('EleutherAI/gpt-neox-'):
|
||||||
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
|
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user