Skip to content

Commit

Permalink
Use absolute value for the comparison of values in test_save_load_fas…
Browse files Browse the repository at this point in the history
…t_init_from_base()
  • Loading branch information
hackyon committed Feb 12, 2024
1 parent 84834ef commit b68240d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ class CopyClass(model_class):
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
else:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
diffs = model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]
max_diff = np.abs(diffs).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_fast_init_context_manager(self):
Expand Down

0 comments on commit b68240d

Please sign in to comment.