diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 56e07375..fc88c051 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -400,24 +400,34 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): atoms = self.atoms if self.model_type != "MACE": raise NotImplementedError("Only implemented for MACE models") + num_interactions = int(self.models[0].num_interactions) if num_layers == -1: - num_layers = int(self.models[0].num_interactions) + num_layers = num_interactions batch = self._atoms_to_batch(atoms) descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] + + irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] + l_max = irreps_out.lmax + num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 + per_layer_features = [irreps_out.dim for _ in range(num_interactions)] + per_layer_features[-1] = ( + num_invariant_features # Equivariant features not created for the last layer + ) + if invariants_only: - irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] - l_max = irreps_out.lmax - num_features = irreps_out.dim // (l_max + 1) ** 2 descriptors = [ extract_invariant( descriptor, num_layers=num_layers, - num_features=num_features, + num_features=num_invariant_features, l_max=l_max, ) for descriptor in descriptors ] - descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors] + to_keep = np.sum(per_layer_features[:num_layers]) + descriptors = [ + descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors + ] if self.num_models == 1: return descriptors[0] diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 74a0ffa3..158cad64 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -481,24 +481,42 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model): desc_invariant = calc.get_descriptors(at, invariants_only=True) desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) - desc_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1) - desc_single_layer_rotated = calc.get_descriptors( + desc_invariant_single_layer = calc.get_descriptors( + at, invariants_only=True, num_layers=1 + ) + desc_invariant_single_layer_rotated = calc.get_descriptors( at_rotated, invariants_only=True, num_layers=1 ) desc = calc.get_descriptors(at, invariants_only=False) + desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) + desc_rotated_single_layer = calc.get_descriptors( + at_rotated, invariants_only=False, num_layers=1 + ) assert desc_invariant.shape[0] == 3 assert desc_invariant.shape[1] == 32 - assert desc_single_layer.shape[0] == 3 - assert desc_single_layer.shape[1] == 16 + assert desc_invariant_single_layer.shape[0] == 3 + assert desc_invariant_single_layer.shape[1] == 16 assert desc.shape[0] == 3 assert desc.shape[1] == 80 + assert desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16 * 4 + assert desc_rotated_single_layer.shape[0] == 3 + assert desc_rotated_single_layer.shape[1] == 16 * 4 np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) - np.testing.assert_allclose(desc_single_layer, desc_invariant[:, :16], atol=1e-6) np.testing.assert_allclose( - desc_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 + ) + assert not np.allclose( + desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 ) assert not np.allclose(desc, desc_rotated, atol=1e-6)