Skip to content

Commit

Permalink
Add apw_base_orbital tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ndaelman committed Aug 20, 2024
1 parent bc56ffe commit 4271eb4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,23 @@ class APWBaseOrbital(ArchiveSection):
""",
) # TODO: add check non-negative # ? to remove

def _get_open_quantities(self) -> set[str]:
"""Extract the open quantities of the `APWBaseOrbital`."""
return {
k for k, v in self.m_def.all_quantities.items() if self.m_get(v) is not None
}

def _get_lengths(self, quantities: set[str]) -> list[int]:
"""Extract the lengths of the `quantities` contained in the set."""
present_quantities = set(quantities) & self.m_quantities
present_quantities = set(quantities) & self._get_open_quantities()
lengths: list[int] = []
for quant in present_quantities:
length = len(getattr(self, quant))
if length > 0: # empty lists are exempt
lengths.append(length)
return lengths

def _of_equal_length(lengths: list[int]) -> bool:
def _of_equal_length(self, lengths: list[int]) -> bool:
"""Check if all elements in the list are of equal length."""
if len(lengths) == 0:
return True
Expand Down Expand Up @@ -266,9 +272,10 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
if self.n_terms is None:
self.n_terms = new_n_terms
elif self.n_terms != new_n_terms:
logger.error(
f'Inconsistent lengths of `APWBaseOrbital` quantities: {self.m_quantities}. Setting back to `None`.'
)
if logger is not None:
logger.error(
f'Inconsistent lengths of `APWBaseOrbital` quantities: {self.m_def.quantities}. Setting back to `None`.'
)
self.n_terms = None

# enforce differential order constraints
Expand Down
33 changes: 33 additions & 0 deletions tests/test_basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,39 @@ def test_full_apw(
numerical_settings = entry.data.model_method[0].numerical_settings
numerical_settings.append(generate_apw(species_def, cutoff=cutoff))

# test structure
assert (
numerical_settings[0].m_to_dict() == refs_apw[ref_index]
) # TODO: add normalization?


@pytest.mark.parametrize(
'ref_n_terms, e, e_n, d_o',
[
(None, [0.0], [0, 0], []), # logically inconsistent
(1, [0.0], [0], [0]), # apw
(2, [0.0, 0.0], [0, 0], [0, 1]), # lapw
],
)
def test_apw_base_orbital(
ref_n_terms: Optional[int], e: list[float], e_n: list[int], d_o: list[int]
):
orb = APWBaseOrbital(
energy_parameter=e,
energy_parameter_n=e_n,
differential_order=d_o,
)

assert orb.get_n_terms() == ref_n_terms


@pytest.mark.parametrize('n_terms, ref_n_terms', [(None, 1), (1, 1), (2, None)])
def test_apw_base_orbital_normalize(n_terms: Optional[int], ref_n_terms: Optional[int]):
orb = APWBaseOrbital(
n_terms=n_terms,
energy_parameter=[0],
energy_parameter_n=[0],
differential_order=[1],
)
orb.normalize(None, logger)
assert orb.n_terms == ref_n_terms

0 comments on commit 4271eb4

Please sign in to comment.