diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index d6e0875d..51f302a0 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -212,10 +212,43 @@ class APWBaseOrbital(ArchiveSection): """, ) # TODO: add check non-negative # ? to remove + def _get_lengths(self, quantities: set[str]) -> list[int]: + """Extract the lengths of the `quantities` contained in the set.""" + present_quantities = set(quantities) & self.m_quantities + lengths: list[int] = [] + for quant in present_quantities: + length = len(getattr(self, quant)) + if length > 0: # empty lists are exempt + lengths.append(length) + return lengths + + def _of_equal_length(lengths: list[int]) -> bool: + """Check if all elements in the list are of equal length.""" + if len(lengths) == 0: + return True + else: + ref_length = lengths[0] + return all(length == ref_length for length in lengths) + + def get_n_terms(self) -> Optional[int]: + """Determine the value of `n_terms` based on the lengths of the representative quantities.""" + rep_quant = {'energy_parameter', 'energy_parameter_n', 'differential_order'} + lengths = self._get_lengths(rep_quant) + if not self._of_equal_length(lengths) or len(lengths) == 0: + return None + else: + return lengths[0] + def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: super().normalize(archive, logger) - if (self.differential_order < 0).any(): - logger.error('`APWBaseOrbital.differential_order` must be non-negative.') + new_n_terms = self.get_n_terms() + if self.n_terms is None: + self.n_terms = new_n_terms + elif self.n_terms != new_n_terms: + logger.error( + f'Inconsistent lengths of `APWBaseOrbital` quantities: {self.m_quantities}. Setting back to `None`.' + ) + self.n_terms = None class APWOrbital(APWBaseOrbital): @@ -228,7 +261,7 @@ class APWOrbital(APWBaseOrbital): """ type = Quantity( - type=MEnum('apw', 'lapw', 'slapw', 'spherical_dirac'), + type=MEnum('apw', 'lapw', 'slapw'), # ? where to put 'spherical_dirac' description=""" Type of augmentation contribution. Abbreviations stand for: | name | description | radial product | @@ -242,6 +275,25 @@ class APWOrbital(APWBaseOrbital): """, ) + def get_type(self, logger: BoundLogger) -> Optional[str]: + """ + Set the type of the APW orbital based on the differential order. + """ + if self.n_terms is None: + logger.error('`APWOrbital.n_terms` must be defined before setting the type.') + return None + if self.n_terms == 0: + return 'apw' + elif self.n_terms == 1: + return 'lapw' + else: + return 'slapw' + + def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: + super().normalize(archive, logger) + if self.type is None: + self.type = self.get_type(logger) + class APWLocalOrbital(APWBaseOrbital): """ @@ -274,6 +326,7 @@ class APWLocalOrbital(APWBaseOrbital): """, ) + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) if np.any(np.isneginf(self.differential_orders)): @@ -310,19 +363,19 @@ def _determine_apw(self, logger: BoundLogger) -> dict[str, int]: Produce a count of the APW components in the l-channel. """ count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0} - order_map = {0: 'apw', 1: 'lapw', 2: 'slapw'} # TODO: add dirac? + order_map = {0: 'apw', 1: 'lapw'} # TODO: add dirac? for orb in self.orbitals: err_msg = f'Unknown APW orbital type {orb.type} and order {orb.differential_order}.' if isinstance(orb, APWOrbital): - if orb.type in order_map.values(): + if orb.type.lower() in order_map.values(): count[orb.type] += 1 - elif orb.differential_order in order_map: + elif orb.n_terms in order_map: count[order_map[orb.differential_order]] += 1 - elif orb.differential_order > 2: + elif orb.n_terms > 2: count['slapw'] += 1 else: - logger.warning(err_msg) + logger.warning(err_msg) # TODO: rewrite using `type normalization` elif isinstance(orb, APWLocalOrbital): count['lo'] += 1 else: