Skip to content

Commit

Permalink
add stress to test_cueq.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 22, 2024
1 parent 214ac8f commit 38834dc
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_cueq.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Any, Dict

import pytest
Expand Down Expand Up @@ -111,9 +112,11 @@ def test_bidirectional_conversion(
# model_e3nn_back = model_e3nn_back.to(device)

# Test forward pass equivalence
out_e3nn = model_e3nn(batch, training=True)
out_cueq = model_cueq(batch, training=True)
out_e3nn_back = model_e3nn_back(batch, training=True)
out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True)
out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True)
out_e3nn_back = model_e3nn_back(
deepcopy(batch), training=True, compute_stress=True
)

# Check outputs match for both conversions
torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
Expand Down

0 comments on commit 38834dc

Please sign in to comment.