From 560cd5fedcb953e117c136ba6a2823dc1229dc59 Mon Sep 17 00:00:00 2001 From: ndaelman Date: Wed, 14 Aug 2024 20:00:06 +0200 Subject: [PATCH] - Flatten out `APWOrbital` structure - Add tests --- .../schema_packages/basis_set.py | 98 ++++--------------- tests/test_basis_set.py | 55 ++++++++++- 2 files changed, 72 insertions(+), 81 deletions(-) diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index 97c1a27e..225fb8f1 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -159,12 +159,20 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: # ? use naming BSE -class APWWavefunction(ArchiveSection): +class APWBaseOrbital(ArchiveSection): """Abstract base section for (S)(L)APW and local orbital component wavefunctions. It helps defining the interface with `APWLChannel`.""" + n_terms = Quantity( + type=np.int32, + description=""" + Number of terms in the local orbital. + """, + ) + energy_parameter = Quantity( type=np.float64, + shape=['n_terms'], unit='joule', description=""" Reference energy parameter for the augmented plane wave (APW) basis set. @@ -174,6 +182,7 @@ class APWWavefunction(ArchiveSection): energy_parameter_n = Quantity( type=np.int32, + shape=['n_terms'], description=""" Reference number of radial nodes for the augmented plane wave (APW) basis set. This is used to derive the `energy_parameter`. @@ -190,13 +199,19 @@ class APWWavefunction(ArchiveSection): differential_order = Quantity( type=np.int32, + shape=['n_terms'], description=""" Derivative order of the radial wavefunction term. """, ) # TODO: add check non-negative # ? to remove + def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: + super().normalize(archive, logger) + if (self.differential_order < 0).any(): + logger.error('`APWBaseOrbital.differential_order` must be non-negative.') + -class APWOrbital(APWWavefunction): +class APWOrbital(APWBaseOrbital): """ Implementation of `APWWavefunction` capturing the foundational (S)(L)APW basis sets, all of the form $\sum_{lm} \left[ \sum_o c_{lmo} \frac{\partial}{\partial r}u_l(r, \epsilon_l) \right] Y_lm$. The energy parameter $\epsilon_l$ is always considered fixed during diagonalization, opposed to the original APW formulation. @@ -220,63 +235,7 @@ class APWOrbital(APWWavefunction): """, ) - def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: - super().normalize(archive, logger) - if self.differential_order < 0: - logger.error('`APWLOrbital.differential_order` must be non-negative.') - - -def generate_apworbitals( - name: str, - params: Union[list[pint.Quantity], list[int]], - logger: BoundLogger, - ) -> list[APWOrbital]: # TODO: move to utils - """ - Generate the APW orbitals encoding `apw` or `lapw`. - Accepts as `params` either energies (`energy_parameter`) or integers (`energy_parameter_n`). - """ - # check edge cases - # no parameters - if len(params) == 0: - logger.error('No parameters provided for APW orbital generation.') - return [] - - # incorrect type in the list - if isinstance(params[0], pint.Quantity): - try: - params[0].to('joule') - param_type = 'energy_parameter' - except pint.errors.DimensionalityError: - logger.error('Unknown parameter dimensionality: not energy.') - return [] - elif isinstance(params[0], int): - param_type = 'energy_parameter_n' - else: - logger.error('Unknown parameter type.') - return [] - - # define order - if name == 'apw': - order = 0 - elif name == 'lapw': - order = 1 - else: - raise ValueError('Unknown APW orbital type.') - - # pad out the parameters - if (len_diff := len(params) - order) < 0: - for _ in range(-len_diff): - params.append(params[-1]) - - return [ - APWOrbital( - type=name, - differential_order=i, - **{param_type: param} - ) for i, param in enumerate(params) - ] - -class APWLocalOrbital(APWWavefunction): +class APWLocalOrbital(APWBaseOrbital): """ Implementation of `APWWavefunction` capturing a local orbital extending a foundational APW basis set. Local orbitals allow for flexible additions to an `APWOrbital` specification. @@ -299,25 +258,6 @@ class APWLocalOrbital(APWWavefunction): """, ) - n_terms = Quantity( - type=np.int32, - description=""" - Number of terms in the local orbital. - """, - ) - - energy_parameter = APWWavefunction.energy_parameter.m_copy() - energy_parameter.shape = ['n_terms'] - - energy_parameter_n = APWWavefunction.energy_parameter_n.m_copy() - energy_parameter_n.shape = ['n_terms'] - - energy_status = APWWavefunction.energy_status.m_copy() - energy_status.shape = ['n_terms'] - - differential_order = APWWavefunction.differential_order.m_copy() - differential_order.shape = ['n_terms'] - boundary_order = Quantity( type=np.int32, shape=['n_terms'], @@ -355,7 +295,7 @@ class APWLChannel(BasisSet): """, ) - orbitals = SubSection(sub_section=APWWavefunction.m_def, repeats=True) + orbitals = SubSection(sub_section=APWBaseOrbital.m_def, repeats=True) def _determine_apw(self, logger: BoundLogger) -> dict[str, int]: """ diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py index b3fd817c..15c7b14a 100644 --- a/tests/test_basis_set.py +++ b/tests/test_basis_set.py @@ -1,8 +1,19 @@ +from . import logger from nomad.units import ureg import numpy as np -from . import logger +import pytest +from typing import Optional -from nomad_simulations.schema_packages.basis_set import APWPlaneWaveBasisSet +from nomad_simulations.schema_packages.basis_set import ( + APWBaseOrbital, + APWOrbital, + APWLocalOrbital, + APWLChannel, + APWPlaneWaveBasisSet, + BasisSet, + BasisSetContainer, + MuffinTinRegion, +) def test_cutoff(): @@ -34,3 +45,43 @@ def test_cutoff_failure(): pw = APWPlaneWaveBasisSet(cutoff_energy=500 * ureg('eV'), cutoff_fractional=1) pw.set_cutoff_fractional(ureg.angstrom, logger) assert pw.cutoff_fractional == 1 + + +@pytest.mark.skip(reason="This function is not meant to be tested directly") +def generate_apw( + species: dict[str, int | APWBaseOrbital], + cutoff: Optional[float] = None +) -> BasisSetContainer: + """ + Generate a mock APW basis set with the following structure: + . + ├── plane-wave basis set + └── muffin-tin regions + └── l-channels + ├── (orbitals) + │ └── wavefunctions + └── local orbitals + """ + basis_set_components: list[BasisSet] = [] + if cutoff is not None: + pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff) + basis_set_components.append(pw) + + mts: list[MuffinTinRegion] = [] + for sp in species: + l_max = sp['l_max'] + mt = MuffinTinRegion( + radius=sp['r'], + l_max=l_max, + l_channels=[ + APWLChannel( + l=l, + orbitals=[APWOrbital(type=orb) for orb in sp['orb_type']] +\ + [APWLocalOrbital(type=lo) for lo in sp['lo_type']], + ) for l in range(l_max) + ] + ) + mts.append(mt) + basis_set_components.append(mts) + + return BasisSetContainer(basis_set_components=basis_set_components)