Skip to content

Commit

Permalink
Added get_chemical_symbols method and testing
Browse files Browse the repository at this point in the history
Added from_ase_atoms method and testing
  • Loading branch information
JosePizarro3 committed Sep 27, 2024
1 parent 0b6c6cd commit 8f65497
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 2 deletions.
42 changes: 41 additions & 1 deletion src/nomad_simulations/schema_packages/model_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,17 @@ def is_equal_cell(self, other) -> bool:
return False
return True

def get_chemical_symbols(self) -> list:
"""
Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`.
Returns:
list: The list of chemical symbols of the atoms in the atomic cell.
"""
if not self.atoms_state:
return []
return [atom_state.chemical_symbol for atom_state in self.atoms_state]

def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:
"""
Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
Expand All @@ -403,7 +414,7 @@ def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:
(Optional[ase.Atoms]): The ASE Atoms object with the basic information from the `AtomicCell`.
"""
# Initialize ase.Atoms object with labels
atoms_labels = [atom_state.chemical_symbol for atom_state in self.atoms_state]
atoms_labels = self.get_chemical_symbols()
ase_atoms = ase.Atoms(symbols=atoms_labels)

# PBC
Expand Down Expand Up @@ -436,6 +447,35 @@ def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:

return ase_atoms

def from_ase_atoms(self, ase_atoms: ase.Atoms, logger: 'BoundLogger') -> None:
"""
Parses the information from an ASE Atoms object to the `AtomicCell` section.
Args:
ase_atoms (ase.Atoms): The ASE Atoms object to parse.
logger (BoundLogger): The logger to log messages.
"""
# `AtomsState[*].chemical_symbol`
chemical_symbols = ase_atoms.get_chemical_symbols()
for symbol in chemical_symbols:
atom_state = AtomsState(chemical_symbol=symbol)
self.atoms_state.append(atom_state)

# `periodic_boundary_conditions`
self.periodic_boundary_conditions = ase_atoms.get_pbc()

# `lattice_vectors`
cell = ase_atoms.get_cell()
self.lattice_vectors = ase.geometry.complete_cell(cell) * ureg('angstrom')

# `positions`
positions = ase_atoms.get_positions()
if (
not positions.tolist()
): # ASE assigns a shape=(0, 3) array if no positions are found
return None
self.positions = positions * ureg('angstrom')

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)

