diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index bfda386..8773beb 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module):
         self.pos_idx_in_fp32 = pos_idx_in_fp32
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = self._compute_inv_freq(device)
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
         scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                  / (1.4 * dim) if scale_base is not None else None)
-        self.register_buffer("scale", scale)
+        self.register_buffer("scale", scale, persistent=False)
 
         self._seq_len_cached = 0
         self._cos_cached = None
diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 0b6472b..cebf843 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
         elif model_name.startswith('EleutherAI/gpt-j-'):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
-            strict = False  # We have rotary_emb.inf_freq buffers not in the GPT-J checkpoint
         elif model_name.startswith('EleutherAI/gpt-neox-'):
             state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
         else:
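
For context on the change: registering a buffer with persistent=False keeps it out of state_dict(), which is why the strict = False workaround for the GPT-J checkpoint (whose state dict has no rotary-embedding buffers) can be dropped. Below is a minimal standalone sketch of that PyTorch buffer behavior; it is not part of the patch, and the module names are purely illustrative.

import torch
import torch.nn as nn


class PersistentBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Default: the buffer is saved in state_dict and expected when loading.
        self.register_buffer("inv_freq", torch.ones(4))


class NonPersistentBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False: the buffer still lives on the module but is excluded
        # from state_dict, so loading a checkpoint that lacks it succeeds even
        # with strict=True.
        self.register_buffer("inv_freq", torch.ones(4), persistent=False)


print(list(PersistentBuffer().state_dict().keys()))     # ['inv_freq']
print(list(NonPersistentBuffer().state_dict().keys()))  # []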