Skip to content

Commit

Permalink
Move some functionalities of APWLChannel._determine_apw() down the …
Browse files Browse the repository at this point in the history
…composition tree
  • Loading branch information
ndaelman committed Aug 19, 2024
1 parent 713b904 commit 7050071
Showing 1 changed file with 61 additions and 8 deletions.
69 changes: 61 additions & 8 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,43 @@ class APWBaseOrbital(ArchiveSection):
""",
) # TODO: add check non-negative # ? to remove

def _get_lengths(self, quantities: set[str]) -> list[int]:
"""Extract the lengths of the `quantities` contained in the set."""
present_quantities = set(quantities) & self.m_quantities
lengths: list[int] = []
for quant in present_quantities:
length = len(getattr(self, quant))
if length > 0: # empty lists are exempt
lengths.append(length)
return lengths

def _of_equal_length(lengths: list[int]) -> bool:
"""Check if all elements in the list are of equal length."""
if len(lengths) == 0:
return True
else:
ref_length = lengths[0]
return all(length == ref_length for length in lengths)

def get_n_terms(self) -> Optional[int]:
"""Determine the value of `n_terms` based on the lengths of the representative quantities."""
rep_quant = {'energy_parameter', 'energy_parameter_n', 'differential_order'}
lengths = self._get_lengths(rep_quant)
if not self._of_equal_length(lengths) or len(lengths) == 0:
return None
else:
return lengths[0]

def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super().normalize(archive, logger)
if (self.differential_order < 0).any():
logger.error('`APWBaseOrbital.differential_order` must be non-negative.')
new_n_terms = self.get_n_terms()
if self.n_terms is None:
self.n_terms = new_n_terms
elif self.n_terms != new_n_terms:
logger.error(
f'Inconsistent lengths of `APWBaseOrbital` quantities: {self.m_quantities}. Setting back to `None`.'
)
self.n_terms = None


class APWOrbital(APWBaseOrbital):
Expand All @@ -228,7 +261,7 @@ class APWOrbital(APWBaseOrbital):
"""

type = Quantity(
type=MEnum('apw', 'lapw', 'slapw', 'spherical_dirac'),
type=MEnum('apw', 'lapw', 'slapw'), # ? where to put 'spherical_dirac'
description="""
Type of augmentation contribution. Abbreviations stand for:
| name | description | radial product |
Expand All @@ -242,6 +275,25 @@ class APWOrbital(APWBaseOrbital):
""",
)

def get_type(self, logger: BoundLogger) -> Optional[str]:
"""
Set the type of the APW orbital based on the differential order.
"""
if self.n_terms is None:
logger.error('`APWOrbital.n_terms` must be defined before setting the type.')
return None
if self.n_terms == 0:
return 'apw'
elif self.n_terms == 1:
return 'lapw'
else:
return 'slapw'

def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super().normalize(archive, logger)
if self.type is None:
self.type = self.get_type(logger)


class APWLocalOrbital(APWBaseOrbital):
"""
Expand Down Expand Up @@ -274,6 +326,7 @@ class APWLocalOrbital(APWBaseOrbital):
""",
)


def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
if np.any(np.isneginf(self.differential_orders)):
Expand Down Expand Up @@ -310,19 +363,19 @@ def _determine_apw(self, logger: BoundLogger) -> dict[str, int]:
Produce a count of the APW components in the l-channel.
"""
count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0}
order_map = {0: 'apw', 1: 'lapw', 2: 'slapw'} # TODO: add dirac?
order_map = {0: 'apw', 1: 'lapw'} # TODO: add dirac?

for orb in self.orbitals:
err_msg = f'Unknown APW orbital type {orb.type} and order {orb.differential_order}.'
if isinstance(orb, APWOrbital):
if orb.type in order_map.values():
if orb.type.lower() in order_map.values():
count[orb.type] += 1
elif orb.differential_order in order_map:
elif orb.n_terms in order_map:
count[order_map[orb.differential_order]] += 1
elif orb.differential_order > 2:
elif orb.n_terms > 2:
count['slapw'] += 1
else:
logger.warning(err_msg)
logger.warning(err_msg) # TODO: rewrite using `type normalization`
elif isinstance(orb, APWLocalOrbital):
count['lo'] += 1
else:
Expand Down

0 comments on commit 7050071

Please sign in to comment.