Skip to content

Commit

Permalink
Rework _determine_apw in a more OOP + schema fashion
Browse files Browse the repository at this point in the history
  • Loading branch information
ndaelman committed Aug 20, 2024
1 parent 7050071 commit bc56ffe
Showing 1 changed file with 126 additions and 59 deletions.
185 changes: 126 additions & 59 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
import pint
from scipy import constants as const
from structlog.stdlib import BoundLogger
from typing import TYPE_CHECKING, Optional, Any

if TYPE_CHECKING:
from nomad.metainfo import Context, Section
from typing import Optional, Any, Callable

from nomad_simulations.schema_packages.atoms_state import AtomsState
from nomad_simulations.schema_packages.numerical_settings import (
Expand All @@ -21,6 +18,20 @@
from nomad_simulations.schema_packages.properties.energies import EnergyContribution


def check_normalized(func: Callable):
"""
Decorator to check if the section is already normalized.
"""

def wrapper(self, archive: EntryArchive, logger: BoundLogger) -> None:
if self._is_normalized:
return None
func(self, archive, logger)
self._is_normalized = True

return wrapper


class BasisSet(NumericalSettings):
"""A type section denoting a basis set component of a simulation.
Should be used as a base section for more specialized sections.
Expand Down Expand Up @@ -167,8 +178,10 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:


class APWBaseOrbital(ArchiveSection):
"""Abstract base section for (S)(L)APW and local orbital component wavefunctions.
It helps defining the interface with `APWLChannel`."""
"""
Abstract base section for (S)(L)APW and local orbital component wavefunctions.
It helps defining the interface with `APWLChannel`.
"""

n_terms = Quantity(
type=np.int32,
Expand Down Expand Up @@ -230,17 +243,25 @@ def _of_equal_length(lengths: list[int]) -> bool:
ref_length = lengths[0]
return all(length == ref_length for length in lengths)

def get_n_terms(self) -> Optional[int]:
def get_n_terms(
self,
representative_quantities: set[str] = {
'energy_parameter',
'energy_parameter_n',
'differential_order',
},
) -> Optional[int]:
"""Determine the value of `n_terms` based on the lengths of the representative quantities."""
rep_quant = {'energy_parameter', 'energy_parameter_n', 'differential_order'}
lengths = self._get_lengths(rep_quant)
lengths = self._get_lengths(representative_quantities)
if not self._of_equal_length(lengths) or len(lengths) == 0:
return None
else:
return lengths[0]

def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super().normalize(archive, logger)

# enforce quantity length (will be used for type assignment)
new_n_terms = self.get_n_terms()
if self.n_terms is None:
self.n_terms = new_n_terms
Expand All @@ -250,6 +271,19 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
)
self.n_terms = None

# enforce differential order constraints
if np.any(np.isneginf(self.differential_order)):
logger.error(
'`APWBaseOrbital.differential_order` must be completely non-negative.'
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# it's hard to enforce commutative diagrams between `_determine_apw` and `normalize`
# instead, make all `_determine_apw` soft-coupled and dependent on the normalized state
# leverage normalize ∘ normalize = normalize
self._is_normalized = False


class APWOrbital(APWBaseOrbital):
"""
Expand All @@ -275,24 +309,32 @@ class APWOrbital(APWBaseOrbital):
""",
)

def get_type(self, logger: BoundLogger) -> Optional[str]:
def n_terms_to_type(self, n_terms: Optional[int]) -> Optional[str]:
"""
Set the type of the APW orbital based on the differential order.
"""
if self.n_terms is None:
logger.error('`APWOrbital.n_terms` must be defined before setting the type.')
if n_terms is None:
return None
if self.n_terms == 0:
if n_terms == 0:
return 'apw'
elif self.n_terms == 1:
elif n_terms == 1:
return 'lapw'
else:
return 'slapw'

@check_normalized
def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super().normalize(archive, logger)
# assign a APW orbital type
# this snippet works of the previous normalization
new_type = self.n_terms_to_type(self.n_terms)
if self.type is None:
self.type = self.get_type(logger)
self.type = new_type
elif self.type != new_type:
logger.error(
f'Inconsistent `APWOrbital` type: {self.type}. Setting back to `None`.'
)
self.type = None


class APWLocalOrbital(APWBaseOrbital):
Expand Down Expand Up @@ -326,13 +368,24 @@ class APWLocalOrbital(APWBaseOrbital):
""",
)

def get_n_terms(self) -> Optional[int]:
"""Determine the value of `n_terms` based on the lengths of the representative quantities."""
return super().get_n_terms(
representative_quantities={
'energy_parameter',
'energy_parameter_n',
'differential_order',
'boundary_order',
}
)

@check_normalized
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
if np.any(np.isneginf(self.differential_orders)):
logger.error('`APWLOrbital.differential_order` must be non-negative.')
if np.any(np.isneginf(self.boundary_orders)):
logger.error('`APWLOrbital.boundary_orders` must be non-negative.')
if np.any(np.isneginf(self.boundary_order)):
logger.error(
'`APWLOrbital.boundary_order` must be completely non-negative.'
)