Expand Down
91 changes: 90 additions & 1 deletion tests/test_model_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import ase
import numpy as np
import pytest
from nomad.datamodel import EntryArchive
Expand Down Expand Up @@ -171,6 +172,29 @@ def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool):
"""
assert cell_1.is_equal_cell(other=cell_2) == result

@pytest.mark.parametrize(
'atomic_cell, result',
[
(AtomicCell(), []),
(AtomicCell(atoms_state=[AtomsState(chemical_symbol='H')]), ['H']),
(
AtomicCell(
atoms_state=[
AtomsState(chemical_symbol='H'),
AtomsState(chemical_symbol='Fe'),
AtomsState(chemical_symbol='O'),
]
),
['H', 'Fe', 'O'],
),
],
)
def test_get_chemical_symbols(self, atomic_cell: AtomicCell, result: list[str]):
"""
Test the `get_chemical_symbols` method of `AtomicCell`.
"""
assert atomic_cell.get_chemical_symbols() == result

@pytest.mark.parametrize(
'chemical_symbols, atomic_numbers, formula, lattice_vectors, positions, periodic_boundary_conditions',
[
Expand Down Expand Up @@ -216,7 +240,7 @@ def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool):
), # missing lattice_vectors
],
)
def test_generate_ase_atoms(
def test_to_ase_atoms(
self,
chemical_symbols: list[str],
atomic_numbers: list[int],
Expand Down Expand Up @@ -258,6 +282,71 @@ def test_generate_ase_atoms(
assert (ase_atoms.symbols.numbers == atomic_numbers).all()
assert ase_atoms.symbols.get_chemical_formula() == formula

@pytest.mark.parametrize(
'ase_atoms, chemical_symbols, pbc, lattice_vectors, positions',
[
(
ase.Atoms(),
[],
[False, False, False],
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
None,
),
(
ase.Atoms(symbols='CO'),
['C', 'O'],
[False, False, False],
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
[[0, 0, 0], [0, 0, 0]],
),
(
ase.Atoms(symbols='CO', pbc=True),
['C', 'O'],
[True, True, True],
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
[[0, 0, 0], [0, 0, 0]],
),
(
ase.Atoms(symbols='CO', positions=[[0, 0, 0], [0, 0, 1.1]]),
['C', 'O'],
[False, False, False],
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
[[0, 0, 0], [0, 0, 1.1]],
),
(
ase.Atoms(
symbols='Au', positions=[[0, 5, 5]], cell=[2.9, 5, 5], pbc=[1, 0, 0]
),
['Au'],
[True, False, False],
[[2.9, 0, 0], [0, 5, 0], [0, 0, 5]],
[[0, 5, 5]],
),
],
)
def test_from_ase_atoms(
self,
ase_atoms: ase.Atoms,
chemical_symbols: list[str],
pbc: list[bool],
lattice_vectors: list,
positions: list,
):
atomic_cell = AtomicCell()
atomic_cell.from_ase_atoms(ase_atoms=ase_atoms, logger=logger)
assert atomic_cell.get_chemical_symbols() == chemical_symbols
assert atomic_cell.periodic_boundary_conditions == pbc
assert (
atomic_cell.lattice_vectors.to('angstrom').magnitude
== np.array(lattice_vectors)
).all()
if positions is None:
assert atomic_cell.positions is None
else:
assert (
atomic_cell.positions.to('angstrom').magnitude == np.array(positions)
).all()

@pytest.mark.parametrize(
'chemical_symbols, atomic_numbers, lattice_vectors, positions, vectors_results, angles_results, volume',
[
Expand Down

1 comment on commit 8f65497

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/nomad_simulations
   __init__.py4250%3–4
   _version.py11282%5–6
src/nomad_simulations/schema_packages
   __init__.py15287%39–41
   atoms_state.py1902189%13–15, 201–204, 228, 283–284, 352–353, 355, 537, 549–550, 611–615, 630–634, 641
   basis_set.py2402888%8–9, 122–133, 172–185, 208, 391–395, 417–418, 462–465, 584, 615, 617
   general.py89891%4–7, 121, 185, 295–296, 306
   model_method.py2697871%10–12, 171–174, 177–184, 276–277, 297, 318–339, 355–381, 384–401, 587, 780, 791, 833–840, 878, 897, 977, 1034, 1109, 1223
   model_system.py3122393%25–27, 378, 612–615, 662–669, 843–844, 1065–1069, 1075–1076, 1084–1085, 1090, 1113
   numerical_settings.py2596176%12–14, 217, 219–220, 223–226, 230–231, 238–241, 250–253, 257–260, 262–265, 270–273, 279–282, 469–496, 571, 606–609, 633, 636, 681, 683–686, 690, 694, 741, 745–766, 821–822, 889
   outputs.py1201092%9–10, 252–255, 295–298, 323, 325, 362, 381
   physical_property.py102793%20–22, 202, 331–333
   variables.py861286%8–10, 98, 121, 145, 167, 189, 211, 233, 256, 276
src/nomad_simulations/schema_packages/properties
   band_gap.py51590%8–10, 135–136
   band_structure.py1232580%9–11, 232–265, 278, 285, 321–322, 325, 372–373, 378
   energies.py42979%7–9, 36, 57, 82, 103, 119, 134
   fermi_surface.py17476%7–9, 40
   forces.py22673%7–9, 36, 56, 79
   greens_function.py991387%7–9, 210–211, 214, 235–236, 239, 260–261, 264, 400
   hopping_matrix.py29583%7–9, 58, 94
   permittivity.py48883%7–9, 97–105
   spectral_profile.py26012851%9–11, 57–60, 95–98, 199–300, 356–368, 393–396, 416, 421–424, 466–502, 526, 573–576, 592–593, 598–604
   thermodynamics.py752764%7–9, 35, 56, 72, 81, 90, 101, 110, 137, 147, 157, 172–174, 177, 193, 213–215, 218, 234, 254–256, 259
src/nomad_simulations/schema_packages/utils
   utils.py701480%8–11, 65–74, 83–84, 89, 92
TOTAL254449880% 

Tests Skipped Failures Errors Time
409 0 💤 0 ❌ 0 🔥 5.769s ⏱️

Please sign in to comment.