diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 02e7aac88b4432..4b170c1023a6ed 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1297,7 +1297,7 @@ def __init__(self, config, embed_tokens=None, embed_patches=None): # get weights from encoder position bias self.relative_bias = self._get_relative_bias(config) - # tie weights of original position bias of encoder + def _tie_weights(self): for bias in self.relative_bias.biases: if isinstance(bias, RelativePositionBias1D): self._tie_or_clone_weights( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5f3ac898daeea6..5480105054a909 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -21,7 +21,6 @@ import random import re import tempfile -import unittest import warnings from collections import defaultdict from typing import Dict, List, Tuple @@ -444,7 +443,6 @@ class CopyClass(model_class): @slow @require_accelerate @mark.accelerate_tests - @unittest.skip("Need to fix since we have a device mismatch") def test_save_load_low_cpu_mem_usage(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() with tempfile.TemporaryDirectory() as saved_model_path: @@ -457,7 +455,6 @@ def test_save_load_low_cpu_mem_usage(self): @slow @require_accelerate @mark.accelerate_tests - @unittest.skip("Need to fix since we have a device mismatch") def test_save_load_low_cpu_mem_usage_checkpoints(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() with tempfile.TemporaryDirectory() as saved_model_path: @@ -471,7 +468,6 @@ def test_save_load_low_cpu_mem_usage_checkpoints(self): @slow @require_accelerate @mark.accelerate_tests - @unittest.skip("Need to fix since we have a device mismatch") def test_save_load_low_cpu_mem_usage_no_safetensors(self): with tempfile.TemporaryDirectory() as saved_model_path: for model_class in self.all_model_classes: @@ -482,6 +478,8 @@ def test_save_load_low_cpu_mem_usage_no_safetensors(self): self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path) def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path): + from accelerate.utils.modeling import named_module_tensors + # Load the low usage and the normal models. model_low_usage, loading_info = model_class.from_pretrained( saved_model_path, @@ -496,16 +494,13 @@ def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path): # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then # subsequently loaded with the correct values and onto the correct device. We check if there are any # remaining params that were not properly loaded. - for name, param in model_low_usage.named_parameters(): + for name, tensor in named_module_tensors(model_low_usage, recurse=True): self.assertNotEqual( - param.device, + tensor.device, torch.device("meta"), - "Parameter '" + name + "' has not been properly loaded and has device=meta.", + "Tensor '" + name + "' has not been properly loaded and has device=meta.", ) - # Tests moving the model to a device other than meta. - model_low_usage.to(torch_device) - # Check that the parameters are equal. for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()): self.assertEquals(p1.data.ne(p2.data).sum(), 0)