-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
126 apw missing energy parameters #139
base: develop
Are you sure you want to change the base?
Changes from 4 commits
6260099
048374f
e7a4bff
e20a2c9
6f13113
bf59758
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
import os | ||
from glob import glob | ||
import numpy as np | ||
|
||
from nomad.datamodel import EntryArchive | ||
from nomad.datamodel.metainfo.simulation.run import Run | ||
|
@@ -28,6 +29,8 @@ | |
ParticleHoleExcitationsMethod, ParticleHoleExcitationsResults, | ||
PhotonPolarization, PhotonPolarizationMethod, PhotonPolarizationResults | ||
) | ||
from nomad.datamodel.metainfo.simulation.method import OrbitalAPW | ||
from typing import Any, Union, Iterable | ||
|
||
|
||
def extract_section(source: EntryArchive, path: str): | ||
|
@@ -283,3 +286,136 @@ def extract_polarization_outputs(): | |
workflow.m_add_sub_section(ParticleHoleExcitations.tasks, task) | ||
|
||
xs_workflow_archive.workflow2 = workflow | ||
|
||
|
||
class OrbitalAPWConstructor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add test |
||
''' | ||
Class for storing and sorting the orbitals in a APW basis set. | ||
''' | ||
def __init__(self, *args, **kwargs): | ||
''' | ||
Initializer for the OrbitalAPWConstructor class. Accept: | ||
- args: list of strings defining the input format for the orbitals, should match the quantity names in OrbitalAPW | ||
default: ['l_quantum_number', 'energy_parameter', 'type', 'updated'] | ||
- order: list of strings defining the order in which the orbitals are sorted, in descending order of relevance | ||
- comparison: list of strings defining the keys used for comparison of orbitals in `overwrite_orbital` | ||
''' | ||
self.orbitals: list[dict[str, Any]] = [] | ||
self.wavefunctions: dict[str, Any] = {} | ||
if args: | ||
self.input_format = args | ||
self.term_order = kwargs.get( | ||
'order', | ||
[ | ||
'l_quantum_number', 'j_quantum_number', 'k_quantum_number', | ||
'energy_parameter_n', 'energy_parameter', 'order', 'updated', | ||
] | ||
) | ||
self.term_order.reverse() | ||
self.comparison_keys = kwargs.get( | ||
'comparison', | ||
['l_quantum_number', 'j_quantum_number', 'k_quantum_number'] | ||
) | ||
|
||
def _convert(self, settings: dict[str, Any], **kwargs) -> dict[str, Any]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this could actually live out of |
||
''' | ||
Convert the input arguments (`args`) to a `dict` as defined by `input_format`. | ||
Keys outside of `input_format` can be added via kwargs. | ||
''' | ||
return dict(sorted({**settings, **kwargs}.items())) | ||
|
||
def append_wavefunction(self, settings, **kwargs): | ||
''' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please add descriptions here and below in other functions? |
||
''' | ||
if self.wavefunctions: | ||
for key, val in self.wavefunctions.items(): | ||
converted = self._convert(settings, **kwargs) | ||
if getattr(OrbitalAPW, key, {}).get('shape'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ndaelman-hu @ladinesa is this correct? I remember Theodore saying that one can safely use getattr for MSections, but just wanted to make sure 🙂 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. getattr is fine, just avoid setattr. |
||
try: | ||
self.wavefunctions[key].append(converted[key]) | ||
except AttributeError: | ||
self.wavefunctions[key] = np.append(self.wavefunctions[key], converted[key]) | ||
elif val != converted[key]: | ||
raise ValueError(f'Wavefunction {key} does not match previous value') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. avoid raising exceptions |
||
else: | ||
for key, val in self._convert(settings, **kwargs).items(): | ||
if getattr(OrbitalAPW, key, {}).get('shape'): | ||
if hasattr(val, 'units'): # check if quantity | ||
self.wavefunctions[key] = np.array([val.magnitude]) * val.units | ||
else: | ||
self.wavefunctions[key] = [val] | ||
else: | ||
self.wavefunctions[key] = val | ||
|
||
def unroll_orbital(self, orders: Union[int, list[int]], settings, **kwargs): | ||
''' | ||
''' | ||
if isinstance(orders, int): | ||
orders = list(range(orders)) | ||
self.wavefunctions = {} # reset wavefunctions | ||
for order in orders: | ||
new_kwargs = {**kwargs, 'order': order} | ||
self.append_wavefunction(settings, **new_kwargs) | ||
|
||
def append_orbital(self, *settings, **kwargs): | ||
''' | ||
''' | ||
if self.wavefunctions: | ||
self.orbitals.append(self.wavefunctions) | ||
self.wavefunctions = {} | ||
return | ||
if len(settings) == 1: | ||
if new_orbital := self._convert(settings[0], **kwargs): | ||
self.orbitals.append(new_orbital) | ||
|
||
def _extract_comparison(self, orbital: dict[str, Any]) -> dict[str, Any]: | ||
''' | ||
Extract the keys used for comparison from the orbital, as defined in `comparison_keys`. | ||
''' | ||
return {k: v for k, v in orbital.items() if k in self.comparison_keys} | ||
|
||
def overwrite_orbital(self, *settings, **kwargs): | ||
''' | ||
''' | ||
if not(converted := self.wavefunctions): | ||
if len(settings) == 1: | ||
converted = self._convert(settings[0], **kwargs) | ||
else: | ||
return | ||
converted_ids = self._extract_comparison(converted) | ||
new_orbitals = [] | ||
for orbital in self.orbitals: | ||
orbital_ids = self._extract_comparison(orbital) | ||
if orbital_ids != converted_ids: | ||
new_orbitals.append(orbital) | ||
self.orbitals = new_orbitals | ||
self.append_orbital(*settings, **kwargs) | ||
|
||
def get_orbitals(self) -> list[OrbitalAPW]: | ||
''' | ||
Return the stored orbitals as a sorted list of `OrbitalAPW` sections. | ||
''' | ||
def _sort_func(orbital: dict[str, Any], term: str): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure what this function does. |
||
''' | ||
Helper function to guide the sorting process. | ||
''' | ||
if term in orbital: | ||
if isinstance(val := orbital[term], Iterable): | ||
try: | ||
return sum(val) | ||
except TypeError: | ||
return tuple(val) | ||
return orbital[term] | ||
return 0 | ||
|
||
# sort out the orbitals by id_keys | ||
for term in self.term_order: | ||
orbitals = sorted(self.orbitals, key=lambda orbital: _sort_func(orbital, term)) | ||
formatted_orbitals: list[OrbitalAPW] = [] | ||
for orbital in orbitals: | ||
# convert to NOMAD section | ||
formatted_orbital = OrbitalAPW() | ||
for term, val in orbital.items(): | ||
setattr(formatted_orbital, term, val) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Coming back to my comment above on You have an example in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can do formatted_orbital.m_set(formatted_orbital.m_get_quantity_defintion(tem), val) |
||
formatted_orbitals.append(formatted_orbital) | ||
return formatted_orbitals |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
replace raise with a warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means that the basis_set output will be complete bull. When should an error be fatal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but we do not want to terminate parsing due to this. simply do not write to archive if deemed incorrect.