Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

126 apw missing energy parameters #139

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 115 additions & 62 deletions electronicparsers/exciting/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from nomad.datamodel.metainfo.simulation.method import (
Method, DFT, Electronic, Smearing, XCFunctional, Functional, Scf, BasisSet, KMesh,
FrequencyMesh, Screening, GW, Photon, BSE, CoreHole, BasisSetContainer,
OrbitalAPW, AtomParameters,
AtomParameters,
)
from nomad.datamodel.metainfo.simulation.system import (
System, Atoms
Expand All @@ -51,9 +51,9 @@
x_exciting_scrcoul_parameters
)
from ..utils import (
get_files, BeyondDFTWorkflowsParser
get_files, BeyondDFTWorkflowsParser, OrbitalAPWConstructor,
)
from typing import Any, Iterable
from typing import Any, Iterable, Callable


re_float = r'[-+]?\d+\.\d*(?:[Ee][-+]\d+)?'
Expand Down Expand Up @@ -1550,77 +1550,130 @@ def _parse_band_out(self, sec_scc):
sec_k_band_segment.value = band_energies[nb] + energy_fermi

def _parse_species(self, sec_method):
def _set_orbital(source: TextParser, l_quantum_number: int,
order: int, type: str = '') -> OrbitalAPW:
if not type:
type = re.sub(r'\+lo', '', source.get('type', 'lapw')).upper()
return OrbitalAPW(
l_quantum_number=l_quantum_number,
type=type,
order=order,
energy_parameter=source['trialEnergy'] * ureg.hartree,
update=source['searchE'],
)

type_order_mapping = {' ': 0, 'apw': 1, 'lap': 2}
self.species_parser.parse()
species_data = self.species_parser.to_dict()

# muffin-tin valence
radius = float(species_data['muffinTin']['radius'])
radialmeshPoints = int(species_data['muffinTin']['radialmeshPoints'])
radial_spacing = radius / radialmeshPoints
bs_val = BasisSet(
scope=['muffin-tin'],
radius=radius * ureg.bohr,
radius_lin_spacing=radial_spacing * ureg.bohr,
)
lo_samplings = {lo['l']: lo.get('wf', []) for lo in species_data.get('lo', [])}
lmax = self.input_xml_parser.get('xs/lmaxapw', 10)

for l_n in range(lmax + 1):
source = species_data.get('default', {})
for custom_settings in species_data.get('custom', []):
if custom_settings['l'] == l_n:
source = custom_settings
break
for order in range(type_order_mapping[source.get('type', 3 * ' ')[:3]]):
bs_val.orbital.append(_set_orbital(source, l_n, order))

# Add lo's
if source.get('type', 2 * ' ')[-2:] == 'lo':
wfs = lo_samplings[l_n] if l_n in lo_samplings else [source]
for wf in wfs:
for order in range(wf.get('matchingOrder', 0), 2):
bs_val.orbital.append(_set_orbital(wf, l_n, order, type='LO'))

# manage atom parameters
if not sec_method.atom_parameters:
sec_method.atom_parameters = []
sp = species_data.get('sp', {})
sec_method.atom_parameters.append(
AtomParameters(
atom_number=abs(sp.get('z')),
label=sp.get('chemicalSymbol'),
mass=sp.get('mass') * ureg.amu if sp.get('mass') else None,
)
)
bs_val.atom_parameters = sec_method.atom_parameters[-1]

# process basis set data
if not sec_method.electrons_representation:
sec_method.electrons_representation = [
BasisSetContainer(
scope=['wavefunction'],
basis_set=[
BasisSet(
type='plane waves',
scope=['valence'],
scope=['valence', 'interstitial'],
cutoff_fractional=self.input_xml_parser.get('xs/cutoffapw', 7.),
),
]
)
]
sec_method.electrons_representation[0].basis_set.append(bs_val)

# muffin-tin
if not (species_data := self.species_parser.results):
self.logger.warning(f'No species data found in {self.species_parser.filepath}')
return

