diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 43c55ac4355a46..017eed79d567d0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -432,7 +432,8 @@ class CopyClass(model_class): if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor): max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item() else: - max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + diffs = model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key] + max_diff = np.abs(diffs).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") def test_fast_init_context_manager(self):