Skip to content

Commit

Permalink
- Flatten out APWOrbital structure
Browse files Browse the repository at this point in the history
- Add tests
  • Loading branch information
ndaelman committed Aug 14, 2024
1 parent b32ec5b commit 560cd5f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 81 deletions.
98 changes: 19 additions & 79 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,20 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
# ? use naming BSE


class APWWavefunction(ArchiveSection):
class APWBaseOrbital(ArchiveSection):
"""Abstract base section for (S)(L)APW and local orbital component wavefunctions.
It helps defining the interface with `APWLChannel`."""

n_terms = Quantity(
type=np.int32,
description="""
Number of terms in the local orbital.
""",
)

energy_parameter = Quantity(
type=np.float64,
shape=['n_terms'],
unit='joule',
description="""
Reference energy parameter for the augmented plane wave (APW) basis set.
Expand All @@ -174,6 +182,7 @@ class APWWavefunction(ArchiveSection):

energy_parameter_n = Quantity(
type=np.int32,
shape=['n_terms'],
description="""
Reference number of radial nodes for the augmented plane wave (APW) basis set.
This is used to derive the `energy_parameter`.
Expand All @@ -190,13 +199,19 @@ class APWWavefunction(ArchiveSection):

differential_order = Quantity(
type=np.int32,
shape=['n_terms'],
description="""
Derivative order of the radial wavefunction term.
""",
) # TODO: add check non-negative # ? to remove

def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super().normalize(archive, logger)
if (self.differential_order < 0).any():
logger.error('`APWBaseOrbital.differential_order` must be non-negative.')


class APWOrbital(APWWavefunction):
class APWOrbital(APWBaseOrbital):
"""
Implementation of `APWWavefunction` capturing the foundational (S)(L)APW basis sets, all of the form $\sum_{lm} \left[ \sum_o c_{lmo} \frac{\partial}{\partial r}u_l(r, \epsilon_l) \right] Y_lm$.
The energy parameter $\epsilon_l$ is always considered fixed during diagonalization, opposed to the original APW formulation.
Expand All @@ -220,63 +235,7 @@ class APWOrbital(APWWavefunction):
""",
)

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
if self.differential_order < 0:
logger.error('`APWLOrbital.differential_order` must be non-negative.')


def generate_apworbitals(
name: str,
params: Union[list[pint.Quantity], list[int]],
logger: BoundLogger,
) -> list[APWOrbital]: # TODO: move to utils
"""
Generate the APW orbitals encoding `apw` or `lapw`.
Accepts as `params` either energies (`energy_parameter`) or integers (`energy_parameter_n`).
"""
# check edge cases
# no parameters
if len(params) == 0:
logger.error('No parameters provided for APW orbital generation.')
return []

# incorrect type in the list
if isinstance(params[0], pint.Quantity):
try:
params[0].to('joule')
param_type = 'energy_parameter'
except pint.errors.DimensionalityError:
logger.error('Unknown parameter dimensionality: not energy.')
return []
elif isinstance(params[0], int):
param_type = 'energy_parameter_n'
else:
logger.error('Unknown parameter type.')
return []

# define order
if name == 'apw':
order = 0
elif name == 'lapw':
order = 1
else:
raise ValueError('Unknown APW orbital type.')

# pad out the parameters
if (len_diff := len(params) - order) < 0:
for _ in range(-len_diff):
params.append(params[-1])

return [
APWOrbital(
type=name,
differential_order=i,
**{param_type: param}
) for i, param in enumerate(params)
]

class APWLocalOrbital(APWWavefunction):
class APWLocalOrbital(APWBaseOrbital):
"""
Implementation of `APWWavefunction` capturing a local orbital extending a foundational APW basis set.
Local orbitals allow for flexible additions to an `APWOrbital` specification.
Expand All @@ -299,25 +258,6 @@ class APWLocalOrbital(APWWavefunction):
""",
)

n_terms = Quantity(
type=np.int32,
description="""
Number of terms in the local orbital.
""",
)

energy_parameter = APWWavefunction.energy_parameter.m_copy()
energy_parameter.shape = ['n_terms']

energy_parameter_n = APWWavefunction.energy_parameter_n.m_copy()
energy_parameter_n.shape = ['n_terms']

energy_status = APWWavefunction.energy_status.m_copy()
energy_status.shape = ['n_terms']

differential_order = APWWavefunction.differential_order.m_copy()
differential_order.shape = ['n_terms']

boundary_order = Quantity(
type=np.int32,
shape=['n_terms'],
Expand Down Expand Up @@ -355,7 +295,7 @@ class APWLChannel(BasisSet):
""",
)

orbitals = SubSection(sub_section=APWWavefunction.m_def, repeats=True)
orbitals = SubSection(sub_section=APWBaseOrbital.m_def, repeats=True)

def _determine_apw(self, logger: BoundLogger) -> dict[str, int]:
"""
Expand Down
55 changes: 53 additions & 2 deletions tests/test_basis_set.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from . import logger
from nomad.units import ureg
import numpy as np
from . import logger
import pytest
from typing import Optional

from nomad_simulations.schema_packages.basis_set import APWPlaneWaveBasisSet
from nomad_simulations.schema_packages.basis_set import (
APWBaseOrbital,
APWOrbital,
APWLocalOrbital,
APWLChannel,
APWPlaneWaveBasisSet,
BasisSet,
BasisSetContainer,
MuffinTinRegion,
)


def test_cutoff():
Expand Down Expand Up @@ -34,3 +45,43 @@ def test_cutoff_failure():
pw = APWPlaneWaveBasisSet(cutoff_energy=500 * ureg('eV'), cutoff_fractional=1)
pw.set_cutoff_fractional(ureg.angstrom, logger)
assert pw.cutoff_fractional == 1


@pytest.mark.skip(reason="This function is not meant to be tested directly")
def generate_apw(
species: dict[str, int | APWBaseOrbital],
cutoff: Optional[float] = None
) -> BasisSetContainer:
"""
Generate a mock APW basis set with the following structure:
.
├── plane-wave basis set
└── muffin-tin regions
└── l-channels
├── (orbitals)
│ └── wavefunctions
└── local orbitals
"""
basis_set_components: list[BasisSet] = []
if cutoff is not None:
pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff)
basis_set_components.append(pw)

mts: list[MuffinTinRegion] = []
for sp in species:
l_max = sp['l_max']
mt = MuffinTinRegion(
radius=sp['r'],
l_max=l_max,
l_channels=[
APWLChannel(
l=l,
orbitals=[APWOrbital(type=orb) for orb in sp['orb_type']] +\
[APWLocalOrbital(type=lo) for lo in sp['lo_type']],
) for l in range(l_max)
]
)
mts.append(mt)
basis_set_components.append(mts)

return BasisSetContainer(basis_set_components=basis_set_components)

0 comments on commit 560cd5f

Please sign in to comment.