Skip to content

Commit

Permalink
Fix low cpu mem usage tests (#30808)
Browse files Browse the repository at this point in the history
* Fix tests

* fix udop failing test

* remove skip

* style
  • Loading branch information
SunMarc authored May 22, 2024
1 parent 934e1b8 commit 5c18600
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/udop/modeling_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,7 @@ def __init__(self, config, embed_tokens=None, embed_patches=None):
# get weights from encoder position bias
self.relative_bias = self._get_relative_bias(config)

# tie weights of original position bias of encoder
def _tie_weights(self):
for bias in self.relative_bias.biases:
if isinstance(bias, RelativePositionBias1D):
self._tie_or_clone_weights(
Expand Down
15 changes: 5 additions & 10 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import random
import re
import tempfile
import unittest
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -444,7 +443,6 @@ class CopyClass(model_class):
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
Expand All @@ -457,7 +455,6 @@ def test_save_load_low_cpu_mem_usage(self):
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
Expand All @@ -471,7 +468,6 @@ def test_save_load_low_cpu_mem_usage_checkpoints(self):
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes:
Expand All @@ -482,6 +478,8 @@ def test_save_load_low_cpu_mem_usage_no_safetensors(self):
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)

def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
from accelerate.utils.modeling import named_module_tensors

# Load the low usage and the normal models.
model_low_usage, loading_info = model_class.from_pretrained(
saved_model_path,
Expand All @@ -496,16 +494,13 @@ def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
# subsequently loaded with the correct values and onto the correct device. We check if there are any
# remaining params that were not properly loaded.
for name, param in model_low_usage.named_parameters():
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
self.assertNotEqual(
param.device,
tensor.device,
torch.device("meta"),
"Parameter '" + name + "' has not been properly loaded and has device=meta.",
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
)

# Tests moving the model to a device other than meta.
model_low_usage.to(torch_device)

# Check that the parameters are equal.
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
self.assertEquals(p1.data.ne(p2.data).sum(), 0)
Expand Down

0 comments on commit 5c18600

Please sign in to comment.