# helper functions
def _convert_keyval(key_vals: list[list]) -> dict[str, Any]:
'''Helper function to convert key-value pairs to a dictionary'''
bool_map = {'false': False, 'true': True}
key_vals_converted: dict[str, Any] = {}
for key_val in key_vals:
if len(key_val) != 2:
raise ValueError(f'Invalid key-value pair: {key_val}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace raise with a warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that the basis_set output will be complete bull. When should an error be fatal?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but we do not want to terminate parsing due to this. simply do not write to archive if deemed incorrect.

if key_val[1] in bool_map:
key_vals_converted[key_val[0]] = bool_map[key_val[1]]
else:
key_vals_converted[key_val[0]] = key_val[1]
return key_vals_converted

def map_to_metainfo(settings: dict[str, Any], **kwargs) -> dict[str, Any]:
''''''
new_settings: dict[str, Any] = {}
mapping = {
'matchingOrder': 'order', 'n': 'energy_parameter_n',
'trialEnergy': 'energy_parameter', 'kappa': 'kappa_quantum_number',
'searchE': 'update',
}
for key, val in settings.items():
if key in mapping:
new_settings[mapping[key]] = val
if key == 'trialEnergy':
new_settings[mapping[key]] *= ureg.hartree
elif key == 'type':
new_settings[key] = val.upper()
else:
new_settings[key] = val
return {'type': 'LAPW', **new_settings, **kwargs}

def _unroll_lo(orbital: dict[str, Any]) -> list[dict[str, Any]]:
'''Helper function to unroll local orbitals'''
if (full_type := orbital.get('type').upper()).endswith('+LO'):
# Note: the input variable is modified in-place here
# This shouldn't be a problem when done in the right order
unrolled_lines = [{**orbital, 'type': full_type[:-3]}]
for order in range(2):
unrolled_lines.append({**orbital, 'type': 'LO', 'matchingOrder': order})
return unrolled_lines
elif full_type is not None:
return [orbital]
else:
return []

orb_constr = OrbitalAPWConstructor()
order_map = {'APW': 1, 'LAPW': 2}
# read default settings
if line_data := _convert_keyval(species_data.get('default', {}).get('key_val', [])):
for l in range(self.input_xml_parser.get('xs/lmaxapw', 10)):
for orbital in _unroll_lo(line_data):
orb = map_to_metainfo(orbital, l_quantum_number=l)
if (orb_type := orb['type']) in order_map:
orb_constr.unroll_orbital(order_map[orb_type], orb)
orb_constr.append_orbital()
elif orb_type == 'LO':
orb_constr.append_wavefunction(orb)
orb_constr.append_orbital()
# read custom settings
if lines_data := species_data.get('custom', []):
for line_data in lines_data:
for orbital in _unroll_lo(_convert_keyval(line_data.get('key_val', []))):
orb = map_to_metainfo(orbital, l_quantum_number=orbital.get('l'))
if (orb_type := orb['type']) in order_map:
orb_constr.unroll_orbital(order_map[orb_type], orb)
orb_constr.overwrite_orbital()
elif orb_type == 'LO':
orb_constr.append_wavefunction(orb)
orb_constr.append_orbital()
# read in local orbitals
if lines_data := species_data.get('lo', []):
for line_data in lines_data:
if (l := line_data.get('l')) is not None:
for wf in line_data.get('wf', []):
wf = map_to_metainfo(_convert_keyval(wf.get('key_val', [])))
orb_constr.append_wavefunction(wf, l_quantum_number=l, type='LO')
orb_constr.append_orbital()
# write out the orbitals
bs = BasisSet(
scope=['muffin-tin'],
orbital=orb_constr.get_orbitals(),
)
if mt := species_data.get('muffinTin', {}):
if radius := mt.get('radius'):
radius = float(radius) * ureg.bohr
bs.radius = radius
if rmp := mt.get('radialmeshPoints'):
bs.radial_spacing = radius / int(rmp)
sec_method.electrons_representation[0].basis_set.append(bs)

# manage atom parameters
if not sec_method.atom_parameters:
sec_method.atom_parameters = []
if sp := _convert_keyval(species_data.get('sp', {}).get('key_val', [])):
sec_method.atom_parameters.append(
AtomParameters(
atom_number=abs(sp.get('z')) if sp.get('z') else None,
label=sp.get('chemicalSymbol'),
mass=sp.get('mass') * ureg.amu if sp.get('mass') else None,
)
)
sec_method.electrons_representation[0].basis_set[-1].atom_parameters = sec_method.atom_parameters[-1]

def parse_file(self, name, section, filepath=None):
# TODO add support for info.xml, wannier.out
Expand Down
2 changes: 1 addition & 1 deletion electronicparsers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
# limitations under the License.

from .utils import (
extract_section, get_files, BeyondDFTWorkflowsParser
extract_section, get_files, BeyondDFTWorkflowsParser, OrbitalAPWConstructor,
)
136 changes: 136 additions & 0 deletions electronicparsers/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import os
from glob import glob
import numpy as np

from nomad.datamodel import EntryArchive
from nomad.datamodel.metainfo.simulation.run import Run
Expand All @@ -28,6 +29,8 @@
ParticleHoleExcitationsMethod, ParticleHoleExcitationsResults,
PhotonPolarization, PhotonPolarizationMethod, PhotonPolarizationResults
)
from nomad.datamodel.metainfo.simulation.method import OrbitalAPW
from typing import Any, Union, Iterable


