diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index c66df084..54b19cc6 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -225,9 +225,15 @@ class APWBaseOrbital(ArchiveSection): """, ) # TODO: add check non-negative # ? to remove + def _get_open_quantities(self) -> set[str]: + """Extract the open quantities of the `APWBaseOrbital`.""" + return { + k for k, v in self.m_def.all_quantities.items() if self.m_get(v) is not None + } + def _get_lengths(self, quantities: set[str]) -> list[int]: """Extract the lengths of the `quantities` contained in the set.""" - present_quantities = set(quantities) & self.m_quantities + present_quantities = set(quantities) & self._get_open_quantities() lengths: list[int] = [] for quant in present_quantities: length = len(getattr(self, quant)) @@ -235,7 +241,7 @@ def _get_lengths(self, quantities: set[str]) -> list[int]: lengths.append(length) return lengths - def _of_equal_length(lengths: list[int]) -> bool: + def _of_equal_length(self, lengths: list[int]) -> bool: """Check if all elements in the list are of equal length.""" if len(lengths) == 0: return True @@ -266,9 +272,10 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: if self.n_terms is None: self.n_terms = new_n_terms elif self.n_terms != new_n_terms: - logger.error( - f'Inconsistent lengths of `APWBaseOrbital` quantities: {self.m_quantities}. Setting back to `None`.' - ) + if logger is not None: + logger.error( + f'Inconsistent lengths of `APWBaseOrbital` quantities: {self.m_def.quantities}. Setting back to `None`.' + ) self.n_terms = None # enforce differential order constraints diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py index 51f02302..dc7ee6ec 100644 --- a/tests/test_basis_set.py +++ b/tests/test_basis_set.py @@ -93,6 +93,39 @@ def test_full_apw( numerical_settings = entry.data.model_method[0].numerical_settings numerical_settings.append(generate_apw(species_def, cutoff=cutoff)) + # test structure assert ( numerical_settings[0].m_to_dict() == refs_apw[ref_index] ) # TODO: add normalization? + + +@pytest.mark.parametrize( + 'ref_n_terms, e, e_n, d_o', + [ + (None, [0.0], [0, 0], []), # logically inconsistent + (1, [0.0], [0], [0]), # apw + (2, [0.0, 0.0], [0, 0], [0, 1]), # lapw + ], +) +def test_apw_base_orbital( + ref_n_terms: Optional[int], e: list[float], e_n: list[int], d_o: list[int] +): + orb = APWBaseOrbital( + energy_parameter=e, + energy_parameter_n=e_n, + differential_order=d_o, + ) + + assert orb.get_n_terms() == ref_n_terms + + +@pytest.mark.parametrize('n_terms, ref_n_terms', [(None, 1), (1, 1), (2, None)]) +def test_apw_base_orbital_normalize(n_terms: Optional[int], ref_n_terms: Optional[int]): + orb = APWBaseOrbital( + n_terms=n_terms, + energy_parameter=[0], + energy_parameter_n=[0], + differential_order=[1], + ) + orb.normalize(None, logger) + assert orb.n_terms == ref_n_terms