diff --git a/tests/test_cueq.py b/tests/test_cueq.py index 5f47cbd9..7317d62c 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Any, Dict import pytest @@ -111,9 +112,11 @@ def test_bidirectional_conversion( # model_e3nn_back = model_e3nn_back.to(device) # Test forward pass equivalence - out_e3nn = model_e3nn(batch, training=True) - out_cueq = model_cueq(batch, training=True) - out_e3nn_back = model_e3nn_back(batch, training=True) + out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True) + out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True) + out_e3nn_back = model_e3nn_back( + deepcopy(batch), training=True, compute_stress=True + ) # Check outputs match for both conversions torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])