def extract_section(source: EntryArchive, path: str):
Expand Down Expand Up @@ -283,3 +286,136 @@ def extract_polarization_outputs():
workflow.m_add_sub_section(ParticleHoleExcitations.tasks, task)

xs_workflow_archive.workflow2 = workflow


class OrbitalAPWConstructor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add test

'''
Class for storing and sorting the orbitals in a APW basis set.
'''
def __init__(self, *args, **kwargs):
'''
Initializer for the OrbitalAPWConstructor class. Accept:
- args: list of strings defining the input format for the orbitals, should match the quantity names in OrbitalAPW
default: ['l_quantum_number', 'energy_parameter', 'type', 'updated']
- order: list of strings defining the order in which the orbitals are sorted, in descending order of relevance
- comparison: list of strings defining the keys used for comparison of orbitals in `overwrite_orbital`
'''
self.orbitals: list[dict[str, Any]] = []
self.wavefunctions: dict[str, Any] = {}
if args:
self.input_format = args
self.term_order = kwargs.get(
'order',
[
'l_quantum_number', 'j_quantum_number', 'k_quantum_number',
'energy_parameter_n', 'energy_parameter', 'order', 'updated',
]
)
self.term_order.reverse()
self.comparison_keys = kwargs.get(
'comparison',
['l_quantum_number', 'j_quantum_number', 'k_quantum_number']
)

def _convert(self, settings: dict[str, Any], **kwargs) -> dict[str, Any]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could actually live out of OrbitalAPWConstructor, but you decide. Furthermore, the naming of the function could be more verbose: _convert_to_dict.

'''
Convert the input arguments (`args`) to a `dict` as defined by `input_format`.
Keys outside of `input_format` can be added via kwargs.
'''
return dict(sorted({**settings, **kwargs}.items()))

