Skip to content

Commit

Permalink
Added explicit input and fix testing for atoms_state.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Jun 5, 2024
1 parent 908a34b commit 79af7d0
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
18 changes: 9 additions & 9 deletions src/nomad_simulations/atoms_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nomad.datamodel.metainfo.basesections import Entity
from nomad.datamodel.metainfo.annotations import ELNAnnotation

from .utils import RussellSaundersState
from nomad_simulations.utils import RussellSaundersState


class OrbitalsState(Entity):
Expand Down Expand Up @@ -280,7 +280,7 @@ def resolve_degeneracy(self) -> Optional[int]:
for jj in self.j_quantum_number:
if self.mj_quantum_number is not None:
mjs = RussellSaundersState.generate_MJs(
self.j_quantum_number[0], rising=True
J=self.j_quantum_number[0], rising=True
)
degeneracy += len(
[mj for mj in mjs if mj in self.mj_quantum_number]
Expand All @@ -293,15 +293,15 @@ def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# General checks for physical quantum numbers and symbols
if not self.validate_quantum_numbers(logger):
if not self.validate_quantum_numbers(logger=logger):
logger.error('The quantum numbers are not physical.')
return

# Resolving the quantum numbers and symbols if not available
for quantum_name in ['l', 'ml', 'ms']:
for quantum_type in ['number', 'symbol']:
quantity = self.resolve_number_and_symbol(
quantum_name, quantum_type, logger
quantum_name=quantum_name, quantum_type=quantum_type, logger=logger
)
if getattr(self, f'{quantum_name}_quantum_{quantum_type}') is None:
setattr(self, f'{quantum_name}_quantum_{quantum_type}', quantity)
Expand Down Expand Up @@ -383,7 +383,7 @@ def normalize(self, archive, logger) -> None:
self.n_excited_electrons = None
self.orbital_ref.degeneracy = 1
if self.orbital_ref.occupation is None:
self.orbital_ref.occupation = self.resolve_occupation(logger)
self.orbital_ref.occupation = self.resolve_occupation(logger=logger)


class HubbardInteractions(ArchiveSection):
Expand Down Expand Up @@ -552,11 +552,11 @@ def normalize(self, archive, logger) -> None:
self.u_interaction,
self.u_interorbital_interaction,
self.j_hunds_coupling,
) = self.resolve_u_interactions(logger)
) = self.resolve_u_interactions(logger=logger)

# If u_effective is not available, calculate it
if self.u_effective is None:
self.u_effective = self.resolve_u_effective(logger)
self.u_effective = self.resolve_u_effective(logger=logger)

# Check if length of `orbitals_ref` is the same as the length of `umn`:
if self.u_matrix is not None and self.orbitals_ref is not None:
Expand Down Expand Up @@ -652,6 +652,6 @@ def normalize(self, archive, logger) -> None:

# Get chemical_symbol from atomic_number and viceversa
if self.chemical_symbol is None:
self.chemical_symbol = self.resolve_chemical_symbol(logger)
self.chemical_symbol = self.resolve_chemical_symbol(logger=logger)
if self.atomic_number is None:
self.atomic_number = self.resolve_atomic_number(logger)
self.atomic_number = self.resolve_atomic_number(logger=logger)
59 changes: 51 additions & 8 deletions tests/test_atoms_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,18 @@ def test_validate_quantum_numbers(
):
"""
Test the `validate_quantum_numbers` method.
Args:
number_label (str): The quantum number string to be tested.
values (List[int]): The values stored in `OrbitalState`.
results (List[bool]): The expected results after validation.
"""
orbital_state = OrbitalsState(n_quantum_number=2)
for val, res in zip(values, results):
if number_label == 'ml_quantum_number':
orbital_state.l_quantum_number = 2
setattr(orbital_state, number_label, val)
assert orbital_state.validate_quantum_numbers(logger) == res
assert orbital_state.validate_quantum_numbers(logger=logger) == res

