Skip to content

Commit

Permalink
Merge pull request #703 from ACEsuit/fix-extract-equivariant-features…
Browse files Browse the repository at this point in the history
…-with-num-layers-1

Fix-extract-equivariant-features-with-num-layers-1
  • Loading branch information
ilyes319 authored Nov 22, 2024
2 parents 3cd9bb8 + a6a729a commit 2f9a3a5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
22 changes: 16 additions & 6 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,24 +400,34 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
num_interactions = int(self.models[0].num_interactions)
if num_layers == -1:
num_layers = int(self.models[0].num_interactions)
num_layers = num_interactions
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]

irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = (
num_invariant_features # Equivariant features not created for the last layer
)

if invariants_only:
irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
l_max = irreps_out.lmax
num_features = irreps_out.dim // (l_max + 1) ** 2
descriptors = [
extract_invariant(
descriptor,
num_layers=num_layers,
num_features=num_features,
num_features=num_invariant_features,
l_max=l_max,
)
for descriptor in descriptors
]
descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors]
to_keep = np.sum(per_layer_features[:num_layers])
descriptors = [
descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
]

if self.num_models == 1:
return descriptors[0]
Expand Down
30 changes: 24 additions & 6 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,24 +481,42 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):

desc_invariant = calc.get_descriptors(at, invariants_only=True)
desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True)
desc_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1)
desc_single_layer_rotated = calc.get_descriptors(
desc_invariant_single_layer = calc.get_descriptors(
at, invariants_only=True, num_layers=1
)
desc_invariant_single_layer_rotated = calc.get_descriptors(
at_rotated, invariants_only=True, num_layers=1
)
desc = calc.get_descriptors(at, invariants_only=False)
desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1)
desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False)
desc_rotated_single_layer = calc.get_descriptors(
at_rotated, invariants_only=False, num_layers=1
)

assert desc_invariant.shape[0] == 3
assert desc_invariant.shape[1] == 32
assert desc_single_layer.shape[0] == 3
assert desc_single_layer.shape[1] == 16
assert desc_invariant_single_layer.shape[0] == 3
assert desc_invariant_single_layer.shape[1] == 16
assert desc.shape[0] == 3
assert desc.shape[1] == 80
assert desc_single_layer.shape[0] == 3
assert desc_single_layer.shape[1] == 16 * 4
assert desc_rotated_single_layer.shape[0] == 3
assert desc_rotated_single_layer.shape[1] == 16 * 4

np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
np.testing.assert_allclose(desc_single_layer, desc_invariant[:, :16], atol=1e-6)
np.testing.assert_allclose(
desc_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6
)
np.testing.assert_allclose(
desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
)
np.testing.assert_allclose(
desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6
)
assert not np.allclose(
desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6
)
assert not np.allclose(desc, desc_rotated, atol=1e-6)

Expand Down

0 comments on commit 2f9a3a5

Please sign in to comment.