Skip to content

Commit

Permalink
fix cueq calc descriptors
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Dec 5, 2024
1 parent efaca2c commit ae2d461
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3

from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
Expand Down Expand Up @@ -406,7 +407,7 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]

irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
Expand Down
48 changes: 48 additions & 0 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,54 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
assert not np.allclose(desc, desc_rotated, atol=1e-6)


def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model):
at = fitting_configs[2].copy()
at_rotated = fitting_configs[2].copy()
at_rotated.rotate(90, "x")
calc = trained_equivariant_model

desc_invariant = calc.get_descriptors(at, invariants_only=True, enable_cueq=True)
desc_invariant_rotated = calc.get_descriptors(
at_rotated, invariants_only=True, enable_cueq=True
)
desc_invariant_single_layer = calc.get_descriptors(
at, invariants_only=True, num_layers=1, enable_cueq=True
)
desc_invariant_single_layer_rotated = calc.get_descriptors(
at_rotated, invariants_only=True, num_layers=1, enable_cueq=True
)
desc = calc.get_descriptors(at, invariants_only=False, enable_cueq=True)
desc_single_layer = calc.get_descriptors(
at, invariants_only=False, num_layers=1, enable_cueq=True
)
desc_rotated = calc.get_descriptors(
at_rotated, invariants_only=False, enable_cueq=True
)
desc_rotated_single_layer = calc.get_descriptors(
at_rotated, invariants_only=False, num_layers=1, enable_cueq=True
)

assert desc_invariant.shape[0] == 3
assert desc_invariant.shape[1] == 32
assert desc_invariant_single_layer.shape[0] == 3
assert desc_invariant_single_layer.shape[1] == 16
assert desc.shape[0] == 3
assert desc.shape[1] == 80
assert desc_single_layer.shape[0] == 3
assert desc_single_layer.shape[1] == 16 * 4
assert desc_rotated_single_layer.shape[0] == 3
assert desc_rotated_single_layer.shape[1] == 16 * 4

np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
np.testing.assert_allclose(
desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6
)
np.testing.assert_allclose(
desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
)
np.testing.assert_allclose(desc, desc_rotated, atol=1e-6)


def test_mace_mp(capsys: pytest.CaptureFixture):
mp_mace = mace_mp()
assert isinstance(mp_mace, MACECalculator)
Expand Down

0 comments on commit ae2d461

Please sign in to comment.