From bc56ffe1b8aa62f37d43e7ce22be30f127fe5f74 Mon Sep 17 00:00:00 2001 From: ndaelman Date: Tue, 20 Aug 2024 12:56:01 +0200 Subject: [PATCH] Rework `_determine_apw` in a more OOP + schema fashion --- .../schema_packages/basis_set.py | 185 ++++++++++++------ 1 file changed, 126 insertions(+), 59 deletions(-) diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index 51f302a0..c66df084 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -8,10 +8,7 @@ import pint from scipy import constants as const from structlog.stdlib import BoundLogger -from typing import TYPE_CHECKING, Optional, Any - -if TYPE_CHECKING: - from nomad.metainfo import Context, Section +from typing import Optional, Any, Callable from nomad_simulations.schema_packages.atoms_state import AtomsState from nomad_simulations.schema_packages.numerical_settings import ( @@ -21,6 +18,20 @@ from nomad_simulations.schema_packages.properties.energies import EnergyContribution +def check_normalized(func: Callable): + """ + Decorator to check if the section is already normalized. + """ + + def wrapper(self, archive: EntryArchive, logger: BoundLogger) -> None: + if self._is_normalized: + return None + func(self, archive, logger) + self._is_normalized = True + + return wrapper + + class BasisSet(NumericalSettings): """A type section denoting a basis set component of a simulation. Should be used as a base section for more specialized sections. @@ -167,8 +178,10 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: class APWBaseOrbital(ArchiveSection): - """Abstract base section for (S)(L)APW and local orbital component wavefunctions. - It helps defining the interface with `APWLChannel`.""" + """ + Abstract base section for (S)(L)APW and local orbital component wavefunctions. + It helps defining the interface with `APWLChannel`. + """ n_terms = Quantity( type=np.int32, @@ -230,10 +243,16 @@ def _of_equal_length(lengths: list[int]) -> bool: ref_length = lengths[0] return all(length == ref_length for length in lengths) - def get_n_terms(self) -> Optional[int]: + def get_n_terms( + self, + representative_quantities: set[str] = { + 'energy_parameter', + 'energy_parameter_n', + 'differential_order', + }, + ) -> Optional[int]: """Determine the value of `n_terms` based on the lengths of the representative quantities.""" - rep_quant = {'energy_parameter', 'energy_parameter_n', 'differential_order'} - lengths = self._get_lengths(rep_quant) + lengths = self._get_lengths(representative_quantities) if not self._of_equal_length(lengths) or len(lengths) == 0: return None else: @@ -241,6 +260,8 @@ def get_n_terms(self) -> Optional[int]: def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: super().normalize(archive, logger) + + # enforce quantity length (will be used for type assignment) new_n_terms = self.get_n_terms() if self.n_terms is None: self.n_terms = new_n_terms @@ -250,6 +271,19 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: ) self.n_terms = None + # enforce differential order constraints + if np.any(np.isneginf(self.differential_order)): + logger.error( + '`APWBaseOrbital.differential_order` must be completely non-negative.' + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # it's hard to enforce commutative diagrams between `_determine_apw` and `normalize` + # instead, make all `_determine_apw` soft-coupled and dependent on the normalized state + # leverage normalize ∘ normalize = normalize + self._is_normalized = False + class APWOrbital(APWBaseOrbital): """ @@ -275,24 +309,32 @@ class APWOrbital(APWBaseOrbital): """, ) - def get_type(self, logger: BoundLogger) -> Optional[str]: + def n_terms_to_type(self, n_terms: Optional[int]) -> Optional[str]: """ Set the type of the APW orbital based on the differential order. """ - if self.n_terms is None: - logger.error('`APWOrbital.n_terms` must be defined before setting the type.') + if n_terms is None: return None - if self.n_terms == 0: + if n_terms == 0: return 'apw' - elif self.n_terms == 1: + elif n_terms == 1: return 'lapw' else: return 'slapw' + @check_normalized def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: super().normalize(archive, logger) + # assign a APW orbital type + # this snippet works of the previous normalization + new_type = self.n_terms_to_type(self.n_terms) if self.type is None: - self.type = self.get_type(logger) + self.type = new_type + elif self.type != new_type: + logger.error( + f'Inconsistent `APWOrbital` type: {self.type}. Setting back to `None`.' + ) + self.type = None class APWLocalOrbital(APWBaseOrbital): @@ -326,13 +368,24 @@ class APWLocalOrbital(APWBaseOrbital): """, ) + def get_n_terms(self) -> Optional[int]: + """Determine the value of `n_terms` based on the lengths of the representative quantities.""" + return super().get_n_terms( + representative_quantities={ + 'energy_parameter', + 'energy_parameter_n', + 'differential_order', + 'boundary_order', + } + ) + @check_normalized def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) - if np.any(np.isneginf(self.differential_orders)): - logger.error('`APWLOrbital.differential_order` must be non-negative.') - if np.any(np.isneginf(self.boundary_orders)): - logger.error('`APWLOrbital.boundary_orders` must be non-negative.') + if np.any(np.isneginf(self.boundary_order)): + logger.error( + '`APWLOrbital.boundary_order` must be completely non-negative.' + ) class APWLChannel(BasisSet): @@ -349,7 +402,7 @@ class APWLChannel(BasisSet): """, ) - n_wavefunctions = Quantity( + n_orbitals = Quantity( type=np.int32, description=""" Number of wavefunctions in the l-channel, i.e. $(2l + 1) n_orbitals$. @@ -358,33 +411,36 @@ class APWLChannel(BasisSet): orbitals = SubSection(sub_section=APWBaseOrbital.m_def, repeats=True) - def _determine_apw(self, logger: BoundLogger) -> dict[str, int]: + def _determine_apw(self) -> dict[str, int]: """ Produce a count of the APW components in the l-channel. + Invokes `normalize` on `orbitals`. """ - count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0} - order_map = {0: 'apw', 1: 'lapw'} # TODO: add dirac? + for orb in self.orbitals: + orb.normalize(None, None) + type_count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0} for orb in self.orbitals: - err_msg = f'Unknown APW orbital type {orb.type} and order {orb.differential_order}.' - if isinstance(orb, APWOrbital): - if orb.type.lower() in order_map.values(): - count[orb.type] += 1 - elif orb.n_terms in order_map: - count[order_map[orb.differential_order]] += 1 - elif orb.n_terms > 2: - count['slapw'] += 1 - else: - logger.warning(err_msg) # TODO: rewrite using `type normalization` + if isinstance(orb, APWOrbital) and orb.type.lower() in type_count.keys(): + type_count[orb.type] += 1 elif isinstance(orb, APWLocalOrbital): - count['lo'] += 1 + type_count['lo'] += 1 else: - logger.warning(err_msg) - return count + type_count['other'] += 1 # other de facto operates as a catch-all + return type_count + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_normalized = False + + @check_normalized def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: super(BasisSet).normalize(archive, logger) - self.n_wavefunctions = len(self.orbitals) * (2 * self.name + 1) + self.n_orbitals = len(self.orbitals) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_normalized = False class MuffinTinRegion(BasisSet, Mesh): @@ -417,14 +473,26 @@ class MuffinTinRegion(BasisSet, Mesh): l_channels = SubSection(sub_section=APWLChannel.m_def, repeats=True) - def _determine_apw(self, logger: BoundLogger) -> dict[str, int]: + def _determine_apw(self) -> dict[str, int]: """ Aggregate the APW component count in the muffin-tin region. + Invokes `normalize` on `l_channels`. """ - count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0} - for channel in self.l_channels: - count.update(channel._determine_apw(logger)) - return count + for l_channel in self.l_channels: + l_channel.normalize(None, None) + + type_count: dict[str, int] + if len(self.l_channels) > 0: + l_channel = self.l_channels[0] + # dynamically determine `type_count` structure + type_count = l_channel._determine_apw(l_channel.orbitals) + for channel in self.l_channels[1:]: + type_count.update(channel._determine_apw()) + return type_count + + @check_normalized + def normalize(self, archive, logger): + super().normalize(archive, logger) class BasisSetContainer(NumericalSettings): @@ -450,26 +518,28 @@ class BasisSetContainer(NumericalSettings): basis_set_components = SubSection(sub_section=BasisSet.m_def, repeats=True) - def _determine_apw(self, logger: BoundLogger) -> str: + def _determine_apw(self) -> Optional[str]: """ - Derive the basis set name for a (S)(L)APW case. + Derive the basis set name for a (S)(L)APW case, including local orbitals. + Invokes `normalize` on `basis_set_components`. """ answer, has_plane_wave = '', False for comp in self.basis_set_components: if isinstance(comp, MuffinTinRegion): - count = comp._determine_apw(logger) - if count['apw'] + count['lapw'] + count['slapw'] > 0: - if count['slapw'] > 0: + comp.normalize(None, None) + type_count = comp._determine_apw() + if sum([type_count[i] for i in ('apw', 'lapw', 'slapw')]) > 0: + if type_count['slapw'] > 0: answer += 'slapw'.upper() - elif count['lapw'] > 0: + elif type_count['lapw'] > 0: answer += 'lapw'.upper() - elif count['apw'] > 0: + elif type_count['apw'] > 0: answer += 'apw'.upper() - if count['lo'] > 0: + if type_count['lo'] > 0: answer += '+lo' elif isinstance(comp, PlaneWaveBasisSet): has_plane_wave = True - return answer if has_plane_wave else '' + return answer if has_plane_wave else None def _smallest_mt(self) -> MuffinTinRegion: """ @@ -491,14 +561,11 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: ] if len(pws) > 1: logger.warning('Multiple plane-wave basis sets found were found.') - if name := self._determine_apw(logger): - self.name = name # TODO: set name based on basis sets - try: - pws[0].set_cutoff_fractional(self._smallest_mt(), logger) - except (IndexError, AttributeError): - logger.error( - 'Expected a `APWPlaneWaveBasisSet` instance, but found none.' - ) + self.name = self._determine_apw(logger) + try: + pws[0].set_cutoff_fractional(self._smallest_mt(), logger) + except (IndexError, AttributeError): + logger.error('Expected a `APWPlaneWaveBasisSet` instance, but found none.') def generate_apw(