def append_wavefunction(self, settings, **kwargs):
'''
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add descriptions here and below in other functions?

'''
if self.wavefunctions:
for key, val in self.wavefunctions.items():
converted = self._convert(settings, **kwargs)
if getattr(OrbitalAPW, key, {}).get('shape'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ndaelman-hu @ladinesa is this correct? I remember Theodore saying that one can safely use getattr for MSections, but just wanted to make sure 🙂

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getattr is fine, just avoid setattr.

try:
self.wavefunctions[key].append(converted[key])
except AttributeError:
self.wavefunctions[key] = np.append(self.wavefunctions[key], converted[key])
elif val != converted[key]:
raise ValueError(f'Wavefunction {key} does not match previous value')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid raising exceptions

else:
for key, val in self._convert(settings, **kwargs).items():
if getattr(OrbitalAPW, key, {}).get('shape'):
if hasattr(val, 'units'): # check if quantity
self.wavefunctions[key] = np.array([val.magnitude]) * val.units
else:
self.wavefunctions[key] = [val]
else:
self.wavefunctions[key] = val

def unroll_orbital(self, orders: Union[int, list[int]], settings, **kwargs):
'''
'''
if isinstance(orders, int):
orders = list(range(orders))
self.wavefunctions = {} # reset wavefunctions
for order in orders:
new_kwargs = {**kwargs, 'order': order}
self.append_wavefunction(settings, **new_kwargs)

def append_orbital(self, *settings, **kwargs):
'''
'''
if self.wavefunctions:
self.orbitals.append(self.wavefunctions)
self.wavefunctions = {}
return
if len(settings) == 1:
if new_orbital := self._convert(settings[0], **kwargs):
self.orbitals.append(new_orbital)

def _extract_comparison(self, orbital: dict[str, Any]) -> dict[str, Any]:
'''
Extract the keys used for comparison from the orbital, as defined in `comparison_keys`.
'''
return {k: v for k, v in orbital.items() if k in self.comparison_keys}

def overwrite_orbital(self, *settings, **kwargs):
'''
'''
if not(converted := self.wavefunctions):
if len(settings) == 1:
converted = self._convert(settings[0], **kwargs)
else:
return
converted_ids = self._extract_comparison(converted)
new_orbitals = []
for orbital in self.orbitals:
orbital_ids = self._extract_comparison(orbital)
if orbital_ids != converted_ids:
new_orbitals.append(orbital)
self.orbitals = new_orbitals
self.append_orbital(*settings, **kwargs)

def get_orbitals(self) -> list[OrbitalAPW]:
'''
Return the stored orbitals as a sorted list of `OrbitalAPW` sections.
'''
def _sort_func(orbital: dict[str, Any], term: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what this function does.

'''
Helper function to guide the sorting process.
'''
if term in orbital:
if isinstance(val := orbital[term], Iterable):
try:
return sum(val)
except TypeError:
return tuple(val)
return orbital[term]
return 0

# sort out the orbitals by id_keys
for term in self.term_order:
orbitals = sorted(self.orbitals, key=lambda orbital: _sort_func(orbital, term))
formatted_orbitals: list[OrbitalAPW] = []
for orbital in orbitals:
# convert to NOMAD section
formatted_orbital = OrbitalAPW()
for term, val in orbital.items():
setattr(formatted_orbital, term, val)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coming back to my comment above on getattr (which again, it is just to confirm, so most probably you can use it) setattr should not be used and instead you should m_add_sub_section.

You have an example in nomad/normalizing/method.py:ExcitedStateMethod, at the end I was adding to simulation depending on the dict self._method_def (which is defined at the beginning of the MethodNormalizer().

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can do formatted_orbital.m_set(formatted_orbital.m_get_quantity_defintion(tem), val)

formatted_orbitals.append(formatted_orbital)
return formatted_orbitals
3 changes: 1 addition & 2 deletions electronicparsers/wien2k/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def parse_method(self):
em = BasisSetContainer(scope=['wavefunction'])
em.basis_set.append(
BasisSet(
scope=['intersitial', 'valence'],
scope=['valence', 'intersitial'],
type='plane waves',
cutoff_fractional=source.get('rkmax'),
)
Expand All @@ -828,7 +828,6 @@ def parse_method(self):
for l_n in range(source.get('lmax', -1) + 1):
orbital = OrbitalAPW()
orbital.l_quantum_number = l_n
orbital.core_level = False
e_param = mt.get('e_param', source.get('e_ref', .5) - .2) # TODO: check for +.2 case
update = False
apw_type = mt.get('type')
Expand Down