Skip to content

Commit

Permalink
Fix KMeshBase and changed name to KSpaceFunctionalities
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed May 27, 2024
1 parent 0113056 commit c75baeb
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 90 deletions.
135 changes: 70 additions & 65 deletions src/nomad_simulations/numerical_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,31 +160,43 @@ def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)


class KMeshBase(Mesh):
class KSpaceFunctionalities:
"""
A base section used for abstraction for `KMesh` and `KLinePath` sections. It contains the methods
`_check_reciprocal_lattice_vectors` and `resolve_high_symmetry_points` that are used in both sections.
A functionality class useful for defining methods shared between `KSpace`, `KMesh`, and `KLinePath`.
"""

def _check_reciprocal_lattice_vectors(
self, reciprocal_lattice_vectors: Optional[pint.Quantity], logger: BoundLogger
self,
reciprocal_lattice_vectors: Optional[pint.Quantity],
logger: BoundLogger,
check_grid: bool = False,
grid: Optional[List[int]] = [],
) -> bool:
"""
Check if the `reciprocal_lattice_vectors` exist and if they have the same dimensionality as `grid`.
Args:
reciprocal_lattice_vectors (Optional[pint.Quantity]): The reciprocal lattice vectors of the atomic cell.
logger (BoundLogger): The logger to log messages.
check_grid (bool, optional): Flag to check the `grid` is set to True. Defaults to False.
grid (Optional[List[int]], optional): The grid of the `KMesh`. Defaults to [].
Returns:
(bool): True if the `reciprocal_lattice_vectors` exist and have the same dimensionality as `grid`, False otherwise.
(bool): True if the `reciprocal_lattice_vectors` exist. If `check_grid_too` is set to True, it also checks if the
`reciprocal_lattice_vectors` and the `grid` have the same dimensionality. False otherwise.
"""
if reciprocal_lattice_vectors is None or self.grid is None:
logger.warning(
'Could not find `reciprocal_lattice_vectors` from parent `KSpace` or could not find `KMesh.grid`.'
)
if reciprocal_lattice_vectors is None:
logger.warning('Could not find `reciprocal_lattice_vectors`.')
return False
if len(reciprocal_lattice_vectors) != 3 or len(self.grid) != 3:
# Only checking the `reciprocal_lattice_vectors`
if not check_grid:
return True

# Checking the `grid` too
if grid is None:
logger.warning('Could not find `grid`.')
return False
if len(reciprocal_lattice_vectors) != 3 or len(grid) != 3:
logger.warning(
'The `reciprocal_lattice_vectors` and the `grid` should have the same dimensionality.'
)
Expand Down Expand Up @@ -284,9 +296,6 @@ def resolve_high_symmetry_points(
high_symmetry_points[key] = list(value)
return high_symmetry_points

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)


class KMesh(Mesh):
"""
Expand Down Expand Up @@ -395,8 +404,11 @@ def get_k_line_density(
(np.float64): The k-line density of the `KMesh`.
"""
# Initial check
if not self._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors, logger
if not KSpaceFunctionalities._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors=reciprocal_lattice_vectors,
logger=logger,
check_grid=True,
grid=self.grid,
):
return None

