From 5601e9f9a4a41379da99568aa5bca493a648dfb9 Mon Sep 17 00:00:00 2001 From: JB Lau <1557853+hackyon@users.noreply.github.com> Date: Thu, 8 Feb 2024 15:49:41 -0500 Subject: [PATCH] Use absolute value for the comparison of values in test_save_load_fast_init_from_base() --- tests/test_modeling_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 95d8cb12a4267e..c4f56b14ab2746 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -430,7 +430,8 @@ class CopyClass(model_class): if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor): max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item() else: - max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + diffs = model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key] + max_diff = np.abs(diffs).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") def test_fast_init_context_manager(self):