class APWLChannel(BasisSet):
Expand All @@ -349,7 +402,7 @@ class APWLChannel(BasisSet):
""",
)

n_wavefunctions = Quantity(
n_orbitals = Quantity(
type=np.int32,
description="""
Number of wavefunctions in the l-channel, i.e. $(2l + 1) n_orbitals$.
Expand All @@ -358,33 +411,36 @@ class APWLChannel(BasisSet):

orbitals = SubSection(sub_section=APWBaseOrbital.m_def, repeats=True)

def _determine_apw(self, logger: BoundLogger) -> dict[str, int]:
def _determine_apw(self) -> dict[str, int]:
"""
Produce a count of the APW components in the l-channel.
Invokes `normalize` on `orbitals`.
"""
count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0}
order_map = {0: 'apw', 1: 'lapw'} # TODO: add dirac?
for orb in self.orbitals:
orb.normalize(None, None)

type_count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0}
for orb in self.orbitals:
err_msg = f'Unknown APW orbital type {orb.type} and order {orb.differential_order}.'
if isinstance(orb, APWOrbital):
if orb.type.lower() in order_map.values():
count[orb.type] += 1
elif orb.n_terms in order_map:
count[order_map[orb.differential_order]] += 1
elif orb.n_terms > 2:
count['slapw'] += 1
else:
logger.warning(err_msg) # TODO: rewrite using `type normalization`
if isinstance(orb, APWOrbital) and orb.type.lower() in type_count.keys():
type_count[orb.type] += 1
elif isinstance(orb, APWLocalOrbital):
count['lo'] += 1
type_count['lo'] += 1
else:
logger.warning(err_msg)
return count
type_count['other'] += 1 # other de facto operates as a catch-all
return type_count

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._is_normalized = False

@check_normalized
def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super(BasisSet).normalize(archive, logger)
self.n_wavefunctions = len(self.orbitals) * (2 * self.name + 1)
self.n_orbitals = len(self.orbitals)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._is_normalized = False


class MuffinTinRegion(BasisSet, Mesh):
Expand Down Expand Up @@ -417,14 +473,26 @@ class MuffinTinRegion(BasisSet, Mesh):

l_channels = SubSection(sub_section=APWLChannel.m_def, repeats=True)

def _determine_apw(self, logger: BoundLogger) -> dict[str, int]:
def _determine_apw(self) -> dict[str, int]:
"""
Aggregate the APW component count in the muffin-tin region.
Invokes `normalize` on `l_channels`.
"""
count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0}
for channel in self.l_channels:
count.update(channel._determine_apw(logger))
return count
for l_channel in self.l_channels:
l_channel.normalize(None, None)

type_count: dict[str, int]
if len(self.l_channels) > 0:
l_channel = self.l_channels[0]
# dynamically determine `type_count` structure
type_count = l_channel._determine_apw(l_channel.orbitals)
for channel in self.l_channels[1:]:
type_count.update(channel._determine_apw())
return type_count

@check_normalized
def normalize(self, archive, logger):
super().normalize(archive, logger)


class BasisSetContainer(NumericalSettings):
Expand All @@ -450,26 +518,28 @@ class BasisSetContainer(NumericalSettings):

basis_set_components = SubSection(sub_section=BasisSet.m_def, repeats=True)

def _determine_apw(self, logger: BoundLogger) -> str:
def _determine_apw(self) -> Optional[str]:
"""
Derive the basis set name for a (S)(L)APW case.
Derive the basis set name for a (S)(L)APW case, including local orbitals.
Invokes `normalize` on `basis_set_components`.
"""
answer, has_plane_wave = '', False
for comp in self.basis_set_components:
if isinstance(comp, MuffinTinRegion):
count = comp._determine_apw(logger)
if count['apw'] + count['lapw'] + count['slapw'] > 0:
if count['slapw'] > 0:
comp.normalize(None, None)
type_count = comp._determine_apw()
if sum([type_count[i] for i in ('apw', 'lapw', 'slapw')]) > 0:
if type_count['slapw'] > 0:
answer += 'slapw'.upper()
elif count['lapw'] > 0:
elif type_count['lapw'] > 0:
answer += 'lapw'.upper()
elif count['apw'] > 0:
elif type_count['apw'] > 0:
answer += 'apw'.upper()
if count['lo'] > 0:
if type_count['lo'] > 0:
answer += '+lo'
elif isinstance(comp, PlaneWaveBasisSet):
has_plane_wave = True
return answer if has_plane_wave else ''
return answer if has_plane_wave else None

def _smallest_mt(self) -> MuffinTinRegion:
"""
Expand All @@ -491,14 +561,11 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
]
if len(pws) > 1:
logger.warning('Multiple plane-wave basis sets found were found.')
if name := self._determine_apw(logger):
self.name = name # TODO: set name based on basis sets
try:
pws[0].set_cutoff_fractional(self._smallest_mt(), logger)
except (IndexError, AttributeError):
logger.error(
'Expected a `APWPlaneWaveBasisSet` instance, but found none.'
)
self.name = self._determine_apw(logger)
try:
pws[0].set_cutoff_fractional(self._smallest_mt(), logger)
except (IndexError, AttributeError):
logger.error('Expected a `APWPlaneWaveBasisSet` instance, but found none.')


def generate_apw(
Expand Down

0 comments on commit bc56ffe

Please sign in to comment.