@pytest.mark.parametrize(
'quantum_name, value, expected_result',
Expand Down Expand Up @@ -103,6 +108,11 @@ def test_number_and_symbol(
):
"""
Test the number and symbol resolution for each of the quantum numbers defined in the parametrization.
Args:
quantum_name (str): The quantum number string to be tested.
value (Union[int, float]): The value stored in `OrbitalState`.
expected_result (Optional[str]): The expected result after resolving the counter-type.
"""
# Adding quantum numbers to the `OrbitalsState` section
orbital_state = OrbitalsState(n_quantum_number=2)
Expand All @@ -112,13 +122,13 @@ def test_number_and_symbol(

# Making sure that the `'number'` is assigned
resolved_type = orbital_state.resolve_number_and_symbol(
quantum_name, 'number', logger
quantum_name=quantum_name, quantum_type='number', logger=logger
)
assert resolved_type == value

# Resolving if the counter-type is assigned
resolved_countertype = orbital_state.resolve_number_and_symbol(
quantum_name, 'symbol', logger
quantum_name=quantum_name, quantum_type='symbol', logger=logger
)
assert resolved_countertype == expected_result

Expand Down Expand Up @@ -146,6 +156,14 @@ def test_degeneracy(
):
"""
Test the degeneracy of each orbital states defined in the parametrization.
Args:
l_quantum_number (int): The angular momentum quantum number.
ml_quantum_number (Optional[int]): The magnetic quantum number.
j_quantum_number (Optional[List[float]]): The total angular momentum quantum number.
mj_quantum_number (Optional[List[float]]): The magnetic quantum number for the total angular momentum.
ms_quantum_number (Optional[float]): The spin quantum number.
degeneracy (int): The expected degeneracy of the orbital state.
"""
orbital_state = OrbitalsState(n_quantum_number=2)
self.add_state(
Expand Down Expand Up @@ -195,13 +213,19 @@ def test_occupation(
):
"""
Test the occupation of a core hole for a given set of orbital reference and degeneracy.
Args:
orbital_ref (Optional[OrbitalsState]): The orbital reference of the core hole.
degeneracy (Optional[int]): The degeneracy of the orbital reference.
n_excited_electrons (float): The number of excited electrons.
occupation (Optional[float]): The expected occupation of the core hole.
"""
core_hole = CoreHole(
orbital_ref=orbital_ref, n_excited_electrons=n_excited_electrons
)
if orbital_ref is not None:
assert orbital_ref.resolve_degeneracy() == degeneracy
resolved_occupation = core_hole.resolve_occupation(logger)
resolved_occupation = core_hole.resolve_occupation(logger=logger)
if resolved_occupation is not None:
assert np.isclose(resolved_occupation, occupation)
else:
Expand Down Expand Up @@ -232,6 +256,12 @@ def test_normalize(
):
"""
Test the normalization of the `CoreHole`. Inputs are defined as the quantities of the `CoreHole` section.
Args:
orbital_ref (Optional[OrbitalsState]): The orbital reference of the core hole.
n_excited_electrons (Optional[float]): The number of excited electrons.
dscf_state (Optional[str]): The DSCF state of the core hole.
results (Tuple[Optional[float], Optional[float], Optional[float]]): The expected results after normalization.
"""
core_hole = CoreHole(
orbital_ref=orbital_ref,
Expand Down Expand Up @@ -265,6 +295,10 @@ def test_u_interactions(
):
"""
Test the Hubbard interactions `U`, `U'`, and `J` for a given set of Slater integrals.
Args:
slater_integrals (Optional[List[float]]): The Slater integrals of the Hubbard interactions.
results (Tuple[Optional[float], Optional[float], Optional[float]]): The expected results of the Hubbard interactions.
"""
# Adding `slater_integrals` to the `HubbardInteractions` section
hubbard_interactions = HubbardInteractions()
Expand All @@ -276,7 +310,7 @@ def test_u_interactions(
u_interaction,
u_interorbital_interaction,
j_hunds_coupling,
) = hubbard_interactions.resolve_u_interactions(logger)
) = hubbard_interactions.resolve_u_interactions(logger=logger)

if None not in (u_interaction, u_interorbital_interaction, j_hunds_coupling):
assert np.isclose(u_interaction.to('eV').magnitude, results[0])
Expand Down Expand Up @@ -306,6 +340,11 @@ def test_u_effective(
):
"""
Test the effective Hubbard interaction `U_eff` for a given set of Hubbard interactions `U` and `J`.
Args:
u_interaction (Optional[float]): The Hubbard interaction `U`.
j_local_exchange_interaction (Optional[float]): The Hubbard interaction `J`.
u_effective (Optional[float]): The expected effective Hubbard interaction `U_eff`.
"""
# Adding `u_interaction` and `j_local_exchange_interaction` to the `HubbardInteractions` section
hubbard_interactions = HubbardInteractions()
Expand All @@ -317,7 +356,7 @@ def test_u_effective(
)

# Resolving Ueff from class method
resolved_u_effective = hubbard_interactions.resolve_u_effective(logger)
resolved_u_effective = hubbard_interactions.resolve_u_effective(logger=logger)
if resolved_u_effective is not None:
assert np.isclose(resolved_u_effective.to('eV').magnitude, u_effective)
else:
Expand Down Expand Up @@ -358,10 +397,14 @@ def test_chemical_symbol_and_atomic_number(
):
"""
Test the `chemical_symbol` and `atomic_number` resolution for the `AtomsState` section.
Args:
chemical_symbol (str): The chemical symbol of the atom.
atomic_number (int): The atomic number of the atom.
"""
# Testing `chemical_symbol`
atom_state = AtomsState(chemical_symbol=chemical_symbol)
assert atom_state.resolve_atomic_number(logger) == atomic_number
assert atom_state.resolve_atomic_number(logger=logger) == atomic_number
# Testing `atomic_number`
atom_state.atomic_number = atomic_number
assert atom_state.resolve_chemical_symbol(logger) == chemical_symbol
assert atom_state.resolve_chemical_symbol(logger=logger) == chemical_symbol

0 comments on commit 79af7d0

Please sign in to comment.