Skip to content

Commit

Permalink
Add more APW tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ndaelman committed Aug 20, 2024
1 parent 4271eb4 commit b99c577
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 19 deletions.
26 changes: 17 additions & 9 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,7 @@ def _get_open_quantities(self) -> set[str]:
def _get_lengths(self, quantities: set[str]) -> list[int]:
"""Extract the lengths of the `quantities` contained in the set."""
present_quantities = set(quantities) & self._get_open_quantities()
lengths: list[int] = []
for quant in present_quantities:
length = len(getattr(self, quant))
if length > 0: # empty lists are exempt
lengths.append(length)
return lengths
return [len(getattr(self, quant)) for quant in present_quantities]

def _of_equal_length(self, lengths: list[int]) -> bool:
"""Check if all elements in the list are of equal length."""
Expand Down Expand Up @@ -320,11 +315,11 @@ def n_terms_to_type(self, n_terms: Optional[int]) -> Optional[str]:
"""
Set the type of the APW orbital based on the differential order.
"""
if n_terms is None:
if n_terms is None or n_terms == 0:
return None
if n_terms == 0:
if n_terms == 1:
return 'apw'
elif n_terms == 1:
elif n_terms == 2:
return 'lapw'
else:
return 'slapw'
Expand Down Expand Up @@ -386,6 +381,19 @@ def get_n_terms(self) -> Optional[int]:
}
)

def bo_terms_to_type(self, bo_terms: Optional[int]) -> Optional[str]:
"""
Set the type of the local orbital based on the boundary order.
""" # ? include differential_order
if bo_terms is None or len(bo_terms) == 0:
return None
if sorted(bo_terms) == [0, 1]:
return 'lo'
elif sorted(bo_terms) == [0, 0, 1]: # ! double-check
return 'LO'
else:
return 'custom'

@check_normalized
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
Expand Down
51 changes: 41 additions & 10 deletions tests/test_basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,20 @@ def test_full_apw(


@pytest.mark.parametrize(
'ref_n_terms, e, e_n, d_o',
'ref_n_terms, e, d_o',
[
(None, [0.0], [0, 0], []), # logically inconsistent
(1, [0.0], [0], [0]), # apw
(2, [0.0, 0.0], [0, 0], [0, 1]), # lapw
(None, None, None), # unset
(0, [], []), # empty
(None, [0.0], []), # logically inconsistent
(1, [0.0], [0]), # apw
(2, 2 * [0.0], [0, 1]), # lapw
],
)
def test_apw_base_orbital(
ref_n_terms: Optional[int], e: list[float], e_n: list[int], d_o: list[int]
):
def test_apw_base_orbital(ref_n_terms: Optional[int], e: list[float], d_o: list[int]):
orb = APWBaseOrbital(
energy_parameter=e,
energy_parameter_n=e_n,
differential_order=d_o,
)

assert orb.get_n_terms() == ref_n_terms


Expand All @@ -124,8 +122,41 @@ def test_apw_base_orbital_normalize(n_terms: Optional[int], ref_n_terms: Optiona
orb = APWBaseOrbital(
n_terms=n_terms,
energy_parameter=[0],
energy_parameter_n=[0],
differential_order=[1],
)
orb.normalize(None, logger)
assert orb.n_terms == ref_n_terms


@pytest.mark.parametrize(
'ref_type, n_terms',
[(None, None), (None, 0), ('apw', 1), ('lapw', 2), ('slapw', 3)],
)
def test_apw_orbital(ref_type: Optional[str], n_terms: Optional[int]):
orb = APWOrbital(n_terms=n_terms)
assert orb.n_terms_to_type(orb.n_terms) == ref_type


@pytest.mark.parametrize(
'ref_n_terms, ref_type, e, d_o, b_o',
[
(None, None, [0.0], [], []), # logically inconsistent
(1, 'custom', [0.0], [0], [0]), # custom
(2, 'lo', 2 * [0.0], [0, 1], [0, 1]), # lo
(3, 'LO', 3 * [0.0], [0, 1, 0], [0, 1, 0]), # LO
],
)
def test_apw_local_orbital(
ref_n_terms: Optional[int],
ref_type: str,
e: list[float],
d_o: list[int],
b_o: list[int],
):
orb = APWLocalOrbital(
energy_parameter=e,
differential_order=d_o,
boundary_order=d_o,
)
assert orb.get_n_terms() == ref_n_terms
assert orb.bo_terms_to_type(orb.boundary_order) == ref_type

0 comments on commit b99c577

Please sign in to comment.