diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 33ddc2fbcc438b..40892f0cdc8d9a 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4479,6 +4479,11 @@ def _load_pretrained_model(
             model_buffers = {".".join([prefix, key]) for key in model_buffers}
         unexpected_keys = sorted(unexpected_keys - model_buffers)
 
+        # Clean up buffer for `inv_freq` because the RoPE embedding moved under the base model (https://github.com/huggingface/transformers/pull/34858)
+        has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
+        if has_inv_freq_buffers:
+            unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}
+
         model.tie_weights()
         if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
             ptrs = collections.defaultdict(list)
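
For reviewers, a minimal standalone sketch of the added cleanup logic. The variable names mirror those in `_load_pretrained_model`, but the sample buffer and checkpoint key values below are made up for illustration:

# Sketch of the inv_freq cleanup in isolation, with hypothetical key names.
model_buffers = {"model.rotary_emb.inv_freq", "model.layers.0.attn.bias"}
unexpected_keys = sorted({"model.rotary_emb.inv_freq", "lm_head.extra_weight"})

# If the model registers a rotary `inv_freq` buffer, drop the matching
# checkpoint keys so they are not reported back as unexpected keys.
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
if has_inv_freq_buffers:
    unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}

print(unexpected_keys)  # {'lm_head.extra_weight'}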