Expand Down Expand Up @@ -426,8 +438,11 @@ def resolve_k_line_density(
(Optional[pint.Quantity]): The resolved `k_line_density` of the `KMesh`.
"""
# Initial check
if not self._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors, logger
if not KSpaceFunctionalities._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors=reciprocal_lattice_vectors,
logger=logger,
check_grid=True,
grid=self.grid,
):
return None

Expand Down Expand Up @@ -473,8 +488,10 @@ def normalize(self, archive, logger) -> None:

# Resolve `high_symmetry_points`
if self.high_symmetry_points is None:
self.high_symmetry_points = self.resolve_high_symmetry_points(
model_systems, logger
self.high_symmetry_points = (
KSpaceFunctionalities.resolve_high_symmetry_points(
model_systems=model_systems, logger=logger
)
)


Expand All @@ -485,21 +502,6 @@ class KLinePath(ArchiveSection):
value, one should multiply them by the reciprocal lattice vectors (`points_cartesian = points @ reciprocal_lattice_vectors`).
"""

# high_symmetry_path = Quantity(
# type=JSON,
# shape=['*'],
# description="""
# List of dictionaries containing the high-symmetry path (in units of the `reciprocal_lattice_vectors`) followed in
# the k-line path. E.g., in a cubic lattice:
# high_symmetry_path = [
# {'Gamma': [0, 0, 0]},
# {'X': [0.5, 0, 0]},
# {'Y': [0, 0.5, 0]},
# {'Gamma': [0, 0, 0]},
# ]
# """,
# )

high_symmetry_path_names = Quantity(
type=str,
shape=['*'],
Expand Down Expand Up @@ -533,34 +535,6 @@ class KLinePath(ArchiveSection):
""",
)

def _check_high_symmetry_path(self, logger: BoundLogger) -> bool:
"""
Check if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(bool): True if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length, False otherwise.
"""
if (
self.high_symmetry_path_names is None
or self.high_symmetry_path_values is None
) or (
len(self.high_symmetry_path_names) == 0
or len(self.high_symmetry_path_values) == 0
):
logger.warning(
'Could not find `KLinePath.high_symmetry_path_names` or `KLinePath.high_symmetry_path_values`.'
)
return False
if len(self.high_symmetry_path_names) != len(self.high_symmetry_path_values):
logger.warning(
'The length of `KLinePath.high_symmetry_path_names` and `KLinePath.high_symmetry_path_values` should coincide.'
)
return False
return True

def resolve_high_symmetry_path_values(
self,
model_systems: List[ModelSystem],
Expand All @@ -579,13 +553,15 @@ def resolve_high_symmetry_path_values(
(Optional[List[float]]): The resolved `high_symmetry_path_values`.
"""
# Initial check on the `reciprocal_lattice_vectors`
if not self._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors, logger
if not KSpaceFunctionalities._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors=reciprocal_lattice_vectors, logger=logger
):
return []

# Resolving the dictionary containing the `high_symmetry_points` for the given ModelSystem symmetry
high_symmetry_points = self.resolve_high_symmetry_points(model_systems, logger)
high_symmetry_points = KSpaceFunctionalities.resolve_high_symmetry_points(
model_systems=model_systems, logger=logger
)
if high_symmetry_points is None:
return []

Expand All @@ -598,6 +574,34 @@ def resolve_high_symmetry_path_values(
]
return high_symmetry_path_values

def _check_high_symmetry_path(self, logger: BoundLogger) -> bool:
"""
Check if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(bool): True if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length, False otherwise.
"""
if (
self.high_symmetry_path_names is None
or self.high_symmetry_path_values is None
) or (
len(self.high_symmetry_path_names) == 0
or len(self.high_symmetry_path_values) == 0
):
logger.warning(
'Could not find `KLinePath.high_symmetry_path_names` or `KLinePath.high_symmetry_path_values`.'
)
return False
if len(self.high_symmetry_path_names) != len(self.high_symmetry_path_values):
logger.warning(
'The length of `KLinePath.high_symmetry_path_names` and `KLinePath.high_symmetry_path_values` should coincide.'
)
return False
return True

def get_high_symmetry_path_norms(
self,
reciprocal_lattice_vectors: Optional[pint.Quantity],
Expand Down Expand Up @@ -759,6 +763,7 @@ class KSpace(NumericalSettings):
depending on the k-space sampling: `k_mesh` or `k_line_path`.
"""

# ! This needs to be normalized first in order to extract the `reciprocal_lattice_vectors` from the `ModelSystem.cell` information
reciprocal_lattice_vectors = Quantity(
type=np.float64,
shape=[3, 3],
Expand Down
129 changes: 104 additions & 25 deletions tests/test_numerical_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

import pytest
import numpy as np
from typing import Optional, List, Dict
from typing import Optional, List

from nomad.units import ureg
from nomad.datamodel import EntryArchive

from nomad_simulations.numerical_settings import KMesh, KLinePath
from nomad_simulations.numerical_settings import KMesh, KLinePath, KSpaceFunctionalities

from . import logger
from .conftest import generate_k_line_path, generate_k_space_simulation
Expand Down Expand Up @@ -75,7 +75,64 @@ def test_normalize(
assert k_space.reciprocal_lattice_vectors == result


# TODO add testing for KMesh
class TestKSpaceFunctionalities:
"""
Test the `KSpaceFunctionalities` class defined in `numerical_settings.py`.
"""

@pytest.mark.parametrize(
'reciprocal_lattice_vectors, check_grid, grid, result',
[
(None, None, None, False),
([[1, 0, 0], [0, 1, 0], [0, 0, 1]], False, None, True),
([[1, 0, 0], [0, 1, 0], [0, 0, 1]], True, None, False),
([[1, 0, 0], [0, 1, 0], [0, 0, 1]], True, [6, 6, 6, 4], False),
([[1, 0, 0], [0, 1, 0], [0, 0, 1]], True, [6, 6, 6], True),
],
)
def test_check_reciprocal_lattice_vectors(
self,
reciprocal_lattice_vectors: Optional[List[List[float]]],
check_grid: bool,
grid: Optional[List[int]],
result: bool,
):
"""
Test the `_check_reciprocal_lattice_vectors` private method.
"""
check = KSpaceFunctionalities._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors=reciprocal_lattice_vectors,
logger=logger,
check_grd=check_grid,
grid=grid,
)
assert check == result

def test_resolve_high_symmetry_points(self):
"""
Test the `resolve_high_symmetry_points` method. Only testing the valid situation in which the `ModelSystem` normalization worked.
"""
# `ModelSystem.normalize()` need to extract `bulk` as a type.
simulation = generate_k_space_simulation(
pbc=[True, True, True],
)
model_systems = simulation.model_system
# normalize to extract symmetry
simulation.model_system[0].normalize(EntryArchive(), logger)

# Testing the functionality method
high_symmetry_points = KSpaceFunctionalities.resolve_high_symmetry_points(
model_systems=model_systems, logger=logger
)
assert len(high_symmetry_points) == 4
assert high_symmetry_points == {
'Gamma': [0, 0, 0],
'M': [0.5, 0.5, 0],
'R': [0.5, 0.5, 0.5],
'X': [0, 0.5, 0],
}


class TestKMesh:
"""
Test the `KMesh` class defined in `numerical_settings.py`.
Expand Down Expand Up @@ -231,28 +288,6 @@ def test_resolve_k_line_density(
else:
assert k_line_density == result_k_line_density

def test_resolve_high_symmetry_points(self):
"""
Test the `resolve_high_symmetry_points` method. Only testing the valid situation in which the `ModelSystem` normalization worked.
"""
# `ModelSystem.normalize()` need to extract `bulk` as a type.
simulation = generate_k_space_simulation(
pbc=[True, True, True],
)
model_system = simulation.model_system[0]
model_system.normalize(EntryArchive(), logger) # normalize to extract symmetry
k_mesh = simulation.model_method[0].numerical_settings[0].k_mesh[0]
high_symmetry_points = k_mesh.resolve_high_symmetry_points(
simulation.model_system, logger
)
assert len(high_symmetry_points) == 4
assert high_symmetry_points == {
'Gamma': [0, 0, 0],
'M': [0.5, 0.5, 0],
'R': [0.5, 0.5, 0.5],
'X': [0, 0.5, 0],
}


class TestKLinePath:
"""
Expand Down Expand Up @@ -284,6 +319,50 @@ def test_check_high_symmetry_path(
)
assert k_line_path._check_high_symmetry_path(logger) == result

@pytest.mark.parametrize(
'high_symmetry_path_names, result',
[
(['Gamma', 'X', 'R'], [[0, 0, 0], [0, 0.5, 0], [0.5, 0.5, 0.5]]),
],
)
def test_resolve_high_symmetry_path_values(
self,
high_symmetry_path_names: List[str],
result: bool,
):
"""
Test the `resolve_high_symmetry_path_values` method. Only testing the valid situation in which the `ModelSystem` normalization worked.
"""
# `ModelSystem.normalize()` need to extract `bulk` as a type.
simulation = generate_k_space_simulation(
pbc=[True, True, True],
high_symmetry_path_names=high_symmetry_path_names,
high_symmetry_path_values=None,
)
model_system = simulation.model_system[0]
model_system.normalize(EntryArchive(), logger) # normalize to extract symmetry
# getting `reciprocal_lattice_vectors`
reciprocal_lattice_vectors = (
simulation.model_method[0].numerical_settings[0].reciprocal_lattice_vectors
)

# `KLinePath` can be understood as a `KMeshBase` section
k_line_path = simulation.model_method[0].numerical_settings[0].k_line_path
k_line_path.normalize(EntryArchive(), logger)
high_symmetry_points_values = k_line_path.resolve_high_symmetry_path_values(
simulation.model_system, reciprocal_lattice_vectors, logger
)
# high_symmetry_points = k_mesh_base.resolve_high_symmetry_points(
# simulation.model_system, logger
# )
# assert len(high_symmetry_points) == 4
# assert high_symmetry_points == {
# 'Gamma': [0, 0, 0],
# 'M': [0.5, 0.5, 0],
# 'R': [0.5, 0.5, 0.5],
# 'X': [0, 0.5, 0],
# }

def test_get_high_symmetry_path_norm(self, k_line_path: KLinePath):
"""
Test the `get_high_symmetry_path_norm` method.
Expand Down

0 comments on commit c75baeb

Please sign in to comment.