From 862f5eee868496e32d7c489ac9690dd7f1268df1 Mon Sep 17 00:00:00 2001 From: ndaelman Date: Fri, 16 Aug 2024 22:47:54 +0200 Subject: [PATCH] - Extend APW structure test template - Cover more APW cases - Migrate and touch up `generate_apw` --- .../schema_packages/basis_set.py | 65 ++++++- tests/conftest.py | 181 +++++++++++------- tests/test_basis_set.py | 78 ++------ 3 files changed, 199 insertions(+), 125 deletions(-) diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index 6d6e8b86..7b300341 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -3,11 +3,12 @@ from nomad.datamodel.metainfo.annotations import ELNAnnotation from nomad.metainfo import MEnum, Quantity, SubSection from nomad.units import ureg +import itertools import numpy as np import pint from scipy import constants as const from structlog.stdlib import BoundLogger -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Any if TYPE_CHECKING: from nomad.metainfo import Context, Section @@ -442,3 +443,65 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None: logger.error( 'Expected a `APWPlaneWaveBasisSet` instance, but found none.' ) + + +def generate_apw( + species: dict[str, dict[str, Any]], cutoff: Optional[float] = None +) -> BasisSetContainer: + """ + Generate a mock APW basis set with the following structure: + . + ├── 1 x plane-wave basis set + └── n x muffin-tin regions + └── l_max x l-channels + ├── orbitals + └── local orbitals + + from a dictionary + { + : { + 'r': , + 'l_max': , + 'orb_type': [], + 'lo_type': [], + } + } + """ + + basis_set_components: list[BasisSet] = [] + if cutoff is not None: + pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff) + basis_set_components.append(pw) + + for sp_name, sp in species.items(): + sp['r'] = sp.get('r', None) + sp['l_max'] = sp.get('l_max', 0) + sp['orb_type'] = sp.get('orb_type', []) + sp['lo_type'] = sp.get('lo_type', []) + + basis_set_components.extend( + [ + MuffinTinRegion( + species_scope=AtomsState( + chemical_symbol=sp_name + ), # TODO: extend to search through a model_system + radius=sp['r'], + l_max=sp['l_max'], + l_channels=[ + APWLChannel( + name=l, + orbitals=list( + itertools.chain( + (APWOrbital(type=orb) for orb in sp['orb_type']), + (APWLocalOrbital(type=lo) for lo in sp['lo_type']), + ) + ), + ) + for l in range(sp['l_max'] + 1) + ], + ) + + ] + ) + + return BasisSetContainer(basis_set_components=basis_set_components) diff --git a/tests/conftest.py b/tests/conftest.py index 9cc1d4f3..baf6d783 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -404,69 +404,118 @@ def k_space_simulation() -> Simulation: return generate_k_space_simulation() -apw = { - 'basis_set_components': [ - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet', - 'cutoff_energy': 500.0, - }, - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.MuffinTinRegion', - 'radius': 1.823, - 'l_max': 2, - 'l_channels': [ - { - 'name': 0, - 'orbitals': [ - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', - 'type': 'apw', - }, - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', - 'type': 'lapw', - }, - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', - 'type': 'lo', - }, - ], - }, - { - 'name': 1, - 'orbitals': [ - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', - 'type': 'apw', - }, - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', - 'type': 'lapw', - }, - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', - 'type': 'lo', - }, - ], - }, - { - 'name': 2, - 'orbitals': [ - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', - 'type': 'apw', - }, - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', - 'type': 'lapw', - }, - { - 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', - 'type': 'lo', - }, - ], - }, - ], - }, - ] -} +refs_apw = [ + {}, + { + 'basis_set_components': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet', + 'cutoff_energy': 500.0, + }, + ] + }, + { + 'basis_set_components': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.MuffinTinRegion', + 'radius': 1.823, + 'l_max': 2, + 'l_channels': [ + { + 'name': 0, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + ], + }, + { + 'name': 1, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + ], + }, + { + 'name': 2, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + ], + }, + ], + }, + ] + }, + { + 'basis_set_components': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet', + 'cutoff_energy': 500.0, + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.MuffinTinRegion', + 'radius': 1.823, + 'l_max': 2, + 'l_channels': [ + { + 'name': 0, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'lapw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', + 'type': 'lo', + }, + ], + }, + { + 'name': 1, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'lapw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', + 'type': 'lo', + }, + ], + }, + { + 'name': 2, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'lapw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', + 'type': 'lo', + }, + ], + }, + ], + }, + ] + }, +] diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py index 01bd1cdc..496a2b24 100644 --- a/tests/test_basis_set.py +++ b/tests/test_basis_set.py @@ -1,11 +1,11 @@ -import itertools +from typing import Any, Optional + +import pytest from . import logger from nomad.units import ureg import numpy as np -import pytest -from typing import Optional, Any -from tests.conftest import apw +from tests.conftest import refs_apw from nomad_simulations.schema_packages.basis_set import ( APWBaseOrbital, @@ -16,6 +16,7 @@ BasisSet, BasisSetContainer, MuffinTinRegion, + generate_apw, ) @@ -50,57 +51,18 @@ def test_cutoff_failure(): assert pw.cutoff_fractional == 1 -@pytest.mark.skip(reason='This function is not meant to be tested directly') -def generate_apw( - species: dict[str, dict[str, Any]], cutoff: Optional[float] = None -) -> BasisSetContainer: - """ - Generate a mock APW basis set with the following structure: - . - ├── 1 x plane-wave basis set - └── n x muffin-tin regions - └── l_max x l-channels - ├── orbitals - └── local orbitals - """ - basis_set_components: list[BasisSet] = [] - if cutoff is not None: - pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff) - basis_set_components.append(pw) - - for sp_name, sp in species.items(): - l_max = sp['l_max'] - mt = MuffinTinRegion( - radius=sp['r'], - l_max=l_max, - l_channels=[ - APWLChannel( - name=l, - orbitals=list( - itertools.chain( - (APWOrbital(type=orb) for orb in sp['orb_type']), - (APWLocalOrbital(type=lo) for lo in sp['lo_type']), - ) - ), - ) - for l in range(l_max + 1) - ], - ) - basis_set_components.append(mt) - - return BasisSetContainer(basis_set_components=basis_set_components) - - -def test_full_apw(): - ref_apw = generate_apw( - { - 'A': { - 'r': 1.823, - 'l_max': 2, - 'orb_type': ['apw', 'lapw'], - 'lo_type': ['lo'], - } - }, - cutoff=500, - ) - assert ref_apw.m_to_dict() == apw +@pytest.mark.parametrize( + 'ref_index, species_def, cutoff', + [ + (0, {}, None), + (1, {}, 500.0), + (2, {'H': {'r': 1, 'l_max': 2, 'orb_type': ['apw']}}, 500.0), + ], +) +def test_full_apw( + ref_index: int, species_def: dict[str, dict[str, Any]], cutoff: Optional[float] +): + """Test the composite structure of APW basis sets.""" + assert ( + generate_apw(species_def, cutoff=cutoff).m_to_dict() == refs_apw[ref_index] + ) # TODO: add normalization?