Skip to content

Commit

Permalink
Add test template for the APW structure
Browse files Browse the repository at this point in the history
  • Loading branch information
ndaelman committed Aug 16, 2024
1 parent 519cc38 commit 9b61bdf
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 21 deletions.
5 changes: 2 additions & 3 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class BasisSet(ArchiveSection):

def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs):
super().__init__(m_def, m_context, **kwargs)
# Set the name of the section
self.name = self.m_def.name


class PlaneWaveBasisSet(BasisSet, Mesh):
Expand Down Expand Up @@ -235,6 +233,7 @@ class APWOrbital(APWBaseOrbital):
""",
)


class APWLocalOrbital(APWBaseOrbital):
"""
Implementation of `APWWavefunction` capturing a local orbital extending a foundational APW basis set.
Expand Down Expand Up @@ -286,7 +285,7 @@ class APWLChannel(BasisSet):
description="""
Angular momentum quantum number of the local orbital.
""",
)
) # TODO: add `l` as a quantity

n_wavefunctions = Quantity(
type=np.int32,
Expand Down
68 changes: 68 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,71 @@ def k_line_path() -> KLinePathSettings:
@pytest.fixture(scope='session')
def k_space_simulation() -> Simulation:
return generate_k_space_simulation()


apw = {
'basis_set_components': [
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet',
'cutoff_energy': 500.0,
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.MuffinTinRegion',
'radius': 1.823,
'l_max': 2,
'l_channels': [
{
'name': 0,
'orbitals': [
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital',
'type': 'apw',
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital',
'type': 'lapw',
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital',
'type': 'lo',
},
],
},
{
'name': 1,
'orbitals': [
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital',
'type': 'apw',
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital',
'type': 'lapw',
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital',
'type': 'lo',
},
],
},
{
'name': 2,
'orbitals': [
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital',
'type': 'apw',
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital',
'type': 'lapw',
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital',
'type': 'lo',
},
],
},
],
},
]
}
55 changes: 37 additions & 18 deletions tests/test_basis_set.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import itertools
from . import logger
from nomad.units import ureg
import numpy as np
import pytest
from typing import Optional
from typing import Optional, Any

from tests.conftest import apw

from nomad_simulations.schema_packages.basis_set import (
APWBaseOrbital,
Expand Down Expand Up @@ -47,41 +50,57 @@ def test_cutoff_failure():
assert pw.cutoff_fractional == 1


@pytest.mark.skip(reason="This function is not meant to be tested directly")
@pytest.mark.skip(reason='This function is not meant to be tested directly')
def generate_apw(
species: dict[str, int | APWBaseOrbital],
cutoff: Optional[float] = None
species: dict[str, dict[str, Any]], cutoff: Optional[float] = None
) -> BasisSetContainer:
"""
Generate a mock APW basis set with the following structure:
.
├── plane-wave basis set
└── muffin-tin regions
└── l-channels
├── (orbitals)
│ └── wavefunctions
├── 1 x plane-wave basis set
└── n x muffin-tin regions
└── l_max x l-channels
├── orbitals
└── local orbitals
"""
basis_set_components: list[BasisSet] = []
if cutoff is not None:
pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff)
basis_set_components.append(pw)

mts: list[MuffinTinRegion] = []
for sp in species:
for sp_name, sp in species.items():
l_max = sp['l_max']
mt = MuffinTinRegion(
radius=sp['r'],
l_max=l_max,
l_channels=[
APWLChannel(
l=l,
orbitals=[APWOrbital(type=orb) for orb in sp['orb_type']] +\
[APWLocalOrbital(type=lo) for lo in sp['lo_type']],
) for l in range(l_max)
]
name=l,
orbitals=list(
itertools.chain(
(APWOrbital(type=orb) for orb in sp['orb_type']),
(APWLocalOrbital(type=lo) for lo in sp['lo_type']),
)
),
)
for l in range(l_max + 1)
],
)
mts.append(mt)
basis_set_components.append(mts)
basis_set_components.append(mt)

return BasisSetContainer(basis_set_components=basis_set_components)


def test_full_apw():
ref_apw = generate_apw(
{
'A': {
'r': 1.823,
'l_max': 2,
'orb_type': ['apw', 'lapw'],
'lo_type': ['lo'],
}
},
cutoff=500,
)
assert ref_apw.m_to_dict() == apw

0 comments on commit 9b61bdf

Please sign in to comment.