diff --git a/src/nomad_simulations/schema_packages/model_system.py b/src/nomad_simulations/schema_packages/model_system.py index 0a1897d1..29e70324 100644 --- a/src/nomad_simulations/schema_packages/model_system.py +++ b/src/nomad_simulations/schema_packages/model_system.py @@ -391,6 +391,17 @@ def is_equal_cell(self, other) -> bool: return False return True + def get_chemical_symbols(self) -> list: + """ + Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`. + + Returns: + list: The list of chemical symbols of the atoms in the atomic cell. + """ + if not self.atoms_state: + return [] + return [atom_state.chemical_symbol for atom_state in self.atoms_state] + def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]: """ Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell` @@ -403,7 +414,7 @@ def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]: (Optional[ase.Atoms]): The ASE Atoms object with the basic information from the `AtomicCell`. """ # Initialize ase.Atoms object with labels - atoms_labels = [atom_state.chemical_symbol for atom_state in self.atoms_state] + atoms_labels = self.get_chemical_symbols() ase_atoms = ase.Atoms(symbols=atoms_labels) # PBC @@ -436,6 +447,35 @@ def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]: return ase_atoms + def from_ase_atoms(self, ase_atoms: ase.Atoms, logger: 'BoundLogger') -> None: + """ + Parses the information from an ASE Atoms object to the `AtomicCell` section. + + Args: + ase_atoms (ase.Atoms): The ASE Atoms object to parse. + logger (BoundLogger): The logger to log messages. + """ + # `AtomsState[*].chemical_symbol` + chemical_symbols = ase_atoms.get_chemical_symbols() + for symbol in chemical_symbols: + atom_state = AtomsState(chemical_symbol=symbol) + self.atoms_state.append(atom_state) + + # `periodic_boundary_conditions` + self.periodic_boundary_conditions = ase_atoms.get_pbc() + + # `lattice_vectors` + cell = ase_atoms.get_cell() + self.lattice_vectors = ase.geometry.complete_cell(cell) * ureg('angstrom') + + # `positions` + positions = ase_atoms.get_positions() + if ( + not positions.tolist() + ): # ASE assigns a shape=(0, 3) array if no positions are found + return None + self.positions = positions * ureg('angstrom') + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) diff --git a/tests/test_model_system.py b/tests/test_model_system.py index 87ecd33b..a32efb7a 100644 --- a/tests/test_model_system.py +++ b/tests/test_model_system.py @@ -1,5 +1,6 @@ from typing import Optional +import ase import numpy as np import pytest from nomad.datamodel import EntryArchive @@ -171,6 +172,29 @@ def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool): """ assert cell_1.is_equal_cell(other=cell_2) == result + @pytest.mark.parametrize( + 'atomic_cell, result', + [ + (AtomicCell(), []), + (AtomicCell(atoms_state=[AtomsState(chemical_symbol='H')]), ['H']), + ( + AtomicCell( + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='Fe'), + AtomsState(chemical_symbol='O'), + ] + ), + ['H', 'Fe', 'O'], + ), + ], + ) + def test_get_chemical_symbols(self, atomic_cell: AtomicCell, result: list[str]): + """ + Test the `get_chemical_symbols` method of `AtomicCell`. + """ + assert atomic_cell.get_chemical_symbols() == result + @pytest.mark.parametrize( 'chemical_symbols, atomic_numbers, formula, lattice_vectors, positions, periodic_boundary_conditions', [ @@ -216,7 +240,7 @@ def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool): ), # missing lattice_vectors ], ) - def test_generate_ase_atoms( + def test_to_ase_atoms( self, chemical_symbols: list[str], atomic_numbers: list[int], @@ -258,6 +282,71 @@ def test_generate_ase_atoms( assert (ase_atoms.symbols.numbers == atomic_numbers).all() assert ase_atoms.symbols.get_chemical_formula() == formula + @pytest.mark.parametrize( + 'ase_atoms, chemical_symbols, pbc, lattice_vectors, positions', + [ + ( + ase.Atoms(), + [], + [False, False, False], + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + None, + ), + ( + ase.Atoms(symbols='CO'), + ['C', 'O'], + [False, False, False], + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 0, 0], [0, 0, 0]], + ), + ( + ase.Atoms(symbols='CO', pbc=True), + ['C', 'O'], + [True, True, True], + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 0, 0], [0, 0, 0]], + ), + ( + ase.Atoms(symbols='CO', positions=[[0, 0, 0], [0, 0, 1.1]]), + ['C', 'O'], + [False, False, False], + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 0, 0], [0, 0, 1.1]], + ), + ( + ase.Atoms( + symbols='Au', positions=[[0, 5, 5]], cell=[2.9, 5, 5], pbc=[1, 0, 0] + ), + ['Au'], + [True, False, False], + [[2.9, 0, 0], [0, 5, 0], [0, 0, 5]], + [[0, 5, 5]], + ), + ], + ) + def test_from_ase_atoms( + self, + ase_atoms: ase.Atoms, + chemical_symbols: list[str], + pbc: list[bool], + lattice_vectors: list, + positions: list, + ): + atomic_cell = AtomicCell() + atomic_cell.from_ase_atoms(ase_atoms=ase_atoms, logger=logger) + assert atomic_cell.get_chemical_symbols() == chemical_symbols + assert atomic_cell.periodic_boundary_conditions == pbc + assert ( + atomic_cell.lattice_vectors.to('angstrom').magnitude + == np.array(lattice_vectors) + ).all() + if positions is None: + assert atomic_cell.positions is None + else: + assert ( + atomic_cell.positions.to('angstrom').magnitude == np.array(positions) + ).all() + @pytest.mark.parametrize( 'chemical_symbols, atomic_numbers, lattice_vectors, positions, vectors_results, angles_